"""
From the paper:
Averaging Weights Leads to Wider Optima and Better Generalization
Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson
https://arxiv.org/abs/1803.05407
2018
Author's implementation: https://github.com/timgaripov/swa
"""
import torch

from .sgdr import *
from .core import *


class SWA(Callback):
    """
    Callback that maintains a running average of `model`'s parameters in
    `swa_model`. Averaging begins once `swa_start` epochs have completed.
    """
    def __init__(self, model, swa_model, swa_start):
        super().__init__()
        self.model, self.swa_model, self.swa_start = model, swa_model, swa_start

    def on_train_begin(self):
        self.epoch = 0
        self.swa_n = 0

    def on_epoch_end(self, metrics):
        if (self.epoch + 1) >= self.swa_start:
            self.update_average_model()
            self.swa_n += 1

        self.epoch += 1

    def update_average_model(self):
        # update running average of parameters
        model_params = self.model.parameters()
        swa_params = self.swa_model.parameters()
        for model_param, swa_param in zip(model_params, swa_params):
            swa_param.data *= self.swa_n
            swa_param.data += model_param.data
            swa_param.data /= (self.swa_n + 1)
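
# A minimal sketch (illustrative only, not part of the original module) of the
# running-average update performed in SWA.update_average_model above: after
# `n` models have been folded in, the stored value is their mean, and folding
# in model `n + 1` keeps it the mean.
#
#     swa_w = torch.tensor(1.0)   # average of the first model's weight (n = 1)
#     new_w = torch.tensor(3.0)   # weight from the second model
#     n = 1
#     swa_w = (swa_w * n + new_w) / (n + 1)
#     assert swa_w.item() == 2.0  # mean of 1.0 and 3.0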


def collect_bn_modules(module, bn_modules):
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        bn_modules.append(module)


def fix_batchnorm(swa_model, train_dl):
    """
    During training, batch norm layers keep track of a running mean and
    variance of the previous layer's activations. Because the parameters
    of the SWA model are computed as the average of other models' parameters,
    the SWA model never sees the training data itself, and therefore has no
    opportunity to compute the correct batch norm statistics. Before performing
    inference with the SWA model, we perform a single pass over the training data
    to calculate an accurate running mean and variance for each batch norm layer.
    """
    bn_modules = []
    swa_model.apply(lambda module: collect_bn_modules(module, bn_modules))

    if not bn_modules: return

    swa_model.train()

    # reset the running statistics so they are recomputed from scratch
    for module in bn_modules:
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)

    momenta = [m.momentum for m in bn_modules]
    inputs_seen = 0

    for (*x, y) in iter(train_dl):
        xs = V(x)
        batch_size = xs[0].size(0)

        # choose the momentum so the running statistics end up as a
        # batch-size-weighted average of the per-batch statistics
        momentum = batch_size / (inputs_seen + batch_size)
        for module in bn_modules:
            module.momentum = momentum

        # forward pass only to update the batch norm running statistics
        res = swa_model(*xs)

        inputs_seen += batch_size

    # restore the original momenta
    for module, momentum in zip(bn_modules, momenta):
        module.momentum = momentum
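

# Example usage (an illustrative sketch, not part of the original module). It
# assumes a training loop that invokes the Callback hooks; `build_model`,
# `train_one_epoch`, `n_epochs` and `train_dl` are hypothetical stand-ins for
# your own model constructor, training step, epoch count and DataLoader.
#
#     import copy
#
#     model = build_model()
#     swa_model = copy.deepcopy(model)           # holds the running average
#     swa = SWA(model, swa_model, swa_start=10)  # averaging begins at the end
#                                                # of the 10th epoch
#     swa.on_train_begin()
#     for epoch in range(n_epochs):
#         metrics = train_one_epoch(model, train_dl)
#         swa.on_epoch_end(metrics)
#
#     # recompute batch norm statistics before evaluating the averaged model
#     fix_batchnorm(swa_model, train_dl)
#     swa_model.eval()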