from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.legacy.nn  # legacy modules used below (Sequential, SpatialCrossMapLRN)
from torch.autograd import Variable
from torch.utils.serialization import load_lua
import numpy as np
import os
import math
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        # result is a list [Variable1, Variable2, ...]
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        # result is a single Variable
        return reduce(self.lambda_func, self.forward_prepare(input))
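
# Illustrative sketch (not part of the converter): the Lambda wrappers above let
# arbitrary functions live inside an nn.Sequential. For example, LambdaReduce can
# emulate CAddTable by summing its children's outputs; the module choices and
# shapes here are hypothetical:
#   branch = LambdaReduce(lambda x, y: x + y,
#                         nn.Conv2d(3, 8, 1), nn.Conv2d(3, 8, 1))
#   out = branch(Variable(torch.randn(1, 3, 4, 4)))  # elementwise sum of both conv outputs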

def copy_param(m, n):
    # copy weight, bias and (for batch norm) running statistics from the
    # legacy Torch module m into the PyTorch module n
    if m.weight is not None: n.weight.data.copy_(m.weight)
    if m.bias is not None: n.bias.data.copy_(m.bias)
    if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
    if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)

def add_submodule(seq, *args):
    # register each module under its positional index ('0', '1', ...)
    for n in args:
        seq.add_module(str(len(seq._modules)), n)

def lua_recursive_model(module, seq):
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution':
            if not hasattr(m, 'groups'): m.groups = 1
            n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, bias=(m.bias is not None))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialBatchNormalization':
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'ReLU':
            n = nn.ReLU()
            add_submodule(seq, n)
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialUpSamplingNearest':
            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
            add_submodule(seq, n)
        elif name == 'View':
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Linear':
            # nn.Linear in PyTorch only accepts 2D input, so prepend a reshape
            n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
            n2 = nn.Linear(m.weight.size(1), m.weight.size(0), bias=(m.bias is not None))
            copy_param(m, n2)
            n = nn.Sequential(n1, n2)
            add_submodule(seq, n)
        elif name == 'Dropout':
            m.inplace = False
            n = nn.Dropout(m.p)
            add_submodule(seq, n)
        elif name == 'SoftMax':
            n = nn.Softmax()
            add_submodule(seq, n)
        elif name == 'Identity':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialReplicationPadding':
            n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'SpatialReflectionPadding':
            n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'Copy':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'Narrow':
            n = Lambda(lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
            add_submodule(seq, n)
        elif name == 'SpatialCrossMapLRN':
            lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta, m.k)
            n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq, n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'ConcatTable':  # output is a list
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'CAddTable':  # input is a list
            n = LambdaReduce(lambda x, y: x + y)
            add_submodule(seq, n)
        elif name == 'Concat':
            dim = m.dimension
            n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'TorchObject':
            print('Not implemented:', name, real._typename)
        else:
            print('Not implemented:', name)
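
# Illustrative sketch (hypothetical variable names), mirroring the call made in
# torch_to_pytorch below: given a legacy container `legacy_net` loaded with
# load_lua,
#   pt_net = nn.Sequential()
#   lua_recursive_model(legacy_net, pt_net)
#   # pt_net now mirrors legacy_net layer by layer, with parameters copied over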

def lua_recursive_source(module):
    s = []
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution':
            if not hasattr(m, 'groups'): m.groups = 1
            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)]
        elif name == 'SpatialBatchNormalization':
            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'ReLU':
            s += ['nn.ReLU()']
        elif name == 'SpatialMaxPooling':
            s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
        elif name == 'SpatialAveragePooling':
            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
        elif name == 'SpatialUpSamplingNearest':
            s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
        elif name == 'View':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
        elif name == 'Linear':
            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
            s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None))
            s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
        elif name == 'Dropout':
            s += ['nn.Dropout({})'.format(m.p)]
        elif name == 'SoftMax':
            s += ['nn.Softmax()']
        elif name == 'Identity':
            s += ['Lambda(lambda x: x), # Identity']
        elif name == 'SpatialFullConvolution':
            s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH))]
        elif name == 'SpatialReplicationPadding':
            s += ['nn.ReplicationPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
        elif name == 'SpatialReflectionPadding':
            s += ['nn.ReflectionPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
        elif name == 'Copy':
            s += ['Lambda(lambda x: x), # Copy']
        elif name == 'Narrow':
            s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension, m.index, m.length))]
        elif name == 'SpatialCrossMapLRN':
            lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format((m.size, m.alpha, m.beta, m.k))
            s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
        elif name == 'Sequential':
            s += ['nn.Sequential( # Sequential']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'ConcatTable':
            s += ['LambdaMap(lambda x: x, # ConcatTable']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'CAddTable':
            s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
        elif name == 'Concat':
            dim = m.dimension
            s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
            s += lua_recursive_source(m)
            s += [')']
        else:
            s += ['# ' + name + ' Not Implemented']
    s = map(lambda x: '\t{}'.format(x), s)
    return s

def simplify_source(s):
    # collapse trailing arguments that match the PyTorch defaults back to the short form
    s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace('),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace('),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace('),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
    s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)

    s = map(lambda x: '{},\n'.format(x),s)
    s = map(lambda x: x[1:],s)
    s = reduce(lambda x,y: x+y, s)
    return s
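
# For illustration (hypothetical layer sizes): simplify_source rewrites a fully
# spelled-out line emitted by lua_recursive_source, such as
#   nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1),1,1,bias=True),#Conv2d
# into the equivalent short form
#   nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1)),
# by stripping the arguments that match the PyTorch defaults.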

def torch_to_pytorch(t7_filename, outputname=None):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict': model = model.model
    model.gradInput = None

    # generate the python source for the converted model
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None: outputname = varname
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    # build the equivalent nn.Sequential in memory and save its weights
    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
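
# Minimal sketch of programmatic use, assuming a t7 file named 'vgg16.t7'
# (hypothetical path) exists:
#   torch_to_pytorch('vgg16.t7', outputname='vgg16')
#   # writes vgg16.py (model definition) and vgg16.pth (state dict)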

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert a Torch7 model (.t7) to PyTorch')
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='torch model file in t7 format')
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='output file name prefix, xxx.py xxx.pth')
    args = parser.parse_args()
    torch_to_pytorch(args.model, args.output)
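
# Example CLI invocation (hypothetical filename):
#   python convert_torch.py -m vgg16.t7
# which writes vgg16.py and vgg16.pth in the working directory.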