from fastai.layers import *
from .layers import *
from fastai.torch_core import *
from fastai.callbacks.hooks import *
from fastai.vision import *

# The code below is ideally meant to be merged into fastai v1 itself.
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']

def _get_sfs_idxs(sizes: Sizes) -> List[int]:
    "Get the indexes of the layers where the size of the activation changes."
    feature_szs = [size[-1] for size in sizes]
    sfs_idxs = list(
        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
    )
    if feature_szs[0] != feature_szs[1]:
        sfs_idxs = [0] + sfs_idxs
    return sfs_idxs

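# Illustrative example (an assumption, not in the original): for a ResNet-34
# body evaluated at 256px, `model_sizes` reports spatial widths
# [128, 128, 128, 64, 64, 32, 16, 8] across the top-level children, so this
# returns [2, 4, 5, 6] -- the last index at each resolution, i.e. where the
# U-Net attaches its skip-connection hooks.
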
class CustomPixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."

    def __init__(
        self,
        ni: int,
        nf: int = None,
        scale: int = 2,
        blur: bool = False,
        leaky: float = None,
        **kwargs
    ):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = custom_conv_layer(
            ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
        )
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        # Keep the flag separately: `self.blur` is an `nn.AvgPool2d` module and
        # is always truthy, so testing it directly would apply blurring
        # unconditionally and silently ignore the `blur` argument.
        self.do_blur = blur
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.do_blur else x

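# Quick shape check (illustrative, not part of the original module):
# CustomPixelShuffle_ICNR(64, 32, scale=2) maps (B, 64, H, W) -> (B, 32, 2H, 2W):
# the 1x1 conv expands to 32 * 2**2 = 128 channels, and `nn.PixelShuffle` trades
# that factor of 4 in channels for a 2x increase in each spatial dimension.
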
class UnetBlockDeep(nn.Module):
    "A quasi-UNet block, upsampling with `CustomPixelShuffle_ICNR`."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        nf_factor: float = 1.0,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        ni = up_in_c // 2 + x_in_c
        nf = int((ni if final_div else ni // 2) * nf_factor)
        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))

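# Channel bookkeeping (illustrative, not in the original): with up_in_c=512 and
# x_in_c=256, the shuffle halves the upsampled path to 256 channels,
# concatenation with the hooked skip gives ni = 256 + 256 = 512, and both convs
# emit nf = int(512 * nf_factor) channels (half that when final_div is False).
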
class DynamicUnetDeep(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final: bool = True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: float = 1.0,
        **kwargs
    ):
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()

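# Usage sketch (an assumption, not part of the original module): build the
# decoder over a backbone cut by fastai's `create_body`, which is already in
# scope via `from fastai.vision import *`.
#
#   body = create_body(models.resnet34, pretrained=True)
#   net = DynamicUnetDeep(body, n_classes=3, blur=True, self_attention=True,
#                         norm_type=NormType.Spectral, nf_factor=1.5)
#   out = net(torch.randn(1, 3, 64, 64))  # -> torch.Size([1, 3, 64, 64])
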
# ------------------------------------------------------


class UnetBlockWide(nn.Module):
    "A quasi-UNet block, upsampling with `CustomPixelShuffle_ICNR`."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        n_out: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        up_out = x_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        ni = up_out + x_in_c
        self.conv = custom_conv_layer(
            ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv(cat_x)

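# Channel bookkeeping (illustrative, not in the original): both the shuffled
# path and the merge conv are sized to n_out // 2, so the block emits
# n_out // 2 channels; e.g. with n_out=512 and x_in_c=256, the concatenation
# has ni = 256 + 256 = 512 channels and the output has 256.
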
class DynamicUnetWide(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final: bool = True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):
        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            n_out = nf if not_final else nf // 2
            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()

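# Usage sketch (an assumption, not part of the original module): DeOldify-style
# construction widens the decoder via `nf_factor` and uses spectral norm.
# `create_body` and `models` come from `from fastai.vision import *` above.
#
#   body = create_body(models.resnet101, pretrained=True)
#   net = DynamicUnetWide(body, n_classes=3, blur=True, blur_final=True,
#                         self_attention=True, y_range=(-3.0, 3.0),
#                         norm_type=NormType.Spectral, nf_factor=2)
#   out = net(torch.randn(1, 3, 64, 64))  # -> torch.Size([1, 3, 64, 64])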