# modules.py

from fastai.torch_imports import *
from fastai.conv_learner import *
from torch.nn.utils.spectral_norm import spectral_norm
# Explicit imports for the names used below, independent of the star imports above.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    # Conv2d -> optional activation -> optional BatchNorm2d -> optional SelfAttention,
    # with spectral normalization applied to the conv when sn=True.
    def __init__(self, ni:int, no:int, ks:int=3, stride:int=1, pad:int=None, actn:bool=True,
                 bn:bool=True, bias:bool=True, sn:bool=False, leakyReLu:bool=False, self_attention:bool=False,
                 inplace_relu:bool=True):
        super().__init__()
        if pad is None: pad = ks//2//stride

        if sn:
            layers = [spectral_norm(nn.Conv2d(ni, no, ks, stride, padding=pad, bias=bias))]
        else:
            layers = [nn.Conv2d(ni, no, ks, stride, padding=pad, bias=bias)]
        if actn:
            layers.append(nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu))
        if bn:
            layers.append(nn.BatchNorm2d(no))
        if self_attention:
            layers.append(SelfAttention(no, 1))
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)
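
# --- Illustrative usage sketch (not part of the original module) ---
# With the defaults, ConvBlock is Conv2d -> ReLU -> BatchNorm2d; the values below
# are arbitrary examples:
#
#   block = ConvBlock(ni=3, no=64)                      # 3 -> 64 channels, 3x3, stride 1
#   y = block(torch.randn(2, 3, 64, 64))                # -> (2, 64, 64, 64)
#
#   # spectrally-normalized, LeakyReLU variant of the kind used in GAN critics
#   critic_block = ConvBlock(3, 64, stride=2, sn=True, bn=False, leakyReLu=True)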

class UpSampleBlock(nn.Module):
    # Sub-pixel upsampling: each stage is a conv to nf*4 channels followed by
    # PixelShuffle(2), with the first conv's weights seeded by ICNR initialization.
    @staticmethod
    def _conv(ni:int, nf:int, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
        layers = [ConvBlock(ni, nf, ks=ks, sn=sn, bn=bn, actn=False, leakyReLu=leakyReLu)]
        return nn.Sequential(*layers)

    @staticmethod
    def _icnr(x:torch.Tensor, scale:int=2):
        # ICNR: initialize an (out_channels / scale**2) subkernel, then tile it so the
        # conv + PixelShuffle pair starts out equivalent to nearest-neighbour upsampling,
        # which reduces checkerboard artifacts early in training.
        init = nn.init.kaiming_normal_
        new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
        subkernel = torch.zeros(new_shape)
        subkernel = init(subkernel)
        subkernel = subkernel.transpose(0, 1)
        subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                                subkernel.shape[1], -1)
        kernel = subkernel.repeat(1, 1, scale ** 2)
        transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:])
        kernel = kernel.contiguous().view(transposed_shape)
        kernel = kernel.transpose(0, 1)
        return kernel

    def __init__(self, ni:int, nf:int, scale:int=2, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
        super().__init__()
        layers = []
        assert (math.log(scale, 2)).is_integer()
        for i in range(int(math.log(scale, 2))):
            layers += [UpSampleBlock._conv(ni, nf*4, ks=ks, bn=bn, sn=sn, leakyReLu=leakyReLu),
                       nn.PixelShuffle(2)]
            if bn:
                layers += [nn.BatchNorm2d(nf)]
            ni = nf
        self.sequence = nn.Sequential(*layers)
        self._icnr_init()

    def _icnr_init(self):
        # Apply ICNR to the conv of the first conv + PixelShuffle stage.
        conv_shuffle = self.sequence[0][0].seq[0]
        kernel = UpSampleBlock._icnr(conv_shuffle.weight)
        conv_shuffle.weight.data.copy_(kernel)

    def forward(self, x):
        return self.sequence(x)
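
# --- Illustrative usage sketch (not part of the original module) ---
# Each stage convolves to nf*4 channels and PixelShuffle(2) rearranges them into
# nf channels at twice the height/width; scale must be a power of two. Example
# values are arbitrary:
#
#   up = UpSampleBlock(ni=64, nf=32, scale=2)
#   y = up(torch.randn(2, 64, 16, 16))                  # -> (2, 32, 32, 32)
#
#   up4 = UpSampleBlock(64, 32, scale=4)                # two conv + PixelShuffle stages -> 4x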

class UnetBlock(nn.Module):
    # Decoder block: upsample the lower-resolution path 2x, project the encoder skip
    # with a 1x1 conv, concatenate the two halves, then activation (+ optional BN
    # and self-attention).
    def __init__(self, up_in:int, x_in:int, n_out:int, bn:bool=True, sn:bool=False, leakyReLu:bool=False,
                 self_attention:bool=False, inplace_relu:bool=True):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv = ConvBlock(x_in, x_out, ks=1, bn=False, actn=False, sn=sn, inplace_relu=inplace_relu)
        self.tr_conv = UpSampleBlock(up_in, up_out, 2, bn=bn, sn=sn, leakyReLu=leakyReLu)
        self.relu = nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu)

        out_layers = []
        if bn:
            out_layers.append(nn.BatchNorm2d(n_out))
        if self_attention:
            out_layers.append(SelfAttention(n_out))
        self.out = nn.Sequential(*out_layers)

    def forward(self, up_p:torch.Tensor, x_p:torch.Tensor):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        x = torch.cat([up_p, x_p], dim=1)
        x = self.relu(x)
        return self.out(x)
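
# --- Illustrative usage sketch (not part of the original module) ---
# UnetBlock upsamples the decoder path 2x, projects the encoder skip with a 1x1
# conv, and concatenates the two n_out//2-channel halves. Example values are arbitrary:
#
#   unet_block = UnetBlock(up_in=64, x_in=32, n_out=64)
#   up_p = torch.randn(2, 64, 16, 16)                   # decoder features
#   x_p  = torch.randn(2, 32, 32, 32)                   # encoder skip (already 2x larger)
#   y = unet_block(up_p, x_p)                           # -> (2, 64, 32, 32)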

class SaveFeatures():
    # Forward hook that stores a module's output so intermediate activations can be
    # read back after a forward pass.
    features = None

    def __init__(self, m:nn.Module):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output

    def remove(self):
        self.hook.remove()
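
# --- Illustrative usage sketch (not part of the original module) ---
# SaveFeatures hooks a layer so its activations can be read after a forward pass,
# e.g. to collect encoder features for the U-Net skip connections:
#
#   body = nn.Sequential(ConvBlock(3, 16), ConvBlock(16, 32, stride=2))
#   sf = SaveFeatures(body[0])                          # hook the first block
#   _ = body(torch.randn(1, 3, 32, 32))
#   feats = sf.features                                 # (1, 16, 32, 32) activations
#   sf.remove()                                         # detach the hook when done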

class SelfAttention(nn.Module):
    # SAGAN-style self-attention over the flattened spatial grid: spectrally
    # normalized 1x1 (Conv1d) query/key/value projections and a learned residual
    # weight gamma initialized to zero, so the layer starts as an identity.
    def __init__(self, in_channel:int, gain:int=1):
        super().__init__()
        self.query = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1), gain=gain)
        self.key = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1), gain=gain)
        self.value = self._spectral_init(nn.Conv1d(in_channel, in_channel, 1), gain=gain)
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def _spectral_init(self, module:nn.Module, gain:int=1):
        # Note: the second positional argument of kaiming_uniform_ is the leaky-ReLU slope `a`.
        nn.init.kaiming_uniform_(module.weight, gain)
        if module.bias is not None:
            module.bias.data.zero_()
        return spectral_norm(module)

    def forward(self, input:torch.Tensor):
        shape = input.shape
        flatten = input.view(shape[0], shape[1], -1)     # (B, C, H*W)
        query = self.query(flatten).permute(0, 2, 1)     # (B, H*W, C//8)
        key = self.key(flatten)                          # (B, C//8, H*W)
        value = self.value(flatten)                      # (B, C, H*W)
        query_key = torch.bmm(query, key)                # (B, H*W, H*W)
        attn = F.softmax(query_key, 1)
        attn = torch.bmm(value, attn)                    # (B, C, H*W)
        attn = attn.view(*shape)
        out = self.gamma * attn + input                  # residual connection
        return out
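
# --- Minimal smoke test (illustrative, not part of the original module) ---
# Runs only when this file is executed directly; the shapes asserted below follow
# from the definitions above and assume nothing beyond torch.
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)
    sa = SelfAttention(64)
    assert sa(x).shape == x.shape                        # attention preserves the input shape

    block = ConvBlock(3, 64, self_attention=True)
    print(block(torch.randn(2, 3, 64, 64)).shape)        # torch.Size([2, 64, 64, 64])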