123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.model_zoo as model_zoo
- from torch.autograd import Variable
# Download location and preprocessing metadata for the pretrained NASNet-A
# Large weights. Both entries point at the same checkpoint URL and share the
# same preprocessing values; they differ only in ``num_classes``.
pretrained_settings = {
    'nasnetalarge': {
        'imagenet': dict(
            url='http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
            input_space='RGB',
            # resize 354 (presumably: resize to 354, then crop to 331 — confirm)
            input_size=[3, 331, 331],
            input_range=[0, 1],
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            num_classes=1000,
        ),
        'imagenet+background': dict(
            url='http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
            input_space='RGB',
            # resize 354 (presumably: resize to 354, then crop to 331 — confirm)
            input_size=[3, 331, 331],
            input_range=[0, 1],
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            num_classes=1001,
        ),
    }
}
class MaxPoolPad(nn.Module):
    """3x3, stride-2 max pooling with asymmetric (top/left) zero padding.

    The input gets one extra zero row on top and one extra zero column on
    the left, is max-pooled (kernel 3, stride 2, padding 1), and then the
    first row and column of the pooled result are discarded. Presumably this
    reproduces the padding of the framework the weights came from — confirm.
    """

    def __init__(self):
        super().__init__()
        # ZeroPad2d takes (left, right, top, bottom).
        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        pooled = self.pool(self.pad(x))
        # Drop the row/column that the extra top/left padding produced.
        return pooled[:, :, 1:, 1:]
class AvgPoolPad(nn.Module):
    """3x3 average pooling with asymmetric (top/left) zero padding.

    Same pad -> pool -> crop scheme as ``MaxPoolPad``, but with average
    pooling. ``count_include_pad=False`` excludes the pool's own implicit
    padding from the averages. The explicit ZeroPad2d zeros are real input
    values, although with the default stride they only land in the cropped
    first row/column.

    Args:
        stride: pooling stride (default 2).
        padding: implicit padding of the pooling op (default 1).
    """

    def __init__(self, stride=2, padding=1):
        super().__init__()
        # ZeroPad2d takes (left, right, top, bottom).
        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)

    def forward(self, x):
        averaged = self.pool(self.pad(x))
        # Remove the first row/column introduced by the top/left padding.
        return averaged[:, :, 1:, 1:]
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: per-channel spatial conv, then a 1x1 conv.

    Args:
        in_channels: channels of the input (and of the depthwise stage).
        out_channels: channels produced by the pointwise stage.
        dw_kernel: kernel size of the depthwise conv.
        dw_stride: stride of the depthwise conv.
        dw_padding: padding of the depthwise conv.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
        super().__init__()
        # groups == in_channels -> one spatial filter per input channel.
        self.depthwise_conv2d = nn.Conv2d(
            in_channels, in_channels, dw_kernel,
            stride=dw_stride, padding=dw_padding, bias=bias, groups=in_channels,
        )
        # 1x1 conv mixes channels and sets the output width.
        self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)

    def forward(self, x):
        return self.pointwise_conv2d(self.depthwise_conv2d(x))
class BranchSeparables(nn.Module):
    """Two stacked ReLU -> separable conv -> BatchNorm stages.

    The first separable conv keeps ``in_channels`` and applies ``stride``.
    The second always has stride 1 and maps to ``out_channels``.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: depthwise kernel size of both separable convs.
        stride: stride of the first separable conv.
        padding: depthwise padding of both separable convs.
        bias: whether the separable convs carry bias terms.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super().__init__()
        # Attribute names and creation order are kept: they define the
        # state_dict keys used when loading pretrained weights.
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
        self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)

    def forward(self, x):
        for stage in (self.relu, self.separable_1, self.bn_sep_1,
                      self.relu1, self.separable_2, self.bn_sep_2):
            x = stage(x)
        return x
class BranchSeparablesStem(nn.Module):
    """Like ``BranchSeparables``, but switches to ``out_channels`` in the first conv.

    Both separable convs (and both BatchNorms) therefore operate at
    ``out_channels``. The first conv applies ``stride``; the second uses
    stride 1.

    Args:
        in_channels: input channel count.
        out_channels: channel count after the first separable conv and at output.
        kernel_size: depthwise kernel size of both separable convs.
        stride: stride of the first separable conv.
        padding: depthwise padding of both separable convs.
        bias: whether the separable convs carry bias terms.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super().__init__()
        # Names/order preserved for state_dict compatibility.
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)

    def forward(self, x):
        for stage in (self.relu, self.separable_1, self.bn_sep_1,
                      self.relu1, self.separable_2, self.bn_sep_2):
            x = stage(x)
        return x
