dataset.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
from typing import Optional, Sequence

import fastai
from fastai import *
from fastai.core import *
from fastai.vision.transform import get_transforms
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats

from .augs import noisify
  7. def get_colorize_data(
  8. sz: int,
  9. bs: int,
  10. crappy_path: Path,
  11. good_path: Path,
  12. random_seed: int = None,
  13. keep_pct: float = 1.0,
  14. num_workers: int = 8,
  15. stats: tuple = imagenet_stats,
  16. xtra_tfms=[],
  17. ) -> ImageDataBunch:
  18. src = (
  19. ImageImageList.from_folder(crappy_path, convert_mode='RGB')
  20. .use_partial_data(sample_pct=keep_pct, seed=random_seed)
  21. .split_by_rand_pct(0.1, seed=random_seed)
  22. )
  23. data = (
  24. src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
  25. .transform(
  26. get_transforms(
  27. max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
  28. ),
  29. size=sz,
  30. tfm_y=True,
  31. )
  32. .databunch(bs=bs, num_workers=num_workers, no_check=True)
  33. .normalize(stats, do_y=True)
  34. )
  35. data.c = 3
  36. return data
  37. def get_dummy_databunch() -> ImageDataBunch:
  38. path = Path('./dummy/')
  39. return get_colorize_data(
  40. sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
  41. )