# adaptive_softmax.py

import torch
import torch.nn as nn

from .lm_rnn import *  # original package-relative import; needs the package context


class AdaptiveSoftmax(nn.Module):
    """Adaptive softmax (Grave et al., 2016). The head scores the cutoff[0]
    most frequent tokens plus one "cluster" slot per tail; each tail scores
    one band of rarer tokens behind a low-rank projection."""
    def __init__(self, input_size, cutoff):
        super().__init__()
        self.input_size, self.cutoff = input_size, cutoff
        # Head outputs: cutoff[0] frequent tokens + one logit per tail cluster.
        self.output_size = cutoff[0] + len(cutoff) - 1
        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()
        for i in range(len(cutoff) - 1):
            # Tail i classifies the cutoff[i+1] - cutoff[i] tokens of its band
            # after projecting the input down by a factor of 4**i.
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // 4 ** i, bias=False),
                nn.Linear(input_size // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False))
            self.tail.append(seq)

    def reset(self):
        nn.init.xavier_normal_(self.head.weight)
        for tail in self.tail:
            nn.init.xavier_normal_(tail[0].weight)
            nn.init.xavier_normal_(tail[1].weight)

    def set_target(self, target):
        # Record, for each tail band, which batch rows have a target in that
        # band; forward() then runs each tail only on the rows that need it.
        self.id = []
        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            if mask.sum() > 0:
                self.id.append(mask.nonzero(as_tuple=False).squeeze(1))
            else:
                self.id.append(None)

    def forward(self, input):
        # Returns [head_logits, tail_0_logits_or_None, ...], matching the
        # remapped targets produced by AdaptiveLoss.remap_target.
        output = [self.head(input)]
        for i in range(len(self.id)):
            if self.id[i] is not None:
                output.append(self.tail[i](input.index_select(0, self.id[i])))
            else:
                output.append(None)
        return output

    def log_prob(self, input):
        # Full log-probabilities over the whole vocabulary, for evaluation:
        # for a tail word, log P(w) = log P(cluster | head) + log P(w | cluster).
        with torch.no_grad():
            lsm = nn.LogSoftmax(dim=1)
            head_out = self.head(input)
            batch_size = head_out.size(0)
            prob = torch.zeros(batch_size, self.cutoff[-1], device=input.device)
            lsm_head = lsm(head_out)
            prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size))
            for i in range(len(self.tail)):
                pos = self.cutoff[i]
                i_size = self.cutoff[i + 1] - pos
                # Log-probability of cluster i, broadcast across its band.
                buffer = lsm_head.narrow(1, self.cutoff[0] + i, 1)
                buffer = buffer.expand(batch_size, i_size)
                lsm_tail = lsm(self.tail[i](input))
                prob.narrow(1, pos, i_size).copy_(buffer).add_(lsm_tail)
        return prob
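
# With cutoff = [2000, 10000, 50000] (hypothetical band boundaries), the head
# scores the 2000 most frequent token ids plus 2 cluster slots; tail 0 covers
# ids 2000..9999 and tail 1 covers ids 10000..49999. set_target() must be
# called with each batch's targets before forward(), so every tail processes
# only the rows whose target falls in its band.
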
class AdaptiveLoss(nn.Module):
    """Cross-entropy over the head plus each active tail, averaged over the
    batch. Targets are remapped so tail tokens appear as their cluster id in
    the head targets and as band-relative indices in the tail targets."""
    def __init__(self, cutoff):
        super().__init__()
        self.cutoff = cutoff
        # One summed cross-entropy per output of AdaptiveSoftmax.forward:
        # the head plus len(cutoff) - 1 tails.
        self.criterions = nn.ModuleList(
            [nn.CrossEntropyLoss(reduction='sum') for _ in self.cutoff])

    def remap_target(self, target):
        new_target = [target.clone()]
        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            # In the head, every token of band i collapses to one cluster id.
            new_target[0][mask] = self.cutoff[0] + i
            if mask.sum() > 0:
                # Within tail i, indices are relative to the band start.
                new_target.append(target[mask] - self.cutoff[i])
            else:
                new_target.append(None)
        return new_target

    def forward(self, input, target):
        batch_size = input[0].size(0)
        target = self.remap_target(target)
        output = 0.0
        for i in range(len(input)):
            if input[i] is not None:
                # Remapped targets must be valid class indices for this output.
                assert target[i].min() >= 0 and target[i].max() < input[i].size(1)
                criterion = self.criterions[i]
                output += criterion(input[i], target[i])
        output /= batch_size
        return output
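
# Minimal usage sketch, not part of the original module; the hidden size,
# cutoffs, batch size, and vocabulary below are illustrative assumptions.
if __name__ == '__main__':
    input_size, cutoff = 512, [2000, 10000, 50000]
    softmax, loss_fn = AdaptiveSoftmax(input_size, cutoff), AdaptiveLoss(cutoff)
    softmax.reset()
    hidden = torch.randn(128, input_size)          # e.g. flattened RNN outputs
    target = torch.randint(0, cutoff[-1], (128,))  # token ids over the vocab
    softmax.set_target(target)                     # select rows for each tail
    loss = loss_fn(softmax(hidden), target)        # head + active-tail loss
    print(loss.item())
    print(softmax.log_prob(hidden).shape)          # torch.Size([128, 50000])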