# darknet.py (2.1 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .layers import *
  5. class ConvBN(nn.Module):
  6. "convolutional layer then batchnorm"
  7. def __init__(self, ch_in, ch_out, kernel_size = 3, stride=1, padding=0):
  8. super().__init__()
  9. self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
  10. self.bn = nn.BatchNorm2d(ch_out, momentum=0.01)
  11. self.relu = nn.LeakyReLU(0.1, inplace=True)
  12. def forward(self, x): return self.relu(self.bn(self.conv(x)))
  13. class DarknetBlock(nn.Module):
  14. def __init__(self, ch_in):
  15. super().__init__()
  16. ch_hid = ch_in//2
  17. self.conv1 = ConvBN(ch_in, ch_hid, kernel_size=1, stride=1, padding=0)
  18. self.conv2 = ConvBN(ch_hid, ch_in, kernel_size=3, stride=1, padding=1)
  19. def forward(self, x): return self.conv2(self.conv1(x)) + x
  20. class Darknet(nn.Module):
  21. "Replicates the darknet classifier from the YOLOv3 paper (table 1)"
  22. def make_group_layer(self, ch_in, num_blocks, stride=1):
  23. layers = [ConvBN(ch_in,ch_in*2,stride=stride)]
  24. for i in range(num_blocks): layers.append(DarknetBlock(ch_in*2))
  25. return layers
  26. def __init__(self, num_blocks, num_classes=1000, start_nf=32):
  27. super().__init__()
  28. nf = start_nf
  29. layers = [ConvBN(3, nf, kernel_size=3, stride=1, padding=1)]
  30. for i,nb in enumerate(num_blocks):
  31. layers += self.make_group_layer(nf, nb, stride=(1 if i==1 else 2))
  32. nf *= 2
  33. layers += [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)]
  34. self.layers = nn.Sequential(*layers)
  35. def forward(self, x): return self.layers(x)
  36. def darknet_53(num_classes=1000): return Darknet([1,2,8,8,4], num_classes)
  37. def darknet_small(num_classes=1000): return Darknet([1,2,4,8,4], num_classes)
  38. def darknet_mini(num_classes=1000): return Darknet([1,2,4,4,2], num_classes, start_nf=24)
  39. def darknet_mini2(num_classes=1000): return Darknet([1,2,8,8,4], num_classes, start_nf=16)
  40. def darknet_mini3(num_classes=1000): return Darknet([1,2,4,4], num_classes)