class BranchSeparablesReduction(BranchSeparables):
    """``BranchSeparables`` with top/left zero padding around the first conv.

    Before the first separable conv, the activated input is zero-padded by
    ``z_padding`` on the top and left only. Afterwards the first row and
    column of the result are dropped. All layers except the extra padding
    module are inherited from ``BranchSeparables``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
        # ZeroPad2d takes (left, right, top, bottom).
        self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))

    def forward(self, x):
        # Stage 1: ReLU -> top/left pad -> separable conv -> crop.
        y = self.separable_1(self.padding(self.relu(x)))
        # NOTE(review): the crop is always one row/column, independent of
        # z_padding; the callers in this file rely on the default z_padding=1.
        y = y[:, :, 1:, 1:].contiguous()
        # Stage 2: BN -> ReLU -> separable conv -> BN.
        y = self.relu1(self.bn_sep_1(y))
        return self.bn_sep_2(self.separable_2(y))
class CellStem0(nn.Module):
    """First NASNet stem cell: 96-channel input -> 4 x 42 = 168-channel output.

    A ReLU/1x1-conv/BN projection of the input (42 channels) feeds the
    pooling branches. The separable "Stem" branches read the raw 96-channel
    input. Most branches use stride 2, so spatial size is reduced.
    """

    def __init__(self):
        super().__init__()
        # Module names and registration order are part of the checkpoint
        # format (state_dict keys) — keep them unchanged.
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(96, 42, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(42, eps=0.001, momentum=0.1, affine=True))
        self.comb_iter_0_left = BranchSeparables(42, 42, 5, 2, 2)
        self.comb_iter_0_right = BranchSeparablesStem(96, 42, 7, 2, 3, bias=False)
        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparablesStem(96, 42, 7, 2, 3, bias=False)
        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparablesStem(96, 42, 5, 2, 2, bias=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_4_left = BranchSeparables(42, 42, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        projected = self.conv_1x1(x)

        comb0 = self.comb_iter_0_left(projected) + self.comb_iter_0_right(x)
        comb1 = self.comb_iter_1_left(projected) + self.comb_iter_1_right(x)
        comb2 = self.comb_iter_2_left(projected) + self.comb_iter_2_right(x)
        # comb0 is only an intermediate: it feeds comb3 and comb4.
        comb3 = self.comb_iter_3_right(comb0) + comb1
        comb4 = self.comb_iter_4_left(comb0) + self.comb_iter_4_right(projected)

        return torch.cat([comb1, comb2, comb3, comb4], 1)
class CellStem1(nn.Module):
    """Second NASNet stem cell: (conv0 output, stem-0 output) -> 4 x 84 = 336 channels.

    ``x_stem_0`` (168 channels) is projected to 84 channels. ``x_conv0``
    (96 channels) goes through a two-path reduction: path 1 subsamples every
    other pixel; path 2 first shifts the map by one pixel (pad bottom/right,
    drop the first row/column) and then subsamples. Each path yields 42
    channels; the two are concatenated and batch-normalised to 84.
    """

    def __init__(self):
        super().__init__()
        # Names/order define the state_dict layout — keep unchanged.
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(168, 84, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(84, eps=0.001, momentum=0.1, affine=True))
        self.relu = nn.ReLU()
        self.path_1 = nn.Sequential()
        # A 1x1 average pool with stride 2 is plain spatial subsampling.
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(96, 42, 1, stride=1, bias=False))
        # ModuleList (not Sequential) because forward crops between pad and pool.
        self.path_2 = nn.ModuleList()
        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(96, 42, 1, stride=1, bias=False))
        self.final_path_bn = nn.BatchNorm2d(84, eps=0.001, momentum=0.1, affine=True)
        self.comb_iter_0_left = BranchSeparables(84, 84, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(84, 84, 7, 2, 3, bias=False)
        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparables(84, 84, 7, 2, 3, bias=False)
        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparables(84, 84, 5, 2, 2, bias=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_4_left = BranchSeparables(84, 84, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x_conv0, x_stem_0):
        left = self.conv_1x1(x_stem_0)

        activated = self.relu(x_conv0)
        even = self.path_1(activated)
        # Shift by one pixel so path 2 samples the complementary grid.
        shifted = self.path_2.pad(activated)[:, :, 1:, 1:]
        odd = self.path_2.conv(self.path_2.avgpool(shifted))
        right = self.final_path_bn(torch.cat([even, odd], 1))

        comb0 = self.comb_iter_0_left(left) + self.comb_iter_0_right(right)
        comb1 = self.comb_iter_1_left(left) + self.comb_iter_1_right(right)
        comb2 = self.comb_iter_2_left(left) + self.comb_iter_2_right(right)
        comb3 = self.comb_iter_3_right(comb0) + comb1
        comb4 = self.comb_iter_4_left(comb0) + self.comb_iter_4_right(left)

        return torch.cat([comb1, comb2, comb3, comb4], 1)
class FirstCell(nn.Module):
    """Normal cell used right after a reduction, where ``x_prev`` is at twice the resolution.

    ``x_prev`` is brought down by a two-path reduction. Path 1 subsamples;
    path 2 shifts by one pixel and then subsamples. Each path gives
    ``out_channels_left`` channels, so ``x_left`` has
    ``2 * out_channels_left`` channels. ``x`` is projected to
    ``out_channels_right`` channels.

    The element-wise additions require ``2 * out_channels_left ==
    out_channels_right``. The output has ``2 * out_channels_left +
    5 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super().__init__()
        # Names/order define the state_dict layout — keep unchanged.
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
        self.relu = nn.ReLU()
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        # ModuleList because forward crops between pad and avgpool.
        self.path_2 = nn.ModuleList()
        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)
        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

    def forward(self, x, x_prev):
        activated = self.relu(x_prev)
        even = self.path_1(activated)
        shifted = self.path_2.pad(activated)[:, :, 1:, 1:]
        odd = self.path_2.conv(self.path_2.avgpool(shifted))
        x_left = self.final_path_bn(torch.cat([even, odd], 1))

        x_right = self.conv_1x1(x)

        comb0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb1 = self.comb_iter_1_left(x_left) + self.comb_iter_1_right(x_left)
        comb2 = self.comb_iter_2_left(x_right) + x_left
        comb3 = self.comb_iter_3_left(x_left) + self.comb_iter_3_right(x_left)
        comb4 = self.comb_iter_4_left(x_right) + x_right

        return torch.cat([x_left, comb0, comb1, comb2, comb3, comb4], 1)
