ℹ️ This is a work-in-progress project; the implementation is still being tested.
A pure PyTorch implementation of the loss described in “Online Segment to Segment Neural Transduction” https://arxiv.org/abs/1609.08194.
There are two versions: a normal version and a memory-efficient version. They should give the same output; please let me know if they don't.
```python
def ssnt_loss_mem(
    log_probs: Tensor,
    targets: Tensor,
    log_p_choose: Tensor,
    source_lengths: Tensor,
    target_lengths: Tensor,
    neg_inf: float = -1e4,
    reduction="mean",
):
    """The memory efficient implementation concatenates along the targets
    dimension to reduce wasted computation on padding positions.

    Assuming the summation of all targets in the batch is T_flat, then
    the original B x T x ... tensor is reduced to T_flat x ...

    The input tensors can be obtained by using a target mask:

    Example:
        >>> target_mask = targets.ne(pad)       # (B, T)
        >>> targets = targets[target_mask]      # (T_flat,)
        >>> log_probs = log_probs[target_mask]  # (T_flat, S, V)

    Args:
        log_probs (Tensor): Word prediction log-probs, should be the output of
            log_softmax. Tensor with shape (T_flat, S, V) where T_flat is the
            summation of all target lengths, S is the maximum number of input
            frames and V is the vocabulary of labels.
        targets (Tensor): Tensor with shape (T_flat,) representing the reference
            target labels for all samples in the minibatch.
        log_p_choose (Tensor): Emission log-probs, should be the output of
            F.logsigmoid. Tensor with shape (T_flat, S) where T_flat is the
            summation of all target lengths and S is the maximum number of
            input frames.
        source_lengths (Tensor): Tensor with shape (N,) representing the number
            of frames for each sample in the minibatch.
        target_lengths (Tensor): Tensor with shape (N,) representing the length
            of the transcription for each sample in the minibatch.
        neg_inf (float, optional): The constant representing -inf used for
            masking. Default: -1e4
        reduction (string, optional): Specifies the reduction. Supports
            "mean" / "sum". Default: "mean"
    """
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from ssnt_loss import ssnt_loss_mem, lengths_to_padding_mask

B, S, H, T, V = 2, 100, 256, 10, 2000

# model
transcriber = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
predictor = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
joiner_trans = nn.Linear(H, V, bias=False).cuda()
joiner_alpha = nn.Sequential(
    nn.Linear(H, 1, bias=True),
    nn.Tanh()
).cuda()

# inputs
src_embed = torch.rand(B, S, H).cuda().requires_grad_()
tgt_embed = torch.rand(B, T, H).cuda().requires_grad_()
targets = torch.randint(0, V, (B, T)).cuda()
adjust = lambda x, goal: x * goal // x.max()
source_lengths = adjust(torch.randint(1, S + 1, (B,)).cuda(), S)
target_lengths = adjust(torch.randint(1, T + 1, (B,)).cuda(), T)

# forward
src_feats, (h1, c1) = transcriber(src_embed.transpose(1, 0))
tgt_feats, (h2, c2) = predictor(tgt_embed.transpose(1, 0))

# memory efficient joint
mask = ~lengths_to_padding_mask(target_lengths)
lattice = F.relu(
    src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)[mask]
log_alpha = F.logsigmoid(joiner_alpha(lattice)).squeeze(-1)
lattice = joiner_trans(lattice).log_softmax(-1)

# normal ssnt loss
loss = ssnt_loss_mem(
    lattice,
    targets[mask],
    log_alpha,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / (B * T)
loss.backward()
print(loss.item())
```
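If you would rather average over the actual number of target tokens than over the padded B * T grid, you can divide the summed loss by `target_lengths.sum()` instead. This is only a suggested alternative normalization, reusing the variables from the example above:

```python
# alternative: average over real (non-padding) target tokens instead of B * T
loss = ssnt_loss_mem(
    lattice,
    targets[mask],
    log_alpha,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / target_lengths.sum()
```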
This implementation is based on the simplifying derivation proposed for monotonic attention, which uses a parallelized cumulative product (`cumprod`) to compute the expected alignment. Because SSNT and monotonic attention share the same monotonic alignment structure, the forward variable alpha(i, j) can be computed in the same way.
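For reference, the parallelized recursion from the monotonic attention derivation can be written in log space with `cumsum` and `logcumsumexp`. The sketch below is illustrative only and is not the code this repo ships: the function name `alpha_step`, the tensor shapes, and the `eps` clamp are my own assumptions, and in SSNT the word-prediction log-probability of the emitted token would additionally be folded into `log_alpha_prev` at each target step.

```python
import torch

def alpha_step(log_alpha_prev, log_p_choose_i, eps=1e-6):
    """One target step of the parallelized monotonic alignment recursion.

    log_alpha_prev : (B, S) log alpha(i-1, :)
    log_p_choose_i : (B, S) emission log-probs log p(i, :), e.g. from logsigmoid
    """
    # log(1 - p), clamped so it stays finite when p is close to 1
    log_1mp = torch.log1p(-log_p_choose_i.exp().clamp(max=1 - eps))
    # exclusive cumulative sum of log(1 - p) along the source axis,
    # i.e. log of cumprod_{k<j} (1 - p(i, k))
    log_cumprod = torch.cumsum(log_1mp, dim=-1) - log_1mp
    # alpha(i, j) = p(i, j) * cumprod_{k<j}(1 - p(i, k))
    #             * cumsum_{k<=j}( alpha(i-1, k) / cumprod_{l<k}(1 - p(i, l)) )
    return (
        log_p_choose_i
        + log_cumprod
        + torch.logcumsumexp(log_alpha_prev - log_cumprod, dim=-1)
    )

# toy usage: alpha(0, :) is a one-hot on the first frame (in log space,
# using a large negative constant as in the neg_inf convention above)
B, S = 2, 5
log_alpha = torch.full((B, S), -1e4)
log_alpha[:, 0] = 0.0
log_p = torch.nn.functional.logsigmoid(torch.randn(B, S))
log_alpha = alpha_step(log_alpha, log_p)
```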
Feel free to contact me if there are bugs in the code.