# dataset.py
  1. import fastai
  2. from fastai import *
  3. from fastai.core import *
  4. from fastai.vision import *
  5. def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None,
  6. keep_pct:float=1.0, num_workers:int=8)->ImageDataBunch:
  7. src = (ImageImageList.from_folder(crappy_path)
  8. .use_partial_data(sample_pct=keep_pct, seed=random_seed)
  9. .random_split_by_pct(0.1, seed=random_seed))
  10. data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
  11. #TODO: Revisit transforms used here....
  12. .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25), size=sz, tfm_y=True)
  13. .databunch(bs=bs, num_workers=num_workers)
  14. .normalize(imagenet_stats, do_y=True))
  15. data.c = 3
  16. return data