# lm_rnn.py

import warnings

from .imports import *
from .torch_imports import *
from .rnn_reg import LockedDropout,WeightDrop,EmbeddingDropout
from .model import Stepper
from .core import set_grad_enabled

IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4')

def seq2seq_reg(output, xtra, loss, alpha=0, beta=0):
    hs,dropped_hs = xtra
    if alpha:  # Activation Regularization
        loss = loss + (alpha * dropped_hs[-1].pow(2).mean()).sum()
    if beta:   # Temporal Activation Regularization (slowness)
        h = hs[-1]
        if len(h)>1: loss = loss + (beta * (h[1:] - h[:-1]).pow(2).mean()).sum()
    return loss

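# A minimal sketch (not part of the original module) of how seq2seq_reg could
# be plugged into a training step. `model`, `x`, `y` and the alpha/beta values
# are illustrative assumptions; the SequentialRNN built later in this file
# returns (decoded, raw_outputs, outputs), so (raw_outputs, outputs) is the
# `xtra` pair that seq2seq_reg expects:
#
#   output, *xtra = model(x)
#   loss = F.cross_entropy(output, y.view(-1))
#   loss = seq2seq_reg(output, xtra, loss, alpha=2, beta=1)
#   loss.backward()
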
def repackage_var(h):
    """Wraps h in new Variables, to detach them from their history."""
    if IS_TORCH_04: return h.detach() if type(h) == torch.Tensor else tuple(repackage_var(v) for v in h)
    else: return Variable(h.data) if type(h) == Variable else tuple(repackage_var(v) for v in h)

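# Illustrative only: the classic use of this helper is truncated
# backpropagation through time, where the hidden state is carried across
# mini-batches but detached at each batch boundary so the autograd graph does
# not grow without bound (`batches` and `model` are hypothetical stand-ins;
# RNN_Encoder.forward below does exactly this with its own new_hidden):
#
#   for x, y in batches:
#       model.hidden = repackage_var(model.hidden)   # cut the graph here
#       output = model(x)
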
class RNN_Encoder(nn.Module):

    """A custom RNN encoder network that uses
        - an embedding matrix to encode input,
        - a stack of LSTM or QRNN layers to drive the network, and
        - variational dropouts in the embedding and LSTM/QRNN layers

        The architecture for this network was inspired by the work done in
        "Regularizing and Optimizing LSTM Language Models".
        (https://arxiv.org/pdf/1708.02182.pdf)
    """

    initrange=0.1

    def __init__(self, ntoken, emb_sz, n_hid, n_layers, pad_token, bidir=False,
                 dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5, qrnn=False):
        """ Default constructor for the RNN_Encoder class

            Args:
                ntoken (int): number of vocabulary words (or tokens) in the source dataset
                emb_sz (int): the embedding size to use to encode each token
                n_hid (int): number of hidden activations per LSTM layer
                n_layers (int): number of LSTM layers to use in the architecture
                pad_token (int): the int value used for padding text.
                bidir (bool): if True, the LSTM layers are bidirectional.
                dropouth (float): dropout to apply to the activations going from one LSTM layer to another
                dropouti (float): dropout to apply to the input layer.
                dropoute (float): dropout to apply to the embedding layer.
                wdrop (float): dropout used for a LSTM's internal (or hidden) recurrent weights.
                qrnn (bool): if True, use QRNN layers instead of LSTMs.

            Returns:
                None
        """
        super().__init__()
        self.ndir = 2 if bidir else 1
        self.bs, self.qrnn = 1, qrnn
        self.encoder = nn.Embedding(ntoken, emb_sz, padding_idx=pad_token)
        self.encoder_with_dropout = EmbeddingDropout(self.encoder)
        if self.qrnn:
            # Using QRNN requires cupy: https://github.com/cupy/cupy
            from .torchqrnn.qrnn import QRNNLayer
            self.rnns = [QRNNLayer(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.ndir,
                save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(n_layers)]
            if wdrop:
                for rnn in self.rnns:
                    rnn.linear = WeightDrop(rnn.linear, wdrop, weights=['weight'])
        else:
            self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.ndir,
                1, bidirectional=bidir) for l in range(n_layers)]
            if wdrop: self.rnns = [WeightDrop(rnn, wdrop) for rnn in self.rnns]
        self.rnns = torch.nn.ModuleList(self.rnns)
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)

        self.emb_sz,self.n_hid,self.n_layers,self.dropoute = emb_sz,n_hid,n_layers,dropoute
        self.dropouti = LockedDropout(dropouti)
        self.dropouths = nn.ModuleList([LockedDropout(dropouth) for l in range(n_layers)])

    def forward(self, input):
        """ Invoked during the forward propagation of the RNN_Encoder module.

            Args:
                input (Tensor): input of shape (sentence length x batch_size)

            Returns:
                (raw_outputs, outputs) (tuple(list(Tensor), list(Tensor))): list of tensors evaluated
                    from each RNN layer without using dropouth, and the same list with dropouth
                    applied between layers.
        """
        sl,bs = input.size()
        if bs!=self.bs:
            self.bs=bs
            self.reset()
        with set_grad_enabled(self.training):
            emb = self.encoder_with_dropout(input, dropout=self.dropoute if self.training else 0)
            emb = self.dropouti(emb)
            raw_output = emb
            new_hidden,raw_outputs,outputs = [],[],[]
            for l, (rnn,drop) in enumerate(zip(self.rnns, self.dropouths)):
                current_input = raw_output
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    raw_output, new_h = rnn(raw_output, self.hidden[l])
                new_hidden.append(new_h)
                raw_outputs.append(raw_output)
                if l != self.n_layers - 1: raw_output = drop(raw_output)
                outputs.append(raw_output)
            self.hidden = repackage_var(new_hidden)
        return raw_outputs, outputs

    def one_hidden(self, l):
        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz)//self.ndir
        if IS_TORCH_04: return Variable(self.weights.new(self.ndir, self.bs, nh).zero_())
        else: return Variable(self.weights.new(self.ndir, self.bs, nh).zero_(), volatile=not self.training)

    def reset(self):
        if self.qrnn: [r.reset() for r in self.rnns]
        self.weights = next(self.parameters()).data
        if self.qrnn: self.hidden = [self.one_hidden(l) for l in range(self.n_layers)]
        else: self.hidden = [(self.one_hidden(l), self.one_hidden(l)) for l in range(self.n_layers)]

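# A minimal usage sketch for RNN_Encoder (not part of the original module;
# all sizes are arbitrary illustrative choices):
#
#   enc = RNN_Encoder(ntoken=10000, emb_sz=400, n_hid=1150, n_layers=3, pad_token=1)
#   inp = torch.zeros(70, 64).long()   # (sentence length x batch size) token ids
#   raw_outputs, outputs = enc(inp)    # one (70, 64, nh) tensor per layer in each list
#
# forward() calls reset() itself the first time it sees a new batch size, so
# the hidden state is zero-initialized lazily.
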
class MultiBatchRNN(RNN_Encoder):
    def __init__(self, bptt, max_seq, *args, **kwargs):
        self.max_seq,self.bptt = max_seq,bptt
        super().__init__(*args, **kwargs)

    def concat(self, arrs):
        return [torch.cat([l[si] for l in arrs]) for si in range(len(arrs[0]))]

    def forward(self, input):
        sl,bs = input.size()
        for l in self.hidden:
            for h in l: h.data.zero_()
        raw_outputs, outputs = [],[]
        for i in range(0, sl, self.bptt):
            r, o = super().forward(input[i: min(i+self.bptt, sl)])
            if i>(sl-self.max_seq):
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs), self.concat(outputs)

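# How MultiBatchRNN walks a long sequence (the numbers below are a worked
# example, not requirements): with sl=2000, bptt=70 and max_seq=1400, the
# document is fed through the encoder in 70-step chunks, carrying the hidden
# state forward, but only chunks starting at i > 2000-1400 = 600 are kept, so
# the classifier head sees roughly the activations of the last max_seq tokens.
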
class LinearDecoder(nn.Module):
    initrange=0.1
    def __init__(self, n_out, n_hid, dropout, tie_encoder=None, bias=False):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.dropout = LockedDropout(dropout)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input):
        raw_outputs, outputs = input
        output = self.dropout(outputs[-1])
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        result = decoded.view(-1, decoded.size(1))
        return result, raw_outputs, outputs

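# Shape walk-through for LinearDecoder.forward (sizes are illustrative):
#   outputs[-1]                (sl, bs, n_hid)   e.g. (70, 64, 400)
#   after view + self.decoder  (sl*bs, n_out)    e.g. (4480, 10000)
# The final .view(-1, decoded.size(1)) leaves the 2-D shape unchanged; the
# flattened layout lets F.cross_entropy consume the logits directly.
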
class LinearBlock(nn.Module):
    def __init__(self, ni, nf, drop):
        super().__init__()
        self.lin = nn.Linear(ni, nf)
        self.drop = nn.Dropout(drop)
        self.bn = nn.BatchNorm1d(ni)

    def forward(self, x): return self.lin(self.drop(self.bn(x)))

class PoolingLinearClassifier(nn.Module):
    def __init__(self, layers, drops):
        super().__init__()
        self.layers = nn.ModuleList([
            LinearBlock(layers[i], layers[i + 1], drops[i]) for i in range(len(layers) - 1)])

    def pool(self, x, bs, is_max):
        f = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
        return f(x.permute(1,2,0), (1,)).view(bs,-1)

    def forward(self, input):
        raw_outputs, outputs = input
        output = outputs[-1]
        sl,bs,_ = output.size()
        avgpool = self.pool(output, bs, False)
        mxpool = self.pool(output, bs, True)
        x = torch.cat([output[-1], mxpool, avgpool], 1)
        for l in self.layers:
            l_x = l(x)
            x = F.relu(l_x)
        return l_x, raw_outputs, outputs

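# Concat pooling illustrated (shapes only, assuming sl=70, bs=64, nh=400):
#   output[-1]  -> (64, 400)    hidden state at the last time step
#   mxpool      -> (64, 400)    adaptive max pool over the time axis
#   avgpool     -> (64, 400)    adaptive average pool over the time axis
#   x           -> (64, 1200)   torch.cat along dim 1
# Consequently the first entry of `layers` must be 3 times the size of the
# encoder's final output (3*emb_sz for the models built in this file).
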
class SequentialRNN(nn.Sequential):
    def reset(self):
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()

def get_language_model(n_tok, emb_sz, n_hid, n_layers, pad_token,
                       dropout=0.4, dropouth=0.3, dropouti=0.5, dropoute=0.1, wdrop=0.5, tie_weights=True, qrnn=False, bias=False):
    """Returns a SequentialRNN language model.

    A RNN_Encoder layer is instantiated using the parameters provided, followed by a
    LinearDecoder layer. By default (i.e. tie_weights=True), the embedding matrix used
    in the RNN_Encoder is also used as the weight matrix of the LinearDecoder layer.

    SequentialRNN is a thin subclass of torch's native Sequential container that runs
    the RNN_Encoder and LinearDecoder layers in order and adds a reset() method.

    Args:
        n_tok (int): number of unique vocabulary words (or tokens) in the source dataset
        emb_sz (int): the embedding size to use to encode each token
        n_hid (int): number of hidden activations per LSTM layer
        n_layers (int): number of LSTM layers to use in the architecture
        pad_token (int): the int value used for padding text.
        dropout (float): dropout to apply to the output of the last RNN layer before decoding.
        dropouth (float): dropout to apply to the activations going from one LSTM layer to another
        dropouti (float): dropout to apply to the input layer.
        dropoute (float): dropout to apply to the embedding layer.
        wdrop (float): dropout used for a LSTM's internal (or hidden) recurrent weights.
        tie_weights (bool): decide if the weights of the embedding matrix in the RNN encoder
            should be tied to the weights of the LinearDecoder layer.
        qrnn (bool): decide if the model is composed of LSTMs (False) or QRNNs (True).
        bias (bool): decide if the decoder should have a bias layer or not.

    Returns:
        A SequentialRNN model
    """
    rnn_enc = RNN_Encoder(n_tok, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token,
                 dropouth=dropouth, dropouti=dropouti, dropoute=dropoute, wdrop=wdrop, qrnn=qrnn)
    enc = rnn_enc.encoder if tie_weights else None
    return SequentialRNN(rnn_enc, LinearDecoder(n_tok, emb_sz, dropout, tie_encoder=enc, bias=bias))

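# Example instantiation (hyper-parameter values are illustrative, loosely
# following the AWD-LSTM paper, not values required by this function):
#
#   lm = get_language_model(n_tok=10000, emb_sz=400, n_hid=1150, n_layers=3,
#                           pad_token=1, dropout=0.4, dropouth=0.3,
#                           dropouti=0.5, dropoute=0.1, wdrop=0.5)
#   lm.reset()   # zero the per-layer hidden state before a new epoch
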
def get_rnn_classifier(bptt, max_seq, n_class, n_tok, emb_sz, n_hid, n_layers, pad_token, layers, drops, bidir=False,
                       dropouth=0.3, dropouti=0.5, dropoute=0.1, wdrop=0.5, qrnn=False):
    """Returns a SequentialRNN text classifier: a MultiBatchRNN encoder followed by a
    PoolingLinearClassifier head."""
    rnn_enc = MultiBatchRNN(bptt, max_seq, n_tok, emb_sz, n_hid, n_layers, pad_token=pad_token, bidir=bidir,
                 dropouth=dropouth, dropouti=dropouti, dropoute=dropoute, wdrop=wdrop, qrnn=qrnn)
    return SequentialRNN(rnn_enc, PoolingLinearClassifier(layers, drops))

get_rnn_classifer = get_rnn_classifier  # backwards-compatibility alias for the original misspelled name
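
# Example instantiation of the classifier (values illustrative; layers[0] must
# be 3*emb_sz because of the concat pooling in PoolingLinearClassifier):
#
#   clf = get_rnn_classifier(bptt=70, max_seq=1400, n_class=2, n_tok=10000,
#                            emb_sz=400, n_hid=1150, n_layers=3, pad_token=1,
#                            layers=[400*3, 50, 2], drops=[0.4, 0.1])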