import math
import torch
from torch.autograd import Variable
from cupy.cuda import function
from cupy.cuda.compiler import _NVRTCProgram
from collections import namedtuple
## Adapted from https://github.com/salesforce/pytorch-qrnn, the QRNN implementation by
## James Bradbury, Stephen Merity, Caiming Xiong, and Richard Socher
## https://arxiv.org/abs/1611.01576
kernel = '''
extern "C"
__global__ void recurrent_forget_mult(float *dst, const float *f, const float *x, int SEQ, int BATCH, int HIDDEN)
{
  /*
  Note: destination is assumed to be one timestep longer than f or x where dst[0] = h_{-1}
  This means the dst array is indexed differently than f or x
  */
  int hid = blockIdx.x * blockDim.x + threadIdx.x;
  int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid >= HIDDEN || bid >= BATCH)
    return;
  //
  for (int ts = 0 + 1; ts < SEQ + 1; ts++) {
    // Good sanity check for debugging - only perform additions to a zeroed chunk of memory
    // Addition seems atomic or near atomic - you should get incorrect answers if doubling up via threads
    // Note: the index i needs to be offset by one as f[0] (f_t) is used for dst[1] (h_t) etc
    // To move timesteps, we step HIDDEN * BATCH
    // To move batches, we move HIDDEN
    // To move neurons, we move +- 1
    // Note: dst[dst_i] = ts * 100 + bid * 10 + hid; is useful for debugging
    int i           = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
    int dst_i       = (ts - 0) * HIDDEN * BATCH + bid * HIDDEN + hid;
    int dst_iminus1 = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
    dst[dst_i]      = f[i] * x[i];
    dst[dst_i]     += (1 - f[i]) * dst[dst_iminus1];
  }
}
- extern "C"
- __global__ void bwd_recurrent_forget_mult(const float *h, const float *f, const float *x, const float *gh, float *gf, float *gx, float *ghinit, int SEQ, int BATCH, int HIDDEN)
- {
- /*
- Note: h is assumed to be one timestep longer than f, x, gf, gx, or gh where dst[0] = h_{-1}
- This means dst array has a separate index than that of f or x
- */
- int hid = blockIdx.x * blockDim.x + threadIdx.x;
- int bid = blockIdx.y * blockDim.y + threadIdx.y;
- if(hid >= HIDDEN || bid >= BATCH)
- return;
- //
- double running_f = 0;
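  // Since h_t = f_t * x_t + (1 - f_t) * h_{t-1}, the local partials are:
  //   dh_t/dx_t = f_t,  dh_t/df_t = x_t - h_{t-1},  dh_t/dh_{t-1} = 1 - f_t.
  // running_f accumulates dL/dh_t while walking backwards in time: it absorbs
  // the incoming gradient gh at each step and decays by (1 - f_t).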
  for (int ts = SEQ - 1 + 1; ts >= 0 + 1; ts--) {
    int i           = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
    int dst_i       = (ts - 0) * HIDDEN * BATCH + bid * HIDDEN + hid;
    int dst_iminus1 = (ts - 1) * HIDDEN * BATCH + bid * HIDDEN + hid;
    //
    running_f      += gh[dst_iminus1];
    // Gradient of X
    gx[i]           = f[i] * running_f;
    // Gradient of F
    gf[i]           = (x[i] - h[dst_iminus1]) * running_f;
    //
    // The line below is likely more numerically stable than (1 - f[i]) * running_f;
    running_f       = running_f - f[i] * running_f;
  }
  ghinit[bid * HIDDEN + hid] = running_f;
}
'''
###

class CPUForgetMult(torch.nn.Module):
    def __init__(self):
        super(CPUForgetMult, self).__init__()

    def forward(self, f, x, hidden_init=None):
        result = []
        ###
        forgets = f.split(1, dim=0)
        prev_h = hidden_init
        for i, h in enumerate((f * x).split(1, dim=0)):
            if prev_h is not None: h = h + (1 - forgets[i]) * prev_h
            # h is (1, batch, hidden) when it needs to be (batch, hidden)
            # Calling squeeze will result in badness if batch size is 1
            h = h.view(h.size()[1:])
            result.append(h)
            prev_h = h
        ###
        return torch.stack(result)

class GPUForgetMult(torch.autograd.Function):
    configured_gpus = {}
    ptx = None

    def __init__(self):
        super(GPUForgetMult, self).__init__()

    def compile(self):
        # Compile the CUDA kernel once with NVRTC, then cache the loaded functions and stream per GPU
        if self.ptx is None:
            program = _NVRTCProgram(kernel.encode(), 'recurrent_forget_mult.cu'.encode())
            GPUForgetMult.ptx = program.compile()

        if torch.cuda.current_device() not in GPUForgetMult.configured_gpus:
            m = function.Module()
            m.load(bytes(self.ptx.encode()))

            self.forget_mult = m.get_function('recurrent_forget_mult')
            self.bwd_forget_mult = m.get_function('bwd_recurrent_forget_mult')

            Stream = namedtuple('Stream', ['ptr'])
            self.stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)

            GPUForgetMult.configured_gpus[torch.cuda.current_device()] = (self.forget_mult, self.bwd_forget_mult, self.stream)

        self.forget_mult, self.bwd_forget_mult, self.stream = GPUForgetMult.configured_gpus[torch.cuda.current_device()]

    def forward(self, f, x, hidden_init=None):
        self.compile()
        seq_size, batch_size, hidden_size = f.size()
        result = f.new(seq_size + 1, batch_size, hidden_size)
        # We only zero the result array (result[0]) if we don't set a hidden initial state
        # All other values (result[1:]) are overwritten by default
        if hidden_init is not None: result[0, :, :] = hidden_init
        else: result = result.zero_()
        ###
        # One thread per (batch, hidden) element; each thread walks the sequence dimension
        grid_hidden_size = min(hidden_size, 512)
        grid = (math.ceil(hidden_size / grid_hidden_size), batch_size)
        self.forget_mult(grid=grid, block=(grid_hidden_size, 1), args=[result.data_ptr(), f.data_ptr(), x.data_ptr(), seq_size, batch_size, hidden_size], stream=self.stream)
        self.save_for_backward(f, x, hidden_init)
        self.result = result
        return result[1:, :, :]

    def backward(self, grad_h):
        self.compile()
        f, x, hidden_init = self.saved_tensors
        h = self.result
        ###
        seq_size, batch_size, hidden_size = f.size()
        # Zeroing is not necessary as these will be overwritten
        grad_f = f.new(*f.size())
        grad_x = f.new(*f.size())
        grad_h_init = f.new(batch_size, hidden_size)
        ###
        grid_hidden_size = min(hidden_size, 512)
        grid = (math.ceil(hidden_size / grid_hidden_size), batch_size)
        self.bwd_forget_mult(grid=grid, block=(grid_hidden_size, 1), args=[h.data_ptr(), f.data_ptr(), x.data_ptr(), grad_h.data_ptr(), grad_f.data_ptr(), grad_x.data_ptr(), grad_h_init.data_ptr(), seq_size, batch_size, hidden_size], stream=self.stream)
        ###
        if hidden_init is not None:
            return grad_f, grad_x, grad_h_init
        return grad_f, grad_x

class ForgetMult(torch.nn.Module):
    r"""ForgetMult computes a simple recurrent equation:
    h_t = f_t * x_t + (1 - f_t) * h_{t-1}

    This equation is equivalent to dynamic weighted averaging.

    Inputs: F, X, hidden_init, use_cuda
        - X (seq_len, batch, input_size): tensor containing the features of the input sequence.
        - F (seq_len, batch, input_size): tensor containing the forget gate values, assumed in range [0, 1].
        - hidden_init (batch, input_size): tensor containing the initial hidden state for the recurrence (h_{t-1}).
        - use_cuda: If True, use the fast element-wise CUDA kernel for the recurrence. If False, use a naive for loop. Default: True.

    A minimal usage sketch is included at the end of this file.
    """

    def __init__(self):
        super(ForgetMult, self).__init__()

    def forward(self, f, x, hidden_init=None, use_cuda=True):
        # Only use CUDA if it is actually available
        use_cuda = use_cuda and torch.cuda.is_available()
        # Ensure the user is aware when ForgetMult is not using the GPU kernel, as it is far faster
        if use_cuda: assert f.is_cuda and x.is_cuda, 'GPU ForgetMult with fast element-wise CUDA kernel requested but tensors not on GPU'
        ###
        # Avoiding 'RuntimeError: expected a Variable argument, but got NoneType' when hidden_init is None
        if hidden_init is None: return GPUForgetMult()(f, x) if use_cuda else CPUForgetMult()(f, x)
        return GPUForgetMult()(f, x, hidden_init) if use_cuda else CPUForgetMult()(f, x, hidden_init)
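
# Minimal usage sketch (not part of the original module); the shapes below are
# illustrative assumptions. This exercises the CPU path; with CUDA tensors and
# use_cuda=True the CUDA kernel above is used instead.
if __name__ == '__main__':
    seq_len, batch, hidden = 4, 2, 8
    f = torch.sigmoid(torch.randn(seq_len, batch, hidden))  # forget gates in [0, 1]
    x = torch.randn(seq_len, batch, hidden)                 # candidate inputs
    h = ForgetMult()(f, x, use_cuda=False)
    print(h.size())  # expected: torch.Size([seq_len, batch, hidden])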