- """
- From the paper:
- Averaging Weights Leads to Wider Optima and Better Generalization
- Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson
- https://arxiv.org/abs/1803.05407
- 2018
-
- Author's implementation: https://github.com/timgaripov/swa
- """
import torch
from .sgdr import *
from .core import *

class SWA(Callback):
    def __init__(self, model, swa_model, swa_start):
        super().__init__()
        self.model, self.swa_model, self.swa_start = model, swa_model, swa_start

    def on_train_begin(self):
        self.epoch = 0
        self.swa_n = 0

    def on_epoch_end(self, metrics):
        # start folding weights into the average once we reach the swa_start epoch
        if (self.epoch + 1) >= self.swa_start:
            self.update_average_model()
            self.swa_n += 1

        self.epoch += 1

    def update_average_model(self):
        # running average of the parameters: swa <- (swa * n + model) / (n + 1)
        model_params = self.model.parameters()
        swa_params = self.swa_model.parameters()
        for model_param, swa_param in zip(model_params, swa_params):
            swa_param.data *= self.swa_n
            swa_param.data += model_param.data
            swa_param.data /= (self.swa_n + 1)
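
# A minimal sanity check of the running-average update above (illustrative only,
# not part of the original module): after k updates, each SWA parameter equals
# the arithmetic mean of the k parameter snapshots folded in so far, since the
# initial value is multiplied by swa_n == 0 on the first update.
#
#     swa, n = 0.0, 0
#     for w in [1.0, 2.0, 3.0]:
#         swa = (swa * n + w) / (n + 1); n += 1
#     assert swa == 2.0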

def collect_bn_modules(module, bn_modules):
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        bn_modules.append(module)

def fix_batchnorm(swa_model, train_dl):
    """
    During training, batch norm layers keep track of a running mean and
    variance of the previous layer's activations. Because the parameters
    of the SWA model are computed as the average of other models' parameters,
    the SWA model never sees the training data itself, and therefore has no
    opportunity to compute the correct batch norm statistics. Before performing
    inference with the SWA model, we perform a single pass over the training data
    to calculate an accurate running mean and variance for each batch norm layer.
    """
    bn_modules = []
    swa_model.apply(lambda module: collect_bn_modules(module, bn_modules))

    if not bn_modules: return

    swa_model.train()

    # reset the running statistics; they are recomputed from scratch below
    for module in bn_modules:
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)

    # remember each layer's momentum so it can be restored afterwards
    momenta = [m.momentum for m in bn_modules]

    inputs_seen = 0
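
    # Setting momentum = batch_size / (inputs_seen + batch_size) turns batch
    # norm's exponential moving average into a cumulative, batch-size-weighted
    # average of the per-batch statistics seen during this pass.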
    for (*x, y) in iter(train_dl):
        xs = V(x)
        batch_size = xs[0].size(0)

        momentum = batch_size / (inputs_seen + batch_size)
        for module in bn_modules:
            module.momentum = momentum

        # forward pass purely to update the batch norm running statistics;
        # the output is discarded
        res = swa_model(*xs)

        inputs_seen += batch_size

    # restore the original momentum values
    for module, momentum in zip(bn_modules, momenta):
        module.momentum = momentum
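

# Usage sketch (illustrative only: `learn`, `fit`, and the callback-passing API
# are assumptions about the surrounding training loop, not part of this module):
#
#     import copy
#     swa_model = copy.deepcopy(model)            # holds the averaged weights
#     swa = SWA(model, swa_model, swa_start=10)   # start averaging at epoch 10
#     learn.fit(lrs, n_cycle, callbacks=[swa])    # train with the callback attached
#     fix_batchnorm(swa_model, train_dl)          # recompute BN statistics
#     # ...then evaluate using swa_model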