nasnet.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.utils.model_zoo as model_zoo
  5. from torch.autograd import Variable
  6. pretrained_settings = {
  7. 'nasnetalarge': {
  8. 'imagenet': {
  9. 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
  10. 'input_space': 'RGB',
  11. 'input_size': [3, 331, 331], # resize 354
  12. 'input_range': [0, 1],
  13. 'mean': [0.5, 0.5, 0.5],
  14. 'std': [0.5, 0.5, 0.5],
  15. 'num_classes': 1000
  16. },
  17. 'imagenet+background': {
  18. 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
  19. 'input_space': 'RGB',
  20. 'input_size': [3, 331, 331], # resize 354
  21. 'input_range': [0, 1],
  22. 'mean': [0.5, 0.5, 0.5],
  23. 'std': [0.5, 0.5, 0.5],
  24. 'num_classes': 1001
  25. }
  26. }
  27. }
  28. class MaxPoolPad(nn.Module):
  29. def __init__(self):
  30. super(MaxPoolPad, self).__init__()
  31. self.pad = nn.ZeroPad2d((1, 0, 1, 0))
  32. self.pool = nn.MaxPool2d(3, stride=2, padding=1)
  33. def forward(self, x):
  34. x = self.pad(x)
  35. x = self.pool(x)
  36. x = x[:, :, 1:, 1:]
  37. return x
  38. class AvgPoolPad(nn.Module):
  39. def __init__(self, stride=2, padding=1):
  40. super(AvgPoolPad, self).__init__()
  41. self.pad = nn.ZeroPad2d((1, 0, 1, 0))
  42. self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)
  43. def forward(self, x):
  44. x = self.pad(x)
  45. x = self.pool(x)
  46. x = x[:, :, 1:, 1:]
  47. return x
  48. class SeparableConv2d(nn.Module):
  49. def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
  50. super(SeparableConv2d, self).__init__()
  51. self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, dw_kernel,
  52. stride=dw_stride,
  53. padding=dw_padding,
  54. bias=bias,
  55. groups=in_channels)
  56. self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)
  57. def forward(self, x):
  58. x = self.depthwise_conv2d(x)
  59. x = self.pointwise_conv2d(x)
  60. return x
  61. class BranchSeparables(nn.Module):
  62. def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
  63. super(BranchSeparables, self).__init__()
  64. self.relu = nn.ReLU()
  65. self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
  66. self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
  67. self.relu1 = nn.ReLU()
  68. self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
  69. self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
  70. def forward(self, x):
  71. x = self.relu(x)
  72. x = self.separable_1(x)
  73. x = self.bn_sep_1(x)
  74. x = self.relu1(x)
  75. x = self.separable_2(x)
  76. x = self.bn_sep_2(x)
  77. return x
  78. class BranchSeparablesStem(nn.Module):
  79. def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
  80. super(BranchSeparablesStem, self).__init__()
  81. self.relu = nn.ReLU()
  82. self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
  83. self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
  84. self.relu1 = nn.ReLU()
  85. self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
  86. self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
  87. def forward(self, x):
  88. x = self.relu(x)
  89. x = self.separable_1(x)
  90. x = self.bn_sep_1(x)
  91. x = self.relu1(x)
  92. x = self.separable_2(x)
  93. x = self.bn_sep_2(x)
  94. return x
  95. class BranchSeparablesReduction(BranchSeparables):
  96. def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
  97. BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias)
  98. self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))
  99. def forward(self, x):
  100. x = self.relu(x)
  101. x = self.padding(x)
  102. x = self.separable_1(x)
  103. x = x[:, :, 1:, 1:].contiguous()
  104. x = self.bn_sep_1(x)
  105. x = self.relu1(x)
  106. x = self.separable_2(x)
  107. x = self.bn_sep_2(x)
  108. return x
  109. class CellStem0(nn.Module):
  110. def __init__(self):
  111. super(CellStem0, self).__init__()
  112. self.conv_1x1 = nn.Sequential()
  113. self.conv_1x1.add_module('relu', nn.ReLU())
  114. self.conv_1x1.add_module('conv', nn.Conv2d(96, 42, 1, stride=1, bias=False))
  115. self.conv_1x1.add_module('bn', nn.BatchNorm2d(42, eps=0.001, momentum=0.1, affine=True))
  116. self.comb_iter_0_left = BranchSeparables(42, 42, 5, 2, 2)
  117. self.comb_iter_0_right = BranchSeparablesStem(96, 42, 7, 2, 3, bias=False)
  118. self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
  119. self.comb_iter_1_right = BranchSeparablesStem(96, 42, 7, 2, 3, bias=False)
  120. self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
  121. self.comb_iter_2_right = BranchSeparablesStem(96, 42, 5, 2, 2, bias=False)
  122. self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  123. self.comb_iter_4_left = BranchSeparables(42, 42, 3, 1, 1, bias=False)
  124. self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
  125. def forward(self, x):
  126. x1 = self.conv_1x1(x)
  127. x_comb_iter_0_left = self.comb_iter_0_left(x1)
  128. x_comb_iter_0_right = self.comb_iter_0_right(x)
  129. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  130. x_comb_iter_1_left = self.comb_iter_1_left(x1)
  131. x_comb_iter_1_right = self.comb_iter_1_right(x)
  132. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  133. x_comb_iter_2_left = self.comb_iter_2_left(x1)
  134. x_comb_iter_2_right = self.comb_iter_2_right(x)
  135. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  136. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  137. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  138. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  139. x_comb_iter_4_right = self.comb_iter_4_right(x1)
  140. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  141. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  142. return x_out
  143. class CellStem1(nn.Module):
  144. def __init__(self):
  145. super(CellStem1, self).__init__()
  146. self.conv_1x1 = nn.Sequential()
  147. self.conv_1x1.add_module('relu', nn.ReLU())
  148. self.conv_1x1.add_module('conv', nn.Conv2d(168, 84, 1, stride=1, bias=False))
  149. self.conv_1x1.add_module('bn', nn.BatchNorm2d(84, eps=0.001, momentum=0.1, affine=True))
  150. self.relu = nn.ReLU()
  151. self.path_1 = nn.Sequential()
  152. self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  153. self.path_1.add_module('conv', nn.Conv2d(96, 42, 1, stride=1, bias=False))
  154. self.path_2 = nn.ModuleList()
  155. self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
  156. self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  157. self.path_2.add_module('conv', nn.Conv2d(96, 42, 1, stride=1, bias=False))
  158. self.final_path_bn = nn.BatchNorm2d(84, eps=0.001, momentum=0.1, affine=True)
  159. self.comb_iter_0_left = BranchSeparables(84, 84, 5, 2, 2, bias=False)
  160. self.comb_iter_0_right = BranchSeparables(84, 84, 7, 2, 3, bias=False)
  161. self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
  162. self.comb_iter_1_right = BranchSeparables(84, 84, 7, 2, 3, bias=False)
  163. self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
  164. self.comb_iter_2_right = BranchSeparables(84, 84, 5, 2, 2, bias=False)
  165. self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  166. self.comb_iter_4_left = BranchSeparables(84, 84, 3, 1, 1, bias=False)
  167. self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
  168. def forward(self, x_conv0, x_stem_0):
  169. x_left = self.conv_1x1(x_stem_0)
  170. x_relu = self.relu(x_conv0)
  171. # path 1
  172. x_path1 = self.path_1(x_relu)
  173. # path 2
  174. x_path2 = self.path_2.pad(x_relu)
  175. x_path2 = x_path2[:, :, 1:, 1:]
  176. x_path2 = self.path_2.avgpool(x_path2)
  177. x_path2 = self.path_2.conv(x_path2)
  178. # final path
  179. x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
  180. x_comb_iter_0_left = self.comb_iter_0_left(x_left)
  181. x_comb_iter_0_right = self.comb_iter_0_right(x_right)
  182. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  183. x_comb_iter_1_left = self.comb_iter_1_left(x_left)
  184. x_comb_iter_1_right = self.comb_iter_1_right(x_right)
  185. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  186. x_comb_iter_2_left = self.comb_iter_2_left(x_left)
  187. x_comb_iter_2_right = self.comb_iter_2_right(x_right)
  188. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  189. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  190. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  191. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  192. x_comb_iter_4_right = self.comb_iter_4_right(x_left)
  193. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  194. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  195. return x_out
  196. class FirstCell(nn.Module):
  197. def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
  198. super(FirstCell, self).__init__()
  199. self.conv_1x1 = nn.Sequential()
  200. self.conv_1x1.add_module('relu', nn.ReLU())
  201. self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
  202. self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
  203. self.relu = nn.ReLU()
  204. self.path_1 = nn.Sequential()
  205. self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  206. self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
  207. self.path_2 = nn.ModuleList()
  208. self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
  209. self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
  210. self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
  211. self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)
  212. self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
  213. self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
  214. self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
  215. self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
  216. self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  217. self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  218. self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  219. self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
  220. def forward(self, x, x_prev):
  221. x_relu = self.relu(x_prev)
  222. # path 1
  223. x_path1 = self.path_1(x_relu)
  224. # path 2
  225. x_path2 = self.path_2.pad(x_relu)
  226. x_path2 = x_path2[:, :, 1:, 1:]
  227. x_path2 = self.path_2.avgpool(x_path2)
  228. x_path2 = self.path_2.conv(x_path2)
  229. # final path
  230. x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
  231. x_right = self.conv_1x1(x)
  232. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  233. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  234. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  235. x_comb_iter_1_left = self.comb_iter_1_left(x_left)
  236. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  237. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  238. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  239. x_comb_iter_2 = x_comb_iter_2_left + x_left
  240. x_comb_iter_3_left = self.comb_iter_3_left(x_left)
  241. x_comb_iter_3_right = self.comb_iter_3_right(x_left)
  242. x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
  243. x_comb_iter_4_left = self.comb_iter_4_left(x_right)
  244. x_comb_iter_4 = x_comb_iter_4_left + x_right
  245. x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  246. return x_out
  247. class NormalCell(nn.Module):
  248. def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
  249. super(NormalCell, self).__init__()
  250. self.conv_prev_1x1 = nn.Sequential()
  251. self.conv_prev_1x1.add_module('relu', nn.ReLU())
  252. self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
  253. self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
  254. self.conv_1x1 = nn.Sequential()
  255. self.conv_1x1.add_module('relu', nn.ReLU())
  256. self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
  257. self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
  258. self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
  259. self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
  260. self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
  261. self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
  262. self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  263. self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  264. self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  265. self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
  266. def forward(self, x, x_prev):
  267. x_left = self.conv_prev_1x1(x_prev)
  268. x_right = self.conv_1x1(x)
  269. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  270. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  271. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  272. x_comb_iter_1_left = self.comb_iter_1_left(x_left)
  273. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  274. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  275. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  276. x_comb_iter_2 = x_comb_iter_2_left + x_left
  277. x_comb_iter_3_left = self.comb_iter_3_left(x_left)
  278. x_comb_iter_3_right = self.comb_iter_3_right(x_left)
  279. x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
  280. x_comb_iter_4_left = self.comb_iter_4_left(x_right)
  281. x_comb_iter_4 = x_comb_iter_4_left + x_right
  282. x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  283. return x_out
  284. class ReductionCell0(nn.Module):
  285. def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
  286. super(ReductionCell0, self).__init__()
  287. self.conv_prev_1x1 = nn.Sequential()
  288. self.conv_prev_1x1.add_module('relu', nn.ReLU())
  289. self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
  290. self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
  291. self.conv_1x1 = nn.Sequential()
  292. self.conv_1x1.add_module('relu', nn.ReLU())
  293. self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
  294. self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
  295. self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
  296. self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
  297. self.comb_iter_1_left = MaxPoolPad()
  298. self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
  299. self.comb_iter_2_left = AvgPoolPad()
  300. self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
  301. self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  302. self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
  303. self.comb_iter_4_right = MaxPoolPad()
  304. def forward(self, x, x_prev):
  305. x_left = self.conv_prev_1x1(x_prev)
  306. x_right = self.conv_1x1(x)
  307. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  308. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  309. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  310. x_comb_iter_1_left = self.comb_iter_1_left(x_right)
  311. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  312. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  313. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  314. x_comb_iter_2_right = self.comb_iter_2_right(x_left)
  315. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  316. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  317. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  318. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  319. x_comb_iter_4_right = self.comb_iter_4_right(x_right)
  320. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  321. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  322. return x_out
  323. class ReductionCell1(nn.Module):
  324. def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
  325. super(ReductionCell1, self).__init__()
  326. self.conv_prev_1x1 = nn.Sequential()
  327. self.conv_prev_1x1.add_module('relu', nn.ReLU())
  328. self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
  329. self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
  330. self.conv_1x1 = nn.Sequential()
  331. self.conv_1x1.add_module('relu', nn.ReLU())
  332. self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
  333. self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
  334. self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
  335. self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
  336. self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
  337. self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
  338. self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
  339. self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
  340. self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
  341. self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
  342. self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
  343. def forward(self, x, x_prev):
  344. x_left = self.conv_prev_1x1(x_prev)
  345. x_right = self.conv_1x1(x)
  346. x_comb_iter_0_left = self.comb_iter_0_left(x_right)
  347. x_comb_iter_0_right = self.comb_iter_0_right(x_left)
  348. x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
  349. x_comb_iter_1_left = self.comb_iter_1_left(x_right)
  350. x_comb_iter_1_right = self.comb_iter_1_right(x_left)
  351. x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
  352. x_comb_iter_2_left = self.comb_iter_2_left(x_right)
  353. x_comb_iter_2_right = self.comb_iter_2_right(x_left)
  354. x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
  355. x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
  356. x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
  357. x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
  358. x_comb_iter_4_right = self.comb_iter_4_right(x_right)
  359. x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
  360. x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
  361. return x_out
  362. class NASNetALarge(nn.Module):
  363. def __init__(self, use_classifier=False, num_classes=1001):
  364. super(NASNetALarge, self).__init__()
  365. self.use_classifier,self.num_classes = use_classifier,num_classes
  366. self.conv0 = nn.Sequential()
  367. self.conv0.add_module('conv', nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, padding=0, stride=2,
  368. bias=False))
  369. self.conv0.add_module('bn', nn.BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True))
  370. self.cell_stem_0 = CellStem0()
  371. self.cell_stem_1 = CellStem1()
  372. self.cell_0 = FirstCell(in_channels_left=168, out_channels_left=84,
  373. in_channels_right=336, out_channels_right=168)
  374. self.cell_1 = NormalCell(in_channels_left=336, out_channels_left=168,
  375. in_channels_right=1008, out_channels_right=168)
  376. self.cell_2 = NormalCell(in_channels_left=1008, out_channels_left=168,
  377. in_channels_right=1008, out_channels_right=168)
  378. self.cell_3 = NormalCell(in_channels_left=1008, out_channels_left=168,
  379. in_channels_right=1008, out_channels_right=168)
  380. self.cell_4 = NormalCell(in_channels_left=1008, out_channels_left=168,
  381. in_channels_right=1008, out_channels_right=168)
  382. self.cell_5 = NormalCell(in_channels_left=1008, out_channels_left=168,
  383. in_channels_right=1008, out_channels_right=168)
  384. self.reduction_cell_0 = ReductionCell0(in_channels_left=1008, out_channels_left=336,
  385. in_channels_right=1008, out_channels_right=336)
  386. self.cell_6 = FirstCell(in_channels_left=1008, out_channels_left=168,
  387. in_channels_right=1344, out_channels_right=336)
  388. self.cell_7 = NormalCell(in_channels_left=1344, out_channels_left=336,
  389. in_channels_right=2016, out_channels_right=336)
  390. self.cell_8 = NormalCell(in_channels_left=2016, out_channels_left=336,
  391. in_channels_right=2016, out_channels_right=336)
  392. self.cell_9 = NormalCell(in_channels_left=2016, out_channels_left=336,
  393. in_channels_right=2016, out_channels_right=336)
  394. self.cell_10 = NormalCell(in_channels_left=2016, out_channels_left=336,
  395. in_channels_right=2016, out_channels_right=336)
  396. self.cell_11 = NormalCell(in_channels_left=2016, out_channels_left=336,
  397. in_channels_right=2016, out_channels_right=336)
  398. self.reduction_cell_1 = ReductionCell1(in_channels_left=2016, out_channels_left=672,
  399. in_channels_right=2016, out_channels_right=672)
  400. self.cell_12 = FirstCell(in_channels_left=2016, out_channels_left=336,
  401. in_channels_right=2688, out_channels_right=672)
  402. self.cell_13 = NormalCell(in_channels_left=2688, out_channels_left=672,
  403. in_channels_right=4032, out_channels_right=672)
  404. self.cell_14 = NormalCell(in_channels_left=4032, out_channels_left=672,
  405. in_channels_right=4032, out_channels_right=672)
  406. self.cell_15 = NormalCell(in_channels_left=4032, out_channels_left=672,
  407. in_channels_right=4032, out_channels_right=672)
  408. self.cell_16 = NormalCell(in_channels_left=4032, out_channels_left=672,
  409. in_channels_right=4032, out_channels_right=672)
  410. self.cell_17 = NormalCell(in_channels_left=4032, out_channels_left=672,
  411. in_channels_right=4032, out_channels_right=672)
  412. self.relu = nn.ReLU()
  413. self.dropout = nn.Dropout()
  414. self.last_linear = nn.Linear(4032, self.num_classes)
  415. def features(self, x):
  416. x_conv0 = self.conv0(x)
  417. x_stem_0 = self.cell_stem_0(x_conv0)
  418. x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
  419. x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
  420. x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
  421. x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
  422. x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
  423. x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
  424. x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
  425. x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
  426. x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
  427. x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
  428. x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
  429. x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
  430. x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
  431. x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
  432. x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
  433. x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
  434. x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
  435. x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
  436. x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
  437. x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
  438. x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
  439. return self.relu(x_cell_17)
  440. def classifier(self, x):
  441. x = F.adaptive_max_pool2d(x, 1)
  442. x = x.view(x.size(0), -1)
  443. x = self.dropout(x)
  444. return F.log_softmax(self.linear(x))
  445. def forward(self, x):
  446. x = self.features(x)
  447. if self.use_classifier: x = self.classifier(x)
  448. return x
  449. def nasnetalarge(num_classes=1000, pretrained='imagenet'):
  450. r"""NASNetALarge model architecture from the
  451. `"NASNet" <https://arxiv.org/abs/1707.07012>`_ paper.
  452. """
  453. if pretrained:
  454. settings = pretrained_settings['nasnetalarge'][pretrained]
  455. assert num_classes == settings['num_classes'], \
  456. "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
  457. # both 'imagenet'&'imagenet+background' are loaded from same parameters
  458. model = NASNetALarge(num_classes=1001)
  459. model.load_state_dict(model_zoo.load_url(settings['url']))
  460. if pretrained == 'imagenet':
  461. new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
  462. new_last_linear.weight.data = model.last_linear.weight.data[1:]
  463. new_last_linear.bias.data = model.last_linear.bias.data[1:]
  464. model.last_linear = new_last_linear
  465. model.input_space = settings['input_space']
  466. model.input_size = settings['input_size']
  467. model.input_range = settings['input_range']
  468. model.mean = settings['mean']
  469. model.std = settings['std']
  470. else:
  471. model = NASNetALarge(num_classes=num_classes)
  472. return model