dataset.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from fastai.dataset import FilesDataset, ModelData, ImageData, Transform, RandomFlip, RandomZoom, TfmType
  2. from fastai.dataset import get_cv_idxs, split_by_idx, tfms_from_stats, inception_stats, open_image
  3. from fastai.core import *
  class MatchedFilesDataset(FilesDataset):
      """Files dataset whose target for item ``i`` is the image stored at ``y[i]``.

      Inputs are loaded by the ``FilesDataset`` base class and then passed
      through the optional ``x_tfms`` chain; targets are opened with
      ``open_image`` from ``path / y[i]``.
      """

      def __init__(self, fnames:np.array, y:np.array, transforms:[Transform], path:Path, x_tfms:[Transform]=None):
          """Store the target file names and input-only transforms, then initialise the base dataset.

          Args:
              fnames: input file names, forwarded to ``FilesDataset``.
              y: target file names, one per entry of ``fnames``, relative to ``path``.
              transforms: forwarded unchanged to ``FilesDataset``.
              path: root directory, forwarded to ``FilesDataset`` (``get_y`` reads it
                  back as ``self.path``).
              x_tfms: callables applied, in order, to every input returned by the
                  base ``get_x``. ``None`` (the default) means no extra transforms.

          Raises:
              ValueError: if ``fnames`` and ``y`` differ in length.
          """
          # Explicit raise instead of ``assert`` so the check survives ``python -O``.
          if len(fnames) != len(y):
              raise ValueError('fnames and y must have the same length, got '
                               + str(len(fnames)) + ' and ' + str(len(y)))
          self.y = y
          # A fresh list per instance: a shared ``[]`` default would leak mutations
          # between datasets.
          self.x_tfms = x_tfms if x_tfms is not None else []
          # Attributes are assigned before the base initialiser runs, as in the
          # original ordering -- presumably the base class may call get_x/get_y
          # during construction (TODO confirm against fastai).
          super().__init__(fnames, transforms, path)

      def get_x(self, i):
          """Return input ``i`` from the base class after applying every transform in ``x_tfms``."""
          x = super().get_x(i)
          for tfm in self.x_tfms:
              # Each transform is called as ``tfm(x, False)`` and returns a pair;
              # only the first element (the transformed input) is kept.
              x, _ = tfm(x, False)
          return x

      def get_y(self, i):
          """Open and return the target image located at ``self.path / self.y[i]``."""
          return open_image(os.path.join(self.path, self.y[i]))

      def get_c(self):
          """Return 0 (presumably "no classes", since targets are images -- confirm against fastai)."""
          return 0
  class ImageGenDataLoader():
      """Build, and cache, a fastai ``ImageData`` object for image-to-image training.

      Image files are discovered recursively under ``path``. The same file list is
      used for both inputs and targets (via ``MatchedFilesDataset``); inputs are
      sized ``sz // reduce_x_scale`` and targets ``sz``.
      """

      def __init__(self, sz:int, bs:int, path:Path, random_seed:int=None, keep_pct:float=1.0, x_tfms:[Transform]=None,
                   file_exts=('jpg','jpeg','png'), extra_aug_tfms:[Transform]=None, reduce_x_scale:int=1):
          """Record the configuration; no files are touched until ``get_model_data`` is called.

          Args:
              sz: target image size passed to the transforms.
              bs: batch size passed to ``ImageData``.
              path: root directory searched for images (``path.parent`` is also used,
                  so a ``pathlib.Path`` is expected).
              random_seed: seed for NumPy's global RNG before sub-sampling; ``None``
                  reseeds from fresh entropy.
              keep_pct: each file is kept when a uniform draw in [0, 1) is below this.
              x_tfms: extra input-only transforms forwarded to the dataset.
              file_exts: file-name endings (compared case-insensitively) that count as images.
              extra_aug_tfms: augmentations appended to the built-in flip and zoom.
              reduce_x_scale: integer divisor applied to ``sz`` for the input size.
          """
          self.md = None  # cache filled by get_model_data()
          self.sz = sz
          self.bs = bs
          self.path = path
          # Fresh lists per instance; a shared ``[]`` default would be visible to
          # every loader created without these arguments.
          self.x_tfms = x_tfms if x_tfms is not None else []
          self.random_seed = random_seed
          self.keep_pct = keep_pct
          self.file_exts = file_exts
          self.extra_aug_tfms = extra_aug_tfms if extra_aug_tfms is not None else []
          self.reduce_x_scale = reduce_x_scale

      def get_model_data(self):
          """Return the ``ImageData`` for this configuration, building it on first use.

          Raises:
              ValueError: if no training image was found under ``self.path``, or if
                  ``keep_pct`` is not positive.
          """
          if self.md is not None:
              return self.md

          resize_amt = self._get_resize_amount()
          resize_folder = 'tmp'
          ((val_x,trn_x),(val_y,trn_y)) = self._get_filename_sets(resize_folder)

          # trn_x is a NumPy array, so test its length rather than its truthiness.
          if len(trn_x) == 0:
              raise ValueError('No image files were found in specified image directory. Path provided was: ' + str(self.path))

          aug_tfms = [RandomFlip(tfm_y=TfmType.PIXEL), RandomZoom(zoom_max=0.18, tfm_y=TfmType.PIXEL)]
          aug_tfms.extend(self.extra_aug_tfms)
          sz_x = self.sz//self.reduce_x_scale
          sz_y = self.sz
          tfms = tfms_from_stats(inception_stats, sz=sz_x, sz_y=sz_y, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms)
          datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,
                                      path=self.path, x_tfms=self.x_tfms)
          resize_path = os.path.join(self.path, resize_folder, str(resize_amt))
          self.md = self._load_model_data(resize_folder, resize_path, resize_amt, datasets, trn_x)
          return self.md

      def _load_model_data(self, resize_folder:str, resize_path:str, resize_amt:int, datasets, trn_x):
          """Create the ``ImageData``, reusing an existing resized copy of the images when one exists."""
          # Optimization: if the first training file already exists under the resize
          # directory, assume the whole resized set is there and read from it.
          if os.path.exists(os.path.join(resize_path, trn_x[0])):
              return ImageData(Path(resize_path), datasets, self.bs, num_workers=16, classes=None)

          md = ImageData(self.path.parent, datasets, self.bs, num_workers=16, classes=None)
          # _get_resize_amount() returns self.sz itself for sizes >= 192, in which
          # case no resize pass is requested.
          if resize_amt != self.sz:
              md = md.resize(resize_amt, new_path=str(resize_folder))
          return md

      def _list_source_images(self, resize_folder:str):
          """Return image paths relative to ``self.path``, skipping any directory named ``resize_folder``.

          Only directory components *below* ``self.path`` are inspected, so a dataset
          that itself lives under a directory with that name (for example ``/tmp``)
          is not filtered out wholesale.
          """
          root = str(self.path)
          rel_paths = []
          for found in self._find_files_recursively(self.path, self.file_exts):
              rel = Path(os.path.relpath(found, root))
              if resize_folder in rel.parts[:-1]:
                  continue
              rel_paths.append(rel)
          return rel_paths

      def _get_filename_sets(self, resize_folder:str):
          """Sub-sample the discovered files and split them into validation and training sets.

          Returns the result of ``split_by_idx`` applied to two copies of the same
          file-name array, which ``get_model_data`` unpacks as
          ``((val_x, trn_x), (val_y, trn_y))``.

          Raises:
              ValueError: if ``keep_pct`` is not positive.
          """
          if self.keep_pct <= 0:
              # Previously surfaced as a bare ZeroDivisionError in the val_pct formula.
              raise ValueError('keep_pct must be greater than 0, got ' + str(self.keep_pct))

          fnames_full = self._list_source_images(resize_folder)
          self._update_np_random_seed()
          keeps = np.random.rand(len(fnames_full)) < self.keep_pct
          # np.asarray rather than np.array(..., copy=False): NumPy >= 2.0 raises when
          # copy=False cannot be honoured, which is always the case for a list.
          fnames = np.asarray(fnames_full)[keeps]
          # The validation fraction grows as keep_pct shrinks, capped at 10%.
          val_idxs = get_cv_idxs(len(fnames), val_pct=min(0.01/self.keep_pct, 0.1))
          return split_by_idx(val_idxs, np.array(fnames), np.array(fnames))

      def _update_np_random_seed(self):
          """Seed NumPy's global RNG with ``random_seed`` (``None`` reseeds from fresh entropy)."""
          np.random.seed(self.random_seed)

      def _get_resize_amount(self):
          """Return the size of the resized image copy to use: 128 for sz < 96, 256 for sz < 192, else sz."""
          if self.sz < 96:
              return 128
          if self.sz < 192:
              return 256
          return self.sz

      def _find_files_recursively(self, root_path:Path, extensions:(str)):
          """Return the joined path of every file under ``root_path`` whose lower-cased name ends with an extension.

          ``extensions`` may be a single string or any iterable of strings; matching
          is case-insensitive.
          """
          if isinstance(extensions, str):
              extensions = (extensions,)
          # str.endswith only accepts a str or a tuple, so normalise lists and sets.
          suffixes = tuple(ext.lower() for ext in extensions)

          matches = []
          for root, _, filenames in os.walk(str(root_path)):
              for filename in filenames:
                  if filename.lower().endswith(suffixes):
                      matches.append(os.path.join(root, filename))
          return matches