123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- from fastai.layers import *
- from .layers import *
- from fastai.torch_core import *
- from fastai.callbacks.hooks import *
- from fastai.vision import *
# NOTE: ideally, the code below would be merged upstream into fastai v1.
# Names exported by a star-import of this module.
__all__ = [
    'CustomDynamicUnet',
    'CustomUnetBlock',
    'CustomPixelShuffle_ICNR',
]
- def _get_sfs_idxs(sizes:Sizes) -> List[int]:
- "Get the indexes of the layers where the size of the activation changes."
- feature_szs = [size[-1] for size in sizes]
- sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
- if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
- return sfs_idxs
class CustomPixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle` and `icnr` init, with optional blur."
    def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, leaky:float=None, **kwargs):
        """Build the upsampling layers.

        ni:     number of input channels.
        nf:     number of output channels after the shuffle; falls back to `ni` when `None`.
        scale:  upscaling factor handed to `nn.PixelShuffle`; the conv emits `nf*scale**2` channels.
        blur:   when true, `forward` applies the pad + average-pool blur after the shuffle.
        leaky:  forwarded to the `relu` helper.
        kwargs: forwarded unchanged to `custom_conv_layer` (e.g. normalisation options).
        """
        super().__init__()
        nf = ifnone(nf, ni)
        # 1x1 convolution without activation producing the channels that PixelShuffle rearranges.
        self.conv = custom_conv_layer(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        # One pixel of replication padding on the left/top followed by a 2x2 stride-1
        # average pool keeps the spatial size unchanged.
        self.pad = nn.ReplicationPad2d((1,0,1,0))
        self.blur = nn.AvgPool2d(2, stride=1)
        # `self.blur` is a module and therefore always truthy; the requested on/off
        # switch has to be kept in its own attribute.
        self.do_blur = blur
        self.relu = relu(True, leaky=leaky)

    def forward(self,x):
        "Apply conv -> activation -> pixel shuffle to `x`, then the blur if it was requested."
        x = self.shuf(self.relu(self.conv(x)))
        # Instances pickled before `do_blur` existed always blurred; keep that for them.
        if getattr(self, 'do_blur', True): return self.blur(self.pad(x))
        return x
class CustomUnetBlock(nn.Module):
    "Decoder block: upsample with `CustomPixelShuffle_ICNR`, concatenate the hooked activation, then apply two conv layers."
    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                 self_attention:bool=False, nf_factor:float=1.0, **kwargs):
        """Build the block.

        up_in_c:        channels of the tensor coming from the previous (lower) stage.
        x_in_c:         channels of the activation stored by `hook`.
        hook:           object whose `.stored` attribute is read in `forward`.
        final_div:      when true the convs keep the concatenated channel count, otherwise it is halved.
        blur, leaky:    forwarded to `CustomPixelShuffle_ICNR`; `leaky` also goes to the convs and `relu`.
        self_attention: forwarded to the second conv layer only.
        nf_factor:      multiplier applied to the conv channel count (result truncated to int).
        kwargs:         forwarded to the upsampler and both conv layers.
        """
        super().__init__()
        self.hook = hook
        half_up_c = up_in_c//2
        self.shuf = CustomPixelShuffle_ICNR(up_in_c, half_up_c, blur=blur, leaky=leaky, **kwargs)
        self.bn = batchnorm_2d(x_in_c)
        # Channel count after concatenating the upsampled tensor with the hooked one.
        cat_c = half_up_c + x_in_c
        base_c = cat_c if final_div else cat_c//2
        out_c = int(base_c*nf_factor)
        self.conv1 = custom_conv_layer(cat_c, out_c, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(out_c, out_c, leaky=leaky, self_attention=self_attention, **kwargs)
        self.relu = relu(leaky=leaky)

    def forward(self, up_in:Tensor) -> Tensor:
        "Upsample `up_in`, match it to the hooked activation's spatial size, concatenate and convolve."
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        # The shuffle output may not match the stored activation's size; resize if so.
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        activated = self.relu(merged)
        return self.conv2(self.conv1(activated))
class CustomDynamicUnet(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
        """Assemble encoder, middle convs, one `CustomUnetBlock` per hooked layer and the output head.

        encoder:        indexable module; the layers at the size-change indexes get output hooks.
        n_classes:      number of channels produced by the final 1x1 conv.
        blur:           enable the blur in the upsampling of the decoder blocks.
        blur_final:     when false, the last decoder block is built without blur even if `blur` is set.
        self_attention: enable self attention in the block at position `len(blocks)-3`.
        y_range:        if given, a `SigmoidRange(*y_range)` layer is appended.
        last_cross:     if true, append `MergeLayer(dense=True)` and a `res_block` before the final conv.
        bottle:         forwarded to `res_block`.
        norm_type:      forwarded to the conv layers; `NormType.Spectral` also turns on `extra_bn`.
        nf_factor:      forwarded to every `CustomUnetBlock`.
        kwargs:         forwarded to the conv layers, decoder blocks, `PixelShuffle_ICNR` and `res_block`.
        """
        extra_bn = norm_type == NormType.Spectral
        # Fixed dummy input size used only to probe activation shapes while building.
        imsize = (256,256)
        sfs_szs = model_sizes(encoder, size=imsize)
        # Reversed so that the hooks are ordered from the last size change back to the first.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
                                    custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        n_blocks = len(sfs_idxs)
        for i,idx in enumerate(sfs_idxs):
            not_final = i != n_blocks-1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # `blur_final=False` must disable the blur on the last block only, so the
            # per-block flag (not the global `blur`) is what gets passed down.
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == n_blocks-3)
            unet_block = CustomUnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                         norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
            layers.append(unet_block)
            # Run the dummy tensor through the block so the next block sees the right channel count.
            x = unet_block(x)

        ni = x.shape[1]
        # If the first recorded activation is not at the input resolution, add one more upsampling step.
        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        "Remove the encoder hooks, if they were created."
        if hasattr(self, "sfs"): self.sfs.remove()
- #------------------------------------------------------
class CustomUnetBlock2(nn.Module):
    "Decoder block with a fixed output width: upsample with `CustomPixelShuffle_ICNR`, concatenate the hooked activation, apply one conv layer."
    def __init__(self, up_in_c:int, x_in_c:int, n_out:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                 self_attention:bool=False, **kwargs):
        """Build the block.

        up_in_c:        channels of the tensor coming from the previous stage.
        x_in_c:         channels of the activation stored by `hook`.
        n_out:          width parameter; both the upsampler and the conv emit `n_out//2` channels.
        hook:           object whose `.stored` attribute is read in `forward`.
        final_div:      accepted for signature compatibility; not used by this block.
        blur, leaky:    forwarded to `CustomPixelShuffle_ICNR`; `leaky` also goes to the conv and `relu`.
        self_attention: forwarded to the conv layer.
        kwargs:         forwarded to the upsampler and the conv layer.
        """
        super().__init__()
        self.hook = hook
        half_out = n_out//2
        self.shuf = CustomPixelShuffle_ICNR(up_in_c, half_out, blur=blur, leaky=leaky, **kwargs)
        self.bn = batchnorm_2d(x_in_c)
        # The conv input is the upsampled tensor concatenated with the hooked activation.
        cat_c = half_out + x_in_c
        self.conv = custom_conv_layer(cat_c, half_out, leaky=leaky, self_attention=self_attention, **kwargs)
        self.relu = relu(leaky=leaky)

    def forward(self, up_in:Tensor) -> Tensor:
        "Upsample `up_in`, match it to the hooked activation's spatial size, concatenate and convolve."
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        # The shuffle output may not match the stored activation's size; resize if so.
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv(self.relu(merged))
class CustomDynamicUnet2(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:int=1, **kwargs):
        """Assemble encoder, middle convs, one `CustomUnetBlock2` per hooked layer and the output head.

        encoder:        indexable module; the layers at the size-change indexes get output hooks.
        n_classes:      number of channels produced by the final 1x1 conv.
        blur:           enable the blur in the upsampling of the decoder blocks.
        blur_final:     when false, the last decoder block is built without blur even if `blur` is set.
        self_attention: enable self attention in the block at position `len(blocks)-3`.
        y_range:        if given, a `SigmoidRange(*y_range)` layer is appended.
        last_cross:     if true, append `MergeLayer(dense=True)` and a `res_block` before the final conv.
        bottle:         forwarded to `res_block`.
        norm_type:      forwarded to the conv layers; `NormType.Spectral` also turns on `extra_bn`.
        nf_factor:      decoder width multiplier; blocks get `n_out = 512*nf_factor` (halved for the last one).
        kwargs:         forwarded to the conv layers, decoder blocks, `PixelShuffle_ICNR` and `res_block`.
        """
        # Channel counts must be integers, so truncate in case a fractional factor is passed.
        nf = int(512 * nf_factor)
        extra_bn = norm_type == NormType.Spectral
        # Fixed dummy input size used only to probe activation shapes while building.
        imsize = (256,256)
        sfs_szs = model_sizes(encoder, size=imsize)
        # Reversed so that the hooks are ordered from the last size change back to the first.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
                                    custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        n_blocks = len(sfs_idxs)
        for i,idx in enumerate(sfs_idxs):
            not_final = i != n_blocks-1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # `blur_final=False` must disable the blur on the last block only, so the
            # per-block flag (not the global `blur`) is what gets passed down.
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == n_blocks-3)
            # The last block is built with half the width of the others.
            n_out = nf if not_final else nf//2
            unet_block = CustomUnetBlock2(up_in_c, x_in_c, n_out, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                          norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
            layers.append(unet_block)
            # Run the dummy tensor through the block so the next block sees the right channel count.
            x = unet_block(x)

        ni = x.shape[1]
        # If the first recorded activation is not at the input resolution, add one more upsampling step.
        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        "Remove the encoder hooks, if they were created."
        if hasattr(self, "sfs"): self.sfs.remove()
|