Переглянути джерело

Merge pull request #205 from jantic/pr/204

Pr/204 with modifications
Jason Antic 5 роки тому
батько
коміт
6766bf4a1e
4 змінені файли: 40 додано та 37 видалено
  1. 3 2
      deoldify/dataset.py
  2. 11 15
      deoldify/filters.py
  3. 2 4
      deoldify/generators.py
  4. 24 16
      deoldify/visualize.py

+ 3 - 2
deoldify/dataset.py

@@ -14,9 +14,10 @@ def get_colorize_data(
     random_seed: int = None,
     keep_pct: float = 1.0,
     num_workers: int = 8,
+    stats: tuple = imagenet_stats,
     xtra_tfms=[],
 ) -> ImageDataBunch:
-
+    
     src = (
         ImageImageList.from_folder(crappy_path, convert_mode='RGB')
         .use_partial_data(sample_pct=keep_pct, seed=random_seed)
@@ -33,7 +34,7 @@ def get_colorize_data(
             tfm_y=True,
         )
         .databunch(bs=bs, num_workers=num_workers, no_check=True)
-        .normalize(imagenet_stats, do_y=True)
+        .normalize(stats, do_y=True)
     )
 
     data.c = 3

+ 11 - 15
deoldify/filters.py

@@ -21,10 +21,10 @@ class IFilter(ABC):
 
 
 class BaseFilter(IFilter):
-    def __init__(self, learn: Learner):
+    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
         super().__init__()
         self.learn = learn
-        self.norm, self.denorm = normalize_funcs(*imagenet_stats)
+        self.norm, self.denorm = normalize_funcs(*stats)
 
     def _transform(self, image: PilImage) -> PilImage:
         return image
@@ -60,21 +60,20 @@ class BaseFilter(IFilter):
 
 
 class ColorizerFilter(BaseFilter):
-    def __init__(self, learn: Learner, map_to_orig: bool = True):
-        super().__init__(learn=learn)
+    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
+        super().__init__(learn=learn, stats=stats)
         self.render_base = 16
-        self.map_to_orig = map_to_orig
 
     def filter(
-        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
-    ) -> PilImage:
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
         render_sz = render_factor * self.render_base
         model_image = self._model_process(orig=filtered_image, sz=render_sz)
+        raw_color = self._unsquare(model_image, orig_image)
 
-        if self.map_to_orig:
-            return self._post_process(model_image, orig_image)
+        if post_process:
+            return self._post_process(raw_color, orig_image)
         else:
-            return self._post_process(model_image, filtered_image)
+            return raw_color
 
     def _transform(self, image: PilImage) -> PilImage:
         return image.convert('LA').convert('RGB')
@@ -85,7 +84,6 @@ class ColorizerFilter(BaseFilter):
     # resolution result at the end.  This is primarily intended just for
     # inference
     def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
-        raw_color = self._unsquare(raw_color, orig)
         color_np = np.asarray(raw_color)
         orig_np = np.asarray(orig)
         color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
@@ -104,11 +102,9 @@ class MasterFilter(BaseFilter):
         self.render_factor = render_factor
 
     def filter(
-        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None
-    ) -> PilImage:
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
         render_factor = self.render_factor if render_factor is None else render_factor
-
         for filter in self.filters:
-            filtered_image = filter.filter(orig_image, filtered_image, render_factor)
+            filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)
 
         return filtered_image

+ 2 - 4
deoldify/generators.py

@@ -6,8 +6,7 @@ from .dataset import *
 
 # Weights are implicitly read from ./models/ folder
 def gen_inference_wide(
-    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101
-) -> Learner:
+    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
     data = get_dummy_databunch()
     learn = gen_learner_wide(
         data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
@@ -80,8 +79,7 @@ def unet_learner_wide(
 
 # Weights are implicitly read from ./models/ folder
 def gen_inference_deep(
-    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5
-) -> Learner:
+    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
     data = get_dummy_databunch()
     learn = gen_learner_deep(
         data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor

+ 24 - 16
deoldify/visualize.py

@@ -72,8 +72,10 @@ class ModelImageVisualizer:
         path: str = 'test_images/image.png',
         figsize: (int, int) = (20, 20),
         render_factor: int = None,
+        
         display_render_factor: bool = False,
         compare: bool = False,
+        post_process: bool = True,
         watermarked: bool = True,
     ) -> Path:
         img = self._get_image_from_url(url)
@@ -84,6 +86,7 @@ class ModelImageVisualizer:
             render_factor=render_factor,
             display_render_factor=display_render_factor,
             compare=compare,
+            post_process = post_process,
             watermarked=watermarked,
         )
 
@@ -94,11 +97,12 @@ class ModelImageVisualizer:
         render_factor: int = None,
         display_render_factor: bool = False,
         compare: bool = False,
+        post_process: bool = True,
         watermarked: bool = True,
     ) -> Path:
         path = Path(path)
         result = self.get_transformed_image(
-            path, render_factor, watermarked=watermarked
+            path, render_factor, post_process=post_process,watermarked=watermarked
         )
         orig = self._open_pil_image(path)
         if compare:
