# conv_learner.py
  1. from .core import *
  2. from .layers import *
  3. from .learner import *
  4. from .initializers import *
# Per-architecture metadata: maps a model-creation function to [cut, lr_cut].
#   cut:    index at which to cut the pretrained network's child list to obtain
#           the convolutional backbone (passed to cut_model in ConvnetBuilder).
#   lr_cut: index used by ConvnetBuilder.get_layer_groups (via split_by_idxs) to
#           split the backbone into layer groups for differential learning rates.
# NOTE(review): exact index semantics depend on cut_model/split_by_idxs defined
# in sibling modules — confirm there before editing these values.
model_meta = {
    resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6],
    vgg16:[0,22], vgg19:[0,22],
    resnext50:[8,6], resnext101:[8,6], resnext101_64:[8,6],
    wrn:[8,6], inceptionresnet_2:[-2,9], inception_4:[-1,9],
    dn121:[0,7], dn161:[0,7], dn169:[0,7], dn201:[0,7],
}
# Explicit feature counts fed into the fully-connected head for architectures
# where the default of num_features(layers)*2 is wrong (see ConvnetBuilder.__init__,
# which falls back to that default for models absent from this dict).
model_features = {inception_4: 3072, dn121: 2048, dn161: 4416,} # nasnetalarge: 4032*2}
  13. class ConvnetBuilder():
  14. """Class representing a convolutional network.
  15. Arguments:
  16. f: a model creation function (e.g. resnet34, vgg16, etc)
  17. c (int): size of the last layer
  18. is_multi (bool): is multilabel classification?
  19. (def here http://scikit-learn.org/stable/modules/multiclass.html)
  20. is_reg (bool): is a regression?
  21. ps (float or array of float): dropout parameters
  22. xtra_fc (list of ints): list of hidden layers with # hidden neurons
  23. xtra_cut (int): # layers earlier than default to cut the model, default is 0
  24. custom_head : add custom model classes that are inherited from nn.modules at the end of the model
  25. that is mentioned on Argument 'f'
  26. """
  27. def __init__(self, f, c, is_multi, is_reg, ps=None, xtra_fc=None, xtra_cut=0, custom_head=None, pretrained=True):
  28. self.f,self.c,self.is_multi,self.is_reg,self.xtra_cut = f,c,is_multi,is_reg,xtra_cut
  29. if xtra_fc is None: xtra_fc = [512]
  30. if ps is None: ps = [0.25]*len(xtra_fc) + [0.5]
  31. self.ps,self.xtra_fc = ps,xtra_fc
  32. if f in model_meta: cut,self.lr_cut = model_meta[f]
  33. else: cut,self.lr_cut = 0,0
  34. cut-=xtra_cut
  35. layers = cut_model(f(pretrained), cut)
  36. self.nf = model_features[f] if f in model_features else (num_features(layers)*2)
  37. if not custom_head: layers += [AdaptiveConcatPool2d(), Flatten()]
  38. self.top_model = nn.Sequential(*layers)
  39. n_fc = len(self.xtra_fc)+1
  40. if not isinstance(self.ps, list): self.ps = [self.ps]*n_fc
  41. if custom_head: fc_layers = [custom_head]
  42. else: fc_layers = self.get_fc_layers()
  43. self.n_fc = len(fc_layers)
  44. self.fc_model = to_gpu(nn.Sequential(*fc_layers))
  45. if not custom_head: apply_init(self.fc_model, kaiming_normal)
  46. self.model = to_gpu(nn.Sequential(*(layers+fc_layers)))
  47. @property
  48. def name(self): return f'{self.f.__name__}_{self.xtra_cut}'
  49. def create_fc_layer(self, ni, nf, p, actn=None):
  50. res=[nn.BatchNorm1d(num_features=ni)]
  51. if p: res.append(nn.Dropout(p=p))
  52. res.append(nn.Linear(in_features=ni, out_features=nf))
  53. if actn: res.append(actn)
  54. return res
  55. def get_fc_layers(self):
  56. res=[]
  57. ni=self.nf
  58. for i,nf in enumerate(self.xtra_fc):
  59. res += self.create_fc_layer(ni, nf, p=self.ps[i], actn=nn.ReLU())
  60. ni=nf
  61. final_actn = nn.Sigmoid() if self.is_multi else nn.LogSoftmax()
  62. if self.is_reg: final_actn = None
  63. res += self.create_fc_layer(ni, self.c, p=self.ps[-1], actn=final_actn)
  64. return res
  65. def get_layer_groups(self, do_fc=False):
  66. if do_fc:
  67. return [self.fc_model]
  68. idxs = [self.lr_cut]
  69. c = children(self.top_model)
  70. if len(c)==3: c = children(c[0])+c[1:]
  71. lgs = list(split_by_idxs(c,idxs))
  72. return lgs+[self.fc_model]
  73. class ConvLearner(Learner):
  74. """
  75. Class used to train a chosen supported covnet model. Eg. ResNet-34, etc.
  76. Arguments:
  77. data: training data for model
  78. models: model architectures to base learner
  79. precompute: bool to reuse precomputed activations
  80. **kwargs: parameters from Learner() class
  81. """
  82. def __init__(self, data, models, precompute=False, **kwargs):
  83. self.precompute = False
  84. super().__init__(data, models, **kwargs)
  85. if hasattr(data, 'is_multi') and not data.is_reg and self.metrics is None:
  86. self.metrics = [accuracy_thresh(0.5)] if self.data.is_multi else [accuracy]
  87. if precompute: self.save_fc1()
  88. self.freeze()
  89. self.precompute = precompute
  90. def _get_crit(self, data):
  91. if not hasattr(data, 'is_multi'): return super()._get_crit(data)
  92. return F.l1_loss if data.is_reg else F.binary_cross_entropy if data.is_multi else F.nll_loss
  93. @classmethod
  94. def pretrained(cls, f, data, ps=None, xtra_fc=None, xtra_cut=0, custom_head=None, precompute=False,
  95. pretrained=True, **kwargs):
  96. models = ConvnetBuilder(f, data.c, data.is_multi, data.is_reg,
  97. ps=ps, xtra_fc=xtra_fc, xtra_cut=xtra_cut, custom_head=custom_head, pretrained=pretrained)
  98. return cls(data, models, precompute, **kwargs)
  99. @classmethod
  100. def lsuv_learner(cls, f, data, ps=None, xtra_fc=None, xtra_cut=0, custom_head=None, precompute=False,
  101. needed_std=1.0, std_tol=0.1, max_attempts=10, do_orthonorm=False, **kwargs):
  102. models = ConvnetBuilder(f, data.c, data.is_multi, data.is_reg,
  103. ps=ps, xtra_fc=xtra_fc, xtra_cut=xtra_cut, custom_head=custom_head, pretrained=False)
  104. convlearn=cls(data, models, precompute, **kwargs)
  105. convlearn.lsuv_init()
  106. return convlearn
  107. @property
  108. def model(self): return self.models.fc_model if self.precompute else self.models.model
  109. def half(self):
  110. if self.fp16: return
  111. self.fp16 = True
  112. if type(self.model) != FP16: self.models.model = FP16(self.model)
  113. if not isinstance(self.models.fc_model, FP16): self.models.fc_model = FP16(self.models.fc_model)
  114. def float(self):
  115. if not self.fp16: return
  116. self.fp16 = False
  117. if type(self.models.model) == FP16: self.models.model = self.model.module.float()
  118. if type(self.models.fc_model) == FP16: self.models.fc_model = self.models.fc_model.module.float()
  119. @property
  120. def data(self): return self.fc_data if self.precompute else self.data_
  121. def create_empty_bcolz(self, n, name):
  122. return bcolz.carray(np.zeros((0,n), np.float32), chunklen=1, mode='w', rootdir=name)
  123. def set_data(self, data, precompute=False):
  124. super().set_data(data)
  125. if precompute:
  126. self.unfreeze()
  127. self.save_fc1()
  128. self.freeze()
  129. self.precompute = True
  130. else:
  131. self.freeze()
  132. def get_layer_groups(self):
  133. return self.models.get_layer_groups(self.precompute)
  134. def summary(self):
  135. precompute = self.precompute
  136. self.precompute = False
  137. res = super().summary()
  138. self.precompute = precompute
  139. return res
  140. def get_activations(self, force=False):
  141. tmpl = f'_{self.models.name}_{self.data.sz}.bc'
  142. # TODO: Somehow check that directory names haven't changed (e.g. added test set)
  143. names = [os.path.join(self.tmp_path, p+tmpl) for p in ('x_act', 'x_act_val', 'x_act_test')]
  144. if os.path.exists(names[0]) and not force:
  145. self.activations = [bcolz.open(p) for p in names]
  146. else:
  147. self.activations = [self.create_empty_bcolz(self.models.nf,n) for n in names]
  148. def save_fc1(self):
  149. self.get_activations()
  150. act, val_act, test_act = self.activations
  151. m=self.models.top_model
  152. if len(self.activations[0])!=len(self.data.trn_ds):
  153. predict_to_bcolz(m, self.data.fix_dl, act)
  154. if len(self.activations[1])!=len(self.data.val_ds):
  155. predict_to_bcolz(m, self.data.val_dl, val_act)
  156. if self.data.test_dl and (len(self.activations[2])!=len(self.data.test_ds)):
  157. if self.data.test_dl: predict_to_bcolz(m, self.data.test_dl, test_act)
  158. self.fc_data = ImageClassifierData.from_arrays(self.data.path,
  159. (act, self.data.trn_y), (val_act, self.data.val_y), self.data.bs, classes=self.data.classes,
  160. test = test_act if self.data.test_dl else None, num_workers=8)
  161. def freeze(self):
  162. """ Freeze all but the very last layer.
  163. Make all layers untrainable (i.e. frozen) except for the last layer.
  164. Returns:
  165. None
  166. """
  167. self.freeze_to(-1)
  168. def unfreeze(self):
  169. """ Unfreeze all layers.
  170. Make all layers trainable by unfreezing. This will also set the `precompute` to `False` since we can
  171. no longer pre-calculate the activation of frozen layers.
  172. Returns:
  173. None
  174. """
  175. self.freeze_to(0)
  176. self.precompute = False