浏览代码

Adding gaussian noise augmentation

Jason Antic 6 年之前
父节点
当前提交
cc1d53c2d6
共有 2 个文件被更改,包括 26 次插入2 次删除
  1. 23 0
      fasterai/augs.py
  2. 3 2
      fasterai/dataset.py

+ 23 - 0
fasterai/augs.py

@@ -0,0 +1,23 @@
+
+from fastai.vision.image import TfmPixel
+import random
+
+#Contributed by Rani Horev. Thank you!
+def _noisify(x, pct_pixels_min:float=0.003, pct_pixels_max:float=0.0031, noise_range:int=30):
+    if noise_range > 255 or noise_range < 0:
+        raise('noise_range must be between 0 and 255, inclusively.')
+    h,w = x.shape[1:]
+    img_size = h * w
+    mult = 10000.0
+    pct_pixels = random.randrange(int(pct_pixels_min*mult), int(pct_pixels_max*mult))/mult
+    noise_count = int(img_size * pct_pixels)
+
+    for ii in range(noise_count):
+        yy = random.randrange(h)
+        xx = random.randrange(w)
+        noise = random.randrange(-noise_range, noise_range)/255.0
+        x[:,yy,xx].add_(noise)
+    return x
+
+
+noisify = TfmPixel(_noisify)

+ 3 - 2
fasterai/dataset.py

@@ -3,17 +3,18 @@ from fastai import *
 from fastai.core import *
 from fastai.vision.transform import get_transforms
 from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
+from .augs import noisify
 
 
 def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None, 
-        keep_pct:float=1.0, num_workers:int=8)->ImageDataBunch:
+        keep_pct:float=1.0, num_workers:int=8, xtra_tfms=[])->ImageDataBunch:
 
     src = (ImageImageList.from_folder(crappy_path)
         .use_partial_data(sample_pct=keep_pct, seed=random_seed)
         .random_split_by_pct(0.1, seed=random_seed))
 
     data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
-        .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25), size=sz, tfm_y=True)
+        .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)
         .databunch(bs=bs, num_workers=num_workers, no_check=True)
         .normalize(imagenet_stats, do_y=True))