# unet.py
from fastai.layers import *
from .layers import *
from fastai.torch_core import *
from fastai.callbacks.hooks import *
from fastai.vision import *

# The code below is ideally meant to be merged into fastai v1.
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']


def _get_sfs_idxs(sizes: Sizes) -> List[int]:
    "Get the indexes of the layers where the size of the activation changes."
    feature_szs = [size[-1] for size in sizes]
    sfs_idxs = list(
        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
    )
    if feature_szs[0] != feature_szs[1]:
        sfs_idxs = [0] + sfs_idxs
    return sfs_idxs
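

# Illustrative example (hypothetical sizes, not taken from a real encoder):
# if the hooked activations have trailing spatial dims [128, 128, 64, 64, 32],
# then feature_szs[:-1] != feature_szs[1:] gives [False, True, False, True],
# so _get_sfs_idxs returns [1, 3] -- the last layer index at each resolution
# before the spatial size changes.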
class CustomPixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."

    def __init__(
        self,
        ni: int,
        nf: int = None,
        scale: int = 2,
        blur: bool = False,
        leaky: float = None,
        **kwargs
    ):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = custom_conv_layer(
            ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
        )
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        # Keep the flag separate from `self.blur`, which holds the pooling module itself.
        self.do_blur = blur
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.do_blur else x
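

# Shape sketch (illustrative numbers): with ni=256, nf=128, scale=2, an input of
# shape (B, 256, 16, 16) is projected to (B, 512, 16, 16) by the 1x1 conv, then
# rearranged by PixelShuffle to (B, 128, 32, 32); the optional replication-pad +
# 2x2 average-pool blurring keeps that spatial size while smoothing checkerboard
# artifacts.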
class UnetBlockDeep(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        nf_factor: float = 1.0,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        ni = up_in_c // 2 + x_in_c
        nf = int((ni if final_div else ni // 2) * nf_factor)
        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))
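

# Worked example (illustrative): with up_in_c=512, x_in_c=256 and nf_factor=1.0,
# the shuffle path yields 256 channels at twice the spatial size, concatenating
# with the batch-normed skip activation gives ni = 256 + 256 = 512 channels, and
# both convs output nf = 512 (or 256 when final_div=False) channels.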
class DynamicUnetDeep(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: float = 1.0,
        **kwargs
    ):
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()
        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,  # respect `blur_final` on the last upsampling block
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()
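

# Usage sketch for DynamicUnetDeep (hypothetical values; assumes a fastai v1
# environment where `create_body`, `models`, and `NormType` come from the star
# imports above):
#
#   body = create_body(models.resnet34, pretrained=True)
#   unet = DynamicUnetDeep(
#       body, n_classes=3, blur=True, self_attention=True,
#       norm_type=NormType.Spectral, nf_factor=1.5,
#   )
#   out = unet(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 3, 256, 256])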
# ------------------------------------------------------
class UnetBlockWide(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        n_out: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        up_out = x_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        ni = up_out + x_in_c
        self.conv = custom_conv_layer(
            ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv(cat_x)
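

# Compared with UnetBlockDeep, UnetBlockWide takes its width directly from
# `n_out` (both the pixel-shuffle path and the single conv produce n_out // 2
# channels) and applies one conv after the skip concatenation instead of two;
# `final_div` is accepted for interface parity but is not used here.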
class DynamicUnetWide(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):
        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()
        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            n_out = nf if not_final else nf // 2
            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,  # respect `blur_final` on the last upsampling block
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()
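

# Usage sketch for DynamicUnetWide (hypothetical values; assumes the same fastai
# v1 helpers as above):
#
#   body = create_body(models.resnet101, pretrained=True)
#   unet = DynamicUnetWide(
#       body, n_classes=3, blur=True, blur_final=True, self_attention=True,
#       y_range=(-3.0, 3.0), norm_type=NormType.Spectral, nf_factor=2,
#   )
#   out = unet(torch.randn(1, 3, 256, 256))  # outputs squashed into y_range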