# wrn_50_2f.py — network definition built as a flat nn.Sequential.

  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable
  4. from functools import reduce
  5. class LambdaBase(nn.Sequential):
  6. def __init__(self, fn, *args):
  7. super(LambdaBase, self).__init__(*args)
  8. self.lambda_func = fn
  9. def forward_prepare(self, input):
  10. output = []
  11. for module in self._modules.values():
  12. output.append(module(input))
  13. return output if output else input
  14. class Lambda(LambdaBase):
  15. def forward(self, input):
  16. return self.lambda_func(self.forward_prepare(input))
  17. class LambdaMap(LambdaBase):
  18. def forward(self, input):
  19. return list(map(self.lambda_func,self.forward_prepare(input)))
  20. class LambdaReduce(LambdaBase):
  21. def forward(self, input):
  22. return reduce(self.lambda_func,self.forward_prepare(input))
  23. def wrn_50_2f(): return nn.Sequential( # Sequential,
  24. nn.Conv2d(3,64,(7, 7),(2, 2),(3, 3),1,1,bias=False),
  25. nn.BatchNorm2d(64),
  26. nn.ReLU(),
  27. nn.MaxPool2d((3, 3),(2, 2),(1, 1)),
  28. nn.Sequential( # Sequential,
  29. nn.Sequential( # Sequential,
  30. LambdaMap(lambda x: x, # ConcatTable,
  31. nn.Sequential( # Sequential,
  32. nn.Conv2d(64,128,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  33. nn.BatchNorm2d(128),
  34. nn.ReLU(),
  35. nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  36. nn.BatchNorm2d(128),
  37. nn.ReLU(),
  38. nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  39. nn.BatchNorm2d(256),
  40. ),
  41. nn.Sequential( # Sequential,
  42. nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  43. nn.BatchNorm2d(256),
  44. ),
  45. ),
  46. LambdaReduce(lambda x,y: x+y), # CAddTable,
  47. nn.ReLU(),
  48. ),
  49. nn.Sequential( # Sequential,
  50. LambdaMap(lambda x: x, # ConcatTable,
  51. nn.Sequential( # Sequential,
  52. nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  53. nn.BatchNorm2d(128),
  54. nn.ReLU(),
  55. nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  56. nn.BatchNorm2d(128),
  57. nn.ReLU(),
  58. nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  59. nn.BatchNorm2d(256),
  60. ),
  61. Lambda(lambda x: x), # Identity,
  62. ),
  63. LambdaReduce(lambda x,y: x+y), # CAddTable,
  64. nn.ReLU(),
  65. ),
  66. nn.Sequential( # Sequential,
  67. LambdaMap(lambda x: x, # ConcatTable,
  68. nn.Sequential( # Sequential,
  69. nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  70. nn.BatchNorm2d(128),
  71. nn.ReLU(),
  72. nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  73. nn.BatchNorm2d(128),
  74. nn.ReLU(),
  75. nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  76. nn.BatchNorm2d(256),
  77. ),
  78. Lambda(lambda x: x), # Identity,
  79. ),
  80. LambdaReduce(lambda x,y: x+y), # CAddTable,
  81. nn.ReLU(),
  82. ),
  83. ),
  84. nn.Sequential( # Sequential,
  85. nn.Sequential( # Sequential,
  86. LambdaMap(lambda x: x, # ConcatTable,
  87. nn.Sequential( # Sequential,
  88. nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  89. nn.BatchNorm2d(256),
  90. nn.ReLU(),
  91. nn.Conv2d(256,256,(3, 3),(2, 2),(1, 1),1,1,bias=False),
  92. nn.BatchNorm2d(256),
  93. nn.ReLU(),
  94. nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  95. nn.BatchNorm2d(512),
  96. ),
  97. nn.Sequential( # Sequential,
  98. nn.Conv2d(256,512,(1, 1),(2, 2),(0, 0),1,1,bias=False),
  99. nn.BatchNorm2d(512),
  100. ),
  101. ),
  102. LambdaReduce(lambda x,y: x+y), # CAddTable,
  103. nn.ReLU(),
  104. ),
  105. nn.Sequential( # Sequential,
  106. LambdaMap(lambda x: x, # ConcatTable,
  107. nn.Sequential( # Sequential,
  108. nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  109. nn.BatchNorm2d(256),
  110. nn.ReLU(),
  111. nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  112. nn.BatchNorm2d(256),
  113. nn.ReLU(),
  114. nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  115. nn.BatchNorm2d(512),
  116. ),
  117. Lambda(lambda x: x), # Identity,
  118. ),
  119. LambdaReduce(lambda x,y: x+y), # CAddTable,
  120. nn.ReLU(),
  121. ),
  122. nn.Sequential( # Sequential,
  123. LambdaMap(lambda x: x, # ConcatTable,
  124. nn.Sequential( # Sequential,
  125. nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  126. nn.BatchNorm2d(256),
  127. nn.ReLU(),
  128. nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  129. nn.BatchNorm2d(256),
  130. nn.ReLU(),
  131. nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  132. nn.BatchNorm2d(512),
  133. ),
  134. Lambda(lambda x: x), # Identity,
  135. ),
  136. LambdaReduce(lambda x,y: x+y), # CAddTable,
  137. nn.ReLU(),
  138. ),
  139. nn.Sequential( # Sequential,
  140. LambdaMap(lambda x: x, # ConcatTable,
  141. nn.Sequential( # Sequential,
  142. nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  143. nn.BatchNorm2d(256),
  144. nn.ReLU(),
  145. nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  146. nn.BatchNorm2d(256),
  147. nn.ReLU(),
  148. nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  149. nn.BatchNorm2d(512),
  150. ),
  151. Lambda(lambda x: x), # Identity,
  152. ),
  153. LambdaReduce(lambda x,y: x+y), # CAddTable,
  154. nn.ReLU(),
  155. ),
  156. ),
  157. nn.Sequential( # Sequential,
  158. nn.Sequential( # Sequential,
  159. LambdaMap(lambda x: x, # ConcatTable,
  160. nn.Sequential( # Sequential,
  161. nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  162. nn.BatchNorm2d(512),
  163. nn.ReLU(),
  164. nn.Conv2d(512,512,(3, 3),(2, 2),(1, 1),1,1,bias=False),
  165. nn.BatchNorm2d(512),
  166. nn.ReLU(),
  167. nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  168. nn.BatchNorm2d(1024),
  169. ),
  170. nn.Sequential( # Sequential,
  171. nn.Conv2d(512,1024,(1, 1),(2, 2),(0, 0),1,1,bias=False),
  172. nn.BatchNorm2d(1024),
  173. ),
  174. ),
  175. LambdaReduce(lambda x,y: x+y), # CAddTable,
  176. nn.ReLU(),
  177. ),
  178. nn.Sequential( # Sequential,
  179. LambdaMap(lambda x: x, # ConcatTable,
  180. nn.Sequential( # Sequential,
  181. nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  182. nn.BatchNorm2d(512),
  183. nn.ReLU(),
  184. nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  185. nn.BatchNorm2d(512),
  186. nn.ReLU(),
  187. nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  188. nn.BatchNorm2d(1024),
  189. ),
  190. Lambda(lambda x: x), # Identity,
  191. ),
  192. LambdaReduce(lambda x,y: x+y), # CAddTable,
  193. nn.ReLU(),
  194. ),
  195. nn.Sequential( # Sequential,
  196. LambdaMap(lambda x: x, # ConcatTable,
  197. nn.Sequential( # Sequential,
  198. nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  199. nn.BatchNorm2d(512),
  200. nn.ReLU(),
  201. nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  202. nn.BatchNorm2d(512),
  203. nn.ReLU(),
  204. nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  205. nn.BatchNorm2d(1024),
  206. ),
  207. Lambda(lambda x: x), # Identity,
  208. ),
  209. LambdaReduce(lambda x,y: x+y), # CAddTable,
  210. nn.ReLU(),
  211. ),
  212. nn.Sequential( # Sequential,
  213. LambdaMap(lambda x: x, # ConcatTable,
  214. nn.Sequential( # Sequential,
  215. nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  216. nn.BatchNorm2d(512),
  217. nn.ReLU(),
  218. nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  219. nn.BatchNorm2d(512),
  220. nn.ReLU(),
  221. nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  222. nn.BatchNorm2d(1024),
  223. ),
  224. Lambda(lambda x: x), # Identity,
  225. ),
  226. LambdaReduce(lambda x,y: x+y), # CAddTable,
  227. nn.ReLU(),
  228. ),
  229. nn.Sequential( # Sequential,
  230. LambdaMap(lambda x: x, # ConcatTable,
  231. nn.Sequential( # Sequential,
  232. nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  233. nn.BatchNorm2d(512),
  234. nn.ReLU(),
  235. nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  236. nn.BatchNorm2d(512),
  237. nn.ReLU(),
  238. nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  239. nn.BatchNorm2d(1024),
  240. ),
  241. Lambda(lambda x: x), # Identity,
  242. ),
  243. LambdaReduce(lambda x,y: x+y), # CAddTable,
  244. nn.ReLU(),
  245. ),
  246. nn.Sequential( # Sequential,
  247. LambdaMap(lambda x: x, # ConcatTable,
  248. nn.Sequential( # Sequential,
  249. nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  250. nn.BatchNorm2d(512),
  251. nn.ReLU(),
  252. nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  253. nn.BatchNorm2d(512),
  254. nn.ReLU(),
  255. nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  256. nn.BatchNorm2d(1024),
  257. ),
  258. Lambda(lambda x: x), # Identity,
  259. ),
  260. LambdaReduce(lambda x,y: x+y), # CAddTable,
  261. nn.ReLU(),
  262. ),
  263. ),
  264. nn.Sequential( # Sequential,
  265. nn.Sequential( # Sequential,
  266. LambdaMap(lambda x: x, # ConcatTable,
  267. nn.Sequential( # Sequential,
  268. nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  269. nn.BatchNorm2d(1024),
  270. nn.ReLU(),
  271. nn.Conv2d(1024,1024,(3, 3),(2, 2),(1, 1),1,1,bias=False),
  272. nn.BatchNorm2d(1024),
  273. nn.ReLU(),
  274. nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  275. nn.BatchNorm2d(2048),
  276. ),
  277. nn.Sequential( # Sequential,
  278. nn.Conv2d(1024,2048,(1, 1),(2, 2),(0, 0),1,1,bias=False),
  279. nn.BatchNorm2d(2048),
  280. ),
  281. ),
  282. LambdaReduce(lambda x,y: x+y), # CAddTable,
  283. nn.ReLU(),
  284. ),
  285. nn.Sequential( # Sequential,
  286. LambdaMap(lambda x: x, # ConcatTable,
  287. nn.Sequential( # Sequential,
  288. nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  289. nn.BatchNorm2d(1024),
  290. nn.ReLU(),
  291. nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  292. nn.BatchNorm2d(1024),
  293. nn.ReLU(),
  294. nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  295. nn.BatchNorm2d(2048),
  296. ),
  297. Lambda(lambda x: x), # Identity,
  298. ),
  299. LambdaReduce(lambda x,y: x+y), # CAddTable,
  300. nn.ReLU(),
  301. ),
  302. nn.Sequential( # Sequential,
  303. LambdaMap(lambda x: x, # ConcatTable,
  304. nn.Sequential( # Sequential,
  305. nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  306. nn.BatchNorm2d(1024),
  307. nn.ReLU(),
  308. nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,1,bias=False),
  309. nn.BatchNorm2d(1024),
  310. nn.ReLU(),
  311. nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False),
  312. nn.BatchNorm2d(2048),
  313. ),
  314. Lambda(lambda x: x), # Identity,
  315. ),
  316. LambdaReduce(lambda x,y: x+y), # CAddTable,
  317. nn.ReLU(),
  318. ),
  319. ),
  320. nn.AvgPool2d((7, 7),(1, 1)),
  321. Lambda(lambda x: x.view(x.size(0),-1)), # View,
  322. nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2048,1000)), # Linear,
  323. )