from fastai.dataset import FilesDataset, ModelData, ImageData, Transform, RandomFlip, RandomZoom, TfmType
from fastai.dataset import get_cv_idxs, split_by_idx, tfms_from_stats, inception_stats, open_image
from fastai.core import *  # this star import supplies np, os, re and Path used throughout this module

class MatchedFilesDataset(FilesDataset):
    """Dataset that pairs each input image with a target image given by a matching filename."""
    def __init__(self, fnames:np.array, y:np.array, transforms:[Transform], path:Path, x_tfms:[Transform]=[]):
        self.y = y
        self.x_tfms = x_tfms
        assert(len(fnames) == len(y))
        super().__init__(fnames, transforms, path)

    def get_x(self, i):
        # apply the input-only transforms on top of the shared pipeline
        x = super().get_x(i)
        for tfm in self.x_tfms:
            x, _ = tfm(x, False)
        return x

    def get_y(self, i):
        return open_image(os.path.join(self.path, self.y[i]))

    def get_c(self):
        # generative task, so there are no classification classes
        return 0

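# A minimal sketch of how MatchedFilesDataset pairs inputs with targets (illustrative only:
# the folder 'data/images' and the two filenames are hypothetical, and in this module the
# dataset is normally built for you by ImageGenDataLoader.get_model_data below).
#
#   tfms = tfms_from_stats(inception_stats, sz=64, sz_y=64, tfm_y=TfmType.PIXEL)
#   fnames = np.array([Path('cat.jpg'), Path('dog.jpg')])   # paths relative to the dataset root
#   ds = MatchedFilesDataset(fnames, fnames, tfms[0], Path('data/images'))
#   x, y = ds[0]   # x: transformed input image, y: the matching target image
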
class ImageGenDataLoader():
    """Builds (and caches) fastai ModelData for image-to-image generation from a folder of images."""
    def __init__(self, sz:int, bs:int, path:Path, random_seed:int=None, keep_pct:float=1.0, x_tfms:[Transform]=[],
                 file_exts=('jpg', 'jpeg', 'png'), extra_aug_tfms:[Transform]=[], reduce_x_scale:int=1):
        self.md = None                        # lazily built ModelData cache
        self.sz = sz                          # target (output) image size
        self.bs = bs                          # batch size
        self.path = path                      # root folder to scan for images
        self.x_tfms = x_tfms                  # transforms applied to the input only
        self.random_seed = random_seed        # seed for the keep_pct subsampling
        self.keep_pct = keep_pct              # fraction of the found files to keep
        self.file_exts = file_exts            # file extensions to include in the scan
        self.extra_aug_tfms = extra_aug_tfms  # augmentations added on top of flip/zoom
        self.reduce_x_scale = reduce_x_scale  # factor by which the input is smaller than the target
    def get_model_data(self):
        if self.md is not None:
            return self.md
        resize_amt = self._get_resize_amount()
        resize_folder = 'tmp'
        ((val_x, trn_x), (val_y, trn_y)) = self._get_filename_sets(resize_folder)
        # flip and zoom are applied to input and target alike (tfm_y=PIXEL keeps them aligned)
        aug_tfms = [RandomFlip(tfm_y=TfmType.PIXEL), RandomZoom(zoom_max=0.18, tfm_y=TfmType.PIXEL)]
        aug_tfms.extend(self.extra_aug_tfms)
        sz_x = self.sz // self.reduce_x_scale
        sz_y = self.sz
        tfms = tfms_from_stats(inception_stats, sz=sz_x, sz_y=sz_y, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms)
        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x, trn_y), (val_x, val_y), tfms,
                                    path=self.path, x_tfms=self.x_tfms)
        resize_path = os.path.join(self.path, resize_folder, str(resize_amt))
        self.md = self._load_model_data(resize_folder, resize_path, resize_amt, datasets, trn_x)
        return self.md
    def _load_model_data(self, resize_folder:str, resize_path:str, resize_amt:int, datasets, trn_x):
        # optimization: reuse already-resized copies if they exist on disk
        if os.path.exists(os.path.join(resize_path, trn_x[0])):
            return ImageData(Path(resize_path), datasets, self.bs, num_workers=16, classes=None)

        md = ImageData(self.path.parent, datasets, self.bs, num_workers=16, classes=None)
        if resize_amt != self.sz:
            md = md.resize(resize_amt, new_path=str(resize_folder))
        return md
    def _get_filename_sets(self, resize_folder:str):
        # exclude anything living under the resize cache folder
        exclude_str = '/' + resize_folder + '/'
        paths = self._find_files_recursively(self.path, self.file_exts)
        paths = filter(lambda path: not re.search(exclude_str, str(path)), paths)
        fnames_full = [Path(str(fname).replace(str(self.path) + '/', '')) for fname in paths]
        self._update_np_random_seed()
        # randomly keep roughly keep_pct of the files
        keeps = np.random.rand(len(fnames_full)) < self.keep_pct
        fnames = np.array(fnames_full, copy=False)[keeps]
        # scale val_pct so the absolute validation size stays roughly constant, capped at 10%
        val_idxs = get_cv_idxs(len(fnames), val_pct=min(0.01 / self.keep_pct, 0.1))
        # x and y reference the same files; get_x applies the extra x_tfms on top
        return split_by_idx(val_idxs, np.array(fnames), np.array(fnames))
    def _update_np_random_seed(self):
        # no seed given: reseed nondeterministically; otherwise make the file subsampling reproducible
        if self.random_seed is None:
            np.random.seed()
        else:
            np.random.seed(self.random_seed)

    def _get_resize_amount(self):
        # size of the on-disk resize cache: 128 or 256 for small training sizes,
        # otherwise sz itself (which skips the cache in _load_model_data)
        if self.sz < 96:
            return 128
        if self.sz < 192:
            return 256
        return self.sz

    def _find_files_recursively(self, root_path:Path, extensions:(str)):
        # walk the tree and collect files whose lower-cased name ends with one of the extensions
        matches = []
        for root, dirnames, filenames in os.walk(str(root_path)):
            for filename in filenames:
                if filename.lower().endswith(extensions):
                    matches.append(os.path.join(root, filename))
        return matches
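

if __name__ == '__main__':
    # Minimal usage sketch (assumptions: a folder of training images exists at data/imagenet;
    # the path, sizes and keep_pct here are illustrative, not values prescribed by this module).
    loader = ImageGenDataLoader(sz=128, bs=32, path=Path('data/imagenet'),
                                random_seed=42, keep_pct=0.25, reduce_x_scale=2)
    md = loader.get_model_data()        # builds and caches the fastai ModelData
    x, y = next(iter(md.trn_dl))        # x: 64px model inputs, y: 128px targets
    print(x.shape, y.shape)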