from .lm_rnn import *
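
# `cutoff` splits a frequency-sorted vocabulary into a small head of frequent
# words plus progressively rarer tail clusters, as in the adaptive softmax of
# Grave et al., "Efficient softmax approximation for GPUs". For example (sizes
# here are hypothetical), cutoff=[2000, 10000, 25000] gives a head scoring
# words 0-1999 plus two cluster slots, one tail scoring words 2000-9999, and
# another tail scoring words 10000-24999.
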
class AdaptiveSoftmax(nn.Module):
    def __init__(self, input_size, cutoff):
        super().__init__()
        self.input_size,self.cutoff = input_size,cutoff
        # The head scores the cutoff[0] most frequent words directly, plus one
        # extra "cluster" slot per tail.
        self.output_size = cutoff[0] + len(cutoff) - 1
        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()
        for i in range(len(cutoff) - 1):
            # Each tail projects the hidden state down by a factor of 4**i
            # (bias-free), then scores the words in its cluster.
            seq = nn.Sequential(nn.Linear(input_size, input_size // 4 ** i, False),
                                nn.Linear(input_size // 4 ** i, cutoff[i + 1] - cutoff[i], False))
            self.tail.append(seq)

    def reset(self):
        nn.init.xavier_normal_(self.head.weight)
        for tail in self.tail:
            nn.init.xavier_normal_(tail[0].weight)
            nn.init.xavier_normal_(tail[1].weight)

    def set_target(self, target):
        # Cache, per tail cluster, the indices of the batch rows whose target
        # falls in that cluster; forward() runs each tail on those rows only.
        self.id = []
        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            if mask.sum() > 0:
                self.id.append(Variable(mask.float().nonzero().squeeze(1)))
            else: self.id.append(None)

    def forward(self, input):
        # Head logits for all rows; tail i logits only for the rows cached by
        # set_target (None when no target in the batch falls in cluster i).
        output = [self.head(input)]
        for i in range(len(self.id)):
            if self.id[i] is not None:
                output.append(self.tail[i](input.index_select(0, self.id[i])))
            else: output.append(None)
        return output

    def log_prob(self, input):
        # Full-vocabulary log-probabilities: each cluster's head log-prob is
        # added to the tail's within-cluster log-probs. Hard-coded to CUDA.
        lsm = nn.LogSoftmax(dim=1).cuda()
        head_out = self.head(input)
        batch_size = head_out.size(0)
        prob = torch.zeros(batch_size, self.cutoff[-1]).cuda()
        lsm_head = lsm(head_out)
        prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data)
        for i in range(len(self.tail)):
            pos = self.cutoff[i]
            i_size = self.cutoff[i + 1] - pos
            # log P(cluster i | input), broadcast over the cluster's words.
            buffer = lsm_head.narrow(1, self.cutoff[0] + i, 1)
            buffer = buffer.expand(batch_size, i_size)
            lsm_tail = lsm(self.tail[i](input))
            prob.narrow(1, pos, i_size).copy_(buffer.data).add_(lsm_tail.data)
        return prob
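
# Inference sketch for AdaptiveSoftmax.log_prob (illustration only; n_hid,
# cutoff and hidden are hypothetical, and log_prob hard-codes .cuda(), so both
# the module and its input are assumed to live on a GPU):
#
#   softmax = AdaptiveSoftmax(n_hid, cutoff).cuda()
#   log_probs = softmax.log_prob(hidden)   # (batch, cutoff[-1]): log P(w | h)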

class AdaptiveLoss(nn.Module):
    def __init__(self, cutoff):
        super().__init__()
        self.cutoff = cutoff
        # One summed cross-entropy per output: the head plus each tail cluster.
        self.criterions = nn.ModuleList([nn.CrossEntropyLoss(reduction='sum') for i in self.cutoff])

    def remap_target(self, target):
        # new_target[0] holds head targets, with each tail word replaced by its
        # cluster slot; new_target[i + 1] holds within-cluster indices for tail i.
        new_target = [target.clone()]
        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            new_target[0][mask] = self.cutoff[0] + i
            if mask.sum() > 0: new_target.append(target[mask].add(-self.cutoff[i]))
            else: new_target.append(None)
        return new_target

    def forward(self, input, target):
        batch_size = input[0].size(0)
        target = self.remap_target(target.data)
        output = 0.0
        for i in range(len(input)):
            if input[i] is not None:
                # Valid class indices lie in [0, n_classes), hence the strict <.
                assert(target[i].min() >= 0 and target[i].max() < input[i].size(1))
                criterion = self.criterions[i]
                output += criterion(input[i], Variable(target[i]))
        output /= batch_size
        return output
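
if __name__ == '__main__':
    # Minimal training-step sketch (an illustration, not part of the library);
    # the hidden size, batch size and cutoff values below are all hypothetical.
    cutoff = [50, 75, 100]                  # head: words 0-49; tails: 50-74, 75-99
    softmax, crit = AdaptiveSoftmax(32, cutoff), AdaptiveLoss(cutoff)
    softmax.reset()
    hidden = torch.randn(8, 32)             # e.g. RNN outputs for 8 time steps
    target = torch.randint(0, cutoff[-1], (8,))
    softmax.set_target(target)              # pick the rows each tail will score
    loss = crit(softmax(hidden), target)    # scalar, averaged over the batch
    loss.backward()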