'''
File name: senet.py
Squeeze-and-Excitation Networks (SENet) implementation for fast.ai/PyTorch with pretrained models.
Credit: https://github.com/hujie-frank/SENet
SENet is the winner of ImageNet-2017 (https://arxiv.org/pdf/1709.01507.pdf).
'''
from collections import OrderedDict
import math

import torch.nn as nn
from torch.utils import model_zoo

from ..layers import *

__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
           'se_resnext50_32x4d', 'se_resnext101_32x4d']

pretrained_settings = {
    'senet154': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnet50': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnet101': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnet152': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnext50_32x4d': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnext101_32x4d': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
}
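
# Example of wiring these settings into an input pipeline (an illustrative
# sketch, not part of the original file; assumes torchvision is available):
#
#   from torchvision import transforms
#   s = pretrained_settings['se_resnet50']['imagenet']
#   preprocess = transforms.Compose([
#       transforms.Resize(256),
#       transforms.CenterCrop(s['input_size'][1]),  # 224
#       transforms.ToTensor(),                      # maps to input_range [0, 1]
#       transforms.Normalize(mean=s['mean'], std=s['std']),
#   ])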


class SEModule(nn.Module):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        # Squeeze: global average pooling collapses each channel to a scalar.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: a two-layer bottleneck of 1x1 convolutions (acting as
        # fully connected layers on 1x1 maps) produces a sigmoid gate per channel.
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
                             padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
                             padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        # Rescale the input channel-wise by the learned gates (broadcast over
        # the spatial dimensions).
        return module_input * x
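
# Minimal usage sketch (illustrative only): the gates have shape (N, C, 1, 1)
# and are broadcast over the (H, W) dimensions of the input.
#
#   import torch
#   se = SEModule(channels=64, reduction=16)
#   y = se(torch.randn(2, 64, 56, 56))  # y.shape == (2, 64, 56, 56)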


class Bottleneck(nn.Module):
    """
    Base class for bottlenecks that implements the `forward()` method.
    """

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # SE recalibration is applied to the block output before the
        # residual addition.
        out = self.se_module(out) + residual
        out = self.relu(out)

        return out


class SEBottleneck(Bottleneck):
    """
    Bottleneck for SENet154.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes * 2)
        self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
                               stride=stride, padding=1, groups=groups,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes * 4)
        self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNetBottleneck(Bottleneck):
    """
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows the
    Caffe implementation and uses `stride=stride` in `conv1` rather than in
    `conv2` (the latter is what the torchvision implementation of ResNet does).
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEResNetBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
                               stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
                               groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNeXtBottleneck(Bottleneck):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super(SEResNeXtBottleneck, self).__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
                               stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
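
# Worked example of the grouped width above (illustrative): in the
# se_resnext50_32x4d configuration the first block has planes=64,
# base_width=4 and groups=32, so width = floor(64 * 4 / 64) * 32 = 128
# channels in the grouped 3x3 convolution.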


class SENet(nn.Module):

    def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
                 inplanes=128, input_3x3=True, downsample_kernel_size=3,
                 downsample_padding=1, num_classes=1000):
        """
        Parameters
        ----------
        block (nn.Module): Bottleneck class.
            - For SENet154: SEBottleneck
            - For SE-ResNet models: SEResNetBottleneck
            - For SE-ResNeXt models: SEResNeXtBottleneck
        layers (list of ints): Number of residual blocks for the 4 layers of
            the network (layer1...layer4).
        groups (int): Number of groups for the 3x3 convolution in each
            bottleneck block.
            - For SENet154: 64
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 32
        reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
            - For all models: 16
        dropout_p (float or None): Drop probability for the Dropout layer.
            If `None` the Dropout layer is not used.
            - For SENet154: 0.2
            - For SE-ResNet models: None
            - For SE-ResNeXt models: None
        inplanes (int): Number of input channels for layer1.
            - For SENet154: 128
            - For SE-ResNet models: 64
            - For SE-ResNeXt models: 64
        input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
            a single 7x7 convolution in layer0.
            - For SENet154: True
            - For SE-ResNet models: False
            - For SE-ResNeXt models: False
        downsample_kernel_size (int): Kernel size for downsampling convolutions
            in layer2, layer3 and layer4.
            - For SENet154: 3
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 1
        downsample_padding (int): Padding for downsampling convolutions in
            layer2, layer3 and layer4.
            - For SENet154: 1
            - For SE-ResNet models: 0
            - For SE-ResNeXt models: 0
        num_classes (int): Number of outputs in the `last_linear` layer.
            - For all models: 1000
        """
        super(SENet, self).__init__()
        self.inplanes = inplanes
        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
                                    bias=False)),
                ('bn1', nn.BatchNorm2d(64)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
                                    bias=False)),
                ('bn2', nn.BatchNorm2d(64)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
                                    bias=False)),
                ('bn3', nn.BatchNorm2d(inplanes)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
                                    padding=3, bias=False)),
                ('bn1', nn.BatchNorm2d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        # To preserve compatibility with Caffe weights `ceil_mode=True`
        # is used instead of `padding=1`.
        layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
                                                    ceil_mode=True)))
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.avg_pool = nn.AvgPool2d(7, stride=1)
        self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
        self.last_linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        # A projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, groups, reduction, stride,
                            downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def features(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def logits(self, x):
        x = self.avg_pool(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x
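
# Note on input size (not part of the original file): `logits()` relies on a
# fixed `nn.AvgPool2d(7)`, so layer4 must produce 7x7 feature maps, i.e.
# 224x224 inputs as listed in `pretrained_settings`. For arbitrary input
# sizes one could substitute `nn.AdaptiveAvgPool2d(1)`, which is numerically
# equivalent on 7x7 maps, but that is a modification of the module as written.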


def initialize_pretrained_model(model, num_classes, settings):
    assert num_classes == settings['num_classes'], \
        'num_classes should be {}, but is {}'.format(
            settings['num_classes'], num_classes)
    model.load_state_dict(model_zoo.load_url(settings['url']))
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']


def senet154(num_classes=1000, pretrained='imagenet'):
    model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                  dropout_p=0.2, num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['senet154'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnet50(num_classes=1000, pretrained='imagenet'):
    model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnet50'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnet101(num_classes=1000, pretrained='imagenet'):
    model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnet101'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnet152(num_classes=1000, pretrained='imagenet'):
    model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnet152'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
    model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnext50_32x4d'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
    model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model
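

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # file). Assumes torch is installed and the module is run inside its
    # parent package (the relative `..layers` import above requires that).
    # `pretrained=None` skips the weight download.
    import torch

    model = se_resnext50_32x4d(num_classes=1000, pretrained=None)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])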