fa_resnet.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import torch.nn as nn
  2. import math
  3. import torch.utils.model_zoo as model_zoo
  4. from ..layers import *
  5. model_urls = {
  6. 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
  7. 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
  8. 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
  9. 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
  10. 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
  11. }
  12. def conv3x3(in_planes, out_planes, stride=1):
  13. "3x3 convolution with padding"
  14. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  15. padding=1, bias=False)
  16. def bn1(planes):
  17. m = nn.BatchNorm1d(planes)
  18. m.weight.data.fill_(1)
  19. m.bias.data.zero_()
  20. return m
  21. def bn(planes, init_zero=False):
  22. m = nn.BatchNorm2d(planes)
  23. m.weight.data.fill_(0 if init_zero else 1)
  24. m.bias.data.zero_()
  25. return m
  26. class BasicBlock(nn.Module):
  27. expansion = 1
  28. def __init__(self, inplanes, planes, stride=1, downsample=None):
  29. super().__init__()
  30. self.conv1 = conv3x3(inplanes, planes, stride)
  31. self.bn1 = bn(planes)
  32. self.relu = nn.ReLU(inplace=True)
  33. self.conv2 = conv3x3(planes, planes)
  34. self.bn2 = bn(planes)
  35. self.downsample = downsample
  36. self.stride = stride
  37. def forward(self, x):
  38. residual = x
  39. if self.downsample is not None: residual = self.downsample(x)
  40. out = self.conv1(x)
  41. out = self.relu(out)
  42. out = self.bn1(out)
  43. out = self.conv2(out)
  44. out += residual
  45. out = self.relu(out)
  46. out = self.bn2(out)
  47. return out
  48. class BottleneckFinal(nn.Module):
  49. expansion = 4
  50. def __init__(self, inplanes, planes, stride=1, downsample=None):
  51. super().__init__()
  52. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  53. self.bn1 = bn(planes)
  54. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  55. padding=1, bias=False)
  56. self.bn2 = bn(planes)
  57. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  58. self.bn3 = bn(planes * 4)
  59. self.relu = nn.ReLU(inplace=True)
  60. self.downsample = downsample
  61. self.stride = stride
  62. def forward(self, x):
  63. residual = x
  64. if self.downsample is not None: residual = self.downsample(x)
  65. out = self.conv1(x)
  66. out = self.bn1(out)
  67. out = self.relu(out)
  68. out = self.conv2(out)
  69. out = self.bn2(out)
  70. out = self.relu(out)
  71. out = self.conv3(out)
  72. out += residual
  73. out = self.bn3(out)
  74. out = self.relu(out)
  75. return out
  76. class BottleneckZero(nn.Module):
  77. expansion = 4
  78. def __init__(self, inplanes, planes, stride=1, downsample=None):
  79. super().__init__()
  80. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  81. self.bn1 = bn(planes)
  82. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  83. padding=1, bias=False)
  84. self.bn2 = bn(planes)
  85. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  86. self.bn3 = bn(planes * 4, init_zero=True)
  87. self.relu = nn.ReLU(inplace=True)
  88. self.downsample = downsample
  89. self.stride = stride
  90. def forward(self, x):
  91. residual = x
  92. if self.downsample is not None: residual = self.downsample(x)
  93. out = self.conv1(x)
  94. out = self.bn1(out)
  95. out = self.relu(out)
  96. out = self.conv2(out)
  97. out = self.bn2(out)
  98. out = self.relu(out)
  99. out = self.conv3(out)
  100. out = self.bn3(out)
  101. out += residual
  102. out = self.relu(out)
  103. return out
  104. class Bottleneck(nn.Module):
  105. expansion = 4
  106. def __init__(self, inplanes, planes, stride=1, downsample=None):
  107. super().__init__()
  108. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  109. self.bn1 = bn(planes)
  110. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  111. padding=1, bias=False)
  112. self.bn2 = bn(planes)
  113. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  114. self.bn3 = bn(planes * 4)
  115. self.relu = nn.ReLU(inplace=True)
  116. self.downsample = downsample
  117. self.stride = stride
  118. def forward(self, x):
  119. residual = x
  120. if self.downsample is not None: residual = self.downsample(x)
  121. out = self.conv1(x)
  122. out = self.bn1(out)
  123. out = self.relu(out)
  124. out = self.conv2(out)
  125. out = self.bn2(out)
  126. out = self.relu(out)
  127. out = self.conv3(out)
  128. out = self.bn3(out)
  129. out += residual
  130. out = self.relu(out)
  131. return out