@@ -156,12 +160,13 @@ class ModelImageVisualizer:
         return result_path
 
     def get_transformed_image(
-        self, path: Path, render_factor: int = None, watermarked: bool = True
+        self, path: Path, render_factor: int = None, post_process: bool = True,
+        watermarked: bool = True,
     ) -> Image:
         self._clean_mem()
         orig_image = self._open_pil_image(path)
         filtered_image = self.filter.filter(
-            orig_image, orig_image, render_factor=render_factor
+            orig_image, orig_image, render_factor=render_factor,post_process=post_process
         )
 
         if watermarked:
@@ -175,7 +180,7 @@ class ModelImageVisualizer:
         render_factor: int,
         axes: Axes = None,
         figsize=(20, 20),
-        display_render_factor: bool = False,
+        display_render_factor = False,
     ):
         if axes is None:
             _, axes = plt.subplots(figsize=figsize)
@@ -241,7 +246,8 @@ class VideoColorizer:
         ).run(capture_stdout=True)
 
     def _colorize_raw_frames(
-        self, source_path: Path, render_factor: int = None, watermarked: bool = True
+        self, source_path: Path, render_factor: int = None, post_process: bool = True,
+        watermarked: bool = True,
     ):
         colorframes_folder = self.colorframes_root / (source_path.stem)
         colorframes_folder.mkdir(parents=True, exist_ok=True)
@@ -250,9 +256,10 @@ class VideoColorizer:
 
         for img in progress_bar(os.listdir(str(bwframes_folder))):
             img_path = bwframes_folder / img
+
             if os.path.isfile(str(img_path)):
                 color_image = self.vis.get_transformed_image(
-                    str(img_path), render_factor=render_factor, watermarked=watermarked
+                    str(img_path), render_factor=render_factor, post_process=post_process,watermarked=watermarked
                 )
                 color_image.save(str(colorframes_folder / img))
 
@@ -311,33 +318,34 @@ class VideoColorizer:
         source_url,
         file_name: str,
         render_factor: int = None,
+        post_process: bool = True,
         watermarked: bool = True,
+
     ) -> Path:
         source_path = self.source_folder / file_name
         self._download_video_from_url(source_url, source_path)
         return self._colorize_from_path(
-            source_path, render_factor=render_factor, watermarked=watermarked
+            source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
         )
 
     def colorize_from_file_name(
-        self, file_name: str, render_factor: int = None, watermarked: bool = True
+        self, file_name: str, render_factor: int = None,  watermarked: bool = True, post_process: bool = True,
     ) -> Path:
         source_path = self.source_folder / file_name
         return self._colorize_from_path(
-            source_path, render_factor=render_factor, watermarked=watermarked
+            source_path, render_factor=render_factor,  post_process=post_process,watermarked=watermarked
         )
 
     def _colorize_from_path(
-        self, source_path: Path, render_factor: int = None, watermarked: bool = True
+        self, source_path: Path, render_factor: int = None,  watermarked: bool = True, post_process: bool = True
     ) -> Path:
         if not source_path.exists():
             raise Exception(
                 'Video at path specfied, ' + str(source_path) + ' could not be found.'
             )
-
         self._extract_raw_frames(source_path)
         self._colorize_raw_frames(
-            source_path, render_factor=render_factor, watermarked=watermarked
+            source_path, render_factor=render_factor,post_process=post_process,watermarked=watermarked
         )
         return self._build_video(source_path)
 
@@ -350,7 +358,7 @@ def get_artistic_video_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeArtistic_gen',
     results_dir='result_images',
-    render_factor: int = 35,
+    render_factor: int = 35
 ) -> VideoColorizer:
     learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
@@ -362,7 +370,7 @@ def get_stable_video_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeVideo_gen',
     results_dir='result_images',
-    render_factor: int = 21,
+    render_factor: int = 21
 ) -> VideoColorizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
@@ -383,7 +391,7 @@ def get_stable_image_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeStable_gen',
     results_dir='result_images',
-    render_factor: int = 35,
+    render_factor: int = 35
 ) -> ModelImageVisualizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
@@ -395,7 +403,7 @@ def get_artistic_image_colorizer(
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeArtistic_gen',
     results_dir='result_images',
-    render_factor: int = 35,
+    render_factor: int = 35
 ) -> ModelImageVisualizer:
     learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)