123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
from typing import Optional, Sequence

import fastai
from fastai import *
from fastai.core import *
from fastai.vision.transform import get_transforms
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats

from .augs import noisify
def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: Optional[int] = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms: Optional[Sequence] = None,
    valid_pct: float = 0.1,
) -> ImageDataBunch:
    """Build an image-to-image ``ImageDataBunch`` of (input, target) pairs.

    Every image found under ``crappy_path`` (loaded as RGB) becomes an input;
    its target is the file at the same relative path under ``good_path``.
    Inputs and targets are transformed together and both are normalized with
    ``stats``.

    Args:
        sz: Size passed to ``.transform(size=...)``.
        bs: Batch size forwarded to ``.databunch``.
        crappy_path: Root folder of the input (degraded) images.
        good_path: Root folder of the target images; must mirror the folder
            layout of ``crappy_path``.
        random_seed: Seed forwarded to both the partial-data sampling and the
            random train/validation split.
        keep_pct: Fraction of the images under ``crappy_path`` to keep.
        num_workers: Number of loader workers forwarded to ``.databunch``.
        stats: Normalization statistics applied to inputs and targets.
        xtra_tfms: Extra transforms forwarded to ``get_transforms``; ``None``
            (the default) means no extra transforms.
        valid_pct: Fraction of the kept items held out for validation
            (defaults to 0.1, the previously hard-coded value).

    Returns:
        The configured ``ImageDataBunch`` with its ``c`` attribute set to 3.
    """
    # Use a ``None`` default rather than a shared mutable ``[]``; resolving it
    # to an empty list reproduces the old default behaviour exactly.
    if xtra_tfms is None:
        xtra_tfms = []

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(valid_pct, seed=random_seed)
    )
    data = (
        # Each target lives at the same relative path, rooted at good_path.
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        .transform(
            get_transforms(
                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
            ),
            size=sz,
            # Transform targets too, so input/target pairs stay aligned.
            tfm_y=True,
        )
        # no_check=True skips fastai's databunch sanity check.
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        # do_y=True normalizes the targets as well as the inputs.
        .normalize(stats, do_y=True)
    )
    # NOTE(review): presumably read by the learner constructor to size the
    # model output; 3 matches the RGB targets. Confirm against the caller.
    data.c = 3
    return data
def get_dummy_databunch() -> ImageDataBunch:
    """Return a minimal placeholder databunch built from ``./dummy/``.

    The same folder is used as both the input and the target root, with image
    size 1, batch size 1, and only 0.1% of the found images kept.
    NOTE(review): presumably meant for building a model when no real training
    data is needed (e.g. inference) — confirm with callers.
    """
    dummy_dir = Path('./dummy/')
    tiny_settings = dict(sz=1, bs=1, keep_pct=0.001)
    return get_colorize_data(
        crappy_path=dummy_dir, good_path=dummy_dir, **tiny_settings
    )
|