model.py

from .imports import *
from .torch_imports import *
from .core import *
from .layer_optimizer import *
from .swa import *
from .fp16 import *

IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4')

def cut_model(m, cut):
    "Returns the first `cut` children of `m`, or `[m]` if `cut` is falsy."
    return list(m.children())[:cut] if cut else [m]

def predict_to_bcolz(m, gen, arr, workers=4):
    "Runs `m` over the batches of `gen` and appends the outputs to the bcolz array `arr`."
    arr.trim(len(arr))
    lock = threading.Lock()
    m.eval()
    for x, *_ in tqdm(gen):
        y = to_np(m(VV(x)).data)
        with lock:
            arr.append(y)
            arr.flush()
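
# Illustrative usage sketch (not part of the original module): predict_to_bcolz expects a
# pre-created, writable bcolz carray whose row shape matches the model output. `net`, `md`
# and `n_feat` below are hypothetical stand-ins for a GPU module, a ModelData object and
# the output feature count.
#
#   import bcolz
#   arr = bcolz.carray(np.zeros((0, n_feat), dtype=np.float32), chunklen=1,
#                      mode='w', rootdir='tmp/val_activations')
#   predict_to_bcolz(net, md.val_dl, arr)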

def num_features(m):
    "Returns the `num_features` of the last layer in `m` that defines one, searching children in reverse."
    c = children(m)
    if len(c) == 0: return None
    for l in reversed(c):
        if hasattr(l, 'num_features'): return l.num_features
        res = num_features(l)
        if res is not None: return res
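
# Example (illustrative): for a typical convnet body, num_features recurses into containers
# and picks up the feature count of the final BatchNorm layer.
#
#   body = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
#   num_features(body)   # -> 64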

def torch_item(x): return x.item() if hasattr(x,'item') else x[0]

class Stepper():
    "Bundles a model, optimizer and loss function; `step` performs a single training update."
    def __init__(self, m, opt, crit, clip=0, reg_fn=None, fp16=False, loss_scale=1):
        self.m,self.opt,self.crit,self.clip,self.reg_fn = m,opt,crit,clip,reg_fn
        self.fp16 = fp16
        self.reset(True)
        if self.fp16: self.fp32_params = copy_model_to_fp32(m, opt)
        self.loss_scale = loss_scale

    def reset(self, train=True):
        if train: apply_leaf(self.m, set_train_mode)
        else: self.m.eval()
        if hasattr(self.m, 'reset'):
            self.m.reset()
            if self.fp16: self.fp32_params = copy_model_to_fp32(self.m, self.opt)

    def step(self, xs, y, epoch):
        xtra = []
        output = self.m(*xs)
        if isinstance(output,tuple): output,*xtra = output
        if self.fp16: self.m.zero_grad()
        else: self.opt.zero_grad()
        loss = raw_loss = self.crit(output, y)
        if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale
        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
        loss.backward()
        if self.fp16: update_fp32_grads(self.fp32_params, self.m)
        if self.loss_scale != 1:
            for param in self.fp32_params: param.grad.data.div_(self.loss_scale)
        if self.clip:  # gradient clipping
            if IS_TORCH_04: nn.utils.clip_grad_norm_(trainable_params_(self.m), self.clip)
            else: nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)
        if 'wd' in self.opt.param_groups[0] and self.opt.param_groups[0]['wd'] != 0:
            # Weight decay decoupled from the loss: applied after the gradient computation but before the step.
            for group in self.opt.param_groups:
                lr, wd = group['lr'], group['wd']
                for p in group['params']:
                    if p.grad is not None: p.data = p.data.add(-wd * lr, p.data)
        self.opt.step()
        if self.fp16:
            copy_fp32_to_model(self.m, self.fp32_params)
            torch.cuda.synchronize()
        return torch_item(raw_loss.data)

    def evaluate(self, xs, y):
        preds = self.m(*xs)
        if isinstance(preds,tuple): preds=preds[0]
        return preds, self.crit(preds, y)
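
# Minimal usage sketch (illustrative, not from the original source): a Stepper drives one
# optimizer update per batch, exactly as fit() uses it below. `net`, `opt` and `md` are
# hypothetical stand-ins for a GPU module, a torch optimizer and a fastai ModelData object.
#
#   step = Stepper(net, opt, F.cross_entropy)
#   for (*x, y) in md.trn_dl:
#       loss = step.step(V(x), V(y), epoch=0)   # one update; returns the raw (unscaled) loss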

def set_train_mode(m):
    if (hasattr(m, 'running_mean') and (getattr(m,'bn_freeze',False)
            or not getattr(m,'trainable',False))): m.eval()
    elif (getattr(m,'drop_freeze',False) and hasattr(m, 'p')
            and ('drop' in type(m).__name__.lower())): m.eval()
    else: m.train()

def fit(model, data, n_epochs, opt, crit, metrics=None, callbacks=None, stepper=Stepper,
        swa_model=None, swa_start=None, swa_eval_freq=None, visualize=False, **kwargs):
    """ Fits a model

    Arguments:
       model (model): any pytorch module, already moved to the GPU (e.g. net = to_gpu(net))
       data (ModelData): see ModelData class and subclasses (can be a list)
       n_epochs (int or list): number of epochs (or list of numbers of epochs)
       opt: an optimizer, e.g. optim.Adam. If n_epochs is a list, this must be a
           LayerOptimizer so the optimizer can be refreshed as each phase changes.
       crit: loss function to optimize, e.g. F.cross_entropy
    """

    seq_first = kwargs.pop('seq_first', False)
    all_val = kwargs.pop('all_val', False)
    get_ep_vals = kwargs.pop('get_ep_vals', False)
    metrics = metrics or []
    callbacks = callbacks or []
    avg_mom = 0.98
    batch_num,avg_loss = 0,0.
    for cb in callbacks: cb.on_train_begin()
    names = ["epoch", "trn_loss", "val_loss"] + [f.__name__ for f in metrics]
    if swa_model is not None:
        swa_names = ['swa_loss'] + [f'swa_{f.__name__}' for f in metrics]
        names += swa_names
        # will use this to call evaluate later
        swa_stepper = stepper(swa_model, None, crit, **kwargs)

    layout = "{!s:10} " * len(names)
    if not isinstance(n_epochs, Iterable): n_epochs = [n_epochs]
    if not isinstance(data, Iterable): data = [data]
    if len(data) == 1: data = data * len(n_epochs)

    for cb in callbacks: cb.on_phase_begin()
    model_stepper = stepper(model, opt.opt if hasattr(opt,'opt') else opt, crit, **kwargs)
    ep_vals = collections.OrderedDict()
    tot_epochs = int(np.ceil(np.array(n_epochs).sum()))
    cnt_phases = np.array([ep * len(dat.trn_dl) for (ep,dat) in zip(n_epochs,data)]).cumsum()
    phase = 0
    for epoch in tnrange(tot_epochs, desc='Epoch'):
        if phase >= len(n_epochs): break  # sometimes accumulated rounding errors make this happen
        model_stepper.reset(True)
        cur_data = data[phase]
        if hasattr(cur_data, 'trn_sampler'): cur_data.trn_sampler.set_epoch(epoch)
        if hasattr(cur_data, 'val_sampler'): cur_data.val_sampler.set_epoch(epoch)
        num_batch = len(cur_data.trn_dl)
        t = tqdm(iter(cur_data.trn_dl), leave=False, total=num_batch, miniters=0)
        if all_val: val_iter = IterBatch(cur_data.val_dl)

        for (*x,y) in t:
            batch_num += 1
            for cb in callbacks: cb.on_batch_begin()
            loss = model_stepper.step(V(x),V(y), epoch)
            avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)
            debias_loss = avg_loss / (1 - avg_mom**batch_num)
            t.set_postfix(loss=debias_loss, refresh=False)
            stop = False
            los = debias_loss if not all_val else [debias_loss] + validate_next(model_stepper, metrics, val_iter)
            for cb in callbacks: stop = stop or cb.on_batch_end(los)
            if stop: return
            if batch_num >= cnt_phases[phase]:
                for cb in callbacks: cb.on_phase_end()
                phase += 1
                if phase >= len(n_epochs):
                    t.close()
                    break
                for cb in callbacks: cb.on_phase_begin()
                if isinstance(opt, LayerOptimizer): model_stepper.opt = opt.opt
                if cur_data != data[phase]:
                    t.close()
                    break

        if not all_val:
            vals = validate(model_stepper, cur_data.val_dl, metrics, seq_first=seq_first)
            stop = False
            for cb in callbacks: stop = stop or cb.on_epoch_end(vals)
            if swa_model is not None:
                if (epoch + 1) >= swa_start and ((epoch + 1 - swa_start) % swa_eval_freq == 0 or epoch == tot_epochs - 1):
                    fix_batchnorm(swa_model, cur_data.trn_dl)
                    swa_vals = validate(swa_stepper, cur_data.val_dl, metrics)
                    vals += swa_vals

            if epoch > 0:
                print_stats(epoch, [debias_loss] + vals, visualize, prev_val)
            else:
                print(layout.format(*names))
                print_stats(epoch, [debias_loss] + vals, visualize)
            prev_val = [debias_loss] + vals
            ep_vals = append_stats(ep_vals, epoch, [debias_loss] + vals)
        if stop: break
    for cb in callbacks: cb.on_train_end()
    if get_ep_vals: return vals, ep_vals
    else: return vals
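
# Usage sketch (illustrative): training a module directly with fit(), bypassing the Learner
# wrapper. `net`, `md` and `accuracy` are assumed to come from user code and the usual
# fastai imports; they are hypothetical here.
#
#   net = to_gpu(net)
#   opt = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)
#   vals = fit(net, md, n_epochs=2, opt=opt, crit=F.cross_entropy, metrics=[accuracy])
#   # vals -> [val_loss, accuracy] for the final epoch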

def append_stats(ep_vals, epoch, values, decimals=6):
    ep_vals[epoch] = list(np.round(values, decimals))
    return ep_vals

def print_stats(epoch, values, visualize, prev_val=[], decimals=6):
    layout = "{!s:^10}" + " {!s:10}" * len(values)
    values = [epoch] + list(np.round(values, decimals))
    sym = ""
    if visualize:
        if epoch == 0: pass
        elif values[1] > prev_val[0] and values[2] > prev_val[1]: sym = " △ △"
        elif values[1] > prev_val[0] and values[2] < prev_val[1]: sym = " △ ▼"
        elif values[1] < prev_val[0] and values[2] > prev_val[1]: sym = " ▼ △"
        elif values[1] < prev_val[0] and values[2] < prev_val[1]: sym = " ▼ ▼"
    print(layout.format(*values) + sym)

class IterBatch():
    "Cycles through a DataLoader one batch at a time, restarting when exhausted."
    def __init__(self, dl):
        self.idx = 0
        self.dl = dl
        self.iter = iter(dl)

    def __iter__(self): return self

    def next(self):
        res = next(self.iter)
        self.idx += 1
        if self.idx == len(self.dl):
            self.iter = iter(self.dl)
            self.idx = 0
        return res

def validate_next(stepper, metrics, val_iter):
    """Computes the loss on the next minibatch of the validation set."""
    stepper.reset(False)
    with no_grad_context():
        (*x,y) = val_iter.next()
        preds,l = stepper.evaluate(VV(x), VV(y))
        res = [delistify(to_np(l))]
        res += [f(preds.data,y) for f in metrics]
    stepper.reset(True)
    return res

def batch_sz(x, seq_first=False):
    if is_listy(x): x = x[0]
    return x.shape[1 if seq_first else 0]

def validate(stepper, dl, metrics, seq_first=False):
    batch_cnts,loss,res = [],[],[]
    stepper.reset(False)
    with no_grad_context():
        for (*x,y) in iter(dl):
            preds, l = stepper.evaluate(VV(x), VV(y))
            batch_cnts.append(batch_sz(x, seq_first=seq_first))
            loss.append(to_np(l))
            res.append([f(preds.data, y) for f in metrics])
    return [np.average(loss, 0, weights=batch_cnts)] + list(np.average(np.stack(res), 0, weights=batch_cnts))
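
# Usage sketch (illustrative): validate() can be called on its own to get the batch-size
# weighted average of the loss and metrics over a validation DataLoader. `net`, `md` and
# `accuracy` are hypothetical.
#
#   stepper = Stepper(to_gpu(net), None, F.cross_entropy)
#   val_loss, acc = validate(stepper, md.val_dl, [accuracy])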

def get_prediction(x):
    if is_listy(x): x = x[0]
    return x.data

def predict(m, dl):
    preda,_ = predict_with_targs_(m, dl)
    return np.concatenate(preda)

def predict_batch(m, x):
    m.eval()
    if hasattr(m, 'reset'): m.reset()
    return m(VV(x))

def predict_with_targs_(m, dl):
    m.eval()
    if hasattr(m, 'reset'): m.reset()
    res = []
    for *x,y in iter(dl): res.append([get_prediction(to_np(m(*VV(x)))),to_np(y)])
    return zip(*res)

def predict_with_targs(m, dl):
    preda,targa = predict_with_targs_(m, dl)
    return np.concatenate(preda), np.concatenate(targa)
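
# Usage sketch (illustrative): getting predictions (and targets) over a DataLoader as
# numpy arrays. `net` and `md` are hypothetical.
#
#   preds, targs = predict_with_targs(to_gpu(net), md.val_dl)
#   preda = predict(to_gpu(net), md.val_dl)   # predictions only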

# From https://github.com/ncullen93/torchsample
def model_summary(m, inputs):
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = '%s-%i' % (class_name, module_idx+1)
            summary[m_key] = OrderedDict()
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1
            if is_listy(output):
                summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output]
            else:
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = -1

            params = 0
            if hasattr(module, 'weight'):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]['trainable'] = module.weight.requires_grad
            if hasattr(module, 'bias') and module.bias is not None:
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]['nb_params'] = params

        if (not isinstance(module, nn.Sequential) and
                not isinstance(module, nn.ModuleList) and
                not (module == m)):
            hooks.append(module.register_forward_hook(hook))

    summary = OrderedDict()
    hooks = []
    m.apply(register_hook)
    xs = [to_gpu(Variable(x)) for x in inputs]
    m(*xs)

    for h in hooks: h.remove()
    return summary
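
# Usage sketch (illustrative): model_summary runs one forward pass with dummy inputs and
# records per-layer shapes and parameter counts via forward hooks. The input size below is
# a hypothetical example; `net` is a hypothetical module.
#
#   summ = model_summary(to_gpu(net), [torch.zeros(2, 3, 224, 224)])
#   for name, info in summ.items():
#       print(name, info['output_shape'], info['nb_params'])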