Quellcode durchsuchen

Fixing a few bugs and cleaning up new post_process/stats parameterizations

Jason Antic vor 5 Jahren
Ursprung
Commit
4969dfa11d
4 geänderte Dateien mit 33 neuen und 44 gelöschten Zeilen
  1. 2 2
      deoldify/dataset.py
  2. 9 14
      deoldify/filters.py
  3. 4 6
      deoldify/generators.py
  4. 18 22
      deoldify/visualize.py

+ 2 - 2
deoldify/dataset.py

@@ -41,8 +41,8 @@ def get_colorize_data(
     return data
 
 
-def get_dummy_databunch(stats=imagenet_stats) -> ImageDataBunch:
+def get_dummy_databunch() -> ImageDataBunch:
     path = Path('./dummy/')
     return get_colorize_data(
-        sz=1, bs=1, crappy_path=path, good_path=path, stats=stats,keep_pct=0.001
+        sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
     )

+ 9 - 14
deoldify/filters.py

@@ -21,7 +21,7 @@ class IFilter(ABC):
 
 
 class BaseFilter(IFilter):
-    def __init__(self, learn: Learner, stats:tuple = imagenet_stats):
+    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
         super().__init__()
         self.learn = learn
         self.norm, self.denorm = normalize_funcs(*stats)
@@ -60,21 +60,20 @@ class BaseFilter(IFilter):
 
 
 class ColorizerFilter(BaseFilter):
-    def __init__(self, learn: Learner, stats: tuple = imagenet_stats, map_to_orig: bool = True):
+    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
         super().__init__(learn=learn, stats=stats)
         self.render_base = 16
-        self.map_to_orig = map_to_orig
 
     def filter(
-        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int,post_process: bool = False
-    ) -> PilImage:
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
         render_sz = render_factor * self.render_base
         model_image = self._model_process(orig=filtered_image, sz=render_sz)
+        raw_color = self._unsquare(model_image, orig_image)
 
-        if self.map_to_orig:
-            return self._post_process(model_image, orig_image, post_process )
+        if post_process:
+            return self._post_process(raw_color, orig_image)
         else:
-            return self._post_process(model_image, filtered_image, post_process)
+            return raw_color
 
     def _transform(self, image: PilImage) -> PilImage:
         return image.convert('LA').convert('RGB')
@@ -84,10 +83,7 @@ class ColorizerFilter(BaseFilter):
     # save a lot on memory and processing in the model, yet get a great high
     # resolution result at the end.  This is primarily intended just for
     # inference
-    def _post_process(self, raw_color: PilImage, orig: PilImage, post_process: bool) -> PilImage:
-        raw_color = self._unsquare(raw_color, orig)
-        if not post_process:
-            return raw_color
+    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
         color_np = np.asarray(raw_color)
         orig_np = np.asarray(orig)
         color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
@@ -106,8 +102,7 @@ class MasterFilter(BaseFilter):
         self.render_factor = render_factor
 
     def filter(
-        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None,post_process: bool = False
-    ) -> PilImage:
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
         render_factor = self.render_factor if render_factor is None else render_factor
         for filter in self.filters:
             filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)

+ 4 - 6
deoldify/generators.py

@@ -6,9 +6,8 @@ from .dataset import *
 
 # Weights are implicitly read from ./models/ folder
 def gen_inference_wide(
-    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101, stats: tuple = imagenet_stats
-) -> Learner:
-    data = get_dummy_databunch(stats)
+    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
+    data = get_dummy_databunch()
     learn = gen_learner_wide(
         data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
     )
@@ -80,9 +79,8 @@ def unet_learner_wide(
 
 # Weights are implicitly read from ./models/ folder
 def gen_inference_deep(
-    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5, stats: tuple = imagenet_stats
-) -> Learner:
-    data = get_dummy_databunch(stats=stats)
+    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
+    data = get_dummy_databunch()
     learn = gen_learner_deep(
         data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
     )

+ 18 - 22
deoldify/visualize.py

@@ -180,7 +180,7 @@ class ModelImageVisualizer:
         render_factor: int,
         axes: Axes = None,
         figsize=(20, 20),
-        display_render_factor=35,
+        display_render_factor = False,
     ):
         if axes is None:
             _, axes = plt.subplots(figsize=figsize)
@@ -246,7 +246,7 @@ class VideoColorizer:
         ).run(capture_stdout=True)
 
     def _colorize_raw_frames(
-        self, source_path: Path, render_factor: int = None, post_process: bool = False,
+        self, source_path: Path, render_factor: int = None, post_process: bool = True,
         watermarked: bool = True,
     ):
         colorframes_folder = self.colorframes_root / (source_path.stem)
@@ -318,7 +318,7 @@ class VideoColorizer:
         source_url,
         file_name: str,
         render_factor: int = None,
-        post_process: bool = False,
+        post_process: bool = True,
         watermarked: bool = True,
 
     ) -> Path:
@@ -350,19 +350,18 @@ class VideoColorizer:
         return self._build_video(source_path)
 
 
-def get_video_colorizer(render_factor: int = 21, stats:tuple = imagenet_stats) -> VideoColorizer:
-    return get_stable_video_colorizer(render_factor=render_factor, stats=stats)
+def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
+    return get_stable_video_colorizer(render_factor=render_factor)
 
 
 def get_artistic_video_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeArtistic_gen',
     results_dir='result_images',
-    render_factor: int = 35,
-    stats:tuple = imagenet_stats
+    render_factor: int = 35
 ) -> VideoColorizer:
     learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
-    filtr = MasterFilter([ColorizerFilter(learn=learn, stats=stats)], render_factor=render_factor)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return VideoColorizer(vis)
 
@@ -371,33 +370,31 @@ def get_stable_video_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeVideo_gen',
     results_dir='result_images',
-    render_factor: int = 21,
-    stats:tuple = imagenet_stats
+    render_factor: int = 21
 ) -> VideoColorizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
-    filtr = MasterFilter([ColorizerFilter(learn=learn,stats=stats)], render_factor=render_factor)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return VideoColorizer(vis)
 
 
 def get_image_colorizer(
-    render_factor: int = 35, artistic: bool = True, stats: tuple = imagenet_stats
+    render_factor: int = 35, artistic: bool = True
 ) -> ModelImageVisualizer:
     if artistic:
-        return get_artistic_image_colorizer(render_factor=render_factor, stats=stats)
+        return get_artistic_image_colorizer(render_factor=render_factor)
     else:
-        return get_stable_image_colorizer(render_factor=render_factor, stats=stats)
+        return get_stable_image_colorizer(render_factor=render_factor)
 
 
 def get_stable_image_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeStable_gen',
     results_dir='result_images',
-    render_factor: int = 35,
-    stats: tuple = imagenet_stats
+    render_factor: int = 35
 ) -> ModelImageVisualizer:
-    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name, stats=stats)
-    filtr = MasterFilter([ColorizerFilter(learn=learn, stats=stats)], render_factor=render_factor)
+    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis
 
@@ -406,11 +403,10 @@ def get_artistic_image_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeArtistic_gen',
     results_dir='result_images',
-    render_factor: int = 35,
-    stats: tuple = imagenet_stats
+    render_factor: int = 35
 ) -> ModelImageVisualizer:
-    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name, stats=stats)
-    filtr = MasterFilter([ColorizerFilter(learn=learn, stats=stats)], render_factor=render_factor)
+    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis