dataset.py

import csv
import urllib.request  # used by open_image for http(s) URLs (assumption: may already come in via .imports)
from .imports import *
from .torch_imports import *
from .core import *
from .transforms import *
from .layer_optimizer import *
from .dataloader import DataLoader

def get_cv_idxs(n, cv_idx=0, val_pct=0.2, seed=42):
    """ Get a list of index values for the validation set from a dataset

    Arguments:
        n : int, total number of elements in the data set
        cv_idx : int, starting index [idx_start = cv_idx*int(val_pct*n)]
        val_pct : float, fraction of the data to use for validation
        seed : seed value for RandomState
    Returns:
        list of indexes
    """
    np.random.seed(seed)
    n_val = int(val_pct*n)
    idx_start = cv_idx*n_val
    idxs = np.random.permutation(n)
    return idxs[idx_start:idx_start+n_val]
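
# Example (hypothetical values): take 20% of a 1000-element dataset for validation.
# val_idxs = get_cv_idxs(1000)            # -> array of 200 shuffled indices
# val_idxs = get_cv_idxs(1000, cv_idx=1)  # -> the next 200-index fold of the same permutation
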
def resize_img(fname, targ, path, new_path):
    """
    Enlarge or shrink a single image to scale, such that the smaller of the height or width dimension is equal to targ.
    """
    dest = os.path.join(path,new_path,str(targ),fname)
    if os.path.exists(dest): return
    im = Image.open(os.path.join(path, fname)).convert('RGB')
    r,c = im.size
    ratio = targ/min(r,c)
    sz = (scale_to(r, ratio, targ), scale_to(c, ratio, targ))
    os.makedirs(os.path.split(dest)[0], exist_ok=True)
    # Image.BILINEAR is the documented name for this filter (Image.LINEAR was an
    # undocumented alias, removed in recent Pillow releases).
    im.resize(sz, Image.BILINEAR).save(dest)

def resize_imgs(fnames, targ, path, new_path):
    """
    Enlarge or shrink a set of images in the same directory to scale, such that the smaller of the height or width dimension is equal to targ.
    Note:
    -- This function is multithreaded for efficiency.
    -- When the destination file or folder already exists, the function exits without raising an error.
    """
    if not os.path.exists(os.path.join(path,new_path,str(targ),fnames[0])):
        with ThreadPoolExecutor(8) as e:
            ims = e.map(lambda x: resize_img(x, targ, path, new_path), fnames)
            for x in tqdm(ims, total=len(fnames), leave=False): pass
    return os.path.join(path,new_path,str(targ))
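
# Example (hypothetical paths): shrink every listed image so its shorter side is 224px.
# new_root = resize_imgs(fnames, 224, 'data/', 'tmp')  # resized copies land under 'data/tmp/224'
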
def read_dir(path, folder):
    """ Returns a list of relative file paths to `path` for all files within `folder` """
    full_path = os.path.join(path, folder)
    fnames = glob(f"{full_path}/*.*")
    directories = glob(f"{full_path}/*/")
    if any(fnames):
        return [os.path.relpath(f,path) for f in fnames]
    elif any(directories):
        raise FileNotFoundError("{} has subdirectories but contains no files. Is your directory structure correct?".format(full_path))
    else:
        raise FileNotFoundError("{} folder doesn't exist or is empty".format(full_path))

def read_dirs(path, folder):
    '''
    Fetches the relative paths of all files under `path/folder`, with each file's label taken from the name of its immediate parent directory.
    '''
    lbls, fnames, all_lbls = [], [], []
    full_path = os.path.join(path, folder)
    for lbl in sorted(os.listdir(full_path)):
        if lbl not in ('.ipynb_checkpoints','.DS_Store'):
            all_lbls.append(lbl)
            for fname in os.listdir(os.path.join(full_path, lbl)):
                # NB: the original `fname not in ('.DS_Store')` tested substring membership
                # in the string '.DS_Store'; an inequality check is what was intended.
                if fname != '.DS_Store':
                    fnames.append(os.path.join(folder, lbl, fname))
                    lbls.append(lbl)
    return fnames, lbls, all_lbls
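
# Expected layout (hypothetical): path/folder/<label>/<file>, e.g.
#   data/train/cats/cat.0.jpg -> fname 'train/cats/cat.0.jpg', label 'cats'
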
def n_hot(ids, c):
    '''
    One-hot encoding by index. Returns an array of length c where all entries are 0, except for the indices in ids, which are 1.
    '''
    res = np.zeros((c,), dtype=np.float32)
    res[ids] = 1
    return res
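
# Example: n_hot([0, 3], 5) -> array([1., 0., 0., 1., 0.], dtype=float32)
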
def folder_source(path, folder):
    """
    Returns the filenames and labels for a folder within a path

    Returns:
    -------
    fnames: a list of the filenames within `folder`
    all_lbls: a list of all of the labels in `folder`, where the # of labels is determined by the # of directories within `folder`
    lbl_arr: a numpy array of the label indices in `all_lbls`
    """
    fnames, lbls, all_lbls = read_dirs(path, folder)
    lbl2idx = {lbl:idx for idx,lbl in enumerate(all_lbls)}
    idxs = [lbl2idx[lbl] for lbl in lbls]
    lbl_arr = np.array(idxs, dtype=int)
    return fnames, lbl_arr, all_lbls

def parse_csv_labels(fn, skip_header=True, cat_separator=' '):
    """Parse filenames and label sets from a CSV file.

    This method expects that the csv file at path :fn: has two columns. If it
    has a header, :skip_header: should be set to True. The labels in the
    label set are expected to be space separated.

    Arguments:
        fn: Path to a CSV file.
        skip_header: A boolean flag indicating whether to skip the header.
        cat_separator: The separator for the categories column.
    Returns:
        a two-tuple of (
            image filenames,
            a dictionary of filenames and corresponding labels
        )
    """
    df = pd.read_csv(fn, index_col=0, header=0 if skip_header else None, dtype=str)
    fnames = df.index.values
    df.iloc[:,0] = df.iloc[:,0].str.split(cat_separator)
    return fnames, list(df.to_dict().values())[0]
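
# Example CSV (hypothetical) and the result of parse_csv_labels on it:
#   id,tags
#   img_001,clear primary
#   img_002,cloudy
# -> fnames == array(['img_001', 'img_002']),
#    labels == {'img_001': ['clear', 'primary'], 'img_002': ['cloudy']}
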
def nhot_labels(label2idx, csv_labels, fnames, c):
    # a missing label cell comes through pandas as NaN (a float), hence the type check
    all_idx = {k: n_hot([label2idx[o] for o in ([] if type(v) == float else v)], c)
               for k,v in csv_labels.items()}
    return np.stack([all_idx[o] for o in fnames])

def csv_source(folder, csv_file, skip_header=True, suffix='', continuous=False, cat_separator=' '):
    fnames,csv_labels = parse_csv_labels(csv_file, skip_header, cat_separator)
    return dict_source(folder, fnames, csv_labels, suffix, continuous)

def dict_source(folder, fnames, csv_labels, suffix='', continuous=False):
    all_labels = sorted(list(set(p for o in csv_labels.values() for p in ([] if type(o) == float else o))))
    full_names = [os.path.join(folder,str(fn)+suffix) for fn in fnames]
    if continuous:
        label_arr = np.array([np.array(csv_labels[i]).astype(np.float32)
                              for i in fnames])
    else:
        label2idx = {v:k for k,v in enumerate(all_labels)}
        label_arr = nhot_labels(label2idx, csv_labels, fnames, len(all_labels))
        is_single = np.all(label_arr.sum(axis=1)==1)
        if is_single: label_arr = np.argmax(label_arr, axis=1)
    return full_names, label_arr, all_labels
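
# dict_source collapses to an index array when every file has exactly one label
# (single-label classification) and keeps an n-hot matrix otherwise. Sketch:
#   {'a.jpg': ['cat'], 'b.jpg': ['dog']}         -> label_arr == [0, 1]
#   {'a.jpg': ['cat', 'dog'], 'b.jpg': ['dog']}  -> label_arr == [[1, 1], [0, 1]]
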
class BaseDataset(Dataset):
    """An abstract class representing a fastai dataset. Extends torch.utils.data.Dataset."""
    def __init__(self, transform=None):
        self.transform = transform
        self.n = self.get_n()
        self.c = self.get_c()
        self.sz = self.get_sz()

    def get1item(self, idx):
        x,y = self.get_x(idx),self.get_y(idx)
        return self.get(self.transform, x, y)

    def __getitem__(self, idx):
        if isinstance(idx,slice):
            xs,ys = zip(*[self.get1item(i) for i in range(*idx.indices(self.n))])
            return np.stack(xs),ys
        return self.get1item(idx)

    def __len__(self): return self.n

    def get(self, tfm, x, y):
        return (x,y) if tfm is None else tfm(x,y)

    @abstractmethod
    def get_n(self):
        """Return number of elements in the dataset == len(self)."""
        raise NotImplementedError

    @abstractmethod
    def get_c(self):
        """Return number of classes in a dataset."""
        raise NotImplementedError

    @abstractmethod
    def get_sz(self):
        """Return maximum size of an image in a dataset."""
        raise NotImplementedError

    @abstractmethod
    def get_x(self, i):
        """Return i-th example (image, wav, etc)."""
        raise NotImplementedError

    @abstractmethod
    def get_y(self, i):
        """Return i-th label."""
        raise NotImplementedError

    @property
    def is_multi(self):
        """Returns True if this data set contains multiple labels per sample."""
        return False

    @property
    def is_reg(self):
        """Returns True if the data set is used to train regression models."""
        return False
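
# A minimal subclass sketch (hypothetical), showing the contract BaseDataset expects:
# class SquaresDataset(BaseDataset):
#     def get_n(self):    return 10              # number of examples
#     def get_c(self):    return 0               # no discrete classes
#     def get_sz(self):   return 1               # nominal "image" size
#     def get_x(self, i): return np.float32(i)
#     def get_y(self, i): return np.float32(i*i)
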
def open_image(fn):
    """ Opens an image using OpenCV given the file path.

    Arguments:
        fn: the file path of the image
    Returns:
        The image in RGB format as a numpy array of floats, normalized to the range 0.0 - 1.0
    """
    flags = cv2.IMREAD_UNCHANGED+cv2.IMREAD_ANYDEPTH+cv2.IMREAD_ANYCOLOR
    if not os.path.exists(fn) and not str(fn).startswith("http"):
        raise OSError('No such file or directory: {}'.format(fn))
    elif os.path.isdir(fn) and not str(fn).startswith("http"):
        raise OSError('Is a directory: {}'.format(fn))
    else:
        try:
            if str(fn).startswith("http"):
                req = urllib.request.urlopen(str(fn))  # Python 3 location of urlopen
                image = np.asarray(bytearray(req.read()), dtype="uint8")
                im = cv2.imdecode(image, flags)
            else:
                im = cv2.imread(str(fn), flags)
            # OpenCV returns None rather than raising when it cannot decode a file,
            # so check before calling astype on the result.
            if im is None: raise OSError(f'File not recognized by opencv: {fn}')
            return cv2.cvtColor(im.astype(np.float32)/255, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise OSError('Error handling image at: {}'.format(fn)) from e
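
# Example (hypothetical path):
# im = open_image('data/train/cats/cat.0.jpg')  # -> float32 array of shape (h, w, 3) in [0, 1]
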
class FilesDataset(BaseDataset):
    def __init__(self, fnames, transform, path):
        self.path,self.fnames = path,fnames
        super().__init__(transform)
    def get_sz(self): return self.transform.sz
    def get_x(self, i): return open_image(os.path.join(self.path, self.fnames[i]))
    def get_n(self): return len(self.fnames)

    def resize_imgs(self, targ, new_path):
        # Calls the module-level resize_imgs, then rebuilds the dataset rooted at the
        # resized copies. NB: the constructor call assumes a subclass with a
        # (fnames, y, transform, path) signature, such as FilesArrayDataset.
        dest = resize_imgs(self.fnames, targ, self.path, new_path)
        return self.__class__(self.fnames, self.y, self.transform, dest)

    def denorm(self,arr):
        """Reverse the normalization done to a batch of images.

        Arguments:
            arr: of shape/size (N,3,sz,sz)
        """
        if type(arr) is not np.ndarray: arr = to_np(arr)
        if len(arr.shape)==3: arr = arr[None]
        return self.transform.denorm(np.rollaxis(arr,1,4))

class FilesArrayDataset(FilesDataset):
    def __init__(self, fnames, y, transform, path):
        self.y=y
        assert(len(fnames)==len(y))
        super().__init__(fnames, transform, path)
    def get_y(self, i): return self.y[i]
    def get_c(self):
        return self.y.shape[1] if len(self.y.shape)>1 else 0

class FilesIndexArrayDataset(FilesArrayDataset):
    def get_c(self): return int(self.y.max())+1

class FilesNhotArrayDataset(FilesArrayDataset):
    @property
    def is_multi(self): return True

class FilesIndexArrayRegressionDataset(FilesArrayDataset):
    # @property keeps this override consistent with BaseDataset.is_reg; without it,
    # `ds.is_reg` would evaluate to a (truthy) bound method instead of True.
    @property
    def is_reg(self): return True

class ArraysDataset(BaseDataset):
    def __init__(self, x, y, transform):
        self.x,self.y=x,y
        assert(len(x)==len(y))
        super().__init__(transform)
    def get_x(self, i): return self.x[i]
    def get_y(self, i): return self.y[i]
    def get_n(self): return len(self.y)
    def get_sz(self): return self.x.shape[1]

class ArraysIndexDataset(ArraysDataset):
    def get_c(self): return int(self.y.max())+1
    def get_y(self, i): return self.y[i]

class ArraysIndexRegressionDataset(ArraysIndexDataset):
    @property
    def is_reg(self): return True

class ArraysNhotDataset(ArraysDataset):
    def get_c(self): return self.y.shape[1]
    @property
    def is_multi(self): return True
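
# Naming convention across these dataset classes: *IndexArray* -> single-label
# (integer class indices), *Nhot* -> multi-label (n-hot vectors),
# *Regression* -> continuous targets.
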
class ModelData():
    """Encapsulates DataLoaders and Datasets for training, validation, test. Base class for fastai *Data classes."""
    def __init__(self, path, trn_dl, val_dl, test_dl=None):
        self.path,self.trn_dl,self.val_dl,self.test_dl = path,trn_dl,val_dl,test_dl

    @classmethod
    def from_dls(cls, path, trn_dl, val_dl, test_dl=None):
        return cls(path, trn_dl, val_dl, test_dl)

    @property
    def is_reg(self): return self.trn_ds.is_reg
    @property
    def is_multi(self): return self.trn_ds.is_multi
    @property
    def trn_ds(self): return self.trn_dl.dataset
    @property
    def val_ds(self): return self.val_dl.dataset
    @property
    def test_ds(self): return self.test_dl.dataset
    @property
    def trn_y(self): return self.trn_ds.y
    @property
    def val_y(self): return self.val_ds.y

class ImageData(ModelData):
    def __init__(self, path, datasets, bs, num_workers, classes):
        trn_ds,val_ds,fix_ds,aug_ds,test_ds,test_aug_ds = datasets
        self.path,self.bs,self.num_workers,self.classes = path,bs,num_workers,classes
        self.trn_dl,self.val_dl,self.fix_dl,self.aug_dl,self.test_dl,self.test_aug_dl = [
            self.get_dl(ds,shuf) for ds,shuf in [
                (trn_ds,True),(val_ds,False),(fix_ds,False),(aug_ds,False),
                (test_ds,False),(test_aug_ds,False)
            ]
        ]

    def get_dl(self, ds, shuffle):
        if ds is None: return None
        return DataLoader(ds, batch_size=self.bs, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=False)

    @property
    def sz(self): return self.trn_ds.sz
    @property
    def c(self): return self.trn_ds.c

    def resized(self, dl, targ, new_path):
        return dl.dataset.resize_imgs(targ,new_path) if dl else None

    def resize(self, targ_sz, new_path='tmp'):
        new_ds = []
        dls = [self.trn_dl,self.val_dl,self.fix_dl,self.aug_dl]
        if self.test_dl: dls += [self.test_dl, self.test_aug_dl]
        else: dls += [None,None]
        t = tqdm_notebook(dls)
        for dl in t: new_ds.append(self.resized(dl, targ_sz, new_path))
        t.close()
        return self.__class__(new_ds[0].path, new_ds, self.bs, self.num_workers, self.classes)

    @staticmethod
    def get_ds(fn, trn, val, tfms, test=None, **kwargs):
        res = [
            fn(trn[0], trn[1], tfms[0], **kwargs), # train
            fn(val[0], val[1], tfms[1], **kwargs), # val
            fn(trn[0], trn[1], tfms[1], **kwargs), # fix
            fn(val[0], val[1], tfms[0], **kwargs)  # aug
        ]
        if test is not None:
            if isinstance(test, tuple):
                test_lbls = test[1]
                test = test[0]
            else:
                # no test labels supplied: build dummy zero labels shaped like the training labels
                if len(trn[1].shape) == 1:
                    test_lbls = np.zeros((len(test),1))
                else:
                    test_lbls = np.zeros((len(test),trn[1].shape[1]))
            res += [
                fn(test, test_lbls, tfms[1], **kwargs), # test
                fn(test, test_lbls, tfms[0], **kwargs)  # test_aug
            ]
        else: res += [None,None]
        return res
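
# get_ds builds six datasets from a single constructor fn:
#   train (training data, training tfms), val (validation data, eval tfms),
#   fix (training data, eval tfms - un-augmented training images),
#   aug (validation data, training tfms - test-time augmentation),
# plus test/test_aug when test data is supplied, else None placeholders.
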
class ImageClassifierData(ImageData):
    @classmethod
    def from_arrays(cls, path, trn, val, bs=64, tfms=(None,None), classes=None, num_workers=4, test=None, continuous=False):
        """ Read in images and their labels given as numpy arrays

        Arguments:
            path: a root path of the data (used for storing trained models, precomputed values, etc)
            trn: a tuple of training data matrix and target label/classification array (e.g. `trn=(x,y)` where `x` has the
                shape of `(5000, 784)` and `y` has the shape of `(5000,)`)
            val: a tuple of validation data matrix and target label/classification array.
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
            classes: a list of all labels/classifications
            num_workers: number of workers
            test: a matrix of test data (the shape should match `trn[0]`)
            continuous: if True, the targets are treated as continuous values (regression)
        Returns:
            ImageClassifierData
        """
        f = ArraysIndexRegressionDataset if continuous else ArraysIndexDataset
        datasets = cls.get_ds(f, trn, val, tfms, test=test)
        return cls(path, datasets, bs, num_workers, classes=classes)
    @classmethod
    def from_paths(cls, path, bs=64, tfms=(None,None), trn_name='train', val_name='valid', test_name=None, test_with_labels=False, num_workers=8):
        """ Read in images and their labels given as sub-folder names

        Arguments:
            path: a root path of the data (used for storing trained models, precomputed values, etc)
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
            trn_name: name of the folder that contains training images.
            val_name: name of the folder that contains validation images.
            test_name: name of the folder that contains test images.
            test_with_labels: whether the test folder is itself organized into labelled sub-folders.
            num_workers: number of workers
        Returns:
            ImageClassifierData
        """
        assert not(tfms[0] is None or tfms[1] is None), "please provide transformations for your train and validation sets"
        trn,val = [folder_source(path, o) for o in (trn_name, val_name)]
        if test_name:
            test = folder_source(path, test_name) if test_with_labels else read_dir(path, test_name)
        else: test = None
        datasets = cls.get_ds(FilesIndexArrayDataset, trn, val, tfms, path=path, test=test)
        return cls(path, datasets, bs, num_workers, classes=trn[2])
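
    # Typical usage (hypothetical paths), for a data/train + data/valid folder layout:
    # tfms = tfms_from_model(resnet34, sz=224)
    # data = ImageClassifierData.from_paths('data/', bs=64, tfms=tfms)
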
    @classmethod
    def from_csv(cls, path, folder, csv_fname, bs=64, tfms=(None,None),
            val_idxs=None, suffix='', test_name=None, continuous=False, skip_header=True, num_workers=8, cat_separator=' '):
        """ Read in images and their labels given as a CSV file.

        This method should be used when training image labels are given in a CSV file as opposed to
        sub-directories with label names.

        Arguments:
            path: a root path of the data (used for storing trained models, precomputed values, etc)
            folder: a name of the folder in which training images are contained.
            csv_fname: a name of the CSV file which contains target labels.
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
            val_idxs: index of images to be used for validation. e.g. output of `get_cv_idxs`.
                If None, default arguments to get_cv_idxs are used.
            suffix: suffix to add to image names in CSV file (sometimes CSV only contains the file name without file
                extension e.g. '.jpg' - in which case, you can set suffix as '.jpg')
            test_name: a name of the folder which contains test images.
            continuous: if True, the label column contains continuous values to regress against, rather than categories.
            skip_header: skip the first row of the CSV file.
            num_workers: number of workers
            cat_separator: Labels category separator
        Returns:
            ImageClassifierData
        """
        assert not (tfms[0] is None or tfms[1] is None), "please provide transformations for your train and validation sets"
        assert not (os.path.isabs(folder)), "folder needs to be a relative path"
        fnames,y,classes = csv_source(folder, csv_fname, skip_header, suffix, continuous=continuous, cat_separator=cat_separator)
        return cls.from_names_and_array(path, fnames, y, classes, val_idxs, test_name,
            num_workers=num_workers, suffix=suffix, tfms=tfms, bs=bs, continuous=continuous)
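
    # Typical usage (hypothetical paths/filenames):
    # val_idxs = get_cv_idxs(n)  # n = number of labelled rows
    # data = ImageClassifierData.from_csv('data/', 'train', 'data/labels.csv',
    #                                     tfms=tfms, val_idxs=val_idxs, suffix='.jpg')
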
    @classmethod
    def from_path_and_array(cls, path, folder, y, classes=None, val_idxs=None, test_name=None,
            num_workers=8, tfms=(None,None), bs=64):
        """ Read in images given a sub-folder and their labels given a numpy array

        Arguments:
            path: a root path of the data (used for storing trained models, precomputed values, etc)
            folder: a name of the folder in which training images are contained.
            y: numpy array which contains target labels ordered by filenames.
            classes: a list of all labels/classifications
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
            val_idxs: index of images to be used for validation. e.g. output of `get_cv_idxs`.
                If None, default arguments to get_cv_idxs are used.
            test_name: a name of the folder which contains test images.
            num_workers: number of workers
        Returns:
            ImageClassifierData
        """
        assert not (tfms[0] is None or tfms[1] is None), "please provide transformations for your train and validation sets"
        assert not (os.path.isabs(folder)), "folder needs to be a relative path"
        # os.path.join avoids relying on `path` ending with a slash (the original f'{path}{folder}' did)
        fnames = np.core.defchararray.add(f'{folder}/', sorted(os.listdir(os.path.join(path, folder))))
        return cls.from_names_and_array(path, fnames, y, classes, val_idxs, test_name,
            num_workers=num_workers, tfms=tfms, bs=bs)
    @classmethod
    def from_names_and_array(cls, path, fnames, y, classes, val_idxs=None, test_name=None,
            num_workers=8, suffix='', tfms=(None,None), bs=64, continuous=False):
        val_idxs = get_cv_idxs(len(fnames)) if val_idxs is None else val_idxs
        ((val_fnames,trn_fnames),(val_y,trn_y)) = split_by_idx(val_idxs, np.array(fnames), y)
        test_fnames = read_dir(path, test_name) if test_name else None
        if continuous: f = FilesIndexArrayRegressionDataset
        else:
            f = FilesIndexArrayDataset if len(trn_y.shape)==1 else FilesNhotArrayDataset
        datasets = cls.get_ds(f, (trn_fnames,trn_y), (val_fnames,val_y), tfms,
            path=path, test=test_fnames)
        return cls(path, datasets, bs, num_workers, classes=classes)

def split_by_idx(idxs, *a):
    """
    Split each array passed as *a into a pair of arrays: (elements selected by idxs, the remaining elements).
    This can be used to split multiple arrays containing training data into a validation and a training set.

    :param idxs [int]: list of indexes selected
    :param a: list of np.array; each array should have the same number of elements in the first dimension
    :return: list of tuples, each containing a split of the corresponding array from *a.
        The first element of each tuple is an array composed of the elements selected by idxs,
        the second element is an array of the remaining elements.
    """
    mask = np.zeros(len(a[0]),dtype=bool)
    mask[np.array(idxs)] = True
    return [(o[mask],o[~mask]) for o in a]
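
# Example: split_by_idx([0, 2], np.array([10, 11, 12, 13]))
#   -> [(array([10, 12]), array([11, 13]))]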