class NormalCell(nn.Module):
    """Resolution-preserving NASNet cell combining the current and previous cell outputs.

    ``x_prev`` is projected to ``out_channels_left`` and ``x`` to
    ``out_channels_right`` channels, each by ReLU -> 1x1 conv -> BN. Every
    branch uses stride 1. The element-wise additions require
    ``out_channels_left == out_channels_right``. The output has
    ``out_channels_left + 5 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super().__init__()
        # Names/order define the state_dict layout — keep unchanged.
        self.conv_prev_1x1 = nn.Sequential()
        self.conv_prev_1x1.add_module('relu', nn.ReLU())
        self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
        self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
        self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        comb0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb1 = self.comb_iter_1_left(x_left) + self.comb_iter_1_right(x_left)
        comb2 = self.comb_iter_2_left(x_right) + x_left
        comb3 = self.comb_iter_3_left(x_left) + self.comb_iter_3_right(x_left)
        comb4 = self.comb_iter_4_left(x_right) + x_right

        return torch.cat([x_left, comb0, comb1, comb2, comb3, comb4], 1)
class ReductionCell0(nn.Module):
    """First NASNet reduction cell, built from the padded ``*Pad`` / ``BranchSeparablesReduction`` ops.

    Both inputs are projected by ReLU -> 1x1 conv -> BN. The stride-2
    branches then reduce spatial size. The element-wise additions require
    ``out_channels_left == out_channels_right``. The output has
    ``4 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super().__init__()
        # Names/order define the state_dict layout — keep unchanged.
        self.conv_prev_1x1 = nn.Sequential()
        self.conv_prev_1x1.add_module('relu', nn.ReLU())
        self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
        self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
        self.comb_iter_1_left = MaxPoolPad()
        self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
        self.comb_iter_2_left = AvgPoolPad()
        self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_4_right = MaxPoolPad()

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        comb0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_left)
        comb2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_left)
        # comb0 is only an intermediate: it feeds comb3 and comb4.
        comb3 = self.comb_iter_3_right(comb0) + comb1
        comb4 = self.comb_iter_4_left(comb0) + self.comb_iter_4_right(x_right)

        return torch.cat([comb1, comb2, comb3, comb4], 1)
