torch_imports.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. from distutils.version import LooseVersion
  3. import torch, torchvision, torchtext
  4. from torch import nn, cuda, backends, FloatTensor, LongTensor, optim
  5. import torch.nn.functional as F
  6. from torch.autograd import Variable
  7. from torch.utils.data import Dataset, TensorDataset
  8. from torch.nn.init import kaiming_uniform, kaiming_normal
  9. from torchvision.transforms import Compose
  10. from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
  11. from torchvision.models import vgg16_bn, vgg19_bn
  12. from torchvision.models import densenet121, densenet161, densenet169, densenet201
  13. from .models.resnext_50_32x4d import resnext_50_32x4d
  14. from .models.resnext_101_32x4d import resnext_101_32x4d
  15. from .models.resnext_101_64x4d import resnext_101_64x4d
  16. from .models.wrn_50_2f import wrn_50_2f
  17. from .models.inceptionresnetv2 import InceptionResnetV2
  18. from .models.inceptionv4 import inceptionv4
  19. from .models.nasnet import nasnetalarge
  20. from .models.fa_resnet import *
  21. import warnings
  22. warnings.filterwarnings('ignore', message='Implicit dimension choice', category=UserWarning)
  23. def children(m): return m if isinstance(m, (list, tuple)) else list(m.children())
  24. def save_model(m, p): torch.save(m.state_dict(), p)
  25. def load_model(m, p):
  26. sd = torch.load(p, map_location=lambda storage, loc: storage)
  27. names = set(m.state_dict().keys())
  28. for n in list(sd.keys()): # list "detatches" the iterator
  29. if n not in names and n+'_raw' in names:
  30. if n+'_raw' not in sd: sd[n+'_raw'] = sd[n]
  31. del sd[n]
  32. m.load_state_dict(sd)
  33. def load_pre(pre, f, fn):
  34. m = f()
  35. path = os.path.dirname(__file__)
  36. if pre: load_model(m, f'{path}/weights/{fn}.pth')
  37. return m
  38. def _fastai_model(name, paper_title, paper_href):
  39. def add_docs_wrapper(f):
  40. f.__doc__ = f"""{name} model from
  41. `"{paper_title}" <{paper_href}>`_
  42. Args:
  43. pre (bool): If True, returns a model pre-trained on ImageNet
  44. """
  45. return f
  46. return add_docs_wrapper
  47. @_fastai_model('Inception 4', 'Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning',
  48. 'https://arxiv.org/pdf/1602.07261.pdf')
  49. def inception_4(pre): return children(inceptionv4(pretrained=pre))[0]
  50. @_fastai_model('Inception 4', 'Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning',
  51. 'https://arxiv.org/pdf/1602.07261.pdf')
  52. def inceptionresnet_2(pre): return load_pre(pre, InceptionResnetV2, 'inceptionresnetv2-d579a627')
  53. @_fastai_model('ResNeXt 50', 'Aggregated Residual Transformations for Deep Neural Networks',
  54. 'https://arxiv.org/abs/1611.05431')
  55. def resnext50(pre): return load_pre(pre, resnext_50_32x4d, 'resnext_50_32x4d')
  56. @_fastai_model('ResNeXt 101_32', 'Aggregated Residual Transformations for Deep Neural Networks',
  57. 'https://arxiv.org/abs/1611.05431')
  58. def resnext101(pre): return load_pre(pre, resnext_101_32x4d, 'resnext_101_32x4d')
  59. @_fastai_model('ResNeXt 101_64', 'Aggregated Residual Transformations for Deep Neural Networks',
  60. 'https://arxiv.org/abs/1611.05431')
  61. def resnext101_64(pre): return load_pre(pre, resnext_101_64x4d, 'resnext_101_64x4d')
  62. @_fastai_model('Wide Residual Networks', 'Wide Residual Networks',
  63. 'https://arxiv.org/pdf/1605.07146.pdf')
  64. def wrn(pre): return load_pre(pre, wrn_50_2f, 'wrn_50_2f')
  65. @_fastai_model('Densenet-121', 'Densely Connected Convolutional Networks',
  66. 'https://arxiv.org/pdf/1608.06993.pdf')
  67. def dn121(pre): return children(densenet121(pre))[0]
  68. @_fastai_model('Densenet-169', 'Densely Connected Convolutional Networks',
  69. 'https://arxiv.org/pdf/1608.06993.pdf')
  70. def dn161(pre): return children(densenet161(pre))[0]
  71. @_fastai_model('Densenet-161', 'Densely Connected Convolutional Networks',
  72. 'https://arxiv.org/pdf/1608.06993.pdf')
  73. def dn169(pre): return children(densenet169(pre))[0]
  74. @_fastai_model('Densenet-201', 'Densely Connected Convolutional Networks',
  75. 'https://arxiv.org/pdf/1608.06993.pdf')
  76. def dn201(pre): return children(densenet201(pre))[0]
  77. @_fastai_model('Vgg-16 with batch norm added', 'Very Deep Convolutional Networks for Large-Scale Image Recognition',
  78. 'https://arxiv.org/pdf/1409.1556.pdf')
  79. def vgg16(pre): return children(vgg16_bn(pre))[0]
  80. @_fastai_model('Vgg-19 with batch norm added', 'Very Deep Convolutional Networks for Large-Scale Image Recognition',
  81. 'https://arxiv.org/pdf/1409.1556.pdf')
  82. def vgg19(pre): return children(vgg19_bn(pre))[0]