unet.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from fastai.layers import *
  2. from .layers import *
  3. from fastai.torch_core import *
  4. from fastai.callbacks.hooks import *
  5. from fastai.vision import *
# The code below is ideally meant to be merged into fastai v1.
# Names exported by a star-import of this module: the two U-Net builder classes.
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
  8. def _get_sfs_idxs(sizes:Sizes) -> List[int]:
  9. "Get the indexes of the layers where the size of the activation changes."
  10. feature_szs = [size[-1] for size in sizes]
  11. sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
  12. if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
  13. return sfs_idxs
  14. class CustomPixelShuffle_ICNR(nn.Module):
  15. "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
  16. def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, leaky:float=None, **kwargs):
  17. super().__init__()
  18. nf = ifnone(nf, ni)
  19. self.conv = custom_conv_layer(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
  20. icnr(self.conv[0].weight)
  21. self.shuf = nn.PixelShuffle(scale)
  22. # Blurring over (h*w) kernel
  23. # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
  24. # - https://arxiv.org/abs/1806.02658
  25. self.pad = nn.ReplicationPad2d((1,0,1,0))
  26. self.blur = nn.AvgPool2d(2, stride=1)
  27. self.relu = relu(True, leaky=leaky)
  28. def forward(self,x):
  29. x = self.shuf(self.relu(self.conv(x)))
  30. return self.blur(self.pad(x)) if self.blur else x
  31. class UnetBlockDeep(nn.Module):
  32. "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
  33. def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
  34. self_attention:bool=False, nf_factor:float=1.0, **kwargs):
  35. super().__init__()
  36. self.hook = hook
  37. self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
  38. self.bn = batchnorm_2d(x_in_c)
  39. ni = up_in_c//2 + x_in_c
  40. nf = int((ni if final_div else ni//2)*nf_factor)
  41. self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
  42. self.conv2 = custom_conv_layer(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
  43. self.relu = relu(leaky=leaky)
  44. def forward(self, up_in:Tensor) -> Tensor:
  45. s = self.hook.stored
  46. up_out = self.shuf(up_in)
  47. ssh = s.shape[-2:]
  48. if ssh != up_out.shape[-2:]:
  49. up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
  50. cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
  51. return self.conv2(self.conv1(cat_x))
  52. class DynamicUnetDeep(SequentialEx):
  53. "Create a U-Net from a given architecture."
  54. def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
  55. y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
  56. norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
  57. extra_bn = norm_type == NormType.Spectral
  58. imsize = (256,256)
  59. sfs_szs = model_sizes(encoder, size=imsize)
  60. sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
  61. self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
  62. x = dummy_eval(encoder, imsize).detach()
  63. ni = sfs_szs[-1][1]
  64. middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
  65. custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
  66. x = middle_conv(x)
  67. layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
  68. for i,idx in enumerate(sfs_idxs):
  69. not_final = i!=len(sfs_idxs)-1
  70. up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
  71. do_blur = blur and (not_final or blur_final)
  72. sa = self_attention and (i==len(sfs_idxs)-3)
  73. unet_block = UnetBlockDeep(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
  74. norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
  75. layers.append(unet_block)
  76. x = unet_block(x)
  77. ni = x.shape[1]
  78. if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
  79. if last_cross:
  80. layers.append(MergeLayer(dense=True))
  81. ni += in_channels(encoder)
  82. layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
  83. layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
  84. if y_range is not None: layers.append(SigmoidRange(*y_range))
  85. super().__init__(*layers)
  86. def __del__(self):
  87. if hasattr(self, "sfs"): self.sfs.remove()
  88. #------------------------------------------------------
  89. class UnetBlockWide(nn.Module):
  90. "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
  91. def __init__(self, up_in_c:int, x_in_c:int, n_out:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
  92. self_attention:bool=False, **kwargs):
  93. super().__init__()
  94. self.hook = hook
  95. up_out = x_out = n_out//2
  96. self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, leaky=leaky, **kwargs)
  97. self.bn = batchnorm_2d(x_in_c)
  98. ni = up_out + x_in_c
  99. self.conv = custom_conv_layer(ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs)
  100. self.relu = relu(leaky=leaky)
  101. def forward(self, up_in:Tensor) -> Tensor:
  102. s = self.hook.stored
  103. up_out = self.shuf(up_in)
  104. ssh = s.shape[-2:]
  105. if ssh != up_out.shape[-2:]:
  106. up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
  107. cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
  108. return self.conv(cat_x)
  109. class DynamicUnetWide(SequentialEx):
  110. "Create a U-Net from a given architecture."
  111. def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
  112. y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
  113. norm_type:Optional[NormType]=NormType.Batch, nf_factor:int=1, **kwargs):
  114. nf = 512 * nf_factor
  115. extra_bn = norm_type == NormType.Spectral
  116. imsize = (256,256)
  117. sfs_szs = model_sizes(encoder, size=imsize)
  118. sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
  119. self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
  120. x = dummy_eval(encoder, imsize).detach()
  121. ni = sfs_szs[-1][1]
  122. middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
  123. custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
  124. x = middle_conv(x)
  125. layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
  126. for i,idx in enumerate(sfs_idxs):
  127. not_final = i!=len(sfs_idxs)-1
  128. up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
  129. do_blur = blur and (not_final or blur_final)
  130. sa = self_attention and (i==len(sfs_idxs)-3)
  131. n_out = nf if not_final else nf//2
  132. unet_block = UnetBlockWide(up_in_c, x_in_c, n_out, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
  133. norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
  134. layers.append(unet_block)
  135. x = unet_block(x)
  136. ni = x.shape[1]
  137. if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
  138. if last_cross:
  139. layers.append(MergeLayer(dense=True))
  140. ni += in_channels(encoder)
  141. layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
  142. layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
  143. if y_range is not None: layers.append(SigmoidRange(*y_range))
  144. super().__init__(*layers)
  145. def __del__(self):
  146. if hasattr(self, "sfs"): self.sfs.remove()