
Better implementation of generator padding; Adding convenient device id access to generator/critic

Jason Antic 6 years ago
parent commit
210c33dd05
3 changed files with 9 additions and 5 deletions
  1. +1 -0 .gitignore
  2. +5 -5 fasterai/generators.py
  3. +3 -0 fasterai/training.py

+ 1 - 0
.gitignore

@@ -27,3 +27,4 @@ fasterai/SymbolicLinks.sh
 SymbolicLinks.sh
 .ipynb_checkpoints/README-checkpoint.md
 .ipynb_checkpoints/ComboVisualization-checkpoint.ipynb
+.ipynb_checkpoints/ColorizeTraining2-checkpoint.ipynb

+ 5 - 5
fasterai/generators.py

@@ -20,7 +20,7 @@ class GeneratorModule(ABC, nn.Module):
         for l in c[n:]: set_trainable(l, True)
 
     def get_device(self):
-        next(self.parameters()).device
+        return next(self.parameters()).device
 
  
 class Unet34(GeneratorModule): 
@@ -51,7 +51,7 @@ class Unet34(GeneratorModule):
         self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
 
     #Gets around irritating inconsistent halving coming from resnet
-    def _pad(self, x, target):
+    def _pad(self, x: torch.Tensor, target: torch.Tensor)-> torch.Tensor:
         h = x.shape[2] 
         w = x.shape[3]
 
@@ -59,9 +59,9 @@ class Unet34(GeneratorModule):
         target_w = target.shape[3]*2
 
         if h<target_h or w<target_w:
-            target = Variable(torch.zeros(x.shape[0], x.shape[1], target_h, target_w))
-            target[:,:,:h,:w]=x
-            return to_gpu(target)
+            padh = target_h-h if target_h > h else 0
+            padw = target_w-w if target_w > w else 0
+            return F.pad(x, (0,padw,0,padh), "constant",0)
 
         return x
            

+ 3 - 0
fasterai/training.py

@@ -27,6 +27,9 @@ class CriticModule(ABC, nn.Module):
     def get_layer_groups(self)->[]:
         pass
 
+    def get_device(self):
+        return next(self.parameters()).device
+
 class DCCritic(CriticModule):
 
     def _generate_reduce_layers(self, nf:int):
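The same get_device helper is now available on both the generator and the critic. A rough usage sketch (the TinyCritic class and the batch tensor below are hypothetical; only the get_device body matches the commit): training code can move inputs to whichever device the module's parameters live on without passing a device id around.

import torch
import torch.nn as nn

class TinyCritic(nn.Module):
    # Hypothetical stand-in for a CriticModule subclass.
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.LeakyReLU())

    def get_device(self):
        # Same pattern as the commit: report the device of the first parameter.
        return next(self.parameters()).device

    def forward(self, x):
        return self.body(x)

critic = TinyCritic()
if torch.cuda.is_available():
    critic = critic.cuda()

batch = torch.randn(2, 3, 64, 64).to(critic.get_device())  # follow the model's device
out = critic(batch)
print(out.shape)  # torch.Size([2, 8, 64, 64])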