Browse Source

Adding new palette watermark

Jason Antic 5 years ago
parent
commit
2326e4090b
2 changed files with 71 additions and 10 deletions
  1. 68 10
      deoldify/visualize.py
  2. 3 0
      resource_images/watermark.png

+ 68 - 10
deoldify/visualize.py

@@ -17,6 +17,27 @@ import base64
 from IPython import display as ipythondisplay
 from IPython import display as ipythondisplay
 from IPython.display import HTML
 from IPython.display import HTML
 from IPython.display import Image as ipythonimage
 from IPython.display import Image as ipythonimage
+import cv2
+
+# adapted from https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/
+def get_watermarked(pil_image: Image) -> Image:
+    image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+    (h, w) = image.shape[:2]
+    image = np.dstack([image, np.ones((h, w), dtype="uint8") * 255])
+    full_watermark = cv2.imread('./resource_images/watermark.png', cv2.IMREAD_UNCHANGED)
+    pct = 0.05
+    (fwH, fwW) = full_watermark.shape[:2]
+    wH = int(pct * h)
+    wW = int((pct * h / fwH) * fwW)
+    watermark = cv2.resize(full_watermark, (wH, wW), interpolation=cv2.INTER_AREA)
+    overlay = np.zeros((h, w, 4), dtype="uint8")
+    overlay[h - wH - 10 : h - 10, 10 : 10 + wW] = watermark
+    # blend the two images together using transparent overlays
+    output = image.copy()
+    cv2.addWeighted(overlay, 0.5, output, 1.0, 0, output)
+    rgb_image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+    final_image = Image.fromarray(rgb_image)
+    return final_image
 
 
 
 
 class ModelImageVisualizer:
 class ModelImageVisualizer:
@@ -45,6 +66,7 @@ class ModelImageVisualizer:
         render_factor: int = None,
         render_factor: int = None,
         display_render_factor: bool = False,
         display_render_factor: bool = False,
         compare: bool = False,
         compare: bool = False,
+        watermarked: bool = True,
     ) -> Path:
     ) -> Path:
         img = self._get_image_from_url(url)
         img = self._get_image_from_url(url)
         img.save(path)
         img.save(path)
@@ -54,6 +76,7 @@ class ModelImageVisualizer:
             render_factor=render_factor,
             render_factor=render_factor,
             display_render_factor=display_render_factor,
             display_render_factor=display_render_factor,
             compare=compare,
             compare=compare,
+            watermarked=watermarked,
         )
         )
 
 
     def plot_transformed_image(
     def plot_transformed_image(
@@ -63,9 +86,12 @@ class ModelImageVisualizer:
         render_factor: int = None,
         render_factor: int = None,
         display_render_factor: bool = False,
         display_render_factor: bool = False,
         compare: bool = False,
         compare: bool = False,
+        watermarked: bool = True,
     ) -> Path:
     ) -> Path:
         path = Path(path)
         path = Path(path)
