from .imports import *
from .torch_imports import *


def cond_init(m, init_fn):
    """Initialise the parameters of a single module, leaving BatchNorm layers untouched.

    Args:
        m: the module to initialise (assumed to be an ``nn.Module``).
        init_fn: callable that receives ``m.weight``; presumably an in-place
            initialiser such as ``nn.init.kaiming_normal_`` -- confirm with callers.

    If ``m`` is an ``nn.BatchNorm1d``/``2d``/``3d`` nothing is done. Otherwise
    ``init_fn`` is called on ``m.weight`` and ``m.bias`` is zero-filled, each
    only when that parameter is actually present. PyTorch modules created
    without a parameter (e.g. ``nn.Linear(..., bias=False)``) still register the
    attribute with the value ``None``, so presence is checked by value rather
    than with ``hasattr``.
    """
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        init_fn(weight)

    bias = getattr(m, 'bias', None)
    # `hasattr(bias, 'data')` rejects both a missing/None bias and any
    # non-tensor attribute that merely happens to be called `bias`.
    if hasattr(bias, 'data'):
        bias.data.fill_(0.)


def apply_init(m, init_fn):
    """Apply ``cond_init`` with ``init_fn`` to ``m`` and every submodule of ``m``.

    Delegates the traversal to ``m.apply`` (``nn.Module.apply`` visits each
    child module recursively and then ``m`` itself). Returns ``None``.
    """
    m.apply(lambda x: cond_init(x, init_fn))