123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- # CIFAR-10 Wide ResNet for the DAWN submission
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ...layers import *
def conv_2d(ni, nf, ks, stride):
    """Build a bias-free 2-D convolution padded by half the kernel size.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.
        ks: Kernel size; the padding is ``ks // 2``.
        stride: Convolution stride.

    Returns:
        An ``nn.Conv2d`` created with ``bias=False``.
    """
    half_kernel = ks // 2
    return nn.Conv2d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=stride,
        padding=half_kernel,
        bias=False,
    )
def bn(ni, init_zero=False):
    """Create a ``BatchNorm2d`` layer with explicitly initialized affine terms.

    Args:
        ni: Number of channels the layer normalizes.
        init_zero: When true the scale (weight) starts at 0, otherwise at 1.

    Returns:
        The ``nn.BatchNorm2d`` with its weight filled as above and its bias
        set to zero.
    """
    layer = nn.BatchNorm2d(ni)
    scale_init = 0 if init_zero else 1
    nn.init.constant_(layer.weight, scale_init)
    nn.init.zeros_(layer.bias)
    return layer
def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    """Stack BatchNorm -> ReLU -> convolution into one ``nn.Sequential``.

    Args:
        ni: Number of input channels (also the BatchNorm width).
        nf: Number of output channels of the convolution.
        ks: Convolution kernel size.
        stride: Convolution stride.
        init_zero: Forwarded to ``bn`` to choose a zero or unit initial scale.

    Returns:
        An ``nn.Sequential`` of the three layers, in that order.
    """
    stages = [
        bn(ni, init_zero=init_zero),
        nn.ReLU(inplace=True),
        conv_2d(ni, nf, ks, stride),
    ]
    return nn.Sequential(*stages)
def noop(x):
    """Identity function: hand back ``x`` untouched."""
    return x
class BasicBlock(nn.Module):
    """Pre-activation residual block with a down-weighted residual branch.

    The input is normalized and passed through ReLU, then fed both to the
    shortcut and to the residual branch (3x3 conv -> optional dropout ->
    BatchNorm/ReLU/3x3 conv). The residual branch is multiplied by 0.2 before
    being added to the shortcut.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.
        stride: Stride of the first 3x3 convolution (and of the shortcut
            projection when one is needed).
        drop_p: Dropout probability applied after the first convolution;
            a falsy value disables dropout entirely.
    """

    def __init__(self, ni, nf, stride, drop_p=0.0):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        # A projection is required whenever the residual branch changes the
        # tensor shape: either the channel count differs or the first conv
        # is strided. Checking only the channels would leave an identity
        # shortcut whose spatial size no longer matches the branch output.
        needs_projection = ni != nf or stride != 1
        self.shortcut = conv_2d(ni, nf, 1, stride) if needs_projection else noop

    def forward(self, x):
        """Apply the block to ``x`` and return shortcut + 0.2 * residual."""
        activated = F.relu(self.bn(x), inplace=True)
        # The shortcut is taken from the activated input, not the raw one.
        residual = self.shortcut(activated)
        out = self.conv1(activated)
        # Explicit None check rather than relying on module truthiness.
        if self.drop is not None:
            out = self.drop(out)
        out = self.conv2(out) * 0.2
        # In-place add is safe here: ``out`` is the fresh tensor produced by
        # the multiplication above.
        return out.add_(residual)
def _make_group(N, ni, nf, block, stride, drop_p):
    """Build the list of ``N`` blocks that form one network group.

    Only the first block sees ``ni`` input channels and the requested
    ``stride``; every following block maps ``nf`` -> ``nf`` with stride 1.

    Args:
        N: Number of blocks in the group.
        ni: Input channels of the first block.
        nf: Output channels of every block.
        block: Callable invoked as ``block(in_channels, nf, stride, drop_p)``.
        stride: Stride used by the first block.
        drop_p: Dropout probability passed to every block.

    Returns:
        A list with the ``N`` constructed blocks, in order.
    """
    group = []
    in_channels, current_stride = ni, stride
    for _ in range(N):
        group.append(block(in_channels, nf, current_stride, drop_p))
        # After the first block the width is already ``nf`` and no further
        # striding is applied.
        in_channels, current_stride = nf, 1
    return group
