qrnn.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import torch
  2. from torch import nn
  3. from torch.autograd import Variable
  4. from .forget_mult import ForgetMult
  5. ##Adapted from the code here https://github.com/salesforce/pytorch-qrnn, implementation of the QRNN by
  6. ##Bradbury, James and Merity, Stephen and Xiong, Caiming and Socher, Richard
  7. ##https://arxiv.org/abs/1611.01576
  8. class QRNNLayer(nn.Module):
  9. r"""Applies a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.
  10. Args:
  11. input_size: The number of expected features in the input x.
  12. hidden_size: The number of features in the hidden state h. If not specified, the input size is used.
  13. save_prev_x: Whether to store previous inputs for use in future convolutional windows (i.e. for a continuing sequence such as in language modeling). If true, you must call reset to remove cached previous values of x. Default: False.
  14. window: Defines the size of the convolutional window (how many previous tokens to look when computing the QRNN values). Supports 1 and 2. Default: 1.
  15. zoneout: Whether to apply zoneout (i.e. failing to update elements in the hidden state) to the hidden state updates. Default: 0.
  16. output_gate: If True, performs QRNN-fo (applying an output gate to the output). If False, performs QRNN-f. Default: True.
  17. use_cuda: If True, uses fast custom CUDA kernel. If False, uses naive for loop. Default: True.
  18. Inputs: X, hidden
  19. - X (seq_len, batch, input_size): tensor containing the features of the input sequence.
  20. - hidden (batch, hidden_size): tensor containing the initial hidden state for the QRNN.
  21. Outputs: output, h_n
  22. - output (seq_len, batch, hidden_size): tensor containing the output of the QRNN for each timestep.
  23. - h_n (batch, hidden_size): tensor containing the hidden state for t=seq_len
  24. """
  25. def __init__(self, input_size, hidden_size=None, save_prev_x=False, zoneout=0, window=1, output_gate=True, use_cuda=True):
  26. super(QRNNLayer, self).__init__()
  27. assert window in [1, 2], "This QRNN implementation currently only handles convolutional window of size 1 or size 2"
  28. self.window = window
  29. self.input_size = input_size
  30. self.hidden_size = hidden_size if hidden_size else input_size
  31. self.zoneout = zoneout
  32. self.save_prev_x = save_prev_x
  33. self.prevX = None
  34. self.output_gate = output_gate
  35. self.use_cuda = use_cuda
  36. # One large matmul with concat is faster than N small matmuls and no concat
  37. self.linear = nn.Linear(self.window * self.input_size, 3 * self.hidden_size if self.output_gate else 2 * self.hidden_size)
  38. def reset(self):
  39. # If you are saving the previous value of x, you should call this when starting with a new state
  40. self.prevX = None
  41. def forward(self, X, hidden=None):
  42. seq_len, batch_size, _ = X.size()
  43. source = None
  44. if self.window == 1:
  45. source = X
  46. elif self.window == 2:
  47. # Construct the x_{t-1} tensor with optional x_{-1}, otherwise a zeroed out value for x_{-1}
  48. Xm1 = []
  49. Xm1.append(self.prevX if self.prevX is not None else X[:1, :, :] * 0)
  50. # Note: in case of len(X) == 1, X[:-1, :, :] results in slicing of empty tensor == bad
  51. if len(X) > 1:
  52. Xm1.append(X[:-1, :, :])
  53. Xm1 = torch.cat(Xm1, 0)
  54. # Convert two (seq_len, batch_size, hidden) tensors to (seq_len, batch_size, 2 * hidden)
  55. source = torch.cat([X, Xm1], 2)
  56. # Matrix multiplication for the three outputs: Z, F, O
  57. Y = self.linear(source)
  58. # Convert the tensor back to (batch, seq_len, len([Z, F, O]) * hidden_size)
  59. if self.output_gate:
  60. Y = Y.view(seq_len, batch_size, 3 * self.hidden_size)
  61. Z, F, O = Y.chunk(3, dim=2)
  62. else:
  63. Y = Y.view(seq_len, batch_size, 2 * self.hidden_size)
  64. Z, F = Y.chunk(2, dim=2)
  65. ###
  66. Z = torch.nn.functional.tanh(Z)
  67. F = torch.nn.functional.sigmoid(F)
  68. # If zoneout is specified, we perform dropout on the forget gates in F
  69. # If an element of F is zero, that means the corresponding neuron keeps the old value
  70. if self.zoneout:
  71. if self.training:
  72. mask = Variable(F.data.new(*F.size()).bernoulli_(1 - self.zoneout), requires_grad=False)
  73. F = F * mask
  74. else:
  75. F *= 1 - self.zoneout
  76. # Ensure the memory is laid out as expected for the CUDA kernel
  77. # This is a null op if the tensor is already contiguous
  78. Z = Z.contiguous()
  79. F = F.contiguous()
  80. # The O gate doesn't need to be contiguous as it isn't used in the CUDA kernel
  81. # Forget Mult
  82. # For testing QRNN without ForgetMult CUDA kernel, C = Z * F may be useful
  83. C = ForgetMult()(F, Z, hidden, use_cuda=self.use_cuda)
  84. # Apply (potentially optional) output gate
  85. if self.output_gate:
  86. H = torch.nn.functional.sigmoid(O) * C
  87. else:
  88. H = C
  89. # In an optimal world we may want to backprop to x_{t-1} but ...
  90. if self.window > 1 and self.save_prev_x:
  91. self.prevX = Variable(X[-1:, :, :].data, requires_grad=False)
  92. return H, C[-1:, :, :]
  93. class QRNN(torch.nn.Module):
  94. r"""Applies a multiple layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.
  95. Args:
  96. input_size: The number of expected features in the input x.
  97. hidden_size: The number of features in the hidden state h. If not specified, the input size is used.
  98. num_layers: The number of QRNN layers to produce.
  99. layers: List of preconstructed QRNN layers to use for the QRNN module (optional).
  100. save_prev_x: Whether to store previous inputs for use in future convolutional windows (i.e. for a continuing sequence such as in language modeling). If true, you must call reset to remove cached previous values of x. Default: False.
  101. window: Defines the size of the convolutional window (how many previous tokens to look when computing the QRNN values). Supports 1 and 2. Default: 1.
  102. zoneout: Whether to apply zoneout (i.e. failing to update elements in the hidden state) to the hidden state updates. Default: 0.
  103. output_gate: If True, performs QRNN-fo (applying an output gate to the output). If False, performs QRNN-f. Default: True.
  104. use_cuda: If True, uses fast custom CUDA kernel. If False, uses naive for loop. Default: True.
  105. Inputs: X, hidden
  106. - X (seq_len, batch, input_size): tensor containing the features of the input sequence.
  107. - hidden (layers, batch, hidden_size): tensor containing the initial hidden state for the QRNN.
  108. Outputs: output, h_n
  109. - output (seq_len, batch, hidden_size): tensor containing the output of the QRNN for each timestep.
  110. - h_n (layers, batch, hidden_size): tensor containing the hidden state for t=seq_len
  111. """
  112. def __init__(self, input_size, hidden_size,
  113. num_layers=1, bias=True, batch_first=False,
  114. dropout=0, bidirectional=False, layers=None, **kwargs):
  115. assert bidirectional == False, 'Bidirectional QRNN is not yet supported'
  116. assert batch_first == False, 'Batch first mode is not yet supported'
  117. assert bias == True, 'Removing underlying bias is not yet supported'
  118. super(QRNN, self).__init__()
  119. self.layers = torch.nn.ModuleList(layers if layers else [QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, **kwargs) for l in range(num_layers)])
  120. self.input_size = input_size
  121. self.hidden_size = hidden_size
  122. self.num_layers = len(layers) if layers else num_layers
  123. self.bias = bias
  124. self.batch_first = batch_first
  125. self.dropout = dropout
  126. self.bidirectional = bidirectional
  127. def reset(self):
  128. r'''If your convolutional window is greater than 1, you must reset at the beginning of each new sequence'''
  129. [layer.reset() for layer in self.layers]
  130. def forward(self, input, hidden=None):
  131. next_hidden = []
  132. for i, layer in enumerate(self.layers):
  133. input, hn = layer(input, None if hidden is None else hidden[i])
  134. next_hidden.append(hn)
  135. if self.dropout != 0 and i < len(self.layers) - 1:
  136. input = torch.nn.functional.dropout(input, p=self.dropout, training=self.training, inplace=False)
  137. next_hidden = torch.cat(next_hidden, 0).view(self.num_layers, *next_hidden[0].size()[-2:])
  138. return input, next_hidden