fp16.py

from distutils.version import LooseVersion

import torch
import torch.nn as nn
from torch.autograd import Variable

from .core import trainable_params_
from .torch_imports import *

IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4')
class FP16(nn.Module):
    """Wraps `module` so it runs in half precision, with BatchNorm layers kept in fp32."""
    def __init__(self, module):
        super().__init__()
        self.module = batchnorm_to_fp32(module.half())

    def forward(self, input):
        # Cast incoming float tensors to half so they match the fp16 weights.
        if is_float(input): input = input.half()
        return self.module(input)

    def load_state_dict(self, *inputs, **kwargs):
        self.module.load_state_dict(*inputs, **kwargs)

    def state_dict(self, *inputs, **kwargs):
        return self.module.state_dict(*inputs, **kwargs)

    def __getitem__(self, idx):
        return self.module[idx]
def is_float(tensor):
    """True if `tensor` holds single-precision floating-point data."""
    if IS_TORCH_04: return tensor.is_floating_point()
    if isinstance(tensor, Variable): tensor = tensor.data
    return isinstance(tensor, torch.cuda.FloatTensor)
def batchnorm_to_fp32(module):
    '''
    Convert BatchNorm layers back to single precision so their parameters
    stay in fp32. This walks the module tree and converts only BatchNorm
    layers; it can't be done with the built-in .apply, because that applies
    the function to all modules, parameters, and buffers, so we couldn't
    guard the float conversion on the module type.
    '''
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        batchnorm_to_fp32(child)
    return module
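
# Illustrative sketch (not part of the original module): wrapping a tiny model in
# FP16 should leave BatchNorm weights in single precision while the rest of the
# network runs in half precision. The helper name and layer sizes are made up for
# demonstration; assumes a CUDA device and PyTorch >= 0.4 (for `.dtype`).
def _demo_fp16_wrapper():
    net = FP16(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda())
    x = torch.randn(2, 3, 16, 16).cuda()                  # fp32 input; forward() casts it to half
    out = net(x)
    assert net.module[0].weight.dtype == torch.float16    # conv weights were converted to half
    assert net.module[1].weight.dtype == torch.float32    # BatchNorm weights stay in fp32
    return out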
def copy_model_to_fp32(m, optim):
    """Creates an fp32 copy of the model's trainable parameters and points the
    optimizer's parameter groups at that copy; returns the fp32 parameters."""
    fp32_params = [m_param.clone().type(torch.cuda.FloatTensor).detach() for m_param in trainable_params_(m)]
    optim_groups = [group['params'] for group in optim.param_groups]
    iter_fp32_params = iter(fp32_params)
    for group_params in optim_groups:
        for i in range(len(group_params)):
            if not group_params[i].requires_grad: continue  # only replace trainable params
            fp32_param = next(iter_fp32_params)
            assert fp32_param.shape == group_params[i].shape
            fp32_param.requires_grad = group_params[i].requires_grad
            group_params[i] = fp32_param
    return fp32_params
def copy_fp32_to_model(m, fp32_params):
    """Copies the fp32 master parameters back into the (fp16) model parameters."""
    m_params = trainable_params_(m)
    assert len(m_params) == len(fp32_params)
    for fp32_param, m_param in zip(fp32_params, m_params):
        m_param.data.copy_(fp32_param.data)
def update_fp32_grads(fp32_params, m):
    """Copies the model's (fp16) gradients onto the fp32 master parameters."""
    m_params = trainable_params_(m)
    assert len(m_params) == len(fp32_params)
    for fp32_param, m_param in zip(fp32_params, m_params):
        if fp32_param.grad is None:
            fp32_param.grad = nn.Parameter(fp32_param.data.new().resize_(*fp32_param.data.size()))
        fp32_param.grad.data.copy_(m_param.grad.data)
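
# Illustrative sketch (not part of the original module) of how these helpers fit
# together in one mixed-precision update: the fp16 model runs the forward/backward
# pass, the optimizer steps on fp32 master copies created once up front with
# copy_model_to_fp32(m, optim), and the updated weights are copied back into the
# fp16 model. `crit`, `x`, `y` and the loss scaling are assumptions added for the
# example, not something this file implements.
def _demo_fp16_training_step(m, fp32_params, optim, crit, x, y, loss_scale=512):
    m.zero_grad()
    loss = crit(m(x), y) * loss_scale                     # scale up so small fp16 grads don't flush to zero
    loss.backward()
    update_fp32_grads(fp32_params, m)                     # move fp16 grads onto the fp32 masters
    for p in fp32_params: p.grad.data.div_(loss_scale)    # undo the scaling in fp32
    optim.step()                                          # the optimizer only ever sees fp32 params
    copy_fp32_to_model(m, fp32_params)                    # write updated weights back into the fp16 model
    return loss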