@@ -20,7 +20,7 @@ class GeneratorModule(ABC, nn.Module):
     @abstractmethod
     def forward(self, x_in:torch.Tensor, max_render_sz:int=400):
         pass
-
+
     def freeze_to(self, n:int):
         c=self.get_layer_groups()
         for l in c: set_trainable(l, False)
@@ -176,3 +176,31 @@ class Unet101(AbstractUnet):
         layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
         return layers
 
+class Unet152(AbstractUnet):
+    def __init__(self, nf_factor:int=1, scale:int=1):
+        super().__init__(nf_factor=nf_factor, scale=scale)
+
+    def _get_pretrained_resnet_base(self, layers_cut:int=0):
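+        # Load the ImageNet-pretrained torchvision ResNet-152 and cut it at
+        # the index given by model_meta, keeping the remainder as the encoder.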
+        f = resnet152
+        cut,lr_cut = model_meta[f]
+        cut-=layers_cut
+        layers = cut_model(f(True), cut)
+        return nn.Sequential(*layers), lr_cut
+
+    def _get_decoding_layers(self, nf_factor:int, scale:int):
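+        # Spectral norm and batch norm throughout, self-attention on the
+        # middle block; each UnetBlock joins the upsampled path with an
+        # encoder skip connection (1024, 512, 256, then 64 channels).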
+        self_attention=True
+        bn=True
+        sn=True
+        leakyReLu=False
+        layers = []
+        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        return layers
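
For anyone trying the new class, a minimal smoke test might look like the sketch below. The import path is an assumption (it depends on where this repo keeps the generators module), and it assumes AbstractUnet provides the concrete forward; neither is shown in this patch:

    import torch
    from fasterai.generators import Unet152  # assumed import path; adjust to this repo's layout

    netG = Unet152(nf_factor=1, scale=1)   # defaults from the patch
    netG.freeze_to(1)                      # freeze early layer groups via GeneratorModule.freeze_to
    with torch.no_grad():
        out = netG(torch.randn(1, 3, 256, 256))  # forward(x_in, max_render_sz=400)
    print(out.shape)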