rnn_reg.py

from .torch_imports import *
from .core import *
from functools import wraps
import torch.nn.functional as F
from torch.autograd import Variable

IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4')

def dropout_mask(x, sz, dropout):
    """ Applies a dropout mask whose size is determined by the passed argument 'sz'.

    Args:
        x (nn.Variable): A torch Variable object
        sz (tuple(int, int, int)): The expected size of the new tensor
        dropout (float): The dropout fraction to apply

    This method uses the bernoulli distribution to decide which activations to keep.
    Additionally, the sampled activations are rescaled using the factor 1/(1 - dropout).

    In the example given below, one can see that approximately a 0.8 fraction of the
    returned tensor's elements are zero. Rescaling with the factor 1/(1 - 0.8) gives a
    tensor with 5's in the kept positions.

    The official documentation for the pytorch bernoulli function is here:
        http://pytorch.org/docs/master/torch.html#torch.bernoulli

    Examples:
        >>> a_Var = torch.autograd.Variable(torch.Tensor(2, 3, 4).uniform_(0, 1), requires_grad=False)
        >>> a_Var
        Variable containing:
        (0 ,.,.) =
          0.6890  0.5412  0.4303  0.8918
          0.3871  0.7944  0.0791  0.5979
          0.4575  0.7036  0.6186  0.7217
        (1 ,.,.) =
          0.8354  0.1690  0.1734  0.8099
          0.6002  0.2602  0.7907  0.4446
          0.5877  0.7464  0.4257  0.3386
        [torch.FloatTensor of size 2x3x4]

        >>> a_mask = dropout_mask(a_Var.data, (1, a_Var.size(1), a_Var.size(2)), dropout=0.8)
        >>> a_mask
        (0 ,.,.) =
          0  5  0  0
          0  0  0  5
          5  0  5  0
        [torch.FloatTensor of size 1x3x4]
    """
    return x.new(*sz).bernoulli_(1-dropout)/(1-dropout)

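# Sketch added for illustration (not part of the original module): the 1/(1 - dropout)
# rescaling keeps the expected value of the masked activations equal to that of the
# original activations, so nothing needs to be rescaled at inference time. The demo
# function below is hypothetical and only relies on torch from the imports above.
def _demo_dropout_mask():
    x = torch.ones(1000, 10)
    mask = dropout_mask(x, (1000, 10), dropout=0.8)   # entries are either 0 or 5
    return mask.mean()                                # stays close to 1.0 on average
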
class LockedDropout(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or not self.p: return x
        m = dropout_mask(x.data, (1, x.size(1), x.size(2)), self.p)
        return Variable(m, requires_grad=False) * x

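# Usage sketch added for illustration (not part of the original module): LockedDropout
# samples a single mask of size (1, batch_size, n_hidden) and broadcasts it across every
# timestep, unlike nn.Dropout, which resamples per element. Assumes `nn`, `torch` and
# `Variable` come from the imports above; the demo function name is hypothetical.
def _demo_locked_dropout():
    lock_drop = LockedDropout(p=0.5)
    lock_drop.train()                        # the mask is only applied in training mode
    x = Variable(torch.randn(10, 2, 4))      # (seq_len, batch_size, n_hidden)
    y = lock_drop(x)
    # The same hidden units are zeroed at every timestep because the mask is shared
    # along dimension 0.
    return y
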
class WeightDrop(torch.nn.Module):
    """A custom torch layer that serves as a wrapper on another torch layer.
    Primarily responsible for updating the weights in the wrapped module based
    on a specified dropout.
    """
    def __init__(self, module, dropout, weights=['weight_hh_l0']):
        """ Default constructor for the WeightDrop module

        Args:
            module (torch.nn.Module): A pytorch layer being wrapped
            dropout (float): a dropout value to apply
            weights (list(str)): the parameters of the wrapped **module**
                which should be fractionally dropped.
        """
        super().__init__()
        self.module,self.weights,self.dropout = module,weights,dropout
        self._setup()

    def _setup(self):
        """ For each string defined in self.weights, the corresponding
        attribute in the wrapped module is referenced, then deleted, and subsequently
        registered as a new parameter with a slightly modified name.

        Args:
            None

        Returns:
            None
        """
        if isinstance(self.module, torch.nn.RNNBase): self.module.flatten_parameters = noop
        for name_w in self.weights:
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', nn.Parameter(w.data))

    def _setweights(self):
        """ Uses pytorch's built-in dropout function to apply dropout to the parameters of
        the wrapped module.

        Args:
            None

        Returns:
            None
        """
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            if hasattr(self.module, name_w):
                delattr(self.module, name_w)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        """ Updates the weights and delegates the propagation of the tensor to the wrapped
        module's forward method.

        Args:
            *args: supplied arguments

        Returns:
            tensor obtained by running the forward method on the wrapped module.
        """
        self._setweights()
        return self.module.forward(*args)

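# Usage sketch added for illustration (not part of the original module): wrapping an LSTM
# so that its hidden-to-hidden weight matrix ('weight_hh_l0', the default) is dropped out
# on every forward pass, as in AWD-LSTM style training. Written against the PyTorch
# 0.3/0.4-era API this module targets; the demo function name is hypothetical.
def _demo_weight_drop():
    lstm = nn.LSTM(input_size=4, hidden_size=4)
    wd_lstm = WeightDrop(lstm, dropout=0.5)   # defaults to weights=['weight_hh_l0']
    wd_lstm.train()
    x = Variable(torch.randn(10, 2, 4))       # (seq_len, batch_size, input_size)
    out, (h, c) = wd_lstm(x)
    # 'weight_hh_l0_raw' holds the undropped parameter; 'weight_hh_l0' is rebuilt from it
    # (with dropout applied) by _setweights() before each forward pass.
    return out
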
class EmbeddingDropout(nn.Module):
    """ Applies dropout in the embedding layer by zeroing out some elements of the embedding vector.
    Uses the dropout_mask custom layer to achieve this.

    Args:
        embed (torch.nn.Embedding): An embedding torch layer
        words (torch.nn.Variable): A torch variable
        dropout (float): dropout fraction to apply to the embedding weights
        scale (float): additional scaling to apply to the modified embedding weights

    Returns:
        tensor of size: (batch_size x seq_length x embedding_size)

    Example:
        >> embed = torch.nn.Embedding(10,3)
        >> words = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))
        >> words.size()
        (2,4)
        >> embed_dropout_layer = EmbeddingDropout(embed)
        >> dropout_out_ = embed_dropout_layer(words, dropout=0.40)
        >> dropout_out_
        Variable containing:
        (0 ,.,.) =
          1.2549  1.8230  1.9367
          0.0000 -0.0000  0.0000
          2.2540 -0.1299  1.5448
          0.0000 -0.0000 -0.0000
        (1 ,.,.) =
          2.2540 -0.1299  1.5448
         -4.0457  2.4815 -0.2897
          0.0000 -0.0000  0.0000
          1.8796 -0.4022  3.8773
        [torch.FloatTensor of size 2x4x3]
    """
    def __init__(self, embed):
        super().__init__()
        self.embed = embed

    def forward(self, words, dropout=0.1, scale=None):
        if dropout:
            size = (self.embed.weight.size(0), 1)
            mask = Variable(dropout_mask(self.embed.weight.data, size, dropout))
            masked_embed_weight = mask * self.embed.weight
        else: masked_embed_weight = self.embed.weight
        if scale: masked_embed_weight = scale * masked_embed_weight

        padding_idx = self.embed.padding_idx
        if padding_idx is None: padding_idx = -1

        if IS_TORCH_04:
            X = F.embedding(words,
                masked_embed_weight, padding_idx, self.embed.max_norm,
                self.embed.norm_type, self.embed.scale_grad_by_freq, self.embed.sparse)
        else:
            X = self.embed._backend.Embedding.apply(words,
                masked_embed_weight, padding_idx, self.embed.max_norm,
                self.embed.norm_type, self.embed.scale_grad_by_freq, self.embed.sparse)

        return X
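
# Usage sketch added for illustration (not part of the original module): because the mask
# has size (vocab_size, 1), whole rows of the embedding matrix are zeroed, so every
# occurrence of a dropped word maps to the zero vector within that forward pass. Assumes
# `nn`, `torch` and `Variable` come from the imports above; the demo function name is
# hypothetical.
def _demo_embedding_dropout():
    embed = nn.Embedding(10, 3)
    embed_drop = EmbeddingDropout(embed)
    words = Variable(torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]))
    out = embed_drop(words, dropout=0.4)      # shape (2, 4, 3): batch x seq x emb_size
    return out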