from .imports import *
from .layer_optimizer import *
from enum import IntEnum

def scale_min(im, targ, interpolation=cv2.INTER_AREA):
    """ Scale the image so that the smallest axis is of size targ.

    Arguments:
        im (array): image
        targ (int): target size
    """
    r,c,*_ = im.shape
    ratio = targ/min(r,c)
    sz = (scale_to(c, ratio, targ), scale_to(r, ratio, targ))
    return cv2.resize(im, sz, interpolation=interpolation)

def zoom_cv(x,z):
    """ Zoom the center of image x by a factor of z+1 while retaining the original image size and proportion. """
    if z==0: return x
    r,c,*_ = x.shape
    M = cv2.getRotationMatrix2D((c/2,r/2),0,z+1.)
    return cv2.warpAffine(x,M,(c,r))

def stretch_cv(x,sr,sc,interpolation=cv2.INTER_AREA):
    """ Stretch image x horizontally by sr+1 and vertically by sc+1 while retaining the original image size and proportion. """
    if sr==0 and sc==0: return x
    r,c,*_ = x.shape
    x = cv2.resize(x, None, fx=sr+1, fy=sc+1, interpolation=interpolation)
    nr,nc,*_ = x.shape
    cr = (nr-r)//2; cc = (nc-c)//2
    return x[cr:r+cr, cc:c+cc]

def dihedral(x, dih):
    """ Perform one of the 8 dihedral transformations (combinations of 90-degree rotations and flips) on image x. """
    x = np.rot90(x, dih%4)
    return x if dih<4 else np.fliplr(x)

def lighting(im, b, c):
    """ Adjust image brightness (b) and contrast (c). """
    if b==0 and c==1: return im
    mu = np.average(im)
    return np.clip((im-mu)*c+mu+b,0.,1.).astype(np.float32)

def rotate_cv(im, deg, mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_AREA):
    """ Rotate an image by deg degrees.

    Arguments:
        deg (float): degrees to rotate.
    """
    r,c,*_ = im.shape
    M = cv2.getRotationMatrix2D((c//2,r//2),deg,1)
    return cv2.warpAffine(im,M,(c,r), borderMode=mode, flags=cv2.WARP_FILL_OUTLIERS+interpolation)

def no_crop(im, min_sz=None, interpolation=cv2.INTER_AREA):
    """ Return a square resized image. """
    r,c,*_ = im.shape
    if min_sz is None: min_sz = min(r,c)
    return cv2.resize(im, (min_sz, min_sz), interpolation=interpolation)

def center_crop(im, min_sz=None):
    """ Return a center crop of an image. """
    r,c,*_ = im.shape
    if min_sz is None: min_sz = min(r,c)
    start_r = math.ceil((r-min_sz)/2)
    start_c = math.ceil((c-min_sz)/2)
    return crop(im, start_r, start_c, min_sz)

def googlenet_resize(im, targ, min_area_frac, min_aspect_ratio, max_aspect_ratio, flip_hw_p, interpolation=cv2.INTER_AREA):
    """ Randomly crop an image with a random area fraction and aspect ratio, and return a square image resized to targ.

    References:
    1. https://arxiv.org/pdf/1409.4842.pdf
    2. https://arxiv.org/pdf/1802.07888.pdf
    """
    h,w,*_ = im.shape
    area = h*w
    for _ in range(10):
        targetArea = random.uniform(min_area_frac, 1.0) * area
        aspectR = random.uniform(min_aspect_ratio, max_aspect_ratio)
        ww = int(np.sqrt(targetArea * aspectR) + 0.5)
        hh = int(np.sqrt(targetArea / aspectR) + 0.5)
        if flip_hw_p:
            ww, hh = hh, ww
        if hh <= h and ww <= w:
            x1 = 0 if w == ww else random.randint(0, w - ww)
            y1 = 0 if h == hh else random.randint(0, h - hh)
            out = im[y1:y1 + hh, x1:x1 + ww]
            out = cv2.resize(out, (targ, targ), interpolation=interpolation)
            return out
    out = scale_min(im, targ, interpolation=interpolation)
    out = center_crop(out)
    return out
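
# Illustrative sketch (not part of the original API): how googlenet_resize might be called
# directly on a NumPy HxWxC image. Assumes `np` comes from .imports; the helper name
# `_demo_googlenet_resize` is hypothetical and never called by the library.
def _demo_googlenet_resize():
    im = np.random.rand(300, 400, 3).astype(np.float32)   # fake image with values in [0,1]
    out = googlenet_resize(im, targ=224, min_area_frac=0.08,
                           min_aspect_ratio=0.75, max_aspect_ratio=1.333,
                           flip_hw_p=False)
    assert out.shape[:2] == (224, 224)   # always a square crop resized to targ
    return out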

def cutout(im, n_holes, length):
    """ Cut out n_holes number of square holes of size length in image at random locations. Holes may overlap. """
    r,c,*_ = im.shape
    mask = np.ones((r, c), np.int32)
    for n in range(n_holes):
        y = np.random.randint(length / 2, r - length / 2)
        x = np.random.randint(length / 2, c - length / 2)
        y1 = int(np.clip(y - length / 2, 0, r))
        y2 = int(np.clip(y + length / 2, 0, r))
        x1 = int(np.clip(x - length / 2, 0, c))
        x2 = int(np.clip(x + length / 2, 0, c))
        mask[y1: y2, x1: x2] = 0.
    mask = mask[:,:,None]
    im = im * mask
    return im
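
# Illustrative sketch (not part of the original API): cutout zeroes random squares while the
# image is still HxWxC. The helper name `_demo_cutout` is hypothetical and never called by
# the library.
def _demo_cutout():
    im = np.random.rand(128, 128, 3).astype(np.float32)
    out = cutout(im, n_holes=2, length=16)      # two 16x16 holes, possibly overlapping
    assert out.shape == im.shape                # shape is preserved; holes are set to 0
    return out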

def scale_to(x, ratio, targ):
    '''Calculate the dimension of an image during scaling with aspect ratio.'''
    return max(math.floor(x*ratio), targ)

def crop(im, r, c, sz):
    '''Crop a square of size sz from image im, starting at row r and column c.'''
    return im[r:r+sz, c:c+sz]

def det_dihedral(dih): return lambda x: dihedral(x, dih)
def det_stretch(sr, sc): return lambda x: stretch_cv(x, sr, sc)
def det_lighting(b, c): return lambda x: lighting(x, b, c)
def det_rotate(deg): return lambda x: rotate_cv(x, deg)
def det_zoom(zoom): return lambda x: zoom_cv(x, zoom)

def rand0(s): return random.random()*(s*2)-s

class TfmType(IntEnum):
    """ Type of transformation.

    NO:    the default; y does not get transformed when x is transformed.
    PIXEL: x and y are images and should be transformed in the same way.
           Example: image segmentation.
    COORD: y contains coordinates (i.e. bounding boxes).
    CLASS: y contains class labels (same behaviour as PIXEL, except no normalization).
    """
    NO = 1
    PIXEL = 2
    COORD = 3
    CLASS = 4

class Denormalize():
    """ De-normalizes an image, returning it to its original format. """
    def __init__(self, m, s):
        self.m=np.array(m, dtype=np.float32)
        self.s=np.array(s, dtype=np.float32)
    def __call__(self, x): return x*self.s+self.m

class Normalize():
    """ Normalizes an image to zero mean and unit standard deviation, given the mean m and std s of the original image. """
    def __init__(self, m, s, tfm_y=TfmType.NO):
        self.m=np.array(m, dtype=np.float32)
        self.s=np.array(s, dtype=np.float32)
        self.tfm_y=tfm_y
    def __call__(self, x, y=None):
        x = (x-self.m)/self.s
        if self.tfm_y==TfmType.PIXEL and y is not None: y = (y-self.m)/self.s
        return x,y
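
# Illustrative sketch (not part of the original API): Normalize and Denormalize are inverses
# when built from the same per-channel statistics. The helper name `_demo_normalize` is
# hypothetical and never called by the library.
def _demo_normalize():
    m, s = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]   # the ImageNet stats defined below
    norm, denorm = Normalize(m, s), Denormalize(m, s)
    im = np.random.rand(64, 64, 3).astype(np.float32)
    x, _ = norm(im)                               # y defaults to None and is passed through
    restored = denorm(x)
    assert np.allclose(restored, im, atol=1e-5)   # round trip recovers the original image
    return x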

class ChannelOrder():
    '''
    Changes image array shape from (h, w, 3) to (3, h, w).
    tfm_y decides the transformation done to the y element.
    '''
    def __init__(self, tfm_y=TfmType.NO): self.tfm_y=tfm_y
    def __call__(self, x, y):
        x = np.rollaxis(x, 2)
        #if isinstance(y,np.ndarray) and (len(y.shape)==3):
        if self.tfm_y==TfmType.PIXEL: y = np.rollaxis(y, 2)
        elif self.tfm_y==TfmType.CLASS: y = y[...,0]
        return x,y

def to_bb(YY, y="deprecated"):
    """Convert mask YY to a bounding box; assumes 0 is background and non-zero values are the object."""
    cols,rows = np.nonzero(YY)
    if len(cols)==0: return np.zeros(4, dtype=np.float32)
    top_row = np.min(rows)
    left_col = np.min(cols)
    bottom_row = np.max(rows)
    right_col = np.max(cols)
    return np.array([left_col, top_row, right_col, bottom_row], dtype=np.float32)

def coords2px(y, x):
    """ Transform bounding-box coordinates into a pixel mask.

    Arguments:
        y : np array
            vector in which (y[0], y[1]) and (y[2], y[3]) are the
            corners of a bounding box.
        x : image
            an image
    Returns:
        Y : image
            of shape x.shape
    """
    rows = np.rint([y[0], y[0], y[2], y[2]]).astype(int)
    cols = np.rint([y[1], y[3], y[1], y[3]]).astype(int)
    r,c,*_ = x.shape
    Y = np.zeros((r, c))
    Y[rows, cols] = 1
    return Y
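
# Illustrative sketch (not part of the original API): to_bb reduces a binary mask to
# [row_min, col_min, row_max, col_max], and coords2px marks the four corners of such a box on
# an empty image. The helper name `_demo_bb_mask` is hypothetical and never called.
def _demo_bb_mask():
    mask = np.zeros((100, 100), dtype=np.float32)
    mask[10:20, 30:40] = 1.
    bb = to_bb(mask)
    assert list(bb) == [10., 30., 19., 39.]      # min/max of the non-zero rows and columns
    corners = coords2px(bb, np.zeros((100, 100, 3)))
    assert corners.sum() == 4                    # only the four corner pixels are set
    return bb, corners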

class Transform():
    """ A class that represents a transform.

    All other transforms should subclass it. All subclasses should override
    do_transform.

    Arguments
    ---------
        tfm_y : TfmType
            type of transform
    """
    def __init__(self, tfm_y=TfmType.NO):
        self.tfm_y=tfm_y
        self.store = threading.local()

    def set_state(self): pass
    def __call__(self, x, y):
        self.set_state()
        x,y = ((self.transform(x),y) if self.tfm_y==TfmType.NO
                else self.transform(x,y) if self.tfm_y in (TfmType.PIXEL, TfmType.CLASS)
                else self.transform_coord(x,y))
        return x, y

    def transform_coord(self, x, y): return self.transform(x),y

    def transform(self, x, y=None):
        x = self.do_transform(x,False)
        return (x, self.do_transform(y,True)) if y is not None else x

    @abstractmethod
    def do_transform(self, x, is_y): raise NotImplementedError

class CoordTransform(Transform):
    """ A coordinate transform. """
    @staticmethod
    def make_square(y, x):
        r,c,*_ = x.shape
        y1 = np.zeros((r, c))
        y = y.astype(int)
        y1[y[0]:y[2], y[1]:y[3]] = 1.
        return y1

    def map_y(self, y0, x):
        y = CoordTransform.make_square(y0, x)
        y_tr = self.do_transform(y, True)
        return to_bb(y_tr)

    def transform_coord(self, x, ys):
        yp = partition(ys, 4)
        y2 = [self.map_y(y,x) for y in yp]
        x = self.do_transform(x, False)
        return x, np.concatenate(y2)
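
# Illustrative sketch (not part of the original API): the minimal contract for a new transform
# is set_state() (draw per-call randomness into the thread-local self.store) plus
# do_transform(x, is_y). The class name `_DemoAddNoise` is hypothetical and never used by the
# library; it assumes images with values in [0,1].
class _DemoAddNoise(Transform):
    def __init__(self, amount=0.05, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.amount = amount
    def set_state(self):
        # one offset per __call__, shared between x and (optionally) y
        self.store.level = random.random() * self.amount
    def do_transform(self, x, is_y):
        if is_y and self.tfm_y != TfmType.PIXEL: return x
        return np.clip(x + self.store.level, 0., 1.)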

class AddPadding(CoordTransform):
    """ A class that represents adding padding to an image.

    The default padding mode is BORDER_REFLECT.
    Arguments
    ---------
        pad : int
            size of padding on top, bottom, left and right
        mode:
            type of cv2 padding mode (e.g. constant, reflect, wrap, replicate).
    """
    def __init__(self, pad, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.pad,self.mode = pad,mode
    def do_transform(self, im, is_y):
        return cv2.copyMakeBorder(im, self.pad, self.pad, self.pad, self.pad, self.mode)

class CenterCrop(CoordTransform):
    """ A class that represents a Center Crop.

    This transform (optionally) transforms x and y with the same parameters.
    Arguments
    ---------
        sz: int
            size of the crop.
        tfm_y : TfmType
            type of y transformation.
    """
    def __init__(self, sz, tfm_y=TfmType.NO, sz_y=None):
        super().__init__(tfm_y)
        self.min_sz,self.sz_y = sz,sz_y
    def do_transform(self, x, is_y):
        return center_crop(x, self.sz_y if is_y else self.min_sz)

class RandomCrop(CoordTransform):
    """ A class that represents a Random Crop transformation.

    This transform (optionally) transforms x and y with the same parameters.
    Arguments
    ---------
        targ_sz: int
            target size of the crop.
        tfm_y: TfmType
            type of y transformation.
    """
    def __init__(self, targ_sz, tfm_y=TfmType.NO, sz_y=None):
        super().__init__(tfm_y)
        self.targ_sz,self.sz_y = targ_sz,sz_y
    def set_state(self):
        self.store.rand_r = random.uniform(0, 1)
        self.store.rand_c = random.uniform(0, 1)
    def do_transform(self, x, is_y):
        r,c,*_ = x.shape
        sz = self.sz_y if is_y else self.targ_sz
        start_r = np.floor(self.store.rand_r*(r-sz)).astype(int)
        start_c = np.floor(self.store.rand_c*(c-sz)).astype(int)
        return crop(x, start_r, start_c, sz)

class NoCrop(CoordTransform):
    """ A transformation that resizes to a square image without cropping.

    This transform (optionally) resizes x and y with the same parameters.
    Arguments:
        sz: int
            target size of the resize.
        tfm_y (TfmType): type of y transformation.
    """
    def __init__(self, sz, tfm_y=TfmType.NO, sz_y=None):
        super().__init__(tfm_y)
        self.sz,self.sz_y = sz,sz_y
    def do_transform(self, x, is_y):
        if is_y: return no_crop(x, self.sz_y, cv2.INTER_AREA if self.tfm_y == TfmType.PIXEL else cv2.INTER_NEAREST)
        else   : return no_crop(x, self.sz,   cv2.INTER_AREA)

class Scale(CoordTransform):
    """ A transformation that scales the min size to sz.

    Arguments:
        sz: int
            target size to scale minimum size.
        tfm_y: TfmType
            type of y transformation.
    """
    def __init__(self, sz, tfm_y=TfmType.NO, sz_y=None):
        super().__init__(tfm_y)
        self.sz,self.sz_y = sz,sz_y
    def do_transform(self, x, is_y):
        if is_y: return scale_min(x, self.sz_y, cv2.INTER_AREA if self.tfm_y == TfmType.PIXEL else cv2.INTER_NEAREST)
        else   : return scale_min(x, self.sz,   cv2.INTER_AREA)

class RandomScale(CoordTransform):
    """ Scales an image so that the min size is a random number in [sz, sz*max_zoom].

    This transform (optionally) scales x and y with the same parameters.
    Arguments:
        sz: int
            target size
        max_zoom: float
            float >= 1.0
        p : float
            a probability for doing the random sizing
        tfm_y: TfmType
            type of y transform
    """
    def __init__(self, sz, max_zoom, p=0.75, tfm_y=TfmType.NO, sz_y=None):
        super().__init__(tfm_y)
        self.sz,self.max_zoom,self.p,self.sz_y = sz,max_zoom,p,sz_y
    def set_state(self):
        min_z = 1.
        max_z = self.max_zoom
        if isinstance(self.max_zoom, collections.abc.Iterable):
            min_z, max_z = self.max_zoom
        self.store.mult = random.uniform(min_z, max_z) if random.random()<self.p else 1
        self.store.new_sz = int(self.store.mult*self.sz)
        if self.sz_y is not None: self.store.new_sz_y = int(self.store.mult*self.sz_y)
    def do_transform(self, x, is_y):
        if is_y: return scale_min(x, self.store.new_sz_y, cv2.INTER_AREA if self.tfm_y == TfmType.PIXEL else cv2.INTER_NEAREST)
        else   : return scale_min(x, self.store.new_sz,   cv2.INTER_AREA)

class RandomRotate(CoordTransform):
    """ Rotates images and (optionally) target y.

    Rotating coordinates is treated differently for x and y on this transform.
    Arguments:
        deg (float): maximum degrees of rotation.
        p (float): probability of rotation
        mode: type of border
        tfm_y (TfmType): type of y transform
    """
    def __init__(self, deg, p=0.75, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.deg,self.p = deg,p
        if tfm_y == TfmType.COORD or tfm_y == TfmType.CLASS:
            self.modes = (mode,cv2.BORDER_CONSTANT)
        else:
            self.modes = (mode,mode)
    def set_state(self):
        self.store.rdeg = rand0(self.deg)
        self.store.rp = random.random()<self.p
    def do_transform(self, x, is_y):
        if self.store.rp: x = rotate_cv(x, self.store.rdeg,
                mode=self.modes[1] if is_y else self.modes[0],
                interpolation=cv2.INTER_NEAREST if is_y else cv2.INTER_AREA)
        return x

class RandomDihedral(CoordTransform):
    """
    Rotates images by random multiples of 90 degrees and/or reflects them.
    See D8, the dihedral group of order eight: the group of all symmetries of the square.
    """
    def set_state(self):
        self.store.rot_times = random.randint(0,3)
        self.store.do_flip = random.random()<0.5
    def do_transform(self, x, is_y):
        x = np.rot90(x, self.store.rot_times)
        return np.fliplr(x).copy() if self.store.do_flip else x

class RandomFlip(CoordTransform):
    def __init__(self, tfm_y=TfmType.NO, p=0.5):
        super().__init__(tfm_y=tfm_y)
        self.p=p
    def set_state(self): self.store.do_flip = random.random()<self.p
    def do_transform(self, x, is_y): return np.fliplr(x).copy() if self.store.do_flip else x

class RandomLighting(Transform):
    def __init__(self, b, c, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.b,self.c = b,c
    def set_state(self):
        self.store.b_rand = rand0(self.b)
        self.store.c_rand = rand0(self.c)
    def do_transform(self, x, is_y):
        if is_y and self.tfm_y != TfmType.PIXEL: return x
        b = self.store.b_rand
        c = self.store.c_rand
        c = -1/(c-1) if c<0 else c+1
        x = lighting(x, b, c)
        return x

class RandomRotateZoom(CoordTransform):
    """
    Selects between a rotate, zoom, stretch, or no transform.

    Arguments:
        deg - maximum degrees of rotation.
        zoom - maximum fraction of zoom.
        stretch - maximum fraction of stretch.
        ps - probabilities for each transform. List of length 4, in the order listed above
             (the 4th probability is 'no transform').
    """
    def __init__(self, deg, zoom, stretch, ps=None, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        if ps is None: ps = [0.25,0.25,0.25,0.25]
        assert len(ps) == 4, 'does not have 4 probabilities for p, it has %d' % len(ps)
        self.transforms = RandomRotate(deg, p=1, mode=mode, tfm_y=tfm_y), RandomZoom(zoom, tfm_y=tfm_y), RandomStretch(stretch, tfm_y=tfm_y)
        self.pass_t = PassThru()
        self.cum_ps = np.cumsum(ps)
        assert self.cum_ps[3]==1, 'probabilities do not sum to 1; they sum to %s' % self.cum_ps[3]
    def set_state(self):
        self.store.trans = self.pass_t
        self.store.choice = self.cum_ps[3]*random.random()
        for i in range(len(self.transforms)):
            if self.store.choice < self.cum_ps[i]:
                self.store.trans = self.transforms[i]
                break
        self.store.trans.set_state()
    def do_transform(self, x, is_y): return self.store.trans.do_transform(x, is_y)

class RandomZoom(CoordTransform):
    def __init__(self, zoom_max, zoom_min=0, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.zoom_max, self.zoom_min = zoom_max, zoom_min
    def set_state(self):
        self.store.zoom = self.zoom_min+(self.zoom_max-self.zoom_min)*random.random()
    def do_transform(self, x, is_y):
        return zoom_cv(x, self.store.zoom)

class RandomStretch(CoordTransform):
    def __init__(self, max_stretch, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.max_stretch = max_stretch
    def set_state(self):
        self.store.stretch = self.max_stretch*random.random()
        self.store.stretch_dir = random.randint(0,1)
    def do_transform(self, x, is_y):
        if self.store.stretch_dir==0: x = stretch_cv(x, self.store.stretch, 0)
        else:                         x = stretch_cv(x, 0, self.store.stretch)
        return x

class PassThru(CoordTransform):
    def do_transform(self, x, is_y):
        return x

class RandomBlur(Transform):
    """
    Adds a Gaussian blur to the image with a given probability.

    Multiple blur strengths can be configured; one of them is picked at random.
    """
    def __init__(self, blur_strengths=5, probability=0.5, tfm_y=TfmType.NO):
        # Blur strength must be an odd number, because it is used as a kernel size.
        super().__init__(tfm_y)
        self.blur_strengths = (np.array(blur_strengths, ndmin=1) * 2) - 1
        if np.any(self.blur_strengths < 0):
            raise ValueError("all blur_strengths must be > 0")
        self.probability = probability
    def set_state(self):
        self.store.apply_transform = random.random() < self.probability
        kernel_size = np.random.choice(self.blur_strengths)
        self.store.kernel = (kernel_size, kernel_size)
    def do_transform(self, x, is_y):
        return cv2.GaussianBlur(src=x, ksize=self.store.kernel, sigmaX=0) if self.store.apply_transform else x

class Cutout(Transform):
    """ Randomly masks squares of size length on the image.

    https://arxiv.org/pdf/1708.04552.pdf
    Arguments:
        n_holes: number of squares
        length: size of the square
        p: probability to apply cutout
        tfm_y: type of y transform
    """
    def __init__(self, n_holes, length, p=0.5, tfm_y=TfmType.NO):
        super().__init__(tfm_y)
        self.n_holes, self.length, self.p = n_holes, length, p
    def set_state(self):
        self.store.apply_transform = random.random() < self.p
    def do_transform(self, img, is_y):
        return cutout(img, self.n_holes, self.length) if self.store.apply_transform else img

class GoogleNetResize(CoordTransform):
    """ Randomly crops an image with a random area fraction and aspect ratio, and returns a square image resized to targ_sz.

    Arguments:
        targ_sz: int
            target size
        min_area_frac: float < 1.0
            minimum area fraction of the original image to keep when cropping
        min_aspect_ratio : float
            minimum aspect ratio
        max_aspect_ratio : float
            maximum aspect ratio
        flip_hw_p : float
            probability of swapping the height and width of the crop
        tfm_y: TfmType
            type of y transform
    """
    def __init__(self, targ_sz,
                 min_area_frac=0.08, min_aspect_ratio=0.75, max_aspect_ratio=1.333, flip_hw_p=0.5,
                 tfm_y=TfmType.NO, sz_y=None):
        super().__init__(tfm_y)
        self.targ_sz, self.tfm_y, self.sz_y = targ_sz, tfm_y, sz_y
        self.min_area_frac, self.min_aspect_ratio, self.max_aspect_ratio, self.flip_hw_p = min_area_frac, min_aspect_ratio, max_aspect_ratio, flip_hw_p
    def set_state(self):
        # if self.random_state: random.seed(self.random_state)
        self.store.fp = random.random()<self.flip_hw_p
    def do_transform(self, x, is_y):
        sz = self.sz_y if is_y else self.targ_sz
        if is_y:
            interpolation = cv2.INTER_NEAREST if self.tfm_y in (TfmType.COORD, TfmType.CLASS) else cv2.INTER_AREA
        else:
            interpolation = cv2.INTER_AREA
        return googlenet_resize(x, sz, self.min_area_frac, self.min_aspect_ratio, self.max_aspect_ratio, self.store.fp, interpolation=interpolation)

def compose(im, y, fns):
    """ Apply a collection of transformation functions fns to image im and target y. """
    for fn in fns:
        #pdb.set_trace()
        im, y = fn(im, y)
    return im if y is None else (im, y)
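
# Illustrative sketch (not part of the original API): compose threads (im, y) through a list
# of transforms; with y=None only the image is returned. The helper name `_demo_compose` is
# hypothetical and never called by the library.
def _demo_compose():
    im = np.random.rand(256, 256, 3).astype(np.float32)
    tfms = [RandomFlip(), RandomLighting(0.05, 0.05), CenterCrop(224)]
    out = compose(im, None, tfms)      # each transform draws its own random state per call
    assert out.shape[:2] == (224, 224)
    return out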

class CropType(IntEnum):
    """ Type of image cropping. """
    RANDOM = 1
    CENTER = 2
    NO = 3
    GOOGLENET = 4

crop_fn_lu = {CropType.RANDOM: RandomCrop, CropType.CENTER: CenterCrop, CropType.NO: NoCrop, CropType.GOOGLENET: GoogleNetResize}

class Transforms():
    def __init__(self, sz, tfms, normalizer, denorm, crop_type=CropType.CENTER,
                 tfm_y=TfmType.NO, sz_y=None):
        if sz_y is None: sz_y = sz
        self.sz,self.denorm,self.norm,self.sz_y = sz,denorm,normalizer,sz_y
        crop_tfm = crop_fn_lu[crop_type](sz, tfm_y, sz_y)
        self.tfms = tfms
        self.tfms.append(crop_tfm)
        if normalizer is not None: self.tfms.append(normalizer)
        self.tfms.append(ChannelOrder(tfm_y))
    def __call__(self, im, y=None): return compose(im, y, self.tfms)
    def __repr__(self): return str(self.tfms)

def image_gen(normalizer, denorm, sz, tfms=None, max_zoom=None, pad=0, crop_type=None,
              tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, scale=None):
    """
    Generate a standard set of transformations.

    Arguments
    ---------
     normalizer :
         image normalizing function
     denorm :
         image denormalizing function
     sz :
         size; sz_y = sz if not specified.
     tfms :
         iterable collection of transformation functions
     max_zoom : float
         maximum zoom
     pad : int
         padding on top, left, right and bottom
     crop_type :
         crop type
     tfm_y :
         y-axis specific transformations
     sz_y :
         y size
     pad_mode :
         cv2 padding style: repeat, reflect, etc.

    Returns
    -------
     type : ``Transforms``
         transformer for the specified image operations.

    See Also
    --------
     Transforms: the transformer object returned by this function
    """
    if tfm_y is None: tfm_y=TfmType.NO
    if tfms is None: tfms=[]
    elif not isinstance(tfms, collections.abc.Iterable): tfms=[tfms]
    if sz_y is None: sz_y = sz
    if scale is None:
        scale = [RandomScale(sz, max_zoom, tfm_y=tfm_y, sz_y=sz_y) if max_zoom is not None
                 else Scale(sz, tfm_y, sz_y=sz_y)]
    elif not is_listy(scale): scale = [scale]
    if pad: scale.append(AddPadding(pad, mode=pad_mode))
    if crop_type!=CropType.GOOGLENET: tfms=scale+tfms
    return Transforms(sz, tfms, normalizer, denorm, crop_type,
                      tfm_y=tfm_y, sz_y=sz_y)

def noop(x):
    """Dummy do-nothing function; equivalent to: lambda x: x"""
    return x

transforms_basic    = [RandomRotate(10), RandomLighting(0.05, 0.05)]
transforms_side_on  = transforms_basic + [RandomFlip()]
transforms_top_down = transforms_basic + [RandomDihedral()]

imagenet_stats = A([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
"""Statistics of ImageNet image data: mean and std of each color channel."""
inception_stats = A([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
inception_models = (inception_4, inceptionresnet_2)

def tfms_from_stats(stats, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,
                    tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):
    """ Given the statistics of the training image set, returns separate training and validation transform functions. """
    if aug_tfms is None: aug_tfms=[]
    tfm_norm = Normalize(*stats, tfm_y=tfm_y if norm_y else TfmType.NO) if stats is not None else None
    tfm_denorm = Denormalize(*stats) if stats is not None else None
    val_crop = CropType.CENTER if crop_type in (CropType.RANDOM,CropType.GOOGLENET) else crop_type
    val_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=val_crop,
                        tfm_y=tfm_y, sz_y=sz_y, scale=scale)
    trn_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=crop_type,
                        tfm_y=tfm_y, sz_y=sz_y, tfms=aug_tfms, max_zoom=max_zoom, pad_mode=pad_mode, scale=scale)
    return trn_tfm, val_tfm

def tfms_from_model(f_model, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,
                    tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):
    """ Returns separate transformers of images for training and validation.

    Transformers are constructed according to the image statistics given by the model. (See tfms_from_stats)
    Arguments:
        f_model: model, pretrained or not pretrained
    """
    stats = inception_stats if f_model in inception_models else imagenet_stats
    return tfms_from_stats(stats, sz, aug_tfms, max_zoom=max_zoom, pad=pad, crop_type=crop_type,
                           tfm_y=tfm_y, sz_y=sz_y, pad_mode=pad_mode, norm_y=norm_y, scale=scale)
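
# Illustrative sketch (not part of the original API): the typical entry points are
# tfms_from_model / tfms_from_stats, which return (train, validation) transform pipelines.
# The helper name `_demo_tfms` is hypothetical and never called by the library.
def _demo_tfms():
    trn_tfm, val_tfm = tfms_from_stats(imagenet_stats, sz=224,
                                       aug_tfms=transforms_side_on, max_zoom=1.1)
    im = np.random.rand(300, 400, 3).astype(np.float32)
    x = trn_tfm(im)                    # y is None, so only the transformed image is returned
    assert x.shape == (3, 224, 224)    # ChannelOrder moves channels first
    return x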