class ReductionCell1(nn.Module):
    """Second NASNet reduction cell, using plain ``nn`` pooling and ``BranchSeparables``.

    Same wiring as ``ReductionCell0`` but without the asymmetric-padding
    variants. The element-wise additions require ``out_channels_left ==
    out_channels_right``. The output has ``4 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super().__init__()
        # Names/order define the state_dict layout — keep unchanged.
        self.conv_prev_1x1 = nn.Sequential()
        self.conv_prev_1x1.add_module('relu', nn.ReLU())
        self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        comb0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_left)
        comb2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_left)
        # comb0 is only an intermediate: it feeds comb3 and comb4.
        comb3 = self.comb_iter_3_right(comb0) + comb1
        comb4 = self.comb_iter_4_left(comb0) + self.comb_iter_4_right(x_right)

        return torch.cat([comb1, comb2, comb3, comb4], 1)
class NASNetALarge(nn.Module):
    """NASNet-A Large network.

    Args:
        use_classifier: if True, ``forward`` returns class log-probabilities
            from :meth:`classifier`. Otherwise it returns the 4032-channel
            feature map from :meth:`features`.
        num_classes: output size of the ``last_linear`` head.
    """

    def __init__(self, use_classifier=False, num_classes=1001):
        super(NASNetALarge, self).__init__()
        self.use_classifier = use_classifier
        self.num_classes = num_classes

        # Submodule names and registration order define the state_dict keys
        # of the pretrained checkpoint — do not rename or reorder.
        self.conv0 = nn.Sequential()
        self.conv0.add_module('conv', nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, padding=0, stride=2,
                                                bias=False))
        self.conv0.add_module('bn', nn.BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True))

        self.cell_stem_0 = CellStem0()
        self.cell_stem_1 = CellStem1()

        self.cell_0 = FirstCell(in_channels_left=168, out_channels_left=84,
                                in_channels_right=336, out_channels_right=168)
        self.cell_1 = NormalCell(in_channels_left=336, out_channels_left=168,
                                 in_channels_right=1008, out_channels_right=168)
        self.cell_2 = NormalCell(in_channels_left=1008, out_channels_left=168,
                                 in_channels_right=1008, out_channels_right=168)
        self.cell_3 = NormalCell(in_channels_left=1008, out_channels_left=168,
                                 in_channels_right=1008, out_channels_right=168)
        self.cell_4 = NormalCell(in_channels_left=1008, out_channels_left=168,
                                 in_channels_right=1008, out_channels_right=168)
        self.cell_5 = NormalCell(in_channels_left=1008, out_channels_left=168,
                                 in_channels_right=1008, out_channels_right=168)

        self.reduction_cell_0 = ReductionCell0(in_channels_left=1008, out_channels_left=336,
                                               in_channels_right=1008, out_channels_right=336)

        self.cell_6 = FirstCell(in_channels_left=1008, out_channels_left=168,
                                in_channels_right=1344, out_channels_right=336)
        self.cell_7 = NormalCell(in_channels_left=1344, out_channels_left=336,
                                 in_channels_right=2016, out_channels_right=336)
        self.cell_8 = NormalCell(in_channels_left=2016, out_channels_left=336,
                                 in_channels_right=2016, out_channels_right=336)
        self.cell_9 = NormalCell(in_channels_left=2016, out_channels_left=336,
                                 in_channels_right=2016, out_channels_right=336)
        self.cell_10 = NormalCell(in_channels_left=2016, out_channels_left=336,
                                  in_channels_right=2016, out_channels_right=336)
        self.cell_11 = NormalCell(in_channels_left=2016, out_channels_left=336,
                                  in_channels_right=2016, out_channels_right=336)

        self.reduction_cell_1 = ReductionCell1(in_channels_left=2016, out_channels_left=672,
                                               in_channels_right=2016, out_channels_right=672)

        self.cell_12 = FirstCell(in_channels_left=2016, out_channels_left=336,
                                 in_channels_right=2688, out_channels_right=672)
        self.cell_13 = NormalCell(in_channels_left=2688, out_channels_left=672,
                                  in_channels_right=4032, out_channels_right=672)
        self.cell_14 = NormalCell(in_channels_left=4032, out_channels_left=672,
                                  in_channels_right=4032, out_channels_right=672)
        self.cell_15 = NormalCell(in_channels_left=4032, out_channels_left=672,
                                  in_channels_right=4032, out_channels_right=672)
        self.cell_16 = NormalCell(in_channels_left=4032, out_channels_left=672,
                                  in_channels_right=4032, out_channels_right=672)
        self.cell_17 = NormalCell(in_channels_left=4032, out_channels_left=672,
                                  in_channels_right=4032, out_channels_right=672)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.last_linear = nn.Linear(4032, self.num_classes)

    def features(self, x):
        """Run the convolutional trunk and return the ReLU-activated 4032-channel map.

        Each cell takes ``(current, previous)`` outputs. The FirstCell right
        after each reduction (cell_6, cell_12) receives as "previous" the
        output two cells back (x_cell_4 / x_cell_10), because the reduction
        cell has no separate pass-through output.
        """
        x_conv0 = self.conv0(x)
        x_stem_0 = self.cell_stem_0(x_conv0)
        x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)

        x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
        x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
        x_cell_5 = self.cell_5(x_cell_4, x_cell_3)

        x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)

        x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
        x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
        x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
        x_cell_11 = self.cell_11(x_cell_10, x_cell_9)

        x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)

        x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
        x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
        x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
        x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
        return self.relu(x_cell_17)

    def classifier(self, x):
        """Map a feature map from :meth:`features` to per-class log-probabilities.

        Global max-pools over the spatial dims, flattens to
        ``(batch, channels)``, applies dropout and ``last_linear``, then
        ``log_softmax`` over the class dimension.
        """
        # NOTE(review): max pooling here — confirm it matches the pooling the
        # pretrained weights were trained with (NASNet is commonly described
        # with global average pooling).
        x = F.adaptive_max_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        # The head is registered as `last_linear`; the previous `self.linear`
        # does not exist and raised AttributeError. dim=1 is the class axis
        # (the same axis the deprecated implicit-dim form picked for 2-D input).
        return F.log_softmax(self.last_linear(x), dim=1)

    def forward(self, x):
        x = self.features(x)
        if self.use_classifier:
            x = self.classifier(x)
        return x
def nasnetalarge(num_classes=1000, pretrained='imagenet'):
    r"""NASNetALarge model architecture from the
    `"NASNet" <https://arxiv.org/abs/1707.07012>`_ paper.

    Args:
        num_classes: number of output classes. When ``pretrained`` is set, it
            must equal the class count of the chosen weight set (1000 for
            ``'imagenet'``, 1001 for ``'imagenet+background'``).
        pretrained: a key of ``pretrained_settings['nasnetalarge']`` naming
            the weights to download, or a falsy value for a randomly
            initialised model.

    Returns:
        A ``NASNetALarge``. Pretrained models also carry ``input_space``,
        ``input_size``, ``input_range``, ``mean`` and ``std`` attributes
        describing the expected input preprocessing.

    Raises:
        KeyError: if ``pretrained`` is not a known weight set.
        AssertionError: if ``num_classes`` does not match the weight set.
    """
    if not pretrained:
        return NASNetALarge(num_classes=num_classes)

    available = pretrained_settings['nasnetalarge']
    if pretrained not in available:
        raise KeyError("unknown pretrained weights {!r}; expected one of {}".format(
            pretrained, sorted(available)))
    settings = available[pretrained]
    assert num_classes == settings['num_classes'], \
        "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

    # both 'imagenet'&'imagenet+background' are loaded from same parameters
    # (a checkpoint with a 1001-way head).
    # NOTE(review): use_classifier keeps its default (False), so forward()
    # returns features; callers must enable it to use last_linear.
    model = NASNetALarge(num_classes=1001)
    model.load_state_dict(model_zoo.load_url(settings['url']))

    if pretrained == 'imagenet':
        # Drop the first output row (the extra class that only the
        # 'imagenet+background' variant keeps) to get a 1000-way head.
        new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
        new_last_linear.weight.data = model.last_linear.weight.data[1:]
        new_last_linear.bias.data = model.last_linear.bias.data[1:]
        model.last_linear = new_last_linear
        # Keep the bookkeeping attribute consistent with the new head.
        model.num_classes = 1000

    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']
    return model
|