-        result = self.get_transformed_image(path, render_factor)
+        result = self.get_transformed_image(
+            path, render_factor, watermarked=watermarked
+        )
         orig = self._open_pil_image(path)
         orig = self._open_pil_image(path)
         if compare:
         if compare:
             self._plot_comparison(
             self._plot_comparison(
@@ -121,12 +147,18 @@ class ModelImageVisualizer:
         image.save(result_path)
         image.save(result_path)
         return result_path
         return result_path
 
 
-    def get_transformed_image(self, path: Path, render_factor: int = None) -> Image:
+    def get_transformed_image(
+        self, path: Path, render_factor: int = None, watermarked: bool = True
+    ) -> Image:
         self._clean_mem()
         self._clean_mem()
         orig_image = self._open_pil_image(path)
         orig_image = self._open_pil_image(path)
         filtered_image = self.filter.filter(
         filtered_image = self.filter.filter(
             orig_image, orig_image, render_factor=render_factor
             orig_image, orig_image, render_factor=render_factor
         )
         )
+
+        if watermarked:
+            return get_watermarked(filtered_image)
+
         return filtered_image
         return filtered_image
 
 
     def _plot_image(
     def _plot_image(
@@ -200,7 +232,9 @@ class VideoColorizer:
             str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0
             str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0
         ).run(capture_stdout=True)
         ).run(capture_stdout=True)
 
 
-    def _colorize_raw_frames(self, source_path: Path, render_factor: int = None):
+    def _colorize_raw_frames(
+        self, source_path: Path, render_factor: int = None, watermarked: bool = True
+    ):
         colorframes_folder = self.colorframes_root / (source_path.stem)
         colorframes_folder = self.colorframes_root / (source_path.stem)
         colorframes_folder.mkdir(parents=True, exist_ok=True)
         colorframes_folder.mkdir(parents=True, exist_ok=True)
         self._purge_images(colorframes_folder)
         self._purge_images(colorframes_folder)
@@ -210,7 +244,7 @@ class VideoColorizer:
             img_path = bwframes_folder / img
             img_path = bwframes_folder / img
             if os.path.isfile(str(img_path)):
             if os.path.isfile(str(img_path)):
                 color_image = self.vis.get_transformed_image(
                 color_image = self.vis.get_transformed_image(
-                    str(img_path), render_factor=render_factor
+                    str(img_path), render_factor=render_factor, watermarked=watermarked
                 )
                 )
                 color_image.save(str(colorframes_folder / img))
                 color_image.save(str(colorframes_folder / img))
 
 
@@ -265,26 +299,38 @@ class VideoColorizer:
         return result_path
         return result_path
 
 
     def colorize_from_url(
     def colorize_from_url(
-        self, source_url, file_name: str, render_factor: int = None
+        self,
+        source_url,
+        file_name: str,
+        render_factor: int = None,
+        watermarked: bool = True,
     ) -> Path:
     ) -> Path:
         source_path = self.source_folder / file_name
         source_path = self.source_folder / file_name
         self._download_video_from_url(source_url, source_path)
         self._download_video_from_url(source_url, source_path)
-        return self._colorize_from_path(source_path, render_factor=render_factor)
+        return self._colorize_from_path(
+            source_path, render_factor=render_factor, watermarked=watermarked
+        )
 
 
     def colorize_from_file_name(
     def colorize_from_file_name(
-        self, file_name: str, render_factor: int = None
+        self, file_name: str, render_factor: int = None, watermarked: bool = True
     ) -> Path:
     ) -> Path:
         source_path = self.source_folder / file_name
         source_path = self.source_folder / file_name
-        return self._colorize_from_path(source_path, render_factor=render_factor)
+        return self._colorize_from_path(
+            source_path, render_factor=render_factor, watermarked=watermarked
+        )
 
 
-    def _colorize_from_path(self, source_path: Path, render_factor: int = None) -> Path:
+    def _colorize_from_path(
+        self, source_path: Path, render_factor: int = None, watermarked: bool = True
+    ) -> Path:
         if not source_path.exists():
         if not source_path.exists():
             raise Exception(
             raise Exception(
                 'Video at path specfied, ' + str(source_path) + ' could not be found.'
                 'Video at path specfied, ' + str(source_path) + ' could not be found.'
             )
             )
 
 
         self._extract_raw_frames(source_path)
         self._extract_raw_frames(source_path)
-        self._colorize_raw_frames(source_path, render_factor=render_factor)
+        self._colorize_raw_frames(
+            source_path, render_factor=render_factor, watermarked=watermarked
+        )
         return self._build_video(source_path)
         return self._build_video(source_path)
 
 
 
 
@@ -292,6 +338,18 @@ def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
     return get_stable_video_colorizer(render_factor=render_factor)
     return get_stable_video_colorizer(render_factor=render_factor)
 
 
 
 
+def get_artistic_video_colorizer(
+    root_folder: Path = Path('./'),
+    weights_name: str = 'ColorizeArtistic_gen',
+    results_dir='result_images',
+    render_factor: int = 35,
+) -> VideoColorizer:
+    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
+    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
+    return VideoColorizer(vis)
+
+
 def get_stable_video_colorizer(
 def get_stable_video_colorizer(
     root_folder: Path = Path('./'),
     root_folder: Path = Path('./'),
     weights_name: str = 'ColorizeVideo_gen',
     weights_name: str = 'ColorizeVideo_gen',

+ 3 - 0
resource_images/watermark.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:568488613b9e2addbda770324d67feb956d6c8c56c29285717b15ecfd6e77f03
+size 9210