class ResNet(nn.Module):
    """ResNet backbone plus classification head, held in a single ``nn.Sequential``.

    Every layer (stem, four residual stages, head) is stored in
    ``self.features`` in construction order. State-dict keys are therefore
    positional, e.g. ``features.0.weight`` for the stem conv and
    ``features.4.0.conv1.weight`` for the first block of stage 1.

    Args:
        block: residual block class (e.g. ``BasicBlock``, ``Bottleneck``). It
            must have an ``expansion`` attribute and be callable as
            ``block(inplanes, planes, stride, downsample)`` and
            ``block(inplanes, planes)``.
        layers: number of blocks in each of the four stages
            (``layers[0]`` .. ``layers[3]``).
        num_classes: output size of the final ``nn.Linear``.
        k: width multiplier for the per-stage plane counts (64/128/256/512),
            truncated with ``int``. The stem conv stays at 64 channels.
        vgg_head: if true, the head is a 3x3 adaptive average pool followed by
            two (Linear 4096, ReLU, BatchNorm1d, Dropout 0.25) stages and a
            final Linear. Otherwise it is a global average pool and one Linear.
    """
    def __init__(self, block, layers, num_classes=1000, k=1, vgg_head=False):
        super().__init__()
        # Channel count entering the next stage; _make_layer reads and updates it.
        self.inplanes = 64
        # Stem (7x7/2 conv, BN, ReLU, 3x3/2 max-pool) followed by four stages.
        # Positions in this list become the ``features.N`` state-dict indices.
        features = [nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            , bn(64) , nn.ReLU(inplace=True) , nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            , self._make_layer(block, int(64*k), layers[0])
            , self._make_layer(block, int(128*k), layers[1], stride=2)
            , self._make_layer(block, int(256*k), layers[2], stride=2)
            , self._make_layer(block, int(512*k), layers[3], stride=2)]
        # Channels produced by the last stage (input width of the head).
        out_sz = int(512*k) * block.expansion
        # Flatten comes from ``..layers`` (star import); presumably it reshapes
        # the pooled map to (batch, features) — confirm in that module.
        if vgg_head:
            features += [nn.AdaptiveAvgPool2d(3), Flatten()
                , nn.Linear(out_sz*3*3, 4096), nn.ReLU(inplace=True), bn1(4096), nn.Dropout(0.25)
                , nn.Linear(4096, 4096), nn.ReLU(inplace=True), bn1(4096), nn.Dropout(0.25)
                , nn.Linear(4096, num_classes)]
        else: features += [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(out_sz, num_classes)]
        self.features = nn.Sequential(*features)
        # Re-initialise every Conv2d, including those inside blocks and
        # downsample paths, from N(0, sqrt(2 / (kh * kw * out_channels))).
        # This is He-normal init with fan-out. BatchNorm and Linear layers are
        # not touched here, so the BatchNorm values set by bn()/bn1() survive.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks with base width ``planes``.

        Only the first block receives ``stride`` and a projection shortcut.
        The shortcut (1x1 conv + BN) is created when ``stride`` is not 1 or
        when the incoming channel count differs from ``planes * block.expansion``.
        The remaining blocks use the block's defaults (stride 1, no downsample).

        Side effect: sets ``self.inplanes`` to ``planes * block.expansion``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                bn(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks): layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)
    # Run stem, stages and head in one call; the result is the final Linear's output.
    def forward(self, x): return self.features(x)
  168. def load(model, pre, name):
  169. if pretrained: model.load_state_dict(model_zoo.load_url(model_urls[name]))
  170. return model
  171. def fa_resnet18(pretrained=False, **kwargs): return load(ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), pretrained, 'resnet18')
  172. def fa_resnet34(pretrained=False, **kwargs): return load(ResNet(BasicBlock, [3, 4, 6, 3], **kwargs), pretrained, 'resnet34')
  173. def fa_resnet50(pretrained=False, **kwargs): return load(ResNet(Bottleneck, [3, 4, 6, 3], **kwargs), pretrained, 'resnet50')
  174. def fa_resnet101(pretrained=False, **kwargs): return load(ResNet(Bottleneck, [3, 4, 23, 3], **kwargs), pretrained, 'resnet101')
  175. def fa_resnet152(pretrained=False, **kwargs): return load(ResNet(Bottleneck, [3, 8, 36, 3], **kwargs), pretrained, 'resnet152')
  176. def bnf_resnet50 (): return ResNet(BottleneckFinal, [3, 4, 6, 3])
  177. def bnz_resnet50 (): return ResNet(BottleneckZero, [3, 4, 6, 3])
  178. def w5_resnet50 (): return ResNet(Bottleneck, [2, 3, 3, 2], k=1.5)
  179. def w25_resnet50(): return ResNet(Bottleneck, [3, 4, 4, 3], k=1.25)
  180. def w125_resnet50(): return ResNet(Bottleneck,[3, 4, 6, 3], k=1.125)
  181. def vgg_resnet50(): return ResNet(Bottleneck, [3, 4, 6, 3], vgg_head=True)