|
@@ -41,7 +41,6 @@ class Unet34(GeneratorModule):
         sn=True
         self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
         self.relu = nn.ReLU()
-        self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up3 = UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
@@ -65,12 +64,24 @@ class Unet34(GeneratorModule):
         return x

     def forward(self, x_in:torch.Tensor):
-        x = self.rn(x_in)
+        x = self.rn[0](x_in)
+        x = self.rn[1](x)
+        x = self.rn[2](x)
+        enc0 = x
+        x = self.rn[3](x)
+        x = self.rn[4](x)
+        enc1 = x
+        x = self.rn[5](x)
+        enc2 = x
+        x = self.rn[6](x)
+        enc3 = x
+        x = self.rn[7](x)
+
         x = self.relu(x)
-        x = self.up1(x, self._pad(self.sfs[3].features, x))
-        x = self.up2(x, self._pad(self.sfs[2].features, x))
-        x = self.up3(x, self._pad(self.sfs[1].features, x))
-        x = self.up4(x, self._pad(self.sfs[0].features, x))
+        x = self.up1(x, self._pad(enc3, x))
+        x = self.up2(x, self._pad(enc2, x))
+        x = self.up3(x, self._pad(enc1, x))
+        x = self.up4(x, self._pad(enc0, x))
         x = self.up5(x)
         x = self.out(x)
         return x
@@ -102,7 +113,6 @@ class Unet34_V2(GeneratorModule):
         sn=True
         self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
         self.relu = nn.ReLU()
-        self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up3 = UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
@@ -125,12 +135,24 @@ class Unet34_V2(GeneratorModule):
         return x

     def forward(self, x_in:torch.Tensor):
-        x = self.rn(x_in)
+        x = self.rn[0](x_in)
+        x = self.rn[1](x)
+        x = self.rn[2](x)
+        enc0 = x
+        x = self.rn[3](x)
+        x = self.rn[4](x)
+        enc1 = x
+        x = self.rn[5](x)
+        enc2 = x
+        x = self.rn[6](x)
+        enc3 = x
+        x = self.rn[7](x)
+
         x = self.relu(x)
-        x = self.up1(x, self._pad(self.sfs[3].features, x))
-        x = self.up2(x, self._pad(self.sfs[2].features, x))
-        x = self.up3(x, self._pad(self.sfs[1].features, x))
-        x = self.up4(x, self._pad(self.sfs[0].features, x))
+        x = self.up1(x, self._pad(enc3, x))
+        x = self.up2(x, self._pad(enc2, x))
+        x = self.up3(x, self._pad(enc1, x))
+        x = self.up4(x, self._pad(enc0, x))
         x = self.up5(x)
         x = self.out(x)
         return x