inception.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to PyTorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output block indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=(DEFAULT_BLOCK_INDEX,),  # tuple avoids a mutable default argument
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
  32. """Build pretrained InceptionV3
  33. Parameters
  34. ----------
  35. output_blocks : list of int
  36. Indices of blocks to return features of. Possible values are:
  37. - 0: corresponds to output of first max pooling
  38. - 1: corresponds to output of second max pooling
  39. - 2: corresponds to output which is fed to aux classifier
  40. - 3: corresponds to output of final average pooling
  41. resize_input : bool
  42. If true, bilinearly resizes input to width and height 299 before
  43. feeding input to model. As the network without fully connected
  44. layers is fully convolutional, it should be able to handle inputs
  45. of arbitrary size, so resizing might not be strictly needed
  46. normalize_input : bool
  47. If true, scales the input from range (0, 1) to the range the
  48. pretrained Inception network expects, namely (-1, 1)
  49. requires_grad : bool
  50. If true, parameters of the model require gradients. Possibly useful
  51. for finetuning the network
  52. use_fid_inception : bool
  53. If true, uses the pretrained Inception model used in Tensorflow's
  54. FID implementation. If false, uses the pretrained Inception model
  55. available in torchvision. The FID Inception model has different
  56. weights and a slightly different structure from torchvision's
  57. Inception model. If you want to compute FID scores, you are
  58. strongly advised to set this parameter to true to get comparable
  59. results.
  60. """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, 'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad
    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.Tensor, corresponding to the selected output
        blocks, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)
            if idx == self.last_needed_block:
                break

        return outp
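
# Example usage (illustrative sketch, an addition to the original module):
# extract the 2048-dimensional pooled features typically used for FID.
#
#   block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
#   model = InceptionV3(output_blocks=[block_idx]).eval()
#   images = torch.rand(8, 3, 299, 299)          # values in (0, 1)
#   with torch.no_grad():
#       features = model(images)[0]              # shape: 8 x 2048 x 1 x 1
#   features = features.squeeze(3).squeeze(2)    # 8 x 2048 activations
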
def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = models.inception_v3(
        num_classes=1008, aux_logits=False, pretrained=False
    )
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception
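
# Illustrative note (an addition, not in the original module): fid_inception_v3()
# returns a full classifier with a 1008-way fc head; the InceptionV3 wrapper
# above uses only its convolutional trunk. Direct use would look like:
#
#   net = fid_inception_v3().eval()
#   logits = net(torch.rand(1, 3, 299, 299) * 2 - 1)  # expects inputs in (-1, 1)
#   # logits.shape == (1, 1008)
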
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
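
# Illustrative note (an addition, not in the original module): with
# count_include_pad=False, border positions are averaged only over the valid
# (unpadded) entries, which is the behaviour the patch comment above refers to:
#
#   x = torch.ones(1, 1, 2, 2)
#   F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=False)  # all 1.0
#   F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=True)   # all 4/9
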
class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)
class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
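
if __name__ == '__main__':
    # Minimal end-to-end sketch (an addition, not in the original module):
    # extract pooled features with InceptionV3 and compute the standard
    # Frechet distance between two sets of activations. Assumes numpy and
    # scipy are installed; the helper names below are illustrative only, and
    # random-noise inputs make this a smoke test rather than a real FID value.
    import numpy as np
    from scipy import linalg

    def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            # Regularize the covariances if the matrix square root misbehaved
            offset = np.eye(sigma1.shape[0]) * eps
            covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    model = InceptionV3(output_blocks=[block_idx]).eval()

    def activations(images):
        # Pooled 2048-dim features; images are Bx3xHxW with values in (0, 1)
        with torch.no_grad():
            feats = model(images)[0]
        return feats.squeeze(3).squeeze(2).numpy()

    real = activations(torch.rand(16, 3, 299, 299))
    fake = activations(torch.rand(16, 3, 299, 299))
    mu_r, sigma_r = real.mean(axis=0), np.cov(real, rowvar=False)
    mu_f, sigma_f = fake.mean(axis=0), np.cov(fake, rowvar=False)
    print('FID (random inputs, sanity check only):',
          frechet_distance(mu_r, sigma_r, mu_f, sigma_f))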