
Adding Unet152 and Colorizer152

Jason Antic 6 years ago
commit 34b7eae5d7
3 changed files with 34 additions and 2 deletions
  1. .gitignore (+1 -0)
  2. fasterai/filters.py (+9 -1)
  3. fasterai/generators.py (+24 -1)

+ 1 - 0
.gitignore

@@ -413,3 +413,4 @@ result_images/Sami1880s.jpg
 result_images/Scotland1919.jpg
 result_images/SenecaNative1908.jpg
 result_images/TitanicGym.jpg
+fasterai/fastai

+ 9 - 1
fasterai/filters.py

@@ -1,6 +1,6 @@
 from numpy import ndarray
 from abc import ABC, abstractmethod
-from .generators import Unet34, Unet101, GeneratorModule
+from .generators import Unet34, Unet101, Unet152, GeneratorModule
 from .transforms import BlackAndWhiteTransform
 from fastai.torch_imports import *
 from fastai.core import *
@@ -132,6 +132,14 @@ class Colorizer101(AbstractColorizer):
         return Unet101(nf_factor=nf_factor).cuda(gpu)
 
 
+class Colorizer152(AbstractColorizer):
+    def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
+        super().__init__(gpu=gpu, weights_path=weights_path, nf_factor=nf_factor, map_to_orig=map_to_orig)
+
+    def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
+        return Unet152(nf_factor=nf_factor).cuda(gpu)
+
+
 #TODO: May not want to do square rendering here like in colorization; it definitely loses
 #fidelity visibly (but not too terribly). Will revisit.
 class DeFader(Filter): 
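
For reference, a minimal usage sketch for the new class, mirroring how the existing colorizers are constructed (the weights path, GPU index, and nf_factor value here are hypothetical, not from the commit):

    from pathlib import Path
    from fasterai.filters import Colorizer152

    # Hypothetical weights path and device index; Colorizer152 takes the
    # same arguments as Colorizer34/Colorizer101 and swaps in the
    # ResNet-152 backed Unet152 as its generator.
    colorizer = Colorizer152(gpu=0,
                             weights_path=Path('path/to/weights.h5'),
                             nf_factor=2,
                             map_to_orig=True)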

+ 24 - 1
fasterai/generators.py

@@ -20,7 +20,7 @@ class GeneratorModule(ABC, nn.Module):
     @abstractmethod
     def forward(self, x_in:torch.Tensor, max_render_sz:int=400):
         pass
-
+        
     def freeze_to(self, n:int):
         c=self.get_layer_groups()
         for l in c:     set_trainable(l, False)
@@ -176,3 +176,26 @@ class Unet101(AbstractUnet):
         layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
         return layers 
 
+class Unet152(AbstractUnet): 
+    def __init__(self, nf_factor:int=1, scale:int=1):
+        super().__init__(nf_factor=nf_factor, scale=scale)
+
+    def _get_pretrained_resnet_base(self, layers_cut:int=0):
+        f = resnet152
+        cut,lr_cut = model_meta[f]
+        cut-=layers_cut
+        layers = cut_model(f(True), cut)
+        return nn.Sequential(*layers), lr_cut
+
+    def _get_decoding_layers(self, nf_factor:int, scale:int):
+        self_attention=True
+        bn=True
+        sn=True
+        leakyReLu=False
+        layers = []
+        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        return layers
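
The decoder mirrors Unet34 and Unet101: each UnetBlock pairs the upsampled output of the previous block with a skip connection from the ResNet-152 encoder, whose stage widths run 64, 256, 512, 1024, and 2048 channels. A quick channel-bookkeeping sketch of the layers above (not model code; nf_factor=2 is an example value, the diff's default is 1):

    # Channel flow through the Unet152 decoder above (example nf_factor=2).
    nf = 2
    blocks = [
        # (upsampled input, encoder skip, output channels)
        (2048,     1024, 512 * nf),  # fed by the 2048-channel bottleneck
        (512 * nf,  512, 512 * nf),
        (512 * nf,  256, 512 * nf),  # self-attention is applied in this block
        (512 * nf,   64, 256 * nf),
    ]
    for up_in, skip_in, out in blocks:
        # each block's input width must equal the previous block's output
        print(f'UnetBlock(up_in={up_in}, skip_in={skip_in}) -> {out}')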