# sgdr.py

from .imports import *
from .layer_optimizer import *
from enum import IntEnum
from timeit import default_timer as timer
import copy


class Callback:
    '''
    An abstract class that all callback classes (e.g., LossRecorder) extend from.
    Must be extended before usage.
    '''
    def on_train_begin(self): pass
    def on_batch_begin(self): pass
    def on_phase_begin(self): pass
    def on_epoch_end(self, metrics): pass
    def on_phase_end(self): pass
    def on_batch_end(self, metrics): pass
    def on_train_end(self): pass
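
# Illustrative sketch (not part of the library): a minimal Callback subclass.
# It counts batches and, like LR_Finder/OptimScheduler below, signals the
# training loop to stop early by returning True from on_batch_end.
class _ExampleStopAfterNBatches(Callback):
    def __init__(self, max_batches):
        super().__init__()
        self.max_batches = max_batches
    def on_train_begin(self):
        self.batch = 0
    def on_batch_end(self, metrics):
        self.batch += 1
        return self.batch >= self.max_batches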

# Useful for maintaining status of a long-running job.
#
# Usage:
# learn.fit(0.01, 1, callbacks = [LoggingCallback(save_path="/tmp/log")])
class LoggingCallback(Callback):
    '''
    A class useful for maintaining status of a long-running job.
    e.g.: learn.fit(0.01, 1, callbacks = [LoggingCallback(save_path="/tmp/log")])
    '''
    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
    def on_train_begin(self):
        self.batch = 0
        self.epoch = 0
        self.phase = 0
        self.f = open(self.save_path, "a", 1)
        self.log("\ton_train_begin")
    def on_batch_begin(self):
        self.log(str(self.batch)+"\ton_batch_begin")
    def on_phase_begin(self):
        self.log(str(self.phase)+"\ton_phase_begin")
    def on_epoch_end(self, metrics):
        self.log(str(self.epoch)+"\ton_epoch_end: "+str(metrics))
        self.epoch += 1
    def on_phase_end(self):
        self.log(str(self.phase)+"\ton_phase_end")
        self.phase += 1
    def on_batch_end(self, metrics):
        self.log(str(self.batch)+"\ton_batch_end: "+str(metrics))
        self.batch += 1
    def on_train_end(self):
        self.log("\ton_train_end")
        self.f.close()
    def log(self, string):
        self.f.write(time.strftime("%Y-%m-%dT%H:%M:%S")+"\t"+string+"\n")

class LossRecorder(Callback):
    '''
    Saves and displays loss values and other metrics.
    This is the default scheduler used when none is specified in a learner.
    '''
    def __init__(self, layer_opt, save_path='', record_mom=False, metrics=[]):
        super().__init__()
        self.layer_opt = layer_opt
        self.init_lrs = np.array(layer_opt.lrs)
        self.save_path, self.record_mom, self.metrics = save_path, record_mom, metrics
    def on_train_begin(self):
        self.losses,self.lrs,self.iterations,self.epochs,self.times = [],[],[],[],[]
        self.start_at = timer()
        self.val_losses, self.rec_metrics = [], []
        if self.record_mom:
            self.momentums = []
        self.iteration = 0
        self.epoch = 0
    def on_epoch_end(self, metrics):
        self.epoch += 1
        self.epochs.append(self.iteration)
        self.times.append(timer() - self.start_at)
        self.save_metrics(metrics)
    def on_batch_end(self, loss):
        self.iteration += 1
        self.lrs.append(self.layer_opt.lr)
        self.iterations.append(self.iteration)
        if isinstance(loss, list):
            self.losses.append(loss[0])
            self.save_metrics(loss[1:])
        else: self.losses.append(loss)
        if self.record_mom: self.momentums.append(self.layer_opt.mom)
    def save_metrics(self, vals):
        self.val_losses.append(delistify(vals[0]))
        if len(vals) > 2: self.rec_metrics.append(vals[1:])
        elif len(vals) == 2: self.rec_metrics.append(vals[1])
    def plot_loss(self, n_skip=10, n_skip_end=5):
        '''
        Plots the loss as a function of iterations.
        When used in a Jupyter notebook, the plot is displayed inline; otherwise both the
        plot and the losses are saved to save_path.
        '''
        if not in_ipynb(): plt.switch_backend('agg')
        plt.plot(self.iterations[n_skip:-n_skip_end], self.losses[n_skip:-n_skip_end])
        if not in_ipynb():
            plt.savefig(os.path.join(self.save_path, 'loss_plot.png'))
            np.save(os.path.join(self.save_path, 'losses.npy'), self.losses[10:])
    def plot_lr(self):
        '''Plots the learning rate in a Jupyter notebook, or saves it to a file, depending on the environment of the learner.'''
        if not in_ipynb():
            plt.switch_backend('agg')
        if self.record_mom:
            fig, axs = plt.subplots(1,2,figsize=(12,4))
            for i in range(0,2): axs[i].set_xlabel('iterations')
            axs[0].set_ylabel('learning rate')
            axs[1].set_ylabel('momentum')
            axs[0].plot(self.iterations,self.lrs)
            axs[1].plot(self.iterations,self.momentums)
        else:
            plt.xlabel("iterations")
            plt.ylabel("learning rate")
            plt.plot(self.iterations, self.lrs)
        if not in_ipynb():
            plt.savefig(os.path.join(self.save_path, 'lr_plot.png'))
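
# Illustrative sketch (not part of the library): drive LossRecorder by hand with a
# minimal stand-in for LayerOptimizer to show what gets recorded on each batch.
class _FakeLayerOpt:
    '''Hypothetical stub exposing just the attributes LossRecorder reads.'''
    def __init__(self, lr=0.01, mom=0.9):
        self.lrs, self.lr, self.mom = [lr], lr, mom

def _demo_loss_recorder():
    rec = LossRecorder(_FakeLayerOpt(), record_mom=True)
    rec.on_train_begin()
    for loss in [1.0, 0.8, 0.6]:
        rec.on_batch_end(loss)   # appends to losses, lrs, iterations, momentums
    return rec.iterations, rec.losses, rec.lrs, rec.momentums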

class LR_Updater(LossRecorder):
    '''
    Abstract class from which all learning rate updaters inherit (e.g., CircularLR).
    Calculates and updates the new learning rate and momentum at the end of each batch.
    Must be extended before use.
    '''
    def on_train_begin(self):
        super().on_train_begin()
        self.update_lr()
        if self.record_mom:
            self.update_mom()
    def on_batch_end(self, loss):
        res = super().on_batch_end(loss)
        self.update_lr()
        if self.record_mom:
            self.update_mom()
        return res
    def update_lr(self):
        new_lrs = self.calc_lr(self.init_lrs)
        self.layer_opt.set_lrs(new_lrs)
    def update_mom(self):
        new_mom = self.calc_mom()
        self.layer_opt.set_mom(new_mom)
    @abstractmethod
    def calc_lr(self, init_lrs): raise NotImplementedError
    @abstractmethod
    def calc_mom(self): raise NotImplementedError

class LR_Finder(LR_Updater):
    '''
    Helps you find an optimal learning rate for a model, following the suggestion of the 2015 CLR paper.
    The learning rate is increased on a linear or log scale, depending on user input, and the resulting
    losses are retained so they can be plotted later.
    '''
    def __init__(self, layer_opt, nb, end_lr=10, linear=False, metrics=[]):
        self.linear, self.stop_dv = linear, True
        ratio = end_lr/layer_opt.lr
        self.lr_mult = (ratio/nb) if linear else ratio**(1/nb)
        super().__init__(layer_opt, metrics=metrics)
    def on_train_begin(self):
        super().on_train_begin()
        self.best = 1e9
    def calc_lr(self, init_lrs):
        mult = self.lr_mult*self.iteration if self.linear else self.lr_mult**self.iteration
        return init_lrs * mult
    def on_batch_end(self, metrics):
        loss = metrics[0] if isinstance(metrics,list) else metrics
        if self.stop_dv and (math.isnan(loss) or loss>self.best*4):
            return True
        if (loss<self.best and self.iteration>10): self.best = loss
        return super().on_batch_end(metrics)
    def plot(self, n_skip=10, n_skip_end=5):
        '''
        Plots the loss as a function of the learning rate, on a log scale.
        '''
        plt.ylabel("validation loss")
        plt.xlabel("learning rate (log scale)")
        plt.plot(self.lrs[n_skip:-(n_skip_end+1)], self.losses[n_skip:-(n_skip_end+1)])
        plt.xscale('log')
        plt.savefig(os.path.join(self.save_path, 'lr_loss_plot.png'))
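
# Illustrative sketch (not part of the library): the exponential sweep LR_Finder applies
# when linear=False, i.e. lr_i = start_lr * (end_lr/start_lr) ** (i/nb), which reaches
# end_lr at i == nb.
def _demo_lr_finder_sweep(start_lr=1e-5, end_lr=10, nb=100):
    lr_mult = (end_lr / start_lr) ** (1 / nb)
    return [start_lr * lr_mult ** i for i in range(nb + 1)]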

class LR_Finder2(LR_Finder):
    """
    A variant of lr_find() that helps find the best learning rate. It doesn't run
    a full epoch but a fixed number of iterations (which may be more or fewer than
    an epoch, depending on your data).
    """
    def __init__(self, layer_opt, nb, end_lr=10, linear=False, metrics=[], stop_dv=True):
        self.nb, self.metrics = nb, metrics
        super().__init__(layer_opt, nb, end_lr, linear, metrics)
        self.stop_dv = stop_dv
    def on_batch_end(self, loss):
        if self.iteration == self.nb:
            return True
        return super().on_batch_end(loss)
    def plot(self, n_skip=10, n_skip_end=5, smoothed=True):
        if self.metrics is None: self.metrics = []
        n_plots = len(self.metrics)+2
        fig, axs = plt.subplots(n_plots, figsize=(6,4*n_plots))
        for i in range(0, n_plots): axs[i].set_xlabel('learning rate')
        axs[0].set_ylabel('training loss')
        axs[1].set_ylabel('validation loss')
        for i,m in enumerate(self.metrics):
            axs[i+2].set_ylabel(m.__name__)
            if len(self.metrics) == 1:
                values = self.rec_metrics
            else:
                values = [rec[i] for rec in self.rec_metrics]
            if smoothed: values = smooth_curve(values, 0.98)
            axs[i+2].plot(self.lrs[n_skip:-n_skip_end], values[n_skip:-n_skip_end])
        plt_val_l = smooth_curve(self.val_losses, 0.98) if smoothed else self.val_losses
        axs[0].plot(self.lrs[n_skip:-n_skip_end], self.losses[n_skip:-n_skip_end])
        axs[1].plot(self.lrs[n_skip:-n_skip_end], plt_val_l[n_skip:-n_skip_end])

class CosAnneal(LR_Updater):
    ''' Learning rate scheduler that implements a cosine annealing schedule. '''
    def __init__(self, layer_opt, nb, on_cycle_end=None, cycle_mult=1):
        self.nb,self.on_cycle_end,self.cycle_mult = nb,on_cycle_end,cycle_mult
        super().__init__(layer_opt)
    def on_train_begin(self):
        self.cycle_iter,self.cycle_count = 0,0
        super().on_train_begin()
    def calc_lr(self, init_lrs):
        if self.iteration<self.nb/20:
            self.cycle_iter += 1
            return init_lrs/100.
        cos_out = np.cos(np.pi*(self.cycle_iter)/self.nb) + 1
        self.cycle_iter += 1
        if self.cycle_iter==self.nb:
            self.cycle_iter = 0
            self.nb *= self.cycle_mult
            if self.on_cycle_end: self.on_cycle_end(self, self.cycle_count)
            self.cycle_count += 1
        return init_lrs / 2 * cos_out
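
# Illustrative sketch (not part of the library): the curve CosAnneal.calc_lr follows over
# one cycle of nb iterations (ignoring the initial nb/20 warm-up at init_lr/100):
# lr_i = init_lr / 2 * (cos(pi * i / nb) + 1).
def _demo_cos_anneal(init_lr=0.01, nb=100):
    return [init_lr / 2 * (np.cos(np.pi * i / nb) + 1) for i in range(nb)]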

class CircularLR(LR_Updater):
    '''
    A learning rate updater that implements the CircularLearningRate (CLR) scheme.
    Learning rate is increased then decreased linearly.
    '''
    def __init__(self, layer_opt, nb, div=4, cut_div=8, on_cycle_end=None, momentums=None):
        self.nb,self.div,self.cut_div,self.on_cycle_end = nb,div,cut_div,on_cycle_end
        if momentums is not None:
            self.moms = momentums
        super().__init__(layer_opt, record_mom=(momentums is not None))
    def on_train_begin(self):
        self.cycle_iter,self.cycle_count = 0,0
        super().on_train_begin()
    def calc_lr(self, init_lrs):
        cut_pt = self.nb//self.cut_div
        if self.cycle_iter>cut_pt:
            pct = 1 - (self.cycle_iter - cut_pt)/(self.nb - cut_pt)
        else: pct = self.cycle_iter/cut_pt
        res = init_lrs * (1 + pct*(self.div-1)) / self.div
        self.cycle_iter += 1
        if self.cycle_iter==self.nb:
            self.cycle_iter = 0
            if self.on_cycle_end: self.on_cycle_end(self, self.cycle_count)
            self.cycle_count += 1
        return res
    def calc_mom(self):
        cut_pt = self.nb//self.cut_div
        if self.cycle_iter>cut_pt:
            pct = (self.cycle_iter - cut_pt)/(self.nb - cut_pt)
        else: pct = 1 - self.cycle_iter/cut_pt
        res = self.moms[1] + pct * (self.moms[0] - self.moms[1])
        return res
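
# Illustrative sketch (not part of the library): the triangular shape CircularLR.calc_lr
# produces, rising from init_lr/div to init_lr over nb//cut_div steps, then falling back.
def _demo_circular_lr(init_lr=0.01, nb=100, div=4, cut_div=8):
    cut_pt = nb // cut_div
    lrs = []
    for i in range(nb):
        pct = 1 - (i - cut_pt) / (nb - cut_pt) if i > cut_pt else i / cut_pt
        lrs.append(init_lr * (1 + pct * (div - 1)) / div)
    return lrs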

class CircularLR_beta(LR_Updater):
    def __init__(self, layer_opt, nb, div=10, pct=10, on_cycle_end=None, momentums=None):
        self.nb,self.div,self.pct,self.on_cycle_end = nb,div,pct,on_cycle_end
        self.cycle_nb = int(nb * (1-pct/100) / 2)
        if momentums is not None:
            self.moms = momentums
        super().__init__(layer_opt, record_mom=(momentums is not None))
    def on_train_begin(self):
        self.cycle_iter,self.cycle_count = 0,0
        super().on_train_begin()
    def calc_lr(self, init_lrs):
        if self.cycle_iter>2*self.cycle_nb:
            pct = (self.cycle_iter - 2*self.cycle_nb)/(self.nb - 2*self.cycle_nb)
            res = init_lrs * (1 + (pct * (1-100)/100)) / self.div
        elif self.cycle_iter>self.cycle_nb:
            pct = 1 - (self.cycle_iter - self.cycle_nb)/self.cycle_nb
            res = init_lrs * (1 + pct*(self.div-1)) / self.div
        else:
            pct = self.cycle_iter/self.cycle_nb
            res = init_lrs * (1 + pct*(self.div-1)) / self.div
        self.cycle_iter += 1
        if self.cycle_iter==self.nb:
            self.cycle_iter = 0
            if self.on_cycle_end: self.on_cycle_end(self, self.cycle_count)
            self.cycle_count += 1
        return res
    def calc_mom(self):
        if self.cycle_iter>2*self.cycle_nb:
            res = self.moms[0]
        elif self.cycle_iter>self.cycle_nb:
            pct = 1 - (self.cycle_iter - self.cycle_nb)/self.cycle_nb
            res = self.moms[0] + pct * (self.moms[1] - self.moms[0])
        else:
            pct = self.cycle_iter/self.cycle_nb
            res = self.moms[0] + pct * (self.moms[1] - self.moms[0])
        return res

class SaveBestModel(LossRecorder):
    """ Saves the weights of the best model found during training.
    If metrics are provided, the first metric in the list is used to
    find the best model.
    If no metrics are provided, the loss is used.
    Args:
        model: the fastai model
        layer_opt: the LayerOptimizer of the learner
        metrics: the metrics being tracked (may be None)
        name: the filename of the weights, without '.h5'
    Usage:
        Briefly, you have your model in the 'learn' variable and call fit.
        >>> learn.fit(lr, 2, cycle_len=2, cycle_mult=1, best_save_name='mybestmodel')
        ....
        >>> learn.load('mybestmodel')
        For more details see http://forums.fast.ai/t/a-code-snippet-to-save-the-best-model-during-training/12066
    """
    def __init__(self, model, layer_opt, metrics, name='best_model'):
        super().__init__(layer_opt)
        self.name = name
        self.model = model
        self.best_loss = None
        self.best_acc = None
        self.save_method = self.save_when_only_loss if metrics is None else self.save_when_acc
    def save_when_only_loss(self, metrics):
        loss = metrics[0]
        if self.best_loss is None or loss < self.best_loss:
            self.best_loss = loss
            self.model.save(f'{self.name}')
    def save_when_acc(self, metrics):
        loss, acc = metrics[0], metrics[1]
        if self.best_acc is None or acc > self.best_acc:
            self.best_acc = acc
            self.best_loss = loss
            self.model.save(f'{self.name}')
        elif acc == self.best_acc and loss < self.best_loss:
            self.best_loss = loss
            self.model.save(f'{self.name}')
    def on_epoch_end(self, metrics):
        super().on_epoch_end(metrics)
        self.save_method(metrics)

class WeightDecaySchedule(Callback):
    def __init__(self, layer_opt, batch_per_epoch, cycle_len, cycle_mult, n_cycles, norm_wds=False, wds_sched_mult=None):
        """
        Implements the weight decay schedule described in https://arxiv.org/abs/1711.05101
        :param layer_opt: The LayerOptimizer
        :param batch_per_epoch: Number of batches in one epoch
        :param cycle_len: Number of epochs in the initial cycle. Subsequent cycle_len = previous cycle_len * cycle_mult
        :param cycle_mult: Cycle multiplier
        :param n_cycles: Number of cycles to be executed
        :param norm_wds: If True, normalize the weight decay by the square root of the number of batches in the current cycle
        :param wds_sched_mult: Optional function of this callback returning the weight decay multiplier (the 'eta' in the paper)
        """
        super().__init__()
        self.layer_opt = layer_opt
        self.batch_per_epoch = batch_per_epoch
        self.init_wds = np.array(layer_opt.wds)  # Weight decays as set by the user
        self.init_lrs = np.array(layer_opt.lrs)  # Learning rates as set by the user
        self.new_wds = None                      # Holds the new weight decay factors, calculated in on_batch_begin()
        self.iteration = 0
        self.epoch = 0
        self.wds_sched_mult = wds_sched_mult
        self.norm_wds = norm_wds
        self.wds_history = list()
        # Pre-calculate, for each epoch, the number of epochs in the cycle it belongs to
        self.epoch_to_num_cycles, i = dict(), 0
        for cycle in range(n_cycles):
            for _ in range(cycle_len):
                self.epoch_to_num_cycles[i] = cycle_len
                i += 1
            cycle_len *= cycle_mult
    def on_train_begin(self):
        self.iteration = 0
        self.epoch = 0
    def on_batch_begin(self):
        # Prepare the weight decay for this batch.
        # Default weight decay (as provided by the user)
        wdn = self.init_wds
        # Weight decay multiplier (the 'eta' in the paper). Optional.
        wdm = 1.0
        if self.wds_sched_mult is not None:
            wdm = self.wds_sched_mult(self)
        # Normalized weight decay. Optional.
        if self.norm_wds:
            wdn = wdn / np.sqrt(self.batch_per_epoch * self.epoch_to_num_cycles[self.epoch])
        # Final wds
        self.new_wds = wdm * wdn
        # Use the decoupled form so the weight decay is not applied inside the optimizer's (e.g. Adam's) update
        self.layer_opt.set_wds_out(self.new_wds)
        self.iteration += 1
    def on_epoch_end(self, metrics):
        self.epoch += 1
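
# Illustrative sketch (not part of the library): a hypothetical wds_sched_mult function.
# It receives the WeightDecaySchedule instance and returns the eta multiplier; here it
# decays eta linearly from 1 to 0 over the whole run.
def _example_wds_sched_mult(cb):
    total_batches = cb.batch_per_epoch * len(cb.epoch_to_num_cycles)
    return 1.0 - cb.iteration / total_batches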

class DecayType(IntEnum):
    ''' Data class; each decay type is assigned a number. '''
    NO = 1
    LINEAR = 2
    COSINE = 3
    EXPONENTIAL = 4
    POLYNOMIAL = 5

class DecayScheduler():
    '''Given a start and an end value, generates the next value depending on the decay type
    and the number of iterations (by calling next_val()).'''
    def __init__(self, dec_type, num_it, start_val, end_val=None, extra=None):
        self.dec_type, self.nb, self.start_val, self.end_val, self.extra = dec_type, num_it, start_val, end_val, extra
        self.it = 0
        if self.end_val is None and self.dec_type not in (DecayType.NO, DecayType.EXPONENTIAL): self.end_val = 0
    def next_val(self):
        self.it += 1
        if self.dec_type == DecayType.NO:
            return self.start_val
        elif self.dec_type == DecayType.LINEAR:
            pct = self.it/self.nb
            return self.start_val + pct * (self.end_val-self.start_val)
        elif self.dec_type == DecayType.COSINE:
            cos_out = np.cos(np.pi*(self.it)/self.nb) + 1
            return self.end_val + (self.start_val-self.end_val) / 2 * cos_out
        elif self.dec_type == DecayType.EXPONENTIAL:
            ratio = self.end_val / self.start_val
            return self.start_val * (ratio ** (self.it/self.nb))
        elif self.dec_type == DecayType.POLYNOMIAL:
            return self.end_val + (self.start_val-self.end_val) * (1 - self.it/self.nb)**self.extra
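
# Illustrative sketch (not part of the library): a cosine decay from 0.01 to 0.001
# over 10 iterations, produced by calling next_val() once per iteration.
def _demo_decay_scheduler():
    sched = DecayScheduler(DecayType.COSINE, 10, 0.01, 0.001)
    return [sched.next_val() for _ in range(10)]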

class TrainingPhase():
    '''
    Object holding the training information for one phase, when multiple phases are involved during training.
    Used in fit_opt_sched in learner.py
    '''
    def __init__(self, epochs=1, opt_fn=optim.SGD, lr=1e-2, lr_decay=DecayType.NO, momentum=0.9,
                 momentum_decay=DecayType.NO, beta=None, wds=None, wd_loss=True):
        """
        Creates an object containing all the relevant information for one part of a model's training.
        Args:
            epochs: number of epochs to train like this
            opt_fn: an optimizer (e.g. optim.Adam)
            lr: one learning rate or a tuple of the form (start_lr, end_lr);
                each of those can be a list/numpy array for differential learning rates
            lr_decay: a DecayType specifying how the learning rate should change
            momentum: one momentum (or beta1 in the case of Adam), or a tuple of the form (start_mom, end_mom)
            momentum_decay: a DecayType specifying how the momentum should change
            beta: beta2 parameter of Adam or alpha parameter of RMSProp
            wds: weight decay (can be an array for differential wds)
        """
        self.epochs, self.opt_fn, self.lr, self.momentum, self.beta, self.wds = epochs, opt_fn, lr, momentum, beta, wds
        if isinstance(lr_decay, tuple): self.lr_decay, self.extra_lr = lr_decay
        else: self.lr_decay, self.extra_lr = lr_decay, None
        if isinstance(momentum_decay, tuple): self.mom_decay, self.extra_mom = momentum_decay
        else: self.mom_decay, self.extra_mom = momentum_decay, None
        self.wd_loss = wd_loss
    def phase_begin(self, layer_opt, nb_batches):
        self.layer_opt = layer_opt
        if isinstance(self.lr, tuple): start_lr, end_lr = self.lr
        else: start_lr, end_lr = self.lr, None
        self.lr_sched = DecayScheduler(self.lr_decay, nb_batches * self.epochs, start_lr, end_lr, extra=self.extra_lr)
        if isinstance(self.momentum, tuple): start_mom, end_mom = self.momentum
        else: start_mom, end_mom = self.momentum, None
        self.mom_sched = DecayScheduler(self.mom_decay, nb_batches * self.epochs, start_mom, end_mom, extra=self.extra_mom)
        self.layer_opt.set_opt_fn(self.opt_fn)
        self.layer_opt.set_lrs(start_lr)
        self.layer_opt.set_mom(start_mom)
        if self.beta is not None: self.layer_opt.set_beta(self.beta)
        if self.wds is not None:
            if not isinstance(self.wds, Iterable): self.wds = [self.wds]
            if len(self.wds) == 1: self.wds = self.wds * len(self.layer_opt.layer_groups)
            if self.wd_loss: self.layer_opt.set_wds(self.wds)
            else: self.layer_opt.set_wds_out(self.wds)
    def update(self):
        new_lr, new_mom = self.lr_sched.next_val(), self.mom_sched.next_val()
        self.layer_opt.set_lrs(new_lr)
        self.layer_opt.set_mom(new_mom)
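
# Illustrative sketch (not part of the library): a pair of TrainingPhase objects of the
# kind passed to fit_opt_sched in learner.py; the specific values are made up.
def _example_phases():
    return [
        TrainingPhase(epochs=1, opt_fn=optim.SGD, lr=(1e-3, 1e-2), lr_decay=DecayType.LINEAR,
                      momentum=(0.95, 0.85), momentum_decay=DecayType.LINEAR),
        TrainingPhase(epochs=1, opt_fn=optim.SGD, lr=(1e-2, 1e-4), lr_decay=DecayType.COSINE,
                      momentum=(0.85, 0.95), momentum_decay=DecayType.LINEAR),
    ]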

class OptimScheduler(LossRecorder):
    '''Learning rate scheduler for training involving multiple phases.'''
    def __init__(self, layer_opt, phases, nb_batches, stop_div=False):
        self.phases, self.nb_batches, self.stop_div = phases, nb_batches, stop_div
        super().__init__(layer_opt, record_mom=True)
    def on_train_begin(self):
        super().on_train_begin()
        self.phase,self.best = 0,1e9
    def on_batch_end(self, metrics):
        loss = metrics[0] if isinstance(metrics,list) else metrics
        if self.stop_div and (math.isnan(loss) or loss>self.best*4):
            return True
        if (loss<self.best and self.iteration>10): self.best = loss
        super().on_batch_end(metrics)
        self.phases[self.phase].update()
    def on_phase_begin(self):
        self.phases[self.phase].phase_begin(self.layer_opt, self.nb_batches[self.phase])
    def on_phase_end(self):
        self.phase += 1
    def plot_lr(self, show_text=True, show_moms=True):
        """Plots the learning rate/momentum schedule."""
        phase_limits = [0]
        for nb_batch, phase in zip(self.nb_batches, self.phases):
            phase_limits.append(phase_limits[-1] + nb_batch * phase.epochs)
        if not in_ipynb():
            plt.switch_backend('agg')
        np_plts = 2 if show_moms else 1
        fig, axs = plt.subplots(1,np_plts,figsize=(6*np_plts,4))
        if not show_moms: axs = [axs]
        for i in range(np_plts): axs[i].set_xlabel('iterations')
        axs[0].set_ylabel('learning rate')
        axs[0].plot(self.iterations,self.lrs)
        if show_moms:
            axs[1].set_ylabel('momentum')
            axs[1].plot(self.iterations,self.momentums)
        if show_text:
            for i, phase in enumerate(self.phases):
                text = phase.opt_fn.__name__
                if phase.wds is not None: text+='\nwds='+str(phase.wds)
                if phase.beta is not None: text+='\nbeta='+str(phase.beta)
                for k in range(np_plts):
                    if i < len(self.phases)-1:
                        draw_line(axs[k], phase_limits[i+1])
                    draw_text(axs[k], (phase_limits[i]+phase_limits[i+1])/2, text)
        if not in_ipynb():
            plt.savefig(os.path.join(self.save_path, 'lr_plot.png'))
    def plot(self, n_skip=10, n_skip_end=5, linear=None):
        if linear is None: linear = self.phases[-1].lr_decay == DecayType.LINEAR
        plt.ylabel("loss")
        plt.plot(self.lrs[n_skip:-n_skip_end], self.losses[n_skip:-n_skip_end])
        if linear: plt.xlabel("learning rate")
        else:
            plt.xlabel("learning rate (log scale)")
            plt.xscale('log')

def draw_line(ax, x):
    xmin, xmax, ymin, ymax = ax.axis()
    ax.plot([x,x],[ymin,ymax], color='red', linestyle='dashed')

def draw_text(ax, x, text):
    xmin, xmax, ymin, ymax = ax.axis()
    ax.text(x,(ymin+ymax)/2,text, horizontalalignment='center', verticalalignment='center', fontsize=14, alpha=0.5)

def smooth_curve(vals, beta):
    avg_val = 0
    smoothed = []
    for (i,v) in enumerate(vals):
        avg_val = beta * avg_val + (1-beta) * v
        smoothed.append(avg_val/(1-beta**(i+1)))
    return smoothed
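
# Illustrative sketch (not part of the library): smooth_curve is a bias-corrected
# exponential moving average, so a constant series comes back (almost) unchanged.
def _demo_smooth_curve():
    assert all(abs(v - 1.0) < 1e-6 for v in smooth_curve([1.0] * 5, 0.98))
    return smooth_curve([1.0, 0.0, 1.0, 0.0], 0.9)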