unet.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def get_sfs_idxs(sfs, last=True):
    """
    Return the indexes of the saved features that will be concatenated in the decoder.
    An illustration follows the function body.
    Inputs:
        sfs (list): features saved by the hook functions, i.e. intermediate activations
        last (bool): if True, keep only the last activation before each size change;
            otherwise keep every activation from the encoder
    """
    if last:
        feature_szs = [sfs_feats.features.size()[-1] for sfs_feats in sfs]
        sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
        if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
    else:
        sfs_idxs = list(range(len(sfs)))
    return sfs_idxs
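
# A small illustration (hypothetical sizes, not from the original file): if the hooks
# captured activations with spatial widths [112, 112, 56, 56, 28, 14], the comparison
# feature_szs[:-1] != feature_szs[1:] flags indices 1, 3 and 4 -- the last activation
# at each resolution before the encoder halves it -- so get_sfs_idxs returns [1, 3, 4],
# and those are the activations the decoder will concatenate with.
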
def conv_bn_relu(in_c, out_c, kernel_size, stride, padding):
    # returns a list of conv -> ReLU -> BatchNorm layers so it can be unpacked into nn.Sequential
    return [
        nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.ReLU(),
        nn.BatchNorm2d(out_c)]
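
# For example (a sketch, channel counts assumed): nn.Sequential(*conv_bn_relu(64, 128, 3, 1, 1))
# wraps the returned list into a single 3x3 conv block with 64 input and 128 output channels.
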
class UnetBlock(nn.Module):
    # TODO: adapt kernel size, stride and padding so that any size decay will be supported
    def __init__(self, up_in_c, x_in_c):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(up_in_c, up_in_c // 2, 2, 2)  # H, W -> 2H, 2W
        self.conv1 = nn.Conv2d(x_in_c + up_in_c // 2, (x_in_c + up_in_c // 2) // 2, 3, 1, 1)
        self.conv2 = nn.Conv2d((x_in_c + up_in_c // 2) // 2, (x_in_c + up_in_c // 2) // 2, 3, 1, 1)
        self.bn = nn.BatchNorm2d((x_in_c + up_in_c // 2) // 2)

    def forward(self, up_in, x_in):
        up_out = self.upconv(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = F.relu(self.conv1(cat_x))
        x = F.relu(self.conv2(x))
        return self.bn(x)
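
# A quick shape check (the channel counts and sizes here are assumptions for illustration):
# an up-path tensor with 512 channels at 8x8 combined with a skip activation of 256 channels
# at 16x16 comes out with (256 + 512 // 2) // 2 = 256 channels at 16x16.
#   block = UnetBlock(up_in_c=512, x_in_c=256)
#   out = block(torch.randn(1, 512, 8, 8), torch.randn(1, 256, 16, 16))
#   out.shape  # torch.Size([1, 256, 16, 16])
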
class SaveFeatures():
    """ Extract pretrained activations"""
    features = None
    def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): self.features = output
    def remove(self): self.hook.remove()
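
# Usage sketch (the encoder and layer index are assumptions): hook one layer of an
# encoder and read its activation after a forward pass.
#   sf = SaveFeatures(encoder[4])
#   _ = encoder(torch.randn(1, 3, 224, 224))
#   sf.features.shape  # output of encoder[4] for that batch
#   sf.remove()        # detach the hook when done
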
class DynamicUnet(nn.Module):
    """
    A dynamic implementation of the Unet architecture, so that skip connections and
    channel counts do not have to be worked out by hand. Given an encoder, this network
    automatically constructs a matching decoder during the first forward pass, for any
    encoder architecture.
    The decoder is heavily based on the original Unet paper:
    https://arxiv.org/abs/1505.04597.
    Inputs:
        encoder (nn.Module): preferably a pretrained model, such as VGG or ResNet
        last (bool): whether to concatenate only the last activation before each size change
        n_classes (int): number of classes output by the final layer of the decoder
    Important note: if the architecture reduces the spatial dimensions of the image in its
    very first layer (e.g. ResNet), the output size would not match the input size, so the
    network adds an extra transposed-convolution layer to compensate. Also, DynamicUnet
    currently expects every size change to be H, W -> H/2, W/2. This is not a problem for
    state-of-the-art architectures, which follow this pattern, but it should be changed
    for custom encoders with a different size decay. A usage sketch is given at the end
    of this file.
    """

    def __init__(self, encoder, last=True, n_classes=3):
        super().__init__()
        self.encoder = encoder
        self.n_children = len(list(encoder.children()))
        self.sfs = [SaveFeatures(encoder[i]) for i in range(self.n_children)]
        self.last = last
        self.n_classes = n_classes
    def forward(self, x):
        # get imsize
        imsize = x.size()[-2:]

        # encoder output
        x = F.relu(self.encoder(x))

        # initialize sfs_idxs, sfs_szs, middle_in_c and middle_conv only once
        if not hasattr(self, 'middle_conv'):
            self.sfs_szs = [sfs_feats.features.size() for sfs_feats in self.sfs]
            self.sfs_idxs = get_sfs_idxs(self.sfs, self.last)
            middle_in_c = self.sfs_szs[-1][1]
            middle_conv = nn.Sequential(*conv_bn_relu(middle_in_c, middle_in_c * 2, 3, 1, 1),
                                        *conv_bn_relu(middle_in_c * 2, middle_in_c, 3, 1, 1))
            self.middle_conv = middle_conv

        # middle conv
        x = self.middle_conv(x)

        # initialize upmodel, extra_block and 1x1 final conv only once
        if not hasattr(self, 'upmodel'):
            x_copy = x.detach()
            upmodel = []
            for idx in self.sfs_idxs[::-1]:
                up_in_c, x_in_c = int(x_copy.size()[1]), int(self.sfs_szs[idx][1])
                unet_block = UnetBlock(up_in_c, x_in_c)
                upmodel.append(unet_block)
                x_copy = unet_block(x_copy, self.sfs[idx].features)
            self.upmodel = nn.Sequential(*upmodel)

            if imsize != self.sfs_szs[0][-2:]:
                extra_in_c = self.upmodel[-1].conv2.out_channels
                self.extra_block = nn.ConvTranspose2d(extra_in_c, extra_in_c, 2, 2)

            final_in_c = self.upmodel[-1].conv2.out_channels
            self.final_conv = nn.Conv2d(final_in_c, self.n_classes, 1)

        # run upsample
        for block, idx in zip(self.upmodel, self.sfs_idxs[::-1]):
            x = block(x, self.sfs[idx].features)
        if hasattr(self, 'extra_block'):
            x = self.extra_block(x)
        out = self.final_conv(x)
        return out
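
# A usage sketch (an assumption, not part of the original file): build a DynamicUnet
# around a torchvision ResNet-34 trimmed to its convolutional layers; load pretrained
# weights as needed. The decoder is only materialized during the first forward pass.
if __name__ == '__main__':
    from torchvision.models import resnet34
    encoder = nn.Sequential(*list(resnet34().children())[:-2])  # drop avgpool and fc
    unet = DynamicUnet(encoder, n_classes=3)
    out = unet(torch.randn(1, 3, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])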