dataset.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import fastai
  2. from fastai import *
  3. from fastai.core import *
  4. from fastai.vision.transform import get_transforms
  5. from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
  6. from .augs import noisify
  7. def get_colorize_data(
  8. sz: int,
  9. bs: int,
  10. crappy_path: Path,
  11. good_path: Path,
  12. random_seed: int = None,
  13. keep_pct: float = 1.0,
  14. num_workers: int = 8,
  15. xtra_tfms=[],
  16. ) -> ImageDataBunch:
  17. src = (
  18. ImageImageList.from_folder(crappy_path, convert_mode='RGB')
  19. .use_partial_data(sample_pct=keep_pct, seed=random_seed)
  20. .split_by_rand_pct(0.1, seed=random_seed)
  21. )
  22. data = (
  23. src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
  24. .transform(
  25. get_transforms(
  26. max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
  27. ),
  28. size=sz,
  29. tfm_y=True,
  30. )
  31. .databunch(bs=bs, num_workers=num_workers, no_check=True)
  32. .normalize(imagenet_stats, do_y=True)
  33. )
  34. data.c = 3
  35. return data
  36. def get_dummy_databunch() -> ImageDataBunch:
  37. path = Path('./dummy/')
  38. return get_colorize_data(
  39. sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
  40. )