initializers.py 334 B

123456789101112
  1. from .imports import *
  2. from .torch_imports import *
  3. def cond_init(m, init_fn):
  4. if not isinstance(m, (nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d)):
  5. if hasattr(m, 'weight'): init_fn(m.weight)
  6. if hasattr(m, 'bias'): m.bias.data.fill_(0.)
  7. def apply_init(m, init_fn):
  8. m.apply(lambda x: cond_init(x, init_fn))