
Getting rid of dumb looking padding in output images!

Jason Antic 6 years ago
parent
commit
8340aa0e86
1 file changed with 28 additions and 78 deletions

fasterai/generators.py  +28 -78
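
The commit targets the visible border that constant (zero) padding could leave in the output images: the skip tensors are now padded with reflection instead of zeros, and the amount of padding is tracked so it can be cropped off again before the final upsample. As a toy illustration (not code from the repo) of why a reflected edge blends in where a zero edge does not:

import torch
import torch.nn.functional as F

row = torch.arange(1.0, 6.0).view(1, 1, 5)            # a 1-D "image row": [1, 2, 3, 4, 5]
# Zero padding appends a hard dark seam that downstream convolutions pick up...
print(F.pad(row, (0, 3), mode="constant", value=0))   # [1., 2., 3., 4., 5., 0., 0., 0.]
# ...while reflection padding continues the existing content past the edge.
print(F.pad(row, (0, 3), mode="reflect"))             # [1., 2., 3., 4., 5., 4., 3., 2.]
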

@@ -49,7 +49,7 @@ class Unet34(GeneratorModule):
         self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
 
     #Gets around irritating inconsistent halving coming from resnet
-    def _pad(self, x:torch.Tensor, target:torch.Tensor)-> torch.Tensor:
+    def _pad(self, x:torch.Tensor, target:torch.Tensor, total_padh:int, total_padw:int)-> torch.Tensor:
         h = x.shape[2] 
         w = x.shape[3]
 
@@ -58,10 +58,20 @@ class Unet34(GeneratorModule):
 
         if h<target_h or w<target_w:
             padh = target_h-h if target_h > h else 0
+            total_padh = total_padh + padh
             padw = target_w-w if target_w > w else 0
-            return F.pad(x, (0,padw,0,padh), "constant",0)
-
-        return x
+            total_padw = total_padw + padw
+            return (F.pad(x, (0,padw,0,padh), "reflect",0), total_padh, total_padw)
+
+        return (x, total_padh, total_padw)
+
+    def _remove_padding(self, x:torch.Tensor, padh:int, padw:int)->torch.Tensor:
+        if padw == 0 and padh == 0:
+            return x 
+        
+        target_h = x.shape[2]-padh
+        target_w = x.shape[3]-padw
+        return x[:,:,:target_h, :target_w]
            
     def forward(self, x_in:torch.Tensor):
         x = self.rn[0](x_in)
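
In isolation, the pattern these two helpers implement looks roughly like the sketch below: pad the smaller skip tensor with reflection so it can be concatenated with the upsampled decoder path, remember how many rows/columns were added, and slice them off again later. Shapes and variable names here are illustrative, not the repo's:

import torch
import torch.nn.functional as F

skip = torch.randn(1, 64, 13, 13)    # encoder activation, smaller due to rounding on the way down
up   = torch.randn(1, 64, 16, 16)    # decoder activation after exact x2 upsampling

padh, padw = up.shape[2] - skip.shape[2], up.shape[3] - skip.shape[3]
skip = F.pad(skip, (0, padw, 0, padh), mode="reflect")   # pad bottom/right, as _pad does
fused = torch.cat([up, skip], dim=1)                     # shapes now agree: (1, 128, 16, 16)

# _remove_padding then drops the extra rows/columns; in the real forward this
# crop happens only once, on the decoder output just before up5.
fused = fused[:, :, :fused.shape[2] - padh, :fused.shape[3] - padw]
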
@@ -77,82 +87,21 @@ class Unet34(GeneratorModule):
         enc3 = x
         x = self.rn[7](x)
 
-        x = self.relu(x)
-        x = self.up1(x, self._pad(enc3, x))
-        x = self.up2(x, self._pad(enc2, x))
-        x = self.up3(x, self._pad(enc1, x))
-        x = self.up4(x, self._pad(enc0, x))
-        x = self.up5(x)
-        x = self.out(x)
-        return x
-    
-    def get_layer_groups(self, precompute:bool=False)->[]:
-        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
-        return lgs + [children(self)[1:]]
-    
-    def close(self):
-        for sf in self.sfs: 
-            sf.remove()
-
-
-class Unet34_V2(GeneratorModule): 
-    @staticmethod
-    def get_pretrained_resnet_base(layers_cut:int=0):
-        f = resnet34
-        cut,lr_cut = model_meta[f]
-        cut-=layers_cut
-        layers = cut_model(f(True), cut)
-        return nn.Sequential(*layers), lr_cut
-
-    def __init__(self, nf_factor:int=1, scale:int=1):
-        super().__init__()
-        assert (math.log(scale,2)).is_integer()
-        leakyReLu=False
-        self_attention=True
-        bn=True
-        sn=True
-        self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
-        self.relu = nn.ReLU()
-        self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
-        self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
-        self.up3 = UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
-        self.up4 = UnetBlock(256*nf_factor,64,128*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
-        self.up5 = UpSampleBlock(128*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn) 
-        self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
-
-    #Gets around irritating inconsistent halving coming from resnet
-    def _pad(self, x:torch.Tensor, target:torch.Tensor)-> torch.Tensor:
-        h = x.shape[2] 
-        w = x.shape[3]
-        target_h = target.shape[2]*2
-        target_w = target.shape[3]*2
+        padw = 0
+        padh = 0
 
-        if h<target_h or w<target_w:
-            padh = target_h-h if target_h > h else 0
-            padw = target_w-w if target_w > w else 0
-            return F.pad(x, (0,padw,0,padh), "constant",0)
+        x = self.relu(x)
+        penc3, padh, padw = self._pad(enc3, x, padh, padw)
+        x = self.up1(x, penc3)
+        penc2, padh, padw  = self._pad(enc2, x, padh, padw)
+        x = self.up2(x, penc2)
+        penc1, padh, padw  = self._pad(enc1, x, padh, padw)
+        x = self.up3(x, penc1)
+        penc0, padh, padw  = self._pad(enc0, x, padh, padw)
+        x = self.up4(x, penc0)
 
-        return x
-           
-    def forward(self, x_in:torch.Tensor):
-        x = self.rn[0](x_in)
-        x = self.rn[1](x)
-        x = self.rn[2](x)
-        enc0 = x
-        x = self.rn[3](x)
-        x = self.rn[4](x)
-        enc1 = x
-        x = self.rn[5](x)
-        enc2 = x
-        x = self.rn[6](x)
-        enc3 = x
-        x = self.rn[7](x)
+        x = self._remove_padding(x, padh, padw)
 
-        x = self.relu(x)
-        x = self.up1(x, self._pad(enc3, x))
-        x = self.up2(x, self._pad(enc2, x))
-        x = self.up3(x, self._pad(enc1, x))
-        x = self.up4(x, self._pad(enc0, x))
         x = self.up5(x)
         x = self.out(x)
         return x
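
For a sense of where the "inconsistent halving" comes from: each stride-2 stage of the resnet34 encoder (conv1, maxpool, layer2-4) effectively computes ceil(h/2), while each decoder step doubles exactly, so for input sizes that are not a multiple of 32 the upsampled decoder path can end up larger than the stored skip. A rough trace with an illustrative odd input size (assuming standard resnet34 strides; not code from the repo):

import math

h, sizes = 101, []                 # illustrative odd input height
for _ in range(5):                 # conv1, maxpool, layer2, layer3, layer4 each halve
    h = math.ceil(h / 2)
    sizes.append(h)                # [51, 26, 13, 7, 4]

dec = sizes[-1]                    # decoder starts from the deepest feature map
for skip in reversed(sizes[:-1]):  # enc3, enc2, enc1, enc0
    dec *= 2                       # each UnetBlock upsamples by 2 before concatenating
    print(f"decoder {dec:>3} vs skip {skip:>3} -> pad {dec - skip}")
# decoder   8 vs skip   7 -> pad 1
# decoder  16 vs skip  13 -> pad 3
# decoder  32 vs skip  26 -> pad 6
# decoder  64 vs skip  51 -> pad 13

These per-level amounts are what forward now accumulates into padh/padw and what _remove_padding slices off before up5.
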
@@ -163,4 +112,5 @@ class Unet34_V2(GeneratorModule):
     
     def close(self):
         for sf in self.sfs: 
-            sf.remove()
+            sf.remove()
+