# forget_mult.py

import math
import torch
from torch.autograd import Variable
from cupy.cuda import function
from cupy.cuda.compiler import _NVRTCProgram
from collections import namedtuple

## Adapted from https://github.com/salesforce/pytorch-qrnn, the implementation of the QRNN by
## Bradbury, James and Merity, Stephen and Xiong, Caiming and Socher, Richard
## https://arxiv.org/abs/1611.01576
kernel = '''
extern "C"
__global__ void recurrent_forget_mult(float *dst, const float *f, const float *x, int SEQ, int BATCH, int HIDDEN)
{
    /*
    Note: destination is assumed to be one timestep longer than f or x where dst[0] = h_{-1}
    This means the dst array is indexed differently than f or x
    */
    int hid = blockIdx.x * blockDim.x + threadIdx.x;
    int bid = blockIdx.y * blockDim.y + threadIdx.y;
    if (hid >= HIDDEN || bid >= BATCH)
        return;
    //
    for (int ts = 0 + 1; ts < SEQ + 1; ts++) {
        // Good sanity check for debugging - only perform additions to a zeroed chunk of memory
        // Addition seems atomic or near atomic - you should get incorrect answers if doubling up via threads
        // Note: the index i needs to be offset by one as f[0] (f_t) is used for dst[1] (h_t) etc
        // To move timesteps, we step HIDDEN * BATCH
        // To move batches, we move HIDDEN
        // To move neurons, we move +- 1
        // Note: dst[dst_i] = ts * 100 + bid * 10 + hid; is useful for debugging
        int i           = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
        int dst_i       = (ts - 0) * HIDDEN * BATCH + bid * HIDDEN + hid;
        int dst_iminus1 = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
        dst[dst_i] = f[i] * x[i];
        dst[dst_i] += (1 - f[i]) * dst[dst_iminus1];
    }
}

extern "C"
__global__ void bwd_recurrent_forget_mult(const float *h, const float *f, const float *x, const float *gh, float *gf, float *gx, float *ghinit, int SEQ, int BATCH, int HIDDEN)
{
    /*
    Note: h is assumed to be one timestep longer than f, x, gf, gx, or gh where h[0] = h_{-1}
    This means the h array is indexed differently than f or x
    */
    int hid = blockIdx.x * blockDim.x + threadIdx.x;
    int bid = blockIdx.y * blockDim.y + threadIdx.y;
    if (hid >= HIDDEN || bid >= BATCH)
        return;
    //
    double running_f = 0;
    for (int ts = SEQ - 1 + 1; ts >= 0 + 1; ts--) {
        int i           = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
        int dst_i       = (ts - 0) * HIDDEN * BATCH + bid * HIDDEN + hid;
        int dst_iminus1 = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
        //
        running_f += gh[dst_iminus1];
        // Gradient of X
        gx[i] = f[i] * running_f;
        // Gradient of F
        gf[i] = (x[i] - h[dst_iminus1]) * running_f;
        //
        // The line below is likely more numerically stable than (1 - f[i]) * running_f;
        running_f = running_f - f[i] * running_f;
    }
    ghinit[bid * HIDDEN + hid] = running_f;
}
'''
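
# A brief sketch of why the backward kernel above is correct, following directly from the
# forward recurrence h_t = f_t * x_t + (1 - f_t) * h_{t-1}:
#   dh_t/dx_t     = f_t
#   dh_t/df_t     = x_t - h_{t-1}
#   dh_t/dh_{t-1} = 1 - f_t
# Walking backwards over timesteps, running_f accumulates dL/dh_t (the incoming gh for that step
# plus the gradient flowing back from later timesteps), so gx = f_t * running_f and
# gf = (x_t - h_{t-1}) * running_f; multiplying by (1 - f_t) then propagates the gradient to h_{t-1}.
# Whatever remains after the first timestep is the gradient w.r.t. the initial hidden state (ghinit).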
###
class CPUForgetMult(torch.nn.Module):
    def __init__(self):
        super(CPUForgetMult, self).__init__()

    def forward(self, f, x, hidden_init=None):
        result = []
        ###
        forgets = f.split(1, dim=0)
        prev_h = hidden_init
        for i, h in enumerate((f * x).split(1, dim=0)):
            if prev_h is not None: h = h + (1 - forgets[i]) * prev_h
            # h is (1, batch, hidden) when it needs to be (batch, hidden)
            # Calling squeeze will result in badness if batch size is 1
            h = h.view(h.size()[1:])
            result.append(h)
            prev_h = h
        ###
        return torch.stack(result)

class GPUForgetMult(torch.autograd.Function):
    configured_gpus = {}
    ptx = None

    def __init__(self):
        super(GPUForgetMult, self).__init__()

    def compile(self):
        if self.ptx is None:
            program = _NVRTCProgram(kernel.encode(), 'recurrent_forget_mult.cu'.encode())
            GPUForgetMult.ptx = program.compile()
        if torch.cuda.current_device() not in GPUForgetMult.configured_gpus:
            m = function.Module()
            m.load(bytes(self.ptx.encode()))
            self.forget_mult = m.get_function('recurrent_forget_mult')
            self.bwd_forget_mult = m.get_function('bwd_recurrent_forget_mult')
            Stream = namedtuple('Stream', ['ptr'])
            self.stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
            GPUForgetMult.configured_gpus[torch.cuda.current_device()] = (self.forget_mult, self.bwd_forget_mult, self.stream)
        self.forget_mult, self.bwd_forget_mult, self.stream = GPUForgetMult.configured_gpus[torch.cuda.current_device()]
    def forward(self, f, x, hidden_init=None):
        self.compile()
        seq_size, batch_size, hidden_size = f.size()
        result = f.new(seq_size + 1, batch_size, hidden_size)
        # We only zero the result array (result[0]) if we don't set a hidden initial state
        # All other values (result[1:]) are overwritten by default
        if hidden_init is not None: result[0, :, :] = hidden_init
        else: result = result.zero_()
        ###
        grid_hidden_size = min(hidden_size, 512)
        grid = (math.ceil(hidden_size / grid_hidden_size), batch_size)
        self.forget_mult(grid=grid, block=(grid_hidden_size, 1), args=[result.data_ptr(), f.data_ptr(), x.data_ptr(), seq_size, batch_size, hidden_size], stream=self.stream)
        self.save_for_backward(f, x, hidden_init)
        self.result = result
        return result[1:, :, :]
    def backward(self, grad_h):
        self.compile()
        f, x, hidden_init = self.saved_tensors
        h = self.result
        ###
        seq_size, batch_size, hidden_size = f.size()
        # Zeroing is not necessary as these will be overwritten
        grad_f = f.new(*f.size())
        grad_x = f.new(*f.size())
        grad_h_init = f.new(batch_size, hidden_size)
        ###
        grid_hidden_size = min(hidden_size, 512)
        grid = (math.ceil(hidden_size / grid_hidden_size), batch_size)
        self.bwd_forget_mult(grid=grid, block=(grid_hidden_size, 1), args=[h.data_ptr(), f.data_ptr(), x.data_ptr(), grad_h.data_ptr(), grad_f.data_ptr(), grad_x.data_ptr(), grad_h_init.data_ptr(), seq_size, batch_size, hidden_size], stream=self.stream)
        ###
        if hidden_init is not None:
            return grad_f, grad_x, grad_h_init
        return grad_f, grad_x

class ForgetMult(torch.nn.Module):
    r"""ForgetMult computes a simple recurrent equation:
    h_t = f_t * x_t + (1 - f_t) * h_{t-1}

    This equation is equivalent to dynamic weighted averaging.

    Inputs: F, X, hidden_init
        - X (seq_len, batch, input_size): tensor containing the features of the input sequence.
        - F (seq_len, batch, input_size): tensor containing the forget gate values, assumed in range [0, 1].
        - hidden_init (batch, input_size): tensor containing the initial hidden state for the recurrence (h_{t-1}).
        - use_cuda: If True, use the fast element-wise CUDA kernel for the recurrence. If False, use a naive for loop. Default: True.
    """

    def __init__(self):
        super(ForgetMult, self).__init__()

    def forward(self, f, x, hidden_init=None, use_cuda=True):
        # Use CUDA by default unless it's not available
        use_cuda = use_cuda and torch.cuda.is_available()
        # Ensure the user is aware when ForgetMult is not the GPU version, as the GPU version is far faster
        if use_cuda: assert f.is_cuda and x.is_cuda, 'GPU ForgetMult with fast element-wise CUDA kernel requested but tensors not on GPU'
        ###
        # Avoiding 'RuntimeError: expected a Variable argument, but got NoneType' when hidden_init is None
        if hidden_init is None: return GPUForgetMult()(f, x) if use_cuda else CPUForgetMult()(f, x)
        return GPUForgetMult()(f, x, hidden_init) if use_cuda else CPUForgetMult()(f, x, hidden_init)
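

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module: run the forget-mult recurrence on
    # random data. Shapes follow the docstring above: (seq_len, batch, input_size). The GPU path
    # additionally requires cupy and the legacy autograd.Function API this file targets; the CPU
    # path works anywhere.
    seq_len, batch, hidden = 4, 2, 8
    x = torch.rand(seq_len, batch, hidden)
    f = torch.sigmoid(torch.randn(seq_len, batch, hidden))  # forget gates squashed into [0, 1]
    h0 = torch.zeros(batch, hidden)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        x, f, h0 = x.cuda(), f.cuda(), h0.cuda()

    out = ForgetMult()(f, x, hidden_init=h0, use_cuda=use_cuda)
    print(out.size())  # expected: (seq_len, batch, hidden) == (4, 2, 8)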