Browse Source

Making generator DataParallel friendly (yet...DataParallel doesn't train correctly yet...wat wat wat?)

Jason Antic 6 năm trước
mục cha
commit
dce09b6417
1 tập tin đã thay đổi với 34 bổ sung và 12 xóa
  1. 34 12
      fasterai/generators.py

+ 34 - 12
fasterai/generators.py

@@ -41,7 +41,6 @@ class Unet34(GeneratorModule):
         sn=True
         self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
         self.relu = nn.ReLU()
-        self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up3 = UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
@@ -65,12 +64,24 @@ class Unet34(GeneratorModule):
         return x
            
     def forward(self, x_in:torch.Tensor):
-        x = self.rn(x_in)
+        x = self.rn[0](x_in)
+        x = self.rn[1](x)
+        x = self.rn[2](x)
+        enc0 = x
+        x = self.rn[3](x)
+        x = self.rn[4](x)
+        enc1 = x
+        x = self.rn[5](x)
+        enc2 = x
+        x = self.rn[6](x)
+        enc3 = x
+        x = self.rn[7](x)
+
         x = self.relu(x)
-        x = self.up1(x, self._pad(self.sfs[3].features, x))
-        x = self.up2(x, self._pad(self.sfs[2].features, x))
-        x = self.up3(x, self._pad(self.sfs[1].features, x))
-        x = self.up4(x, self._pad(self.sfs[0].features, x))
+        x = self.up1(x, self._pad(enc3, x))
+        x = self.up2(x, self._pad(enc2, x))
+        x = self.up3(x, self._pad(enc1, x))
+        x = self.up4(x, self._pad(enc0, x))
         x = self.up5(x)
         x = self.out(x)
         return x
@@ -102,7 +113,6 @@ class Unet34_V2(GeneratorModule):
         sn=True
         self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
         self.relu = nn.ReLU()
-        self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up3 = UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
@@ -125,12 +135,24 @@ class Unet34_V2(GeneratorModule):
         return x
            
     def forward(self, x_in:torch.Tensor):
-        x = self.rn(x_in)
+        x = self.rn[0](x_in)
+        x = self.rn[1](x)
+        x = self.rn[2](x)
+        enc0 = x
+        x = self.rn[3](x)
+        x = self.rn[4](x)
+        enc1 = x
+        x = self.rn[5](x)
+        enc2 = x
+        x = self.rn[6](x)
+        enc3 = x
+        x = self.rn[7](x)
+
         x = self.relu(x)
-        x = self.up1(x, self._pad(self.sfs[3].features, x))
-        x = self.up2(x, self._pad(self.sfs[2].features, x))
-        x = self.up3(x, self._pad(self.sfs[1].features, x))
-        x = self.up4(x, self._pad(self.sfs[0].features, x))
+        x = self.up1(x, self._pad(enc3, x))
+        x = self.up2(x, self._pad(enc2, x))
+        x = self.up3(x, self._pad(enc1, x))
+        x = self.up4(x, self._pad(enc0, x))
         x = self.up5(x)
         x = self.out(x)
         return x