# layer_optimizer.py
  1. from .imports import *
  2. from .torch_imports import *
  3. from .core import *
  4. def opt_params(parm, lr, wd):
  5. return {'params': chain_params(parm), 'lr':lr, 'weight_decay':wd}
class LayerOptimizer():
    """Wrap a torch optimizer so each layer group gets its own learning rate
    and weight decay (per-group / discriminative fine-tuning).

    Attributes set in __init__:
      layer_groups -- list of layer groups (a lone group is wrapped in a list)
      lrs          -- per-group learning rates (a single value is broadcast)
      wds          -- per-group weight decays (default 0., broadcast likewise)
      opt          -- optimizer built by calling opt_fn on self.opt_params()
    """
    def __init__(self, opt_fn, layer_groups, lrs, wds=None):
        # Normalize every argument into a list with one entry per layer group.
        if not isinstance(layer_groups, (list,tuple)): layer_groups=[layer_groups]
        if not isinstance(lrs, Iterable): lrs=[lrs]
        if len(lrs)==1: lrs=lrs*len(layer_groups)
        if wds is None: wds=0.
        if not isinstance(wds, Iterable): wds=[wds]
        if len(wds)==1: wds=wds*len(layer_groups)
        self.layer_groups,self.lrs,self.wds = layer_groups,lrs,wds
        self.opt = opt_fn(self.opt_params())

    def opt_params(self):
        """Return one param-group dict per layer group, pairing each group
        with its learning rate and weight decay."""
        assert(len(self.layer_groups) == len(self.lrs))
        assert(len(self.layer_groups) == len(self.wds))
        params = list(zip(self.layer_groups,self.lrs,self.wds))
        return [opt_params(*p) for p in params]

    @property
    def lr(self):
        # Learning rate of the last layer group (by fastai convention the
        # head — TODO confirm against callers).
        return self.lrs[-1]

    @property
    def mom(self):
        """Current momentum: betas[0] when the optimizer exposes 'betas'
        (Adam-family), else the 'momentum' entry of the first param group."""
        if 'betas' in self.opt.param_groups[0]:
            return self.opt.param_groups[0]['betas'][0]
        else:
            return self.opt.param_groups[0]['momentum']

    def set_lrs(self, lrs):
        """Set per-group learning rates on the wrapped optimizer."""
        if not isinstance(lrs, Iterable): lrs=[lrs]
        if len(lrs)==1: lrs=lrs*len(self.layer_groups)
        # Calls the module-level set_lrs helper, not this method (no recursion).
        set_lrs(self.opt, lrs)
        self.lrs=lrs

    def set_wds_out(self, wds):
        """Set weight decay via the custom 'wd' key on each param group and
        zero the built-in 'weight_decay' — presumably decoupled (AdamW-style)
        decay handled outside the optimizer step; confirm against the
        training loop."""
        if not isinstance(wds, Iterable): wds=[wds]
        if len(wds)==1: wds=wds*len(self.layer_groups)
        set_wds_out(self.opt, wds)
        set_wds(self.opt, [0] * len(self.layer_groups))
        self.wds=wds

    def set_wds(self, wds):
        """Set weight decay via the optimizer's built-in 'weight_decay' key
        and zero the custom 'wd' key (the inverse of set_wds_out)."""
        if not isinstance(wds, Iterable): wds=[wds]
        if len(wds)==1: wds=wds*len(self.layer_groups)
        set_wds(self.opt, wds)
        set_wds_out(self.opt, [0] * len(self.layer_groups))
        self.wds=wds

    def set_mom(self,momentum):
        """Set momentum on every param group: betas[0] for optimizers with
        'betas', else the 'momentum' entry."""
        if 'betas' in self.opt.param_groups[0]:
            for pg in self.opt.param_groups: pg['betas'] = (momentum, pg['betas'][1])
        else:
            for pg in self.opt.param_groups: pg['momentum'] = momentum

    def set_beta(self,beta):
        """Set the secondary smoothing constant: betas[1] for Adam-family
        optimizers, 'alpha' for RMSprop-style ones; silently does nothing
        if the optimizer has neither."""
        if 'betas' in self.opt.param_groups[0]:
            for pg in self.opt.param_groups: pg['betas'] = (pg['betas'][0],beta)
        elif 'alpha' in self.opt.param_groups[0]:
            for pg in self.opt.param_groups: pg['alpha'] = beta

    def set_opt_fn(self, opt_fn):
        """Replace the optimizer only when opt_fn builds a different
        optimizer type; otherwise keep the existing one (and its state).
        Note: the comparison itself constructs a throwaway optimizer."""
        if type(self.opt) != type(opt_fn(self.opt_params())):
            self.opt = opt_fn(self.opt_params())
  59. def zip_strict_(l, r):
  60. assert(len(l) == len(r))
  61. return zip(l, r)
  62. def set_lrs(opt, lrs):
  63. if not isinstance(lrs, Iterable): lrs=[lrs]
  64. if len(lrs)==1: lrs=lrs*len(opt.param_groups)
  65. for pg,lr in zip_strict_(opt.param_groups,lrs): pg['lr'] = lr
  66. def set_wds_out(opt, wds):
  67. if not isinstance(wds, Iterable): wds=[wds]
  68. if len(wds)==1: wds=wds*len(opt.param_groups)
  69. assert(len(opt.param_groups) == len(wds))
  70. for pg,wd in zip_strict_(opt.param_groups,wds): pg['wd'] = wd
  71. def set_wds(opt, wds):
  72. if not isinstance(wds, Iterable): wds=[wds]
  73. if len(wds)==1: wds=wds*len(opt.param_groups)
  74. assert(len(opt.param_groups) == len(wds))
  75. for pg,wd in zip_strict_(opt.param_groups,wds): pg['weight_decay'] = wd