dataset.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from fastai.dataset import FilesDataset, ModelData, ImageData, Transform, RandomFlip, RandomZoom, TfmType
  2. from fastai.dataset import get_cv_idxs, split_by_idx, tfms_from_stats, inception_stats, open_image
  3. from fastai.core import *
  4. class MatchedFilesDataset(FilesDataset):
  5. def __init__(self, fnames:np.array, y:np.array, transforms:[Transform], path:Path, x_tfms:[Transform]=[]):
  6. self.y=y
  7. self.x_tfms=x_tfms
  8. assert(len(fnames)==len(y))
  9. super().__init__(fnames, transforms, path)
  10. def get_x(self, i):
  11. x = super().get_x(i)
  12. for tfm in self.x_tfms:
  13. x,_ = tfm(x, False)
  14. return x
  15. def get_y(self, i):
  16. return open_image(os.path.join(self.path, self.y[i]))
  17. def get_c(self):
  18. return 0
  19. class ImageGenDataLoader():
  20. def __init__(self, sz:int, bs:int, path:Path, random_seed:int=None, keep_pct:float=1.0, x_tfms:[Transform]=[],
  21. file_exts=('jpg','jpeg','png'), extra_aug_tfms:[Transform]=[], reduce_x_scale:int=1):
  22. self.md = None
  23. self.sz = sz
  24. self.bs = bs
  25. self.path = path
  26. self.x_tfms = x_tfms
  27. self.random_seed = random_seed
  28. self.keep_pct = keep_pct
  29. self.file_exts = file_exts
  30. self.extra_aug_tfms=extra_aug_tfms
  31. self.reduce_x_scale=reduce_x_scale
  32. def get_model_data(self):
  33. if self.md is not None:
  34. return self.md
  35. resize_amt = self._get_resize_amount()
  36. resize_folder = 'tmp'
  37. ((val_x,trn_x),(val_y,trn_y)) = self._get_filename_sets(resize_folder)
  38. aug_tfms = [RandomFlip(tfm_y=TfmType.PIXEL), RandomZoom(zoom_max=0.18, tfm_y=TfmType.PIXEL)]
  39. aug_tfms.extend(self.extra_aug_tfms)
  40. sz_x = self.sz//self.reduce_x_scale
  41. sz_y = self.sz
  42. tfms = (tfms_from_stats(inception_stats, sz=sz_x, sz_y=sz_y, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms))
  43. dstype = MatchedFilesDataset
  44. datasets = ImageData.get_ds(dstype, (trn_x,trn_y), (val_x,val_y), tfms, path=self.path, x_tfms=self.x_tfms)
  45. resize_path = os.path.join(self.path,resize_folder,str(resize_amt))
  46. self.md = self._load_model_data(resize_folder, resize_path, resize_amt, datasets, trn_x)
  47. return self.md
  48. def _load_model_data(self, resize_folder:str, resize_path:str, resize_amt:int, datasets, trn_x):
  49. #optimization
  50. if os.path.exists(os.path.join(resize_path,trn_x[0])):
  51. return ImageData(Path(resize_path), datasets, self.bs, num_workers=16, classes=None)
  52. md = ImageData(self.path.parent, datasets, self.bs, num_workers=16, classes=None)
  53. if resize_amt != self.sz:
  54. md = md.resize(resize_amt, new_path=str(resize_folder))
  55. return md
  56. def _get_filename_sets(self, resize_folder:str):
  57. exclude_str = '/' + resize_folder + '/'
  58. paths = self._find_files_recursively(self.path,self.file_exts)
  59. paths = filter(lambda path: not re.search(exclude_str, str(path)), paths)
  60. fnames_full = [Path(str(fname).replace(str(self.path) + '/','')) for fname in paths]
  61. self._update_np_random_seed()
  62. keeps = np.random.rand(len(fnames_full)) < self.keep_pct
  63. fnames = np.array(fnames_full, copy=False)[keeps]
  64. val_idxs = get_cv_idxs(len(fnames), val_pct=min(0.01/self.keep_pct, 0.1))
  65. return split_by_idx(val_idxs, np.array(fnames), np.array(fnames))
  66. def _update_np_random_seed(self):
  67. if self.random_seed is None:
  68. np.random.seed()
  69. else:
  70. np.random.seed(self.random_seed)
  71. def _get_resize_amount(self):
  72. if self.sz<96:
  73. return 128
  74. if self.sz <192:
  75. return 256
  76. return self.sz
  77. def _find_files_recursively(self, root_path:Path, extensions:(str)):
  78. matches = []
  79. for root, dirnames, filenames in os.walk(str(root_path)):
  80. for filename in filenames:
  81. if filename.lower().endswith(extensions):
  82. matches.append(os.path.join(root, filename))
  83. return matches