|
@@ -184,4 +184,69 @@ class DynamicUnet4(SequentialEx):
|
|
super().__init__(*layers)
|
|
super().__init__(*layers)
|
|
|
|
|
|
def __del__(self):
|
|
def __del__(self):
|
|
- if hasattr(self, "sfs"): self.sfs.remove()
|
|
|
|
|
|
+ if hasattr(self, "sfs"): self.sfs.remove()
|
|
|
|
+
|
|
|
|
+class UnetBlock5(nn.Module):
|
|
|
|
+ "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
|
|
|
+ def __init__(self, up_in_c:int, x_in_c:int, out_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
|
|
|
|
+ self_attention:bool=False, **kwargs):
|
|
|
|
+ super().__init__()
|
|
|
|
+ self.hook = hook
|
|
|
|
+ self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
|
|
|
|
+ self.bn = batchnorm_2d(x_in_c)
|
|
|
|
+ ni = up_in_c//2 + x_in_c
|
|
|
|
+ nf = out_c
|
|
|
|
+ self.conv = conv_layer2(ni, nf, leaky=leaky, self_attention=self_attention, **kwargs)
|
|
|
|
+ self.relu = relu(leaky=leaky)
|
|
|
|
+
|
|
|
|
+ def forward(self, up_in:Tensor) -> Tensor:
|
|
|
|
+ s = self.hook.stored
|
|
|
|
+ up_out = self.shuf(up_in)
|
|
|
|
+ ssh = s.shape[-2:]
|
|
|
|
+ if ssh != up_out.shape[-2:]:
|
|
|
|
+ up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
|
|
|
+ cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
|
|
|
+ return self.conv(cat_x)
|
|
|
|
+
|
|
|
|
+#custom filter widths
|
|
|
|
+class DynamicUnet5(SequentialEx):
|
|
|
|
+ "Create a U-Net from a given architecture."
|
|
|
|
+ def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
|
|
|
|
+ y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=True,
|
|
|
|
+ norm_type:Optional[NormType]=NormType.Batch, nf:int=256, **kwargs):
|
|
|
|
+ extra_bn = norm_type == NormType.Spectral
|
|
|
|
+ imsize = (256,256)
|
|
|
|
+ sfs_szs = model_sizes(encoder, size=imsize)
|
|
|
|
+ sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
|
|
|
|
+ self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
|
|
|
|
+ x = dummy_eval(encoder, imsize).detach()
|
|
|
|
+
|
|
|
|
+ ni = sfs_szs[-1][1]
|
|
|
|
+ middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
|
|
|
|
+ conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
|
|
|
|
+ x = middle_conv(x)
|
|
|
|
+ layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
|
|
|
|
+
|
|
|
|
+ for i,idx in enumerate(sfs_idxs):
|
|
|
|
+ not_final = i!=len(sfs_idxs)-1
|
|
|
|
+ up_in_c = int(x.shape[1]) if i == 0 else nf
|
|
|
|
+ x_in_c = int(sfs_szs[idx][1])
|
|
|
|
+ do_blur = blur and (not_final or blur_final)
|
|
|
|
+ sa = self_attention and (i==len(sfs_idxs)-3)
|
|
|
|
+ unet_block = UnetBlock5(up_in_c, x_in_c, nf, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
|
|
|
|
+ norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
|
|
|
|
+ layers.append(unet_block)
|
|
|
|
+ x = unet_block(x)
|
|
|
|
+
|
|
|
|
+ ni = x.shape[1]
|
|
|
|
+ if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
|
|
|
|
+ if last_cross:
|
|
|
|
+ layers.append(MergeLayer(dense=True))
|
|
|
|
+ ni += in_channels(encoder)
|
|
|
|
+ layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
|
|
|
|
+ layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
|
|
|
|
+ if y_range is not None: layers.append(SigmoidRange(*y_range))
|
|
|
|
+ super().__init__(*layers)
|
|
|
|
+
|
|
|
|
+ def __del__(self):
|
|
|
|
+ if hasattr(self, "sfs"): self.sfs.remove()
|