123456789101112 |
- from .imports import *
- from .torch_imports import *
def cond_init(m, init_fn):
    """Initialise ``m``'s own weight with ``init_fn`` and zero its bias.

    BatchNorm layers (1d/2d/3d) are skipped entirely, so their existing
    parameters are left as they are.

    Args:
        m: the module to initialise (presumably an ``nn.Module``; only its
            direct ``weight`` / ``bias`` attributes are touched, not children).
        init_fn: callable that receives the weight parameter and initialises
            it in place (e.g. an ``nn.init`` function); its return value is
            ignored.
    """
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        return
    # ``hasattr`` alone is not enough: layers built with ``bias=False`` (or
    # e.g. ``LayerNorm(elementwise_affine=False)``) register the attribute
    # as ``None``, which would crash ``init_fn`` or the ``.data`` access.
    weight = getattr(m, 'weight', None)
    if weight is not None:
        init_fn(weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
def apply_init(m, init_fn):
    """Run ``cond_init(module, init_fn)`` on every module visited by ``m.apply``.

    Assumes ``m`` is an ``nn.Module``, whose ``apply`` visits each submodule
    and ``m`` itself. Returns ``None``.
    """
    def _init_one(module):
        cond_init(module, init_fn)

    m.apply(_init_one)
|