# unet.py
  1. from fastai.layers import *
  2. from fasterai.layers import *
  3. from fastai.torch_core import *
  4. from fastai.callbacks.hooks import *
#The code below is meant to be merged into fastaiv1 ideally
# Names made available by a star-import of this module.
# NOTE(review): DynamicUnet3, DynamicUnet4, DynamicUnet5, UnetBlock5 and PixelShuffle_ICNR2
# are defined in this file but not listed here - confirm whether that is intentional.
__all__ = ['DynamicUnet2', 'UnetBlock2']
  7. def _get_sfs_idxs(sizes:Sizes) -> List[int]:
  8. "Get the indexes of the layers where the size of the activation changes."
  9. feature_szs = [size[-1] for size in sizes]
  10. sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
  11. if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
  12. return sfs_idxs
  13. class PixelShuffle_ICNR2(nn.Module):
  14. "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
  15. def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, leaky:float=None, **kwargs):
  16. super().__init__()
  17. nf = ifnone(nf, ni)
  18. self.conv = conv_layer2(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
  19. icnr(self.conv[0].weight)
  20. self.shuf = nn.PixelShuffle(scale)
  21. # Blurring over (h*w) kernel
  22. # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
  23. # - https://arxiv.org/abs/1806.02658
  24. self.pad = nn.ReplicationPad2d((1,0,1,0))
  25. self.blur = nn.AvgPool2d(2, stride=1)
  26. self.relu = relu(True, leaky=leaky)
  27. def forward(self,x):
  28. x = self.shuf(self.relu(self.conv(x)))
  29. return self.blur(self.pad(x)) if self.blur else x
  30. class UnetBlock2(nn.Module):
  31. "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
  32. def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
  33. self_attention:bool=False, nf_factor:float=1.0, **kwargs):
  34. super().__init__()
  35. self.hook = hook
  36. self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
  37. self.bn = batchnorm_2d(x_in_c)
  38. ni = up_in_c//2 + x_in_c
  39. nf = int((ni if final_div else ni//2)*nf_factor)
  40. self.conv1 = conv_layer2(ni, nf, leaky=leaky, **kwargs)
  41. self.conv2 = conv_layer2(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
  42. self.relu = relu(leaky=leaky)
  43. def forward(self, up_in:Tensor) -> Tensor:
  44. s = self.hook.stored
  45. up_out = self.shuf(up_in)
  46. ssh = s.shape[-2:]
  47. if ssh != up_out.shape[-2:]:
  48. up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
  49. cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
  50. return self.conv2(self.conv1(cat_x))
  51. class DynamicUnet2(SequentialEx):
  52. "Create a U-Net from a given architecture."
  53. def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
  54. y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
  55. norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
  56. #extra_bn = norm_type in (NormType.Spectral, NormType.Weight)
  57. extra_bn = norm_type == NormType.Spectral
  58. imsize = (256,256)
  59. sfs_szs = model_sizes(encoder, size=imsize)
  60. sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
  61. self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
  62. x = dummy_eval(encoder, imsize).detach()
  63. ni = sfs_szs[-1][1]
  64. middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
  65. conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
  66. x = middle_conv(x)
  67. layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
  68. for i,idx in enumerate(sfs_idxs):
  69. not_final = i!=len(sfs_idxs)-1
  70. up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
  71. do_blur = blur and (not_final or blur_final)
  72. sa = self_attention and (i==len(sfs_idxs)-3)
  73. unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
  74. norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
  75. layers.append(unet_block)
  76. x = unet_block(x)
  77. ni = x.shape[1]
  78. if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
  79. if last_cross:
  80. layers.append(MergeLayer(dense=True))
  81. ni += in_channels(encoder)
  82. #TODO: Missing norm_type argument here. DOH!
  83. layers.append(res_block(ni, bottle=bottle, **kwargs))
  84. layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
  85. if y_range is not None: layers.append(SigmoidRange(*y_range))
  86. super().__init__(*layers)
  87. def __del__(self):
  88. if hasattr(self, "sfs"): self.sfs.remove()
  89. class DynamicUnet3(SequentialEx):
  90. "Create a U-Net from a given architecture."
  91. def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
  92. y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
  93. norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
  94. extra_bn = norm_type == NormType.Spectral
  95. imsize = (256,256)
  96. sfs_szs = model_sizes(encoder, size=imsize)
  97. sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
  98. self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
  99. x = dummy_eval(encoder, imsize).detach()
  100. ni = sfs_szs[-1][1]
  101. middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
  102. conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
  103. x = middle_conv(x)
  104. layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
  105. for i,idx in enumerate(sfs_idxs):
  106. not_final = i!=len(sfs_idxs)-1
  107. up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
  108. do_blur = blur and (not_final or blur_final)
  109. sa = self_attention and (i==len(sfs_idxs)-3)
  110. unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
  111. norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
  112. layers.append(unet_block)
  113. x = unet_block(x)
  114. ni = x.shape[1]
  115. if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
  116. if last_cross:
  117. layers.append(MergeLayer(dense=True))
  118. ni += in_channels(encoder)
  119. layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
  120. layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
  121. if y_range is not None: layers.append(SigmoidRange(*y_range))
  122. super().__init__(*layers)
  123. def __del__(self):
  124. if hasattr(self, "sfs"): self.sfs.remove()
  125. #No batch norm
  126. class DynamicUnet4(SequentialEx):
  127. "Create a U-Net from a given architecture."
  128. def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
  129. y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
  130. norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
  131. #extra_bn = norm_type == NormType.Spectral
  132. extra_bn = False
  133. imsize = (256,256)
  134. sfs_szs = model_sizes(encoder, size=imsize)
  135. sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
  136. self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
  137. x = dummy_eval(encoder, imsize).detach()
  138. ni = sfs_szs[-1][1]
  139. middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
  140. conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
  141. x = middle_conv(x)
  142. #layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
  143. layers = [encoder, nn.ReLU(), middle_conv]
  144. for i,idx in enumerate(sfs_idxs):
  145. not_final = i!=len(sfs_idxs)-1
  146. up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
  147. do_blur = blur and (not_final or blur_final)
  148. sa = self_attention and (i==len(sfs_idxs)-3)
  149. unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
  150. norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
  151. layers.append(unet_block)
  152. x = unet_block(x)
  153. ni = x.shape[1]
  154. if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
  155. if last_cross:
  156. layers.append(MergeLayer(dense=True))
  157. ni += in_channels(encoder)
  158. layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
  159. layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
  160. if y_range is not None: layers.append(SigmoidRange(*y_range))
  161. super().__init__(*layers)
  162. def __del__(self):
  163. if hasattr(self, "sfs"): self.sfs.remove()
  164. class UnetBlock5(nn.Module):
  165. "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
  166. def __init__(self, up_in_c:int, x_in_c:int, out_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
  167. self_attention:bool=False, **kwargs):
  168. super().__init__()
  169. self.hook = hook
  170. self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
  171. self.bn = batchnorm_2d(x_in_c)
  172. ni = up_in_c//2 + x_in_c
  173. nf = out_c
  174. self.conv = conv_layer2(ni, nf, leaky=leaky, self_attention=self_attention, **kwargs)
  175. self.relu = relu(leaky=leaky)
  176. def forward(self, up_in:Tensor) -> Tensor:
  177. s = self.hook.stored
  178. up_out = self.shuf(up_in)
  179. ssh = s.shape[-2:]
  180. if ssh != up_out.shape[-2:]:
  181. up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
  182. cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
  183. return self.conv(cat_x)
  184. #custom filter widths
  185. class DynamicUnet5(SequentialEx):
  186. "Create a U-Net from a given architecture."
  187. def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
  188. y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=True,
  189. norm_type:Optional[NormType]=NormType.Batch, nf:int=256, **kwargs):
  190. extra_bn = norm_type == NormType.Spectral
  191. imsize = (256,256)
  192. sfs_szs = model_sizes(encoder, size=imsize)
  193. sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
  194. self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
  195. x = dummy_eval(encoder, imsize).detach()
  196. ni = sfs_szs[-1][1]
  197. middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
  198. conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
  199. x = middle_conv(x)
  200. layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
  201. for i,idx in enumerate(sfs_idxs):
  202. not_final = i!=len(sfs_idxs)-1
  203. up_in_c = int(x.shape[1]) if i == 0 else nf
  204. x_in_c = int(sfs_szs[idx][1])
  205. do_blur = blur and (not_final or blur_final)
  206. sa = self_attention and (i==len(sfs_idxs)-3)
  207. unet_block = UnetBlock5(up_in_c, x_in_c, nf, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
  208. norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
  209. layers.append(unet_block)
  210. x = unet_block(x)
  211. ni = x.shape[1]
  212. if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
  213. if last_cross:
  214. layers.append(MergeLayer(dense=True))
  215. ni += in_channels(encoder)
  216. layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
  217. layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
  218. if y_range is not None: layers.append(SigmoidRange(*y_range))
  219. super().__init__(*layers)
  220. def __del__(self):
  221. if hasattr(self, "sfs"): self.sfs.remove()