wideresnet.py

# CIFAR-10 WideResNet for the fastai DAWNBench submission
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...layers import *  # provides Flatten, used in the classifier head
def conv_2d(ni, nf, ks, stride):
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=False)

def bn(ni, init_zero=False):
    # BatchNorm whose weight can be zero-initialized, so the residual branch
    # it feeds starts out as the identity.
    m = nn.BatchNorm2d(ni)
    m.weight.data.fill_(0 if init_zero else 1)
    m.bias.data.zero_()
    return m

def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    bn_initzero = bn(ni, init_zero=init_zero)
    return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv_2d(ni, nf, ks, stride))

def noop(x): return x
class BasicBlock(nn.Module):
    # Pre-activation residual block: BN-ReLU precedes each conv, and the
    # shortcut is a 1x1 conv whenever the channel count changes.
    def __init__(self, ni, nf, stride, drop_p=0.0):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        self.shortcut = conv_2d(ni, nf, 1, stride) if ni != nf else noop

    def forward(self, x):
        x2 = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x2)
        x = self.conv1(x2)
        if self.drop: x = self.drop(x)
        x = self.conv2(x) * 0.2  # scale the residual branch before the skip add
        return x.add_(r)
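
# Example (illustrative sketch, not part of the submission): a strided
# BasicBlock halves the spatial resolution and widens the channels via the
# 1x1 shortcut:
#
#   blk = BasicBlock(16, 32, stride=2)
#   blk(torch.randn(1, 16, 32, 32)).shape  # torch.Size([1, 32, 16, 16])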

def _make_group(N, ni, nf, block, stride, drop_p):
    return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]

class WideResNet(nn.Module):
    def __init__(self, num_groups, N, num_classes, k=1, drop_p=0.0, start_nf=16):
        super().__init__()
        n_channels = [start_nf]
        for i in range(num_groups): n_channels.append(start_nf*(2**i)*k)

        layers = [conv_2d(3, n_channels[0], 3, 1)]  # conv1
        for i in range(num_groups):
            layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock,
                                  (1 if i == 0 else 2), drop_p)
        # classifier head
        layers += [nn.BatchNorm2d(n_channels[-1]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                   Flatten(), nn.Linear(n_channels[-1], num_classes)]
        self.features = nn.Sequential(*layers)

    def forward(self, x): return self.features(x)

# DAWNBench configurations; names follow WRN-depth-k with depth = 6*N + 4.
def wrn_22(): return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)
def wrn_22_k8(): return WideResNet(num_groups=3, N=3, num_classes=10, k=8, drop_p=0.)
def wrn_22_k10(): return WideResNet(num_groups=3, N=3, num_classes=10, k=10, drop_p=0.)
def wrn_22_k8_p2(): return WideResNet(num_groups=3, N=3, num_classes=10, k=8, drop_p=0.2)
def wrn_28(): return WideResNet(num_groups=3, N=4, num_classes=10, k=6, drop_p=0.)
def wrn_28_k8(): return WideResNet(num_groups=3, N=4, num_classes=10, k=8, drop_p=0.)
def wrn_28_k8_p2(): return WideResNet(num_groups=3, N=4, num_classes=10, k=8, drop_p=0.2)
def wrn_28_p2(): return WideResNet(num_groups=3, N=4, num_classes=10, k=6, drop_p=0.2)
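
# Usage sketch (illustrative; Flatten comes from the package's layers module,
# so import this file from within the package rather than running it directly):
#
#   model = wrn_22()
#   logits = model(torch.randn(2, 3, 32, 32))  # a CIFAR-10-sized batch
#   logits.shape                               # torch.Size([2, 10])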