
Merge pull request #161 from alexandrevicenzi/refac-2

Refactor and cleanup
Jason Antic 5 years ago
parent
commit
6c92ebe1ae
8 changed files with 735 additions and 314 deletions
  1. deoldify/critics.py (+24, -9)
  2. deoldify/dataset.py (+31, -12)
  3. deoldify/filters.py (+47, -38)
  4. deoldify/generators.py (+134, -44)
  5. deoldify/layers.py (+38, -15)
  6. deoldify/loss.py (+64, -37)
  7. deoldify/unet.py (+169, -58)
  8. deoldify/visualize.py (+228, -101)

deoldify/critics.py (+24, -9)

@@ -5,25 +5,40 @@ from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
 
 _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
 
-def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
+
+def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
     return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
 
-def custom_gan_critic(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
+
+def custom_gan_critic(
+    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
+):
     "Critic to train a `GAN`."
-    layers = [
-        _conv(n_channels, nf, ks=4, stride=2),
-        nn.Dropout2d(p/2)]
+    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
     for i in range(n_blocks):
         layers += [
             _conv(nf, nf, ks=3, stride=1),
             nn.Dropout2d(p),
-            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
+            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
+        ]
         nf *= 2
     layers += [
         _conv(nf, nf, ks=3, stride=1),
         _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
-        Flatten()]
+        Flatten(),
+    ]
     return nn.Sequential(*layers)
 
-def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
-    return Learner(data, custom_gan_critic(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)
+
+def colorize_crit_learner(
+    data: ImageDataBunch,
+    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
+    nf: int = 256,
+) -> Learner:
+    return Learner(
+        data,
+        custom_gan_critic(nf=nf),
+        metrics=accuracy_thresh_expand,
+        loss_func=loss_critic,
+        wd=1e-3,
+    )
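
For reference, the critic defined above can be exercised on its own. A minimal sketch with a dummy CPU input, assuming fastai v1 and the deoldify package are importable (this is not part of the diff):

    import torch
    from deoldify.critics import custom_gan_critic

    # Build the standalone patch-style critic and push a dummy batch through it.
    critic = custom_gan_critic(n_channels=3, nf=256, n_blocks=3)
    with torch.no_grad():
        scores = critic(torch.randn(1, 3, 64, 64))  # flattened critic outputs for the batch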

deoldify/dataset.py (+31, -12)

@@ -6,23 +6,42 @@ from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
 from .augs import noisify
 
 
-def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None, 
-        keep_pct:float=1.0, num_workers:int=8, xtra_tfms=[])->ImageDataBunch:
-
-    src = (ImageImageList.from_folder(crappy_path, convert_mode='RGB')
+def get_colorize_data(
+    sz: int,
+    bs: int,
+    crappy_path: Path,
+    good_path: Path,
+    random_seed: int = None,
+    keep_pct: float = 1.0,
+    num_workers: int = 8,
+    xtra_tfms=[],
+) -> ImageDataBunch:
+
+    src = (
+        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
         .use_partial_data(sample_pct=keep_pct, seed=random_seed)
-        .split_by_rand_pct(0.1, seed=random_seed))
-
-    data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
-        .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)
+        .split_by_rand_pct(0.1, seed=random_seed)
+    )
+
+    data = (
+        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
+        .transform(
+            get_transforms(
+                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
+            ),
+            size=sz,
+            tfm_y=True,
+        )
         .databunch(bs=bs, num_workers=num_workers, no_check=True)
-        .normalize(imagenet_stats, do_y=True))
+        .normalize(imagenet_stats, do_y=True)
+    )
 
     data.c = 3
     return data
 
 
-
-def get_dummy_databunch()->ImageDataBunch:
+def get_dummy_databunch() -> ImageDataBunch:
     path = Path('./dummy/')
-    return get_colorize_data(sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001)
+    return get_colorize_data(
+        sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
+    )
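
As a rough sketch of how get_colorize_data is typically called (the directory names below are hypothetical; the function expects parallel folder trees of degraded and original images):

    from pathlib import Path
    from deoldify.dataset import get_colorize_data

    data = get_colorize_data(
        sz=64,
        bs=8,
        crappy_path=Path('data/bandw'),   # hypothetical folder of grayscale/degraded images
        good_path=Path('data/color'),     # matching folder of color originals
        random_seed=42,
        keep_pct=0.1,
    )
    # data is a normalized ImageDataBunch (data.c == 3) ready to hand to a Learner.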

deoldify/filters.py (+47, -38)

@@ -14,54 +14,60 @@ from PIL import Image as PilImage
 
 class IFilter(ABC):
     @abstractmethod
-    def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int)->PilImage:
-        pass   
-  
+    def filter(
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
+    ) -> PilImage:
+        pass
+
+
 class BaseFilter(IFilter):
-    def __init__(self, learn:Learner):
+    def __init__(self, learn: Learner):
         super().__init__()
-        self.learn=learn
+        self.learn = learn
         self.norm, self.denorm = normalize_funcs(*imagenet_stats)
 
-    def _transform(self, image:PilImage)->PilImage:
+    def _transform(self, image: PilImage) -> PilImage:
         return image
 
-    def _scale_to_square(self, orig:PilImage, targ:int)->PilImage:
-        #a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
-        #I've tried padding to the square as well (reflect, symetric, constant, etc).  Not as good!
+    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
+        # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
+        # I've tried padding to the square as well (reflect, symetric, constant, etc).  Not as good!
         targ_sz = (targ, targ)
         return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)
 
-    def _get_model_ready_image(self, orig:PilImage, sz:int)->PilImage:
+    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
         result = self._scale_to_square(orig, sz)
         result = self._transform(result)
         return result
 
-    def _model_process(self, orig:PilImage, sz:int)->PilImage:
+    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
         model_image = self._get_model_ready_image(orig, sz)
-        x =  pil2tensor(model_image,np.float32)
+        x = pil2tensor(model_image, np.float32)
         x.div_(255)
-        x,y = self.norm((x,x), do_x=True)
-        result = self.learn.pred_batch(ds_type=DatasetType.Valid, 
-            batch=(x[None].cuda(),y[None]), reconstruct=True)
+        x, y = self.norm((x, x), do_x=True)
+        result = self.learn.pred_batch(
+            ds_type=DatasetType.Valid, batch=(x[None].cuda(), y[None]), reconstruct=True
+        )
         out = result[0]
         out = self.denorm(out.px, do_x=False)
-        out = image2np(out*255).astype(np.uint8)
+        out = image2np(out * 255).astype(np.uint8)
         return PilImage.fromarray(out)
 
-    def _unsquare(self, image:PilImage, orig:PilImage)->PilImage:
+    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
         targ_sz = orig.size
         image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
         return image
 
 
 class ColorizerFilter(BaseFilter):
-    def __init__(self, learn:Learner, map_to_orig:bool=True):
+    def __init__(self, learn: Learner, map_to_orig: bool = True):
         super().__init__(learn=learn)
-        self.render_base=16
-        self.map_to_orig=map_to_orig
+        self.render_base = 16
+        self.map_to_orig = map_to_orig
 
-    def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int)->PilImage:
+    def filter(
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
+    ) -> PilImage:
         render_sz = render_factor * self.render_base
         model_image = self._model_process(orig=filtered_image, sz=render_sz)
 
@@ -70,36 +76,39 @@ class ColorizerFilter(BaseFilter):
         else:
             return self._post_process(model_image, filtered_image)
 
-    def  _transform(self, image:PilImage)->PilImage:
+    def _transform(self, image: PilImage) -> PilImage:
         return image.convert('LA').convert('RGB')
 
-    #This takes advantage of the fact that human eyes are much less sensitive to 
-    #imperfections in chrominance compared to luminance.  This means we can
-    #save a lot on memory and processing in the model, yet get a great high
-    #resolution result at the end.  This is primarily intended just for 
-    #inference
-    def _post_process(self, raw_color:PilImage, orig:PilImage)->PilImage:
+    # This takes advantage of the fact that human eyes are much less sensitive to
+    # imperfections in chrominance compared to luminance.  This means we can
+    # save a lot on memory and processing in the model, yet get a great high
+    # resolution result at the end.  This is primarily intended just for
+    # inference
+    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
         raw_color = self._unsquare(raw_color, orig)
         color_np = np.asarray(raw_color)
         orig_np = np.asarray(orig)
         color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
-        #do a black and white transform first to get better luminance values
+        # do a black and white transform first to get better luminance values
         orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
         hires = np.copy(orig_yuv)
-        hires[:,:,1:3] = color_yuv[:,:,1:3]
-        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)  
-        final = PilImage.fromarray(final) 
+        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
+        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
+        final = PilImage.fromarray(final)
         return final
 
+
 class MasterFilter(BaseFilter):
-    def __init__(self, filters:[IFilter], render_factor:int):
-        self.filters=filters
-        self.render_factor=render_factor
+    def __init__(self, filters: [IFilter], render_factor: int):
+        self.filters = filters
+        self.render_factor = render_factor
 
-    def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int=None)->PilImage:
+    def filter(
+        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None
+    ) -> PilImage:
         render_factor = self.render_factor if render_factor is None else render_factor
 
         for filter in self.filters:
-            filtered_image=filter.filter(orig_image, filtered_image, render_factor)
-        
+            filtered_image = filter.filter(orig_image, filtered_image, render_factor)
+
         return filtered_image
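
MasterFilter simply chains IFilter implementations; the sketch below mirrors how visualize.py wires it up. Here `learn` (a loaded generator Learner) and `orig_image` (a PIL image) are assumed to exist already:

    from deoldify.filters import MasterFilter, ColorizerFilter

    # Compose the colorizer filter into a MasterFilter with a default render_factor.
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=35)
    # The original image is passed both as the reference and as the image to filter.
    result = filtr.filter(orig_image, orig_image, render_factor=35)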

deoldify/generators.py (+134, -44)

@@ -4,66 +4,156 @@ from .unet import DynamicUnetWide, DynamicUnetDeep
 from .loss import FeatureLoss
 from .dataset import *
 
-#Weights are implicitly read from ./models/ folder 
-def gen_inference_wide(root_folder:Path, weights_name:str, nf_factor:int=2, arch=models.resnet101)->Learner:
-      data = get_dummy_databunch()
-      learn = gen_learner_wide(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch)
-      learn.path = root_folder
-      learn.load(weights_name)
-      learn.model.eval()
-      return learn
+# Weights are implicitly read from ./models/ folder
+def gen_inference_wide(
+    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101
+) -> Learner:
+    data = get_dummy_databunch()
+    learn = gen_learner_wide(
+        data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
+    )
+    learn.path = root_folder
+    learn.load(weights_name)
+    learn.model.eval()
+    return learn
+
 
-def gen_learner_wide(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet101, nf_factor:int=2)->Learner:
-    return unet_learner_wide(data, arch=arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
-                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)
+def gen_learner_wide(
+    data: ImageDataBunch,
+    gen_loss=FeatureLoss(),
+    arch=models.resnet101,
+    nf_factor: int = 2,
+) -> Learner:
+    return unet_learner_wide(
+        data,
+        arch=arch,
+        wd=1e-3,
+        blur=True,
+        norm_type=NormType.Spectral,
+        self_attention=True,
+        y_range=(-3.0, 3.0),
+        loss_func=gen_loss,
+        nf_factor=nf_factor,
+    )
 
-#The code below is meant to be merged into fastaiv1 ideally
-def unet_learner_wide(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
-                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
-                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
-                 bottle:bool=False, nf_factor:int=1, **kwargs:Any)->Learner:
+
+# The code below is meant to be merged into fastaiv1 ideally
+def unet_learner_wide(
+    data: DataBunch,
+    arch: Callable,
+    pretrained: bool = True,
+    blur_final: bool = True,
+    norm_type: Optional[NormType] = NormType,
+    split_on: Optional[SplitFuncOrIdxList] = None,
+    blur: bool = False,
+    self_attention: bool = False,
+    y_range: Optional[Tuple[float, float]] = None,
+    last_cross: bool = True,
+    bottle: bool = False,
+    nf_factor: int = 1,
+    **kwargs: Any
+) -> Learner:
     "Build Unet learner from `data` and `arch`."
     meta = cnn_config(arch)
     body = create_body(arch, pretrained)
-    model = to_device(DynamicUnetWide(body, n_classes=data.c, blur=blur, blur_final=blur_final,
-          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
-          bottle=bottle, nf_factor=nf_factor), data.device)
+    model = to_device(
+        DynamicUnetWide(
+            body,
+            n_classes=data.c,
+            blur=blur,
+            blur_final=blur_final,
+            self_attention=self_attention,
+            y_range=y_range,
+            norm_type=norm_type,
+            last_cross=last_cross,
+            bottle=bottle,
+            nf_factor=nf_factor,
+        ),
+        data.device,
+    )
     learn = Learner(data, model, **kwargs)
-    learn.split(ifnone(split_on,meta['split']))
-    if pretrained: learn.freeze()
+    learn.split(ifnone(split_on, meta['split']))
+    if pretrained:
+        learn.freeze()
     apply_init(model[2], nn.init.kaiming_normal_)
     return learn
 
-#----------------------------------------------------------------------
 
-#Weights are implicitly read from ./models/ folder 
-def gen_inference_deep(root_folder:Path, weights_name:str, arch=models.resnet34, nf_factor:float=1.5)->Learner:
-      data = get_dummy_databunch()
-      learn = gen_learner_deep(data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor)
-      learn.path = root_folder
-      learn.load(weights_name)
-      learn.model.eval()
-      return learn
+# ----------------------------------------------------------------------
+
+# Weights are implicitly read from ./models/ folder
+def gen_inference_deep(
+    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5
+) -> Learner:
+    data = get_dummy_databunch()
+    learn = gen_learner_deep(
+        data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
+    )
+    learn.path = root_folder
+    learn.load(weights_name)
+    learn.model.eval()
+    return learn
+
+
+def gen_learner_deep(
+    data: ImageDataBunch,
+    gen_loss=FeatureLoss(),
+    arch=models.resnet34,
+    nf_factor: float = 1.5,
+) -> Learner:
+    return unet_learner_deep(
+        data,
+        arch,
+        wd=1e-3,
+        blur=True,
+        norm_type=NormType.Spectral,
+        self_attention=True,
+        y_range=(-3.0, 3.0),
+        loss_func=gen_loss,
+        nf_factor=nf_factor,
+    )
 
-def gen_learner_deep(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:float=1.5)->Learner:
-    return unet_learner_deep(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
-                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)
 
-#The code below is meant to be merged into fastaiv1 ideally
-def unet_learner_deep(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
-                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
-                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
-                 bottle:bool=False, nf_factor:float=1.5, **kwargs:Any)->Learner:
+# The code below is meant to be merged into fastaiv1 ideally
+def unet_learner_deep(
+    data: DataBunch,
+    arch: Callable,
+    pretrained: bool = True,
+    blur_final: bool = True,
+    norm_type: Optional[NormType] = NormType,
+    split_on: Optional[SplitFuncOrIdxList] = None,
+    blur: bool = False,
+    self_attention: bool = False,
+    y_range: Optional[Tuple[float, float]] = None,
+    last_cross: bool = True,
+    bottle: bool = False,
+    nf_factor: float = 1.5,
+    **kwargs: Any
+) -> Learner:
     "Build Unet learner from `data` and `arch`."
     meta = cnn_config(arch)
     body = create_body(arch, pretrained)
-    model = to_device(DynamicUnetDeep(body, n_classes=data.c, blur=blur, blur_final=blur_final,
-          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
-          bottle=bottle, nf_factor=nf_factor), data.device)
+    model = to_device(
+        DynamicUnetDeep(
+            body,
+            n_classes=data.c,
+            blur=blur,
+            blur_final=blur_final,
+            self_attention=self_attention,
+            y_range=y_range,
+            norm_type=norm_type,
+            last_cross=last_cross,
+            bottle=bottle,
+            nf_factor=nf_factor,
+        ),
+        data.device,
+    )
     learn = Learner(data, model, **kwargs)
-    learn.split(ifnone(split_on,meta['split']))
-    if pretrained: learn.freeze()
+    learn.split(ifnone(split_on, meta['split']))
+    if pretrained:
+        learn.freeze()
     apply_init(model[2], nn.init.kaiming_normal_)
     return learn
 
-#-----------------------------
+
+# -----------------------------
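
A short sketch of the inference entry point above: weights are read from a ./models/ subfolder of root_folder, so a pretrained weight file (here 'ColorizeStable_gen', the name used in visualize.py) is assumed to already be present:

    from pathlib import Path
    from deoldify.generators import gen_inference_wide

    # Loads the wide U-Net generator in eval mode with the named pretrained weights.
    learn = gen_inference_wide(root_folder=Path('./'), weights_name='ColorizeStable_gen')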

deoldify/layers.py (+38, -15)

@@ -4,22 +4,45 @@ from torch.nn.parameter import Parameter
 from torch.autograd import Variable
 
 
-#The code below is meant to be merged into fastaiv1 ideally
+# The code below is meant to be merged into fastaiv1 ideally
 
-def custom_conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,
-               norm_type:Optional[NormType]=NormType.Batch,  use_activ:bool=True, leaky:float=None,
-               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False,
-               extra_bn:bool=False):
+
+def custom_conv_layer(
+    ni: int,
+    nf: int,
+    ks: int = 3,
+    stride: int = 1,
+    padding: int = None,
+    bias: bool = None,
+    is_1d: bool = False,
+    norm_type: Optional[NormType] = NormType.Batch,
+    use_activ: bool = True,
+    leaky: float = None,
+    transpose: bool = False,
+    init: Callable = nn.init.kaiming_normal_,
+    self_attention: bool = False,
+    extra_bn: bool = False,
+):
     "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
-    if padding is None: padding = (ks-1)//2 if not transpose else 0
-    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn==True
-    if bias is None: bias = not bn
+    if padding is None:
+        padding = (ks - 1) // 2 if not transpose else 0
+    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
+    if bias is None:
+        bias = not bn
     conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
-    conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
-    if   norm_type==NormType.Weight:   conv = weight_norm(conv)
-    elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
+    conv = init_default(
+        conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
+        init,
+    )
+    if norm_type == NormType.Weight:
+        conv = weight_norm(conv)
+    elif norm_type == NormType.Spectral:
+        conv = spectral_norm(conv)
     layers = [conv]
-    if use_activ: layers.append(relu(True, leaky=leaky))
-    if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
-    if self_attention: layers.append(SelfAttention(nf))
-    return nn.Sequential(*layers)
+    if use_activ:
+        layers.append(relu(True, leaky=leaky))
+    if bn:
+        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
+    if self_attention:
+        layers.append(SelfAttention(nf))
+    return nn.Sequential(*layers)
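
custom_conv_layer is the primitive the U-Net blocks below are built from; a small sketch of a typical call (spectral weight normalization plus the extra batchnorm the generator uses), assuming fastai v1 is installed:

    from fastai.layers import NormType
    from deoldify.layers import custom_conv_layer

    # Spectral-normalized conv -> ReLU -> extra BatchNorm, as used inside the generator blocks.
    block = custom_conv_layer(64, 128, ks=3, stride=1, norm_type=NormType.Spectral, extra_bn=True)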

deoldify/loss.py (+64, -37)

@@ -1,22 +1,26 @@
 from fastai import *
 from fastai.core import *
 from fastai.torch_core import *
-from fastai.callbacks  import hook_outputs
+from fastai.callbacks import hook_outputs
 import torchvision.models as models
 
 
 class FeatureLoss(nn.Module):
-    def __init__(self, layer_wgts=[20,70,10]):
+    def __init__(self, layer_wgts=[20, 70, 10]):
         super().__init__()
 
         self.m_feat = models.vgg16_bn(True).features.cuda().eval()
         requires_grad(self.m_feat, False)
-        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        blocks = [
+            i - 1
+            for i, o in enumerate(children(self.m_feat))
+            if isinstance(o, nn.MaxPool2d)
+        ]
         layer_ids = blocks[2:5]
         self.loss_features = [self.m_feat[i] for i in layer_ids]
         self.hooks = hook_outputs(self.loss_features, detach=False)
         self.wgts = layer_wgts
-        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] 
+        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
         self.base_loss = F.l1_loss
 
     def _make_features(self, x, clone=False):
@@ -26,29 +30,40 @@ class FeatureLoss(nn.Module):
     def forward(self, input, target):
         out_feat = self._make_features(target, clone=True)
         in_feat = self._make_features(input)
-        self.feat_losses = [self.base_loss(input,target)]
-        self.feat_losses += [self.base_loss(f_in, f_out)*w
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-        
+        self.feat_losses = [self.base_loss(input, target)]
+        self.feat_losses += [
+            self.base_loss(f_in, f_out) * w
+            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
+        ]
+
         self.metrics = dict(zip(self.metric_names, self.feat_losses))
         return sum(self.feat_losses)
-    
-    def __del__(self): self.hooks.remove()
+
+    def __del__(self):
+        self.hooks.remove()
 
 
-#Includes wasserstein loss
+# Includes wasserstein loss
 class WassFeatureLoss(nn.Module):
-    def __init__(self, layer_wgts=[5,15,2], wass_wgts=[3.0,0.7,0.01]):
+    def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):
         super().__init__()
         self.m_feat = models.vgg16_bn(True).features.cuda().eval()
         requires_grad(self.m_feat, False)
-        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        blocks = [
+            i - 1
+            for i, o in enumerate(children(self.m_feat))
+            if isinstance(o, nn.MaxPool2d)
+        ]
         layer_ids = blocks[2:5]
         self.loss_features = [self.m_feat[i] for i in layer_ids]
         self.hooks = hook_outputs(self.loss_features, detach=False)
         self.wgts = layer_wgts
         self.wass_wgts = wass_wgts
-        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'wass_{i}' for i in range(len(layer_ids))]
+        self.metric_names = (
+            ['pixel']
+            + [f'feat_{i}' for i in range(len(layer_ids))]
+            + [f'wass_{i}' for i in range(len(layer_ids))]
+        )
         self.base_loss = F.l1_loss
 
     def _make_features(self, x, clone=False):
@@ -58,52 +73,64 @@ class WassFeatureLoss(nn.Module):
     def _calc_2_moments(self, tensor):
         chans = tensor.shape[1]
         tensor = tensor.view(1, chans, -1)
-        n = tensor.shape[2] 
+        n = tensor.shape[2]
         mu = tensor.mean(2)
-        tensor = (tensor - mu[:,:,None]).squeeze(0)
-        #Prevents nasty bug that happens very occassionally- divide by zero.  Why such things happen?
-        if n == 0: return None, None
-        cov = torch.mm(tensor, tensor.t()) / float(n) 
+        tensor = (tensor - mu[:, :, None]).squeeze(0)
+        # Prevents nasty bug that happens very occassionally- divide by zero.  Why such things happen?
+        if n == 0:
+            return None, None
+        cov = torch.mm(tensor, tensor.t()) / float(n)
         return mu, cov
 
     def _get_style_vals(self, tensor):
-        mean, cov = self._calc_2_moments(tensor) 
+        mean, cov = self._calc_2_moments(tensor)
         if mean is None:
             return None, None, None
         eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
-        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))     
-        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())  
-        tr_cov = eigvals.clamp(min=0).sum() 
+        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
+        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
+        tr_cov = eigvals.clamp(min=0).sum()
         return mean, tr_cov, root_cov
 
-    def _calc_l2wass_dist(self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth):
+    def _calc_l2wass_dist(
+        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
+    ):
         tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
         mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
         cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
-        var_overlap = torch.sqrt(torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0)+1e-8).sum()
-        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap
+        var_overlap = torch.sqrt(
+            torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
+        ).sum()
+        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
         return dist
 
     def _single_wass_loss(self, pred, targ):
         mean_test, tr_cov_test, root_cov_test = targ
         mean_synth, cov_synth = self._calc_2_moments(pred)
-        loss = self._calc_l2wass_dist(mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth)
+        loss = self._calc_l2wass_dist(
+            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
+        )
         return loss
-    
+
     def forward(self, input, target):
         out_feat = self._make_features(target, clone=True)
         in_feat = self._make_features(input)
-        self.feat_losses = [self.base_loss(input,target)]
-        self.feat_losses += [self.base_loss(f_in, f_out)*w
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-        
+        self.feat_losses = [self.base_loss(input, target)]
+        self.feat_losses += [
+            self.base_loss(f_in, f_out) * w
+            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
+        ]
+
         styles = [self._get_style_vals(i) for i in out_feat]
 
         if styles[0][0] is not None:
-            self.feat_losses += [self._single_wass_loss(f_pred, f_targ)*w
-                                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)]
-        
+            self.feat_losses += [
+                self._single_wass_loss(f_pred, f_targ) * w
+                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
+            ]
+
         self.metrics = dict(zip(self.metric_names, self.feat_losses))
         return sum(self.feat_losses)
-    
-    def __del__(self): self.hooks.remove()
+
+    def __del__(self):
+        self.hooks.remove()
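
FeatureLoss is the perceptual loss used to train the generator; a minimal sketch of calling it directly. Note that __init__ moves a pretrained VGG16-BN to the GPU, so CUDA (and a torchvision weight download) is assumed:

    import torch
    from deoldify.loss import FeatureLoss

    feat_loss = FeatureLoss(layer_wgts=[20, 70, 10])
    pred = torch.rand(1, 3, 64, 64).cuda()
    target = torch.rand(1, 3, 64, 64).cuda()
    loss = feat_loss(pred, target)   # pixel L1 plus weighted VGG feature L1 terms
    print(feat_loss.metrics)         # per-term breakdown keyed by metric_names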

deoldify/unet.py (+169, -58)

@@ -5,51 +5,83 @@ from fastai.callbacks.hooks import *
 from fastai.vision import *
 
 
-#The code below is meant to be merged into fastaiv1 ideally
+# The code below is meant to be merged into fastaiv1 ideally
 
 __all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
 
-def _get_sfs_idxs(sizes:Sizes) -> List[int]:
+
+def _get_sfs_idxs(sizes: Sizes) -> List[int]:
     "Get the indexes of the layers where the size of the activation changes."
     feature_szs = [size[-1] for size in sizes]
-    sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
-    if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
+    sfs_idxs = list(
+        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
+    )
+    if feature_szs[0] != feature_szs[1]:
+        sfs_idxs = [0] + sfs_idxs
     return sfs_idxs
 
+
 class CustomPixelShuffle_ICNR(nn.Module):
     "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
-    def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, leaky:float=None, **kwargs):
+
+    def __init__(
+        self,
+        ni: int,
+        nf: int = None,
+        scale: int = 2,
+        blur: bool = False,
+        leaky: float = None,
+        **kwargs
+    ):
         super().__init__()
         nf = ifnone(nf, ni)
-        self.conv = custom_conv_layer(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
+        self.conv = custom_conv_layer(
+            ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
+        )
         icnr(self.conv[0].weight)
         self.shuf = nn.PixelShuffle(scale)
         # Blurring over (h*w) kernel
         # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
         # - https://arxiv.org/abs/1806.02658
-        self.pad = nn.ReplicationPad2d((1,0,1,0))
+        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
         self.blur = nn.AvgPool2d(2, stride=1)
         self.relu = relu(True, leaky=leaky)
 
-    def forward(self,x):
+    def forward(self, x):
         x = self.shuf(self.relu(self.conv(x)))
         return self.blur(self.pad(x)) if self.blur else x
 
+
 class UnetBlockDeep(nn.Module):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
-    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
-                 self_attention:bool=False, nf_factor:float=1.0,  **kwargs):
+
+    def __init__(
+        self,
+        up_in_c: int,
+        x_in_c: int,
+        hook: Hook,
+        final_div: bool = True,
+        blur: bool = False,
+        leaky: float = None,
+        self_attention: bool = False,
+        nf_factor: float = 1.0,
+        **kwargs
+    ):
         super().__init__()
         self.hook = hook
-        self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
+        self.shuf = CustomPixelShuffle_ICNR(
+            up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
+        )
         self.bn = batchnorm_2d(x_in_c)
-        ni = up_in_c//2 + x_in_c
-        nf = int((ni if final_div else ni//2)*nf_factor)
+        ni = up_in_c // 2 + x_in_c
+        nf = int((ni if final_div else ni // 2) * nf_factor)
         self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
-        self.conv2 = custom_conv_layer(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
+        self.conv2 = custom_conv_layer(
+            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
+        )
         self.relu = relu(leaky=leaky)
 
-    def forward(self, up_in:Tensor) -> Tensor:
+    def forward(self, up_in: Tensor) -> Tensor:
         s = self.hook.stored
         up_out = self.shuf(up_in)
         ssh = s.shape[-2:]
@@ -61,63 +93,109 @@ class UnetBlockDeep(nn.Module):
 
 class DynamicUnetDeep(SequentialEx):
     "Create a U-Net from a given architecture."
-    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
-                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
-                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
-        extra_bn =  norm_type == NormType.Spectral
-        imsize = (256,256)
+
+    def __init__(
+        self,
+        encoder: nn.Module,
+        n_classes: int,
+        blur: bool = False,
+        blur_final=True,
+        self_attention: bool = False,
+        y_range: Optional[Tuple[float, float]] = None,
+        last_cross: bool = True,
+        bottle: bool = False,
+        norm_type: Optional[NormType] = NormType.Batch,
+        nf_factor: float = 1.0,
+        **kwargs
+    ):
+        extra_bn = norm_type == NormType.Spectral
+        imsize = (256, 256)
         sfs_szs = model_sizes(encoder, size=imsize)
         sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
         self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]
-        middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
-                                    custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        middle_conv = nn.Sequential(
+            custom_conv_layer(
+                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
+            ),
+            custom_conv_layer(
+                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
+            ),
+        ).eval()
         x = middle_conv(x)
         layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
 
-        for i,idx in enumerate(sfs_idxs):
-            not_final = i!=len(sfs_idxs)-1
+        for i, idx in enumerate(sfs_idxs):
+            not_final = i != len(sfs_idxs) - 1
             up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
             do_blur = blur and (not_final or blur_final)
-            sa = self_attention and (i==len(sfs_idxs)-3)
-            unet_block = UnetBlockDeep(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
-                                   norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
+            sa = self_attention and (i == len(sfs_idxs) - 3)
+            unet_block = UnetBlockDeep(
+                up_in_c,
+                x_in_c,
+                self.sfs[i],
+                final_div=not_final,
+                blur=blur,
+                self_attention=sa,
+                norm_type=norm_type,
+                extra_bn=extra_bn,
+                nf_factor=nf_factor,
+                **kwargs
+            ).eval()
             layers.append(unet_block)
             x = unet_block(x)
 
         ni = x.shape[1]
-        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if imsize != sfs_szs[0][-2:]:
+            layers.append(PixelShuffle_ICNR(ni, **kwargs))
         if last_cross:
             layers.append(MergeLayer(dense=True))
             ni += in_channels(encoder)
             layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
-        layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
-        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        layers += [
+            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
+        ]
+        if y_range is not None:
+            layers.append(SigmoidRange(*y_range))
         super().__init__(*layers)
 
     def __del__(self):
-        if hasattr(self, "sfs"): self.sfs.remove()
+        if hasattr(self, "sfs"):
+            self.sfs.remove()
 
 
-
-
-#------------------------------------------------------
+# ------------------------------------------------------
 class UnetBlockWide(nn.Module):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
-    def __init__(self, up_in_c:int, x_in_c:int, n_out:int,  hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
-                 self_attention:bool=False,  **kwargs):
+
+    def __init__(
+        self,
+        up_in_c: int,
+        x_in_c: int,
+        n_out: int,
+        hook: Hook,
+        final_div: bool = True,
+        blur: bool = False,
+        leaky: float = None,
+        self_attention: bool = False,
+        **kwargs
+    ):
         super().__init__()
         self.hook = hook
-        up_out = x_out = n_out//2
-        self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, leaky=leaky, **kwargs)
+        up_out = x_out = n_out // 2
+        self.shuf = CustomPixelShuffle_ICNR(
+            up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
+        )
         self.bn = batchnorm_2d(x_in_c)
         ni = up_out + x_in_c
-        self.conv = custom_conv_layer(ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs)
+        self.conv = custom_conv_layer(
+            ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
+        )
         self.relu = relu(leaky=leaky)
 
-    def forward(self, up_in:Tensor) -> Tensor:
+    def forward(self, up_in: Tensor) -> Tensor:
         s = self.hook.stored
         up_out = self.shuf(up_in)
         ssh = s.shape[-2:]
@@ -129,46 +207,79 @@ class UnetBlockWide(nn.Module):
 
 class DynamicUnetWide(SequentialEx):
     "Create a U-Net from a given architecture."
-    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
-                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
-                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:int=1, **kwargs):
-        
+
+    def __init__(
+        self,
+        encoder: nn.Module,
+        n_classes: int,
+        blur: bool = False,
+        blur_final=True,
+        self_attention: bool = False,
+        y_range: Optional[Tuple[float, float]] = None,
+        last_cross: bool = True,
+        bottle: bool = False,
+        norm_type: Optional[NormType] = NormType.Batch,
+        nf_factor: int = 1,
+        **kwargs
+    ):
+
         nf = 512 * nf_factor
-        extra_bn =  norm_type == NormType.Spectral
-        imsize = (256,256)
+        extra_bn = norm_type == NormType.Spectral
+        imsize = (256, 256)
         sfs_szs = model_sizes(encoder, size=imsize)
         sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
         self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]
-        middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
-                                    custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        middle_conv = nn.Sequential(
+            custom_conv_layer(
+                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
+            ),
+            custom_conv_layer(
+                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
+            ),
+        ).eval()
         x = middle_conv(x)
         layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
 
-        for i,idx in enumerate(sfs_idxs):
-            not_final = i!=len(sfs_idxs)-1
+        for i, idx in enumerate(sfs_idxs):
+            not_final = i != len(sfs_idxs) - 1
             up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
             do_blur = blur and (not_final or blur_final)
-            sa = self_attention and (i==len(sfs_idxs)-3)
+            sa = self_attention and (i == len(sfs_idxs) - 3)
 
-            n_out = nf if not_final else nf//2
+            n_out = nf if not_final else nf // 2
 
-            unet_block = UnetBlockWide(up_in_c, x_in_c, n_out, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
-                                   norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
+            unet_block = UnetBlockWide(
+                up_in_c,
+                x_in_c,
+                n_out,
+                self.sfs[i],
+                final_div=not_final,
+                blur=blur,
+                self_attention=sa,
+                norm_type=norm_type,
+                extra_bn=extra_bn,
+                **kwargs
+            ).eval()
             layers.append(unet_block)
             x = unet_block(x)
 
         ni = x.shape[1]
-        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if imsize != sfs_szs[0][-2:]:
+            layers.append(PixelShuffle_ICNR(ni, **kwargs))
         if last_cross:
             layers.append(MergeLayer(dense=True))
             ni += in_channels(encoder)
             layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
-        layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
-        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        layers += [
+            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
+        ]
+        if y_range is not None:
+            layers.append(SigmoidRange(*y_range))
         super().__init__(*layers)
 
     def __del__(self):
-        if hasattr(self, "sfs"): self.sfs.remove()
+        if hasattr(self, "sfs"):
+            self.sfs.remove()
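
DynamicUnetWide is normally built via unet_learner_wide in generators.py; the stripped-down sketch below mirrors that construction (pretrained=False only to avoid a weight download in this sketch):

    from fastai.vision import *   # models, create_body, NormType, as used in deoldify/unet.py and generators.py
    from deoldify.unet import DynamicUnetWide

    body = create_body(models.resnet101, pretrained=False)  # the wide learner defaults to resnet101
    model = DynamicUnetWide(
        body,
        n_classes=3,
        blur=True,
        self_attention=True,
        y_range=(-3.0, 3.0),
        norm_type=NormType.Spectral,
        last_cross=True,
        nf_factor=2,
    )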

deoldify/visualize.py (+228, -101)

@@ -7,7 +7,7 @@ from .filters import IFilter, MasterFilter, ColorizerFilter
 from .generators import gen_inference_deep, gen_inference_wide
 from tensorboardX import SummaryWriter
 from scipy import misc
-from PIL import Image 
+from PIL import Image
 import ffmpeg
 import youtube_dl
 import gc
@@ -18,212 +18,339 @@ from IPython import display as ipythondisplay
 from IPython.display import HTML
 from IPython.display import Image as ipythonimage
 
-class ModelImageVisualizer():
-    def __init__(self, filter:IFilter, results_dir:str=None):
+
+class ModelImageVisualizer:
+    def __init__(self, filter: IFilter, results_dir: str = None):
         self.filter = filter
-        self.results_dir=None if results_dir is None else Path(results_dir)
+        self.results_dir = None if results_dir is None else Path(results_dir)
         self.results_dir.mkdir(parents=True, exist_ok=True)
-    
+
     def _clean_mem(self):
         torch.cuda.empty_cache()
-        #gc.collect()
+        # gc.collect()
 
-    def _open_pil_image(self, path:Path)->Image:
+    def _open_pil_image(self, path: Path) -> Image:
         return PIL.Image.open(path).convert('RGB')
 
-    def _get_image_from_url(self, url:str)->Image:
+    def _get_image_from_url(self, url: str) -> Image:
         response = requests.get(url, timeout=30)
         img = PIL.Image.open(BytesIO(response.content)).convert('RGB')
         return img
 
-    def plot_transformed_image_from_url(self, url:str, path:str='test_images/image.png', figsize:(int,int)=(20,20), 
-            render_factor:int=None, display_render_factor:bool=False, compare:bool=False)->Path:
+    def plot_transformed_image_from_url(
+        self,
+        url: str,
+        path: str = 'test_images/image.png',
+        figsize: (int, int) = (20, 20),
+        render_factor: int = None,
+        display_render_factor: bool = False,
+        compare: bool = False,
+    ) -> Path:
         img = self._get_image_from_url(url)
         img.save(path)
-        return self.plot_transformed_image(path=path, figsize=figsize, render_factor=render_factor, 
-                                            display_render_factor=display_render_factor, compare=compare)
-
-    def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None, 
-                            display_render_factor:bool=False, compare:bool=False)->Path:
+        return self.plot_transformed_image(
+            path=path,
+            figsize=figsize,
+            render_factor=render_factor,
+            display_render_factor=display_render_factor,
+            compare=compare,
+        )
+
+    def plot_transformed_image(
+        self,
+        path: str,
+        figsize: (int, int) = (20, 20),
+        render_factor: int = None,
+        display_render_factor: bool = False,
+        compare: bool = False,
+    ) -> Path:
         path = Path(path)
         result = self.get_transformed_image(path, render_factor)
         orig = self._open_pil_image(path)
-        if compare: 
-            self._plot_comparison(figsize, render_factor, display_render_factor, orig, result)
+        if compare:
+            self._plot_comparison(
+                figsize, render_factor, display_render_factor, orig, result
+            )
         else:
             self._plot_solo(figsize, render_factor, display_render_factor, result)
 
         return self._save_result_image(path, result)
 
-    def _plot_comparison(self, figsize:(int,int), render_factor:int, display_render_factor:bool, orig:Image, result:Image):
-        fig,axes = plt.subplots(1, 2, figsize=figsize)
-        self._plot_image(orig, axes=axes[0], figsize=figsize, render_factor=render_factor, display_render_factor=False)
-        self._plot_image(result, axes=axes[1], figsize=figsize, render_factor=render_factor, display_render_factor=display_render_factor)
- 
-    def _plot_solo(self, figsize:(int,int), render_factor:int, display_render_factor:bool, result:Image):
-        fig,axes = plt.subplots(1, 1, figsize=figsize)
-        self._plot_image(result, axes=axes, figsize=figsize, render_factor=render_factor, display_render_factor=display_render_factor)
-
-    def _save_result_image(self, source_path:Path, image:Image)->Path:
-        result_path = self.results_dir/source_path.name
+    def _plot_comparison(
+        self,
+        figsize: (int, int),
+        render_factor: int,
+        display_render_factor: bool,
+        orig: Image,
+        result: Image,
+    ):
+        fig, axes = plt.subplots(1, 2, figsize=figsize)
+        self._plot_image(
+            orig,
+            axes=axes[0],
+            figsize=figsize,
+            render_factor=render_factor,
+            display_render_factor=False,
+        )
+        self._plot_image(
+            result,
+            axes=axes[1],
+            figsize=figsize,
+            render_factor=render_factor,
+            display_render_factor=display_render_factor,
+        )
+
+    def _plot_solo(
+        self,
+        figsize: (int, int),
+        render_factor: int,
+        display_render_factor: bool,
+        result: Image,
+    ):
+        fig, axes = plt.subplots(1, 1, figsize=figsize)
+        self._plot_image(
+            result,
+            axes=axes,
+            figsize=figsize,
+            render_factor=render_factor,
+            display_render_factor=display_render_factor,
+        )
+
+    def _save_result_image(self, source_path: Path, image: Image) -> Path:
+        result_path = self.results_dir / source_path.name
         image.save(result_path)
         return result_path
 
-    def get_transformed_image(self, path:Path, render_factor:int=None)->Image:
+    def get_transformed_image(self, path: Path, render_factor: int = None) -> Image:
         self._clean_mem()
         orig_image = self._open_pil_image(path)
-        filtered_image = self.filter.filter(orig_image, orig_image, render_factor=render_factor)
+        filtered_image = self.filter.filter(
+            orig_image, orig_image, render_factor=render_factor
+        )
         return filtered_image
 
-    def _plot_image(self, image:Image, render_factor:int, axes:Axes=None, figsize=(20,20), display_render_factor:bool=False):
-        if axes is None: 
-            _,axes = plt.subplots(figsize=figsize)
-        axes.imshow(np.asarray(image)/255)
+    def _plot_image(
+        self,
+        image: Image,
+        render_factor: int,
+        axes: Axes = None,
+        figsize=(20, 20),
+        display_render_factor: bool = False,
+    ):
+        if axes is None:
+            _, axes = plt.subplots(figsize=figsize)
+        axes.imshow(np.asarray(image) / 255)
         axes.axis('off')
         if render_factor is not None and display_render_factor:
-            plt.text(10,10,'render_factor: ' + str(render_factor), color='white', backgroundcolor='black')
-
-    def _get_num_rows_columns(self, num_images:int, max_columns:int)->(int,int):
+            plt.text(
+                10,
+                10,
+                'render_factor: ' + str(render_factor),
+                color='white',
+                backgroundcolor='black',
+            )
+
+    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> (int, int):
         columns = min(num_images, max_columns)
-        rows = num_images//columns
+        rows = num_images // columns
         rows = rows if rows * columns == num_images else rows + 1
         return rows, columns
 
-class VideoColorizer():
-    def __init__(self, vis:ModelImageVisualizer):
-        self.vis=vis
+
+class VideoColorizer:
+    def __init__(self, vis: ModelImageVisualizer):
+        self.vis = vis
         workfolder = Path('./video')
-        self.source_folder = workfolder/"source"
-        self.bwframes_root = workfolder/"bwframes"
-        self.audio_root = workfolder/"audio"
-        self.colorframes_root = workfolder/"colorframes"
-        self.result_folder = workfolder/"result"
+        self.source_folder = workfolder / "source"
+        self.bwframes_root = workfolder / "bwframes"
+        self.audio_root = workfolder / "audio"
+        self.colorframes_root = workfolder / "colorframes"
+        self.result_folder = workfolder / "result"
 
     def _purge_images(self, dir):
         for f in os.listdir(dir):
             if re.search('.*?\.jpg', f):
                 os.remove(os.path.join(dir, f))
 
-    def _get_fps(self, source_path: Path)->str:
+    def _get_fps(self, source_path: Path) -> str:
         probe = ffmpeg.probe(str(source_path))
-        stream_data = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+        stream_data = next(
+            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
+            None,
+        )
         return stream_data['avg_frame_rate']
 
-    def _download_video_from_url(self, source_url, source_path:Path):
-        if source_path.exists(): source_path.unlink()
+    def _download_video_from_url(self, source_url, source_path: Path):
+        if source_path.exists():
+            source_path.unlink()
 
-        ydl_opts = {    
-            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',     
-            'outtmpl': str(source_path)   
-            }
+        ydl_opts = {
+            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
+            'outtmpl': str(source_path),
+        }
         with youtube_dl.YoutubeDL(ydl_opts) as ydl:
             ydl.download([source_url])
 
-    def _extract_raw_frames(self, source_path:Path):
-        bwframes_folder = self.bwframes_root/(source_path.stem)
-        bwframe_path_template = str(bwframes_folder/'%5d.jpg')
+    def _extract_raw_frames(self, source_path: Path):
+        bwframes_folder = self.bwframes_root / (source_path.stem)
+        bwframe_path_template = str(bwframes_folder / '%5d.jpg')
         bwframes_folder.mkdir(parents=True, exist_ok=True)
         self._purge_images(bwframes_folder)
-        ffmpeg.input(str(source_path)).output(str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0).run(capture_stdout=True)
-
+        ffmpeg.input(str(source_path)).output(
+            str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0
+        ).run(capture_stdout=True)
 
-    def _colorize_raw_frames(self, source_path:Path, render_factor:int=None):
-        colorframes_folder = self.colorframes_root/(source_path.stem)
+    def _colorize_raw_frames(self, source_path: Path, render_factor: int = None):
+        colorframes_folder = self.colorframes_root / (source_path.stem)
         colorframes_folder.mkdir(parents=True, exist_ok=True)
         self._purge_images(colorframes_folder)
-        bwframes_folder = self.bwframes_root/(source_path.stem)
+        bwframes_folder = self.bwframes_root / (source_path.stem)
 
         for img in progress_bar(os.listdir(str(bwframes_folder))):
-            img_path = bwframes_folder/img
+            img_path = bwframes_folder / img
             if os.path.isfile(str(img_path)):
-                color_image = self.vis.get_transformed_image(str(img_path), render_factor=render_factor)
-                color_image.save(str(colorframes_folder/img))
-    
-    def _build_video(self, source_path:Path)->Path:
-        colorized_path = self.result_folder/(source_path.name.replace('.mp4', '_no_audio.mp4'))
-        colorframes_folder = self.colorframes_root/(source_path.stem)
-        colorframes_path_template = str(colorframes_folder/'%5d.jpg')
+                color_image = self.vis.get_transformed_image(
+                    str(img_path), render_factor=render_factor
+                )
+                color_image.save(str(colorframes_folder / img))
+
+    def _build_video(self, source_path: Path) -> Path:
+        colorized_path = self.result_folder / (
+            source_path.name.replace('.mp4', '_no_audio.mp4')
+        )
+        colorframes_folder = self.colorframes_root / (source_path.stem)
+        colorframes_path_template = str(colorframes_folder / '%5d.jpg')
         colorized_path.parent.mkdir(parents=True, exist_ok=True)
-        if colorized_path.exists(): colorized_path.unlink()
+        if colorized_path.exists():
+            colorized_path.unlink()
         fps = self._get_fps(source_path)
 
-        ffmpeg.input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=fps) \
-            .output(str(colorized_path), crf=17, vcodec='libx264') \
-            .run(capture_stdout=True)
-
-        result_path = self.result_folder/source_path.name
-        if result_path.exists(): result_path.unlink()
-        #making copy of non-audio version in case adding back audio doesn't apply or fails.
+        ffmpeg.input(
+            str(colorframes_path_template),
+            format='image2',
+            vcodec='mjpeg',
+            framerate=fps,
+        ).output(str(colorized_path), crf=17, vcodec='libx264').run(capture_stdout=True)
+
+        result_path = self.result_folder / source_path.name
+        if result_path.exists():
+            result_path.unlink()
+        # making copy of non-audio version in case adding back audio doesn't apply or fails.
         shutil.copyfile(str(colorized_path), str(result_path))
 
         # adding back sound here
         audio_file = Path(str(source_path).replace('.mp4', '.aac'))
-        if audio_file.exists(): audio_file.unlink()
+        if audio_file.exists():
+            audio_file.unlink()
 
-        os.system('ffmpeg -y -i "' + str(source_path) + '" -vn -acodec copy "' + str(audio_file) + '"')
+        os.system(
+            'ffmpeg -y -i "'
+            + str(source_path)
+            + '" -vn -acodec copy "'
+            + str(audio_file)
+            + '"'
+        )
 
         if audio_file.exists:
-            os.system('ffmpeg -y -i "' + str(colorized_path) + '" -i "' + str(audio_file) 
-                + '" -shortest -c:v copy -c:a aac -b:a 256k "' + str(result_path) + '"')
+            os.system(
+                'ffmpeg -y -i "'
+                + str(colorized_path)
+                + '" -i "'
+                + str(audio_file)
+                + '" -shortest -c:v copy -c:a aac -b:a 256k "'
+                + str(result_path)
+                + '"'
+            )
         print('Video created here: ' + str(result_path))
         return result_path
 
-    def colorize_from_url(self, source_url, file_name:str, render_factor:int=None)->Path: 
-        source_path =  self.source_folder/file_name
+    def colorize_from_url(
+        self, source_url, file_name: str, render_factor: int = None
+    ) -> Path:
+        source_path = self.source_folder / file_name
         self._download_video_from_url(source_url, source_path)
         return self._colorize_from_path(source_path, render_factor=render_factor)
 
-    def colorize_from_file_name(self, file_name:str, render_factor:int=None)->Path:
-        source_path =  self.source_folder/file_name
+    def colorize_from_file_name(
+        self, file_name: str, render_factor: int = None
+    ) -> Path:
+        source_path = self.source_folder / file_name
         return self._colorize_from_path(source_path, render_factor=render_factor)
 
-    def _colorize_from_path(self, source_path:Path, render_factor:int=None)->Path:
+    def _colorize_from_path(self, source_path: Path, render_factor: int = None) -> Path:
         if not source_path.exists():
-            raise Exception('Video at path specfied, ' + str(source_path) + ' could not be found.')
+            raise Exception(
+                'Video at path specfied, ' + str(source_path) + ' could not be found.'
+            )
 
         self._extract_raw_frames(source_path)
         self._colorize_raw_frames(source_path, render_factor=render_factor)
         return self._build_video(source_path)
 
 
-def get_video_colorizer(render_factor:int=21)->VideoColorizer:
+def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
     return get_stable_video_colorizer(render_factor=render_factor)
 
-def get_stable_video_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeVideo_gen', 
-        results_dir='result_images', render_factor:int=21)->VideoColorizer:
+
+def get_stable_video_colorizer(
+    root_folder: Path = Path('./'),
+    weights_name: str = 'ColorizeVideo_gen',
+    results_dir='result_images',
+    render_factor: int = 21,
+) -> VideoColorizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return VideoColorizer(vis)
 
-def get_image_colorizer(render_factor:int=35, artistic:bool=True)->ModelImageVisualizer:
+
+def get_image_colorizer(
+    render_factor: int = 35, artistic: bool = True
+) -> ModelImageVisualizer:
     if artistic:
         return get_artistic_image_colorizer(render_factor=render_factor)
     else:
         return get_stable_image_colorizer(render_factor=render_factor)
 
-def get_stable_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeStable_gen', 
-        results_dir='result_images', render_factor:int=35)->ModelImageVisualizer:
+
+def get_stable_image_colorizer(
+    root_folder: Path = Path('./'),
+    weights_name: str = 'ColorizeStable_gen',
+    results_dir='result_images',
+    render_factor: int = 35,
+) -> ModelImageVisualizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis
 
-def get_artistic_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeArtistic_gen', 
-        results_dir='result_images', render_factor:int=35)->ModelImageVisualizer:
+
+def get_artistic_image_colorizer(
+    root_folder: Path = Path('./'),
+    weights_name: str = 'ColorizeArtistic_gen',
+    results_dir='result_images',
+    render_factor: int = 35,
+) -> ModelImageVisualizer:
     learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis
 
-def show_image_in_notebook(image_path:Path):
+
+def show_image_in_notebook(image_path: Path):
     ipythondisplay.display(ipythonimage(str(image_path)))
 
-def show_video_in_notebook(video_path:Path):
+
+def show_video_in_notebook(video_path: Path):
     video = io.open(video_path, 'r+b').read()
     encoded = base64.b64encode(video)
-    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
+    ipythondisplay.display(
+        HTML(
+            data='''<video alt="test" autoplay 
                 loop controls style="height: 400px;">
                 <source src="data:video/mp4;base64,{0}" type="video/mp4" />
-             </video>'''.format(encoded.decode('ascii'))))
-
+             </video>'''.format(
+                encoded.decode('ascii')
+            )
+        )
+    )
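
The factory functions at the end of visualize.py are the high-level API the notebooks drive; a minimal notebook-style sketch, assuming the pretrained generator weights are already in ./models/ and the image path exists:

    from deoldify.visualize import get_image_colorizer, show_image_in_notebook

    colorizer = get_image_colorizer(artistic=True, render_factor=35)
    result_path = colorizer.plot_transformed_image(
        'test_images/image.png', render_factor=35, compare=True
    )
    show_image_in_notebook(result_path)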