'''SENet in PyTorch.

SENet is the winner of ImageNet-2017 (https://arxiv.org/abs/1709.01507).
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

        # SE layers: 1x1 convolutions play the role of the two fully connected
        # layers of the SE block (reduction ratio 16).
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)  # Use nn.Conv2d instead of nn.Linear
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Squeeze: global average pooling per channel.
        w = F.avg_pool2d(out, out.size(2))
        # Excitation: bottleneck FC layers followed by a sigmoid gate.
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        # Re-weight the feature maps channel-wise (broadcast over H and W).
        out = out * w

        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
            )

        # SE layers
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))

        # Squeeze: global average pooling per channel.
        w = F.avg_pool2d(out, out.size(2))
        # Excitation: bottleneck FC layers followed by a sigmoid gate.
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        # Re-weight the feature maps channel-wise.
        out = out * w

        out += shortcut
        return out


class SENet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(SENet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block of each stage applies the given stride; the rest use stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_max_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = F.log_softmax(self.linear(out), dim=1)
        return out


def SENet18():
    return SENet(PreActBlock, [2, 2, 2, 2])


def SENet34():
    return SENet(PreActBlock, [3, 4, 6, 3])
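

# Minimal smoke test (a sketch, not part of the training pipeline): builds
# SENet18 and pushes a random CIFAR-sized batch through it to check that the
# output has shape (batch, num_classes). The helper name `test` is illustrative.
def test():
    net = SENet18()
    x = torch.randn(2, 3, 32, 32)   # batch of two 32x32 RGB images
    y = net(x)
    print(y.size())                 # expected: torch.Size([2, 10])


if __name__ == '__main__':
    test()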