class WideResNet(nn.Module):
    """Wide residual network assembled as a single ``nn.Sequential``.

    Layout: a 3x3 stem convolution on 3-channel input, ``num_groups`` groups
    of ``N`` ``BasicBlock`` layers each, then BatchNorm -> ReLU -> global
    average pooling -> ``Flatten`` -> linear classifier.

    Args:
        num_groups: Number of block groups. Group ``i`` has
            ``start_nf * 2**i * k`` channels; the first group uses stride 1
            and every later group uses stride 2.
        N: Number of blocks in each group.
        num_classes: Output size of the final linear layer.
        k: Width multiplier applied to every group's channel count.
        drop_p: Dropout probability forwarded to every block.
        start_nf: Channel count of the stem convolution.
    """

    def __init__(self, num_groups, N, num_classes, k=1, drop_p=0.0, start_nf=16):
        super().__init__()
        # n_channels[0] is the stem width; n_channels[i + 1] is group i's width.
        n_channels = [start_nf]
        n_channels += [start_nf * (2 ** i) * k for i in range(num_groups)]

        layers = [conv_2d(3, n_channels[0], 3, 1)]  # conv1
        for i in range(num_groups):
            group_stride = 1 if i == 0 else 2
            layers += _make_group(
                N, n_channels[i], n_channels[i + 1], BasicBlock, group_stride, drop_p
            )

        # Use the width of the last group rather than a fixed index so the
        # head matches the body for any ``num_groups`` (a hard-coded index 3
        # is only correct when there are exactly three groups).
        final_nf = n_channels[-1]
        # NOTE(review): ``Flatten`` comes from the package's ``layers`` star
        # import; presumably it collapses the pooled (C, 1, 1) map to C
        # features per sample -- confirm against that module.
        layers += [
            nn.BatchNorm2d(final_nf),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(final_nf, num_classes),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full stack and return the classifier output."""
        return self.features(x)
def wrn_22():
    """Return a 10-class ``WideResNet``: 3 groups of 3 blocks, k=6, no dropout."""
    return WideResNet(
        num_groups=3,
        N=3,
        num_classes=10,
        k=6,
        drop_p=0.0,
    )
def wrn_22_k8():
    """Return a 10-class ``WideResNet``: 3 groups of 3 blocks, k=8, no dropout."""
    return WideResNet(
        num_groups=3,
        N=3,
        num_classes=10,
        k=8,
        drop_p=0.0,
    )
def wrn_22_k10():
    """Return a 10-class ``WideResNet``: 3 groups of 3 blocks, k=10, no dropout."""
    return WideResNet(
        num_groups=3,
        N=3,
        num_classes=10,
        k=10,
        drop_p=0.0,
    )
def wrn_22_k8_p2():
    """Return a 10-class ``WideResNet``: 3 groups of 3 blocks, k=8, dropout 0.2."""
    return WideResNet(
        num_groups=3,
        N=3,
        num_classes=10,
        k=8,
        drop_p=0.2,
    )
def wrn_28():
    """Return a 10-class ``WideResNet``: 3 groups of 4 blocks, k=6, no dropout."""
    return WideResNet(
        num_groups=3,
        N=4,
        num_classes=10,
        k=6,
        drop_p=0.0,
    )
def wrn_28_k8():
    """Return a 10-class ``WideResNet``: 3 groups of 4 blocks, k=8, no dropout."""
    return WideResNet(
        num_groups=3,
        N=4,
        num_classes=10,
        k=8,
        drop_p=0.0,
    )
def wrn_28_k8_p2():
    """Return a 10-class ``WideResNet``: 3 groups of 4 blocks, k=8, dropout 0.2."""
    return WideResNet(
        num_groups=3,
        N=4,
        num_classes=10,
        k=8,
        drop_p=0.2,
    )
def wrn_28_p2():
    """Return a 10-class ``WideResNet``: 3 groups of 4 blocks, k=6, dropout 0.2."""
    return WideResNet(
        num_groups=3,
        N=4,
        num_classes=10,
        k=6,
        drop_p=0.2,
    )
|