resnext.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. from torch.nn import init
  4. import math
  5. class ResNeXtBottleneck(nn.Module):
  6. expansion = 4
  7. """
  8. RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
  9. """
  10. def __init__(self, inplanes, planes, cardinality, base_width, stride=1, downsample=None):
  11. super(ResNeXtBottleneck, self).__init__()
  12. self.downsample = downsample
  13. D = int(math.floor(planes * (base_width/64.0)))
  14. C = cardinality
  15. self.conv_reduce = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
  16. self.bn_reduce = nn.BatchNorm2d(D*C)
  17. self.conv_conv = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
  18. self.bn = nn.BatchNorm2d(D*C)
  19. self.conv_expand = nn.Conv2d(D*C, planes*4, kernel_size=1, stride=1, padding=0, bias=False)
  20. self.bn_expand = nn.BatchNorm2d(planes*4)
  21. def forward(self, x):
  22. residual = x
  23. bottleneck = self.conv_reduce(x)
  24. bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
  25. bottleneck = self.conv_conv(bottleneck)
  26. bottleneck = F.relu(self.bn(bottleneck), inplace=True)
  27. bottleneck = self.conv_expand(bottleneck)
  28. bottleneck = self.bn_expand(bottleneck)
  29. if self.downsample is not None: residual = self.downsample(x)
  30. return F.relu(residual + bottleneck, inplace=True)
  31. class CifarResNeXt(nn.Module):
  32. """
  33. ResNext optimized for the Cifar dataset, as specified in
  34. https://arxiv.org/pdf/1611.05431.pdf
  35. """
  36. def __init__(self, block, depth, cardinality, base_width, num_classes):
  37. super(CifarResNeXt, self).__init__()
  38. # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
  39. assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'
  40. self.layer_blocks = (depth - 2) // 9
  41. self.cardinality,self.base_width,self.num_classes,self.block = cardinality,base_width,num_classes,block
  42. self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
  43. self.bn_1 = nn.BatchNorm2d(64)
  44. self.inplanes = 64
  45. self.stage_1 = self._make_layer(64 , 1)
  46. self.stage_2 = self._make_layer(128, 2)
  47. self.stage_3 = self._make_layer(256, 2)
  48. self.avgpool = nn.AdaptiveAvgPool2d((1,1))
  49. self.classifier = nn.Linear(256*block.expansion, num_classes)
  50. for m in self.modules():
  51. if isinstance(m, nn.Conv2d):
  52. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  53. m.weight.data.normal_(0, math.sqrt(2. / n))
  54. elif isinstance(m, nn.BatchNorm2d):
  55. m.weight.data.fill_(1)
  56. m.bias.data.zero_()
  57. elif isinstance(m, nn.Linear):
  58. init.kaiming_normal(m.weight)
  59. m.bias.data.zero_()
  60. def _make_layer(self, planes, stride=1):
  61. downsample = None
  62. exp_planes = planes * self.block.expansion
  63. if stride != 1 or self.inplanes != exp_planes:
  64. downsample = nn.Sequential(
  65. nn.Conv2d(self.inplanes, exp_planes, kernel_size=1, stride=stride, bias=False),
  66. nn.BatchNorm2d(exp_planes),
  67. )
  68. layers = []
  69. layers.append(self.block(self.inplanes, planes, self.cardinality, self.base_width, stride, downsample))
  70. self.inplanes = exp_planes
  71. for i in range(1, self.layer_blocks):
  72. layers.append(self.block(self.inplanes, planes, self.cardinality, self.base_width))
  73. return nn.Sequential(*layers)
  74. def forward(self, x):
  75. x = self.conv_1_3x3(x)
  76. x = F.relu(self.bn_1(x), inplace=True)
  77. x = self.stage_1(x)
  78. x = self.stage_2(x)
  79. x = self.stage_3(x)
  80. x = self.avgpool(x)
  81. x = x.view(x.size(0), -1)
  82. return F.log_softmax(self.classifier(x))
  83. def resnext29_16_64(num_classes=10):
  84. """Constructs a ResNeXt-29, 16*64d model for CIFAR-10 (by default)
  85. Args:
  86. num_classes (uint): number of classes
  87. """
  88. model = CifarResNeXt(ResNeXtBottleneck, 29, 16, 64, num_classes)
  89. return model
  90. def resnext29_8_64(num_classes=10):
  91. """Constructs a ResNeXt-29, 8*64d model for CIFAR-10 (by default)
  92. Args:
  93. num_classes (uint): number of classes
  94. """
  95. model = CifarResNeXt(ResNeXtBottleneck, 29, 8, 64, num_classes)
  96. return model