Просмотр исходного кода

code cleanup and refactor

Using Black and PyLint to improve
code quality.
Alexandre Vicenzi 6 лет назад
Родитель
Сommit
b80aaafe09
2 измененных файлов с 35 добавлено и 21 удалено
  1. 15 9
      deoldify/augs.py
  2. 20 12
      deoldify/save.py

+ 15 - 9
deoldify/augs.py

@@ -1,23 +1,29 @@
+import random
 
 from fastai.vision.image import TfmPixel
-import random
 
-#Contributed by Rani Horev. Thank you!
-def _noisify(x, pct_pixels_min:float=0.001, pct_pixels_max:float=0.4, noise_range:int=30):
+# Contributed by Rani Horev. Thank you!
+def _noisify(
+    x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
+):
     if noise_range > 255 or noise_range < 0:
-        raise Exception('noise_range must be between 0 and 255, inclusively.')
-    h,w = x.shape[1:]
+        raise Exception("noise_range must be between 0 and 255, inclusively.")
+
+    h, w = x.shape[1:]
     img_size = h * w
     mult = 10000.0
-    pct_pixels = random.randrange(int(pct_pixels_min*mult), int(pct_pixels_max*mult))/mult
+    pct_pixels = (
+        random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
+    )
     noise_count = int(img_size * pct_pixels)
 
     for ii in range(noise_count):
         yy = random.randrange(h)
         xx = random.randrange(w)
-        noise = random.randrange(-noise_range, noise_range)/255.0
-        x[:,yy,xx].add_(noise)
+        noise = random.randrange(-noise_range, noise_range) / 255.0
+        x[:, yy, xx].add_(noise)
+
     return x
 
 
-noisify = TfmPixel(_noisify)
+noisify = TfmPixel(_noisify)

+ 20 - 12
deoldify/save.py

@@ -1,21 +1,29 @@
-from fastai.torch_core import *
-from fastai.basic_data import DataBunch
-from fastai.callback import *
 from fastai.basic_train import Learner, LearnerCallback
 from fastai.vision.gan import GANLearner
 
+
 class GANSaveCallback(LearnerCallback):
-    "A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."
-    def __init__(self, learn:GANLearner, learn_gen:Learner, filename:str, save_iters:int=1000): 
+    """A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""
+
+    def __init__(
+        self,
+        learn: GANLearner,
+        learn_gen: Learner,
+        filename: str,
+        save_iters: int = 1000,
+    ):
         super().__init__(learn)
-        self.learn_gen, self.filename, self.save_iters = learn_gen, filename, save_iters
+        self.learn_gen = learn_gen
+        self.filename = filename
+        self.save_iters = save_iters
 
+    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
+        if iteration == 0:
+            return
 
-    def on_batch_end(self, iteration:int, epoch:int, **kwargs)->None:
-        if iteration == 0: return
-        if iteration % self.save_iters == 0: 
+        if iteration % self.save_iters == 0:
             self._save_gen_learner(iteration=iteration, epoch=epoch)
 
-    def _save_gen_learner(self, iteration:int, epoch:int):
-        fn = self.filename + '_' + str(epoch) + '_' + str(iteration)
-        self.learn_gen.save(fn)
+    def _save_gen_learner(self, iteration: int, epoch: int):
+        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
+        self.learn_gen.save(filename)