# lsuv_initializer.py
"""
From https://github.com/ducha-aiki/LSUV-pytorch
Copyright (C) 2017, Dmytro Mishkin
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
  26. import numpy as np
  27. import torch
  28. import torch.nn.init
  29. import torch.nn as nn
  30. gg = {}
  31. gg['hook_position'] = 0
  32. gg['total_fc_conv_layers'] = 0
  33. gg['done_counter'] = -1
  34. gg['hook'] = None
  35. gg['act_dict'] = {}
  36. gg['counter_to_apply_correction'] = 0
  37. gg['correction_needed'] = False
  38. gg['current_coef'] = 1.0
  39. # Orthonorm init code is taked from Lasagne
  40. # https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py
  41. def svd_orthonormal(w):
  42. shape = w.shape
  43. if len(shape) < 2:
  44. raise RuntimeError("Only shapes of length 2 or more are supported.")
  45. flat_shape = (shape[0], np.prod(shape[1:]))
  46. a = np.random.normal(0.0, 1.0, flat_shape)#w;
  47. u, _, v = np.linalg.svd(a, full_matrices=False)
  48. q = u if u.shape == flat_shape else v
  49. q = q.reshape(shape)
  50. return q.astype(np.float32)
  51. def store_activations(self, input, output):
  52. gg['act_dict'] = output.data.cpu().numpy();
  53. return
  54. def add_current_hook(m):
  55. if gg['hook'] is not None:
  56. return
  57. if (isinstance(m, nn.Conv2d)) or (isinstance(m, nn.Linear)):
  58. if gg['hook_position'] > gg['done_counter']:
  59. gg['hook'] = m.register_forward_hook(store_activations)
  60. else:
  61. gg['hook_position'] += 1
  62. return
  63. def count_conv_fc_layers(m):
  64. if (isinstance(m, nn.Conv2d)) or (isinstance(m, nn.Linear)):
  65. gg['total_fc_conv_layers'] +=1
  66. return
  67. def remove_hooks(hooks):
  68. for h in hooks:
  69. h.remove()
  70. return
  71. def orthogonal_weights_init(m):
  72. if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
  73. if hasattr(m, 'weight_v'):
  74. w_ortho = svd_orthonormal(m.weight_v.data.cpu().numpy())
  75. m.weight_v.data = torch.from_numpy(w_ortho)
  76. try:
  77. nn.init.constant(m.bias, 0)
  78. except:
  79. pass
  80. else:
  81. w_ortho = svd_orthonormal(m.weight.data.cpu().numpy())
  82. m.weight.data = torch.from_numpy(w_ortho)
  83. try:
  84. nn.init.constant(m.bias, 0)
  85. except:
  86. pass
  87. return
  88. def apply_weights_correction(m):
  89. if gg['hook'] is None:
  90. return
  91. if not gg['correction_needed']:
  92. return
  93. if (isinstance(m, nn.Conv2d)) or (isinstance(m, nn.Linear)):
  94. if gg['counter_to_apply_correction'] < gg['hook_position']:
  95. gg['counter_to_apply_correction'] += 1
  96. else:
  97. if hasattr(m, 'weight_g'):
  98. m.weight_g.data *= float(gg['current_coef'])
  99. gg['correction_needed'] = False
  100. else:
  101. m.weight.data *= gg['current_coef']
  102. gg['correction_needed'] = False
  103. return
  104. return
  105. def apply_lsuv_init(model, data, needed_std=1.0, std_tol=0.1, max_attempts=10, do_orthonorm=True, cuda=True):
  106. model.eval();
  107. if cuda:
  108. model=model.cuda()
  109. data=data.cuda()
  110. else:
  111. model=model.cpu()
  112. data=data.cpu()
  113. model.apply(count_conv_fc_layers)
  114. if do_orthonorm:
  115. model.apply(orthogonal_weights_init)
  116. if cuda:
  117. model=model.cuda()
  118. for layer_idx in range(gg['total_fc_conv_layers']):
  119. model.apply(add_current_hook)
  120. out = model(data)
  121. current_std = gg['act_dict'].std()
  122. attempts = 0
  123. while (np.abs(current_std - needed_std) > std_tol):
  124. gg['current_coef'] = needed_std / (current_std + 1e-8);
  125. gg['correction_needed'] = True
  126. model.apply(apply_weights_correction)
  127. if cuda:
  128. model=model.cuda()
  129. out = model(data)
  130. current_std = gg['act_dict'].std()
  131. attempts+=1
  132. if attempts > max_attempts:
  133. print(f'Cannot converge in {max_attempts} iterations')
  134. break
  135. if gg['hook'] is not None:
  136. gg['hook'].remove()
  137. gg['done_counter']+=1
  138. gg['counter_to_apply_correction'] = 0
  139. gg['hook_position'] = 0
  140. gg['hook'] = None
  141. if not cuda:
  142. model=model.cpu()
  143. return model