from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.legacy.nn  # needed below for torch.legacy.nn.Sequential and SpatialCrossMapLRN
from torch.autograd import Variable
from torch.utils.serialization import load_lua

import numpy as np
import os
import math
from functools import reduce

# Small wrapper modules used to express Torch constructs (View, ConcatTable,
# CAddTable, ...) that have no single-module equivalent in torch.nn.
class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        # result is a list of Variables [Variable1, Variable2, ...]
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        # result is a single Variable
        return reduce(self.lambda_func,self.forward_prepare(input))
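
# Illustrative sketch (not part of the converter): a Torch ConcatTable followed by a
# CAddTable maps onto LambdaMap + LambdaReduce roughly as below, where branch_a and
# branch_b stand in for arbitrary, hypothetical nn.Modules:
#
#   block = nn.Sequential(
#       LambdaMap(lambda x: x, branch_a, branch_b),  # run each branch on the same input -> list of outputs
#       LambdaReduce(lambda x, y: x + y),            # sum the branch outputs element-wise
#   )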

def copy_param(m,n):
    if m.weight is not None: n.weight.data.copy_(m.weight)
    if m.bias is not None: n.bias.data.copy_(m.bias)
    if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean)
    if hasattr(n,'running_var'): n.running_var.copy_(m.running_var)

def add_submodule(seq, *args):
    for n in args:
        seq.add_module(str(len(seq._modules)),n)
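
# add_submodule names children '0', '1', ... in insertion order, which is exactly how
# nn.Sequential names positional children, so the state_dict saved by torch_to_pytorch()
# below matches the nn.Sequential written into the generated .py file.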

def lua_recursive_model(module,seq):
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.','')
            m = m._obj

        if name == 'SpatialConvolution':
            if not hasattr(m,'groups'): m.groups=1
            n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,bias=(m.bias is not None))
            copy_param(m,n)
            add_submodule(seq,n)
        elif name == 'SpatialBatchNormalization':
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m,n)
            add_submodule(seq,n)
        elif name == 'ReLU':
            n = nn.ReLU()
            add_submodule(seq,n)
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
            add_submodule(seq,n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
            add_submodule(seq,n)
        elif name == 'SpatialUpSamplingNearest':
            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
            add_submodule(seq,n)
        elif name == 'View':
            n = Lambda(lambda x: x.view(x.size(0),-1))
            add_submodule(seq,n)
        elif name == 'Linear':
            # Linear in PyTorch only accepts 2D input, so reshape a 1D input first
            n1 = Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )
            n2 = nn.Linear(m.weight.size(1),m.weight.size(0),bias=(m.bias is not None))
            copy_param(m,n2)
            n = nn.Sequential(n1,n2)
            add_submodule(seq,n)
        elif name == 'Dropout':
            m.inplace = False
            n = nn.Dropout(m.p)
            add_submodule(seq,n)
        elif name == 'SoftMax':
            n = nn.Softmax()
            add_submodule(seq,n)
        elif name == 'Identity':
            n = Lambda(lambda x: x) # do nothing
            add_submodule(seq,n)
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))
            copy_param(m,n)  # also copy the deconvolution parameters; weight layout matches Torch's
            add_submodule(seq,n)
        elif name == 'SpatialReplicationPadding':
            n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
            add_submodule(seq,n)
        elif name == 'SpatialReflectionPadding':
            n = nn.ReflectionPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
            add_submodule(seq,n)
        elif name == 'Copy':
            n = Lambda(lambda x: x) # do nothing
            add_submodule(seq,n)
        elif name == 'Narrow':
            n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a))
            add_submodule(seq,n)
        elif name == 'SpatialCrossMapLRN':
            lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k)
            n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq,n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m,n)
            add_submodule(seq,n)
        elif name == 'ConcatTable': # output is a list
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m,n)
            add_submodule(seq,n)
        elif name == 'CAddTable': # input is a list
            n = LambdaReduce(lambda x,y: x+y)
            add_submodule(seq,n)
        elif name == 'Concat':
            dim = m.dimension
            n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim))
            lua_recursive_model(m,n)
            add_submodule(seq,n)
        elif name == 'TorchObject':
            print('Not Implemented',name,real._typename)
        else:
            print('Not Implemented',name)
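
# lua_recursive_model (above) and lua_recursive_source (below) mirror each other: the
# first builds a live nn.Module tree with parameters copied from the Torch model (later
# saved as a state_dict), the second emits the equivalent Python source text for the
# generated .py file.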

def lua_recursive_source(module):
    s = []
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.','')
            m = m._obj

        if name == 'SpatialConvolution':
            if not hasattr(m,'groups'): m.groups=1
            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
                m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)]
        elif name == 'SpatialBatchNormalization':
            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'ReLU':
            s += ['nn.ReLU()']
        elif name == 'SpatialMaxPooling':
            s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
        elif name == 'SpatialAveragePooling':
            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
        elif name == 'SpatialUpSamplingNearest':
            s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
        elif name == 'View':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
        elif name == 'Linear':
            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
            s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),m.weight.size(0),(m.bias is not None))
            s += ['nn.Sequential({},{}),#Linear'.format(s1,s2)]
        elif name == 'Dropout':
            s += ['nn.Dropout({})'.format(m.p)]
        elif name == 'SoftMax':
            s += ['nn.Softmax()']
        elif name == 'Identity':
            s += ['Lambda(lambda x: x), # Identity']
        elif name == 'SpatialFullConvolution':
            s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))]
        elif name == 'SpatialReplicationPadding':
            s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
        elif name == 'SpatialReflectionPadding':
            s += ['nn.ReflectionPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
        elif name == 'Copy':
            s += ['Lambda(lambda x: x), # Copy']
        elif name == 'Narrow':
            s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))]
        elif name == 'SpatialCrossMapLRN':
            lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k))
            s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
        elif name == 'Sequential':
            s += ['nn.Sequential( # Sequential']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'ConcatTable':
            s += ['LambdaMap(lambda x: x, # ConcatTable']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'CAddTable':
            s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
        elif name == 'Concat':
            dim = m.dimension
            s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
            s += lua_recursive_source(m)
            s += [')']
        else:
            s += ['# ' + name + ' is not implemented']

    s = map(lambda x: '\t{}'.format(x),s)
    return s
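
# For reference, each entry appended above becomes one line of the generated
# nn.Sequential body. For a hypothetical 3->64 channel 3x3 convolution followed by a
# ReLU, the raw (pre-simplify_source) entries would be:
#   'nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1),1,1,bias=True),#Conv2d'
#   'nn.ReLU()'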

def simplify_source(s):
    s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace('),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace('),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace('),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
    s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)

    s = map(lambda x: '{},\n'.format(x),s)
    s = map(lambda x: x[1:],s)
    s = reduce(lambda x,y: x+y, s)
    return s
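
# Illustrative example of the simplification: the raw entry
#   'nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1),1,1,bias=True),#Conv2d'
# keeps its non-default stride and padding but loses the default dilation/groups/bias
# arguments and the '#Conv2d' marker, becoming
#   'nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1))'
# before the trailing ',' and newline are appended.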

def torch_to_pytorch(t7_filename,outputname=None):
    model = load_lua(t7_filename,unknown_classes=True)
    if type(model).__name__=='hashable_uniq_dict': model=model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7','').replace('.','_').replace('-','_')
    s = '{}\n\n{} = {}'.format(header,varname,s[:-2])

    if outputname is None: outputname=varname
    with open(outputname+'.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model,n)
    torch.save(n.state_dict(),outputname+'.pth')
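
# Rough usage of the generated artifacts (file and variable names below are illustrative):
#   from vgg16 import vgg16                       # the generated <name>.py defines <name> as an nn.Sequential
#   vgg16.load_state_dict(torch.load('vgg16.pth'))
#   vgg16.eval()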

parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
parser.add_argument('--model','-m', type=str, required=True,
                    help='torch model file in t7 format')
parser.add_argument('--output', '-o', type=str, default=None,
                    help='output file name prefix, xxx.py xxx.pth')
args = parser.parse_args()
torch_to_pytorch(args.model,args.output)
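
# Example invocation (script and model file names are illustrative):
#   python convert_torch.py -m vgg16.t7
# which writes vgg16.py (the model definition) and vgg16.pth (the converted weights).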