Scheduled Sampling Based on Decoding Steps for Neural Machine Translation (EMNLP-2021 main conference)
Overview
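We propose to conduct scheduled sampling based on decoding steps instead of the training steps used by the original scheduled sampling. The probability of feeding the model its own prediction instead of the ground-truth token therefore increases with the decoding step, not with the number of training updates. We observe that this more realistically simulates the distribution of real translation errors and thus better bridges the gap between training and inference. The paper has been accepted to the main conference of EMNLP-2021.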
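As a rough illustration of the difference (not code from the paper), the sketch below applies the inverse-sigmoid decay k / (k + exp(x / k)), which the sampling code uses for its "sigmoid" strategy, in two ways. First it uses x = the training step, as in the original scheduled sampling. Then it uses x = the decoding step t. The helper names and k = 20 are hypothetical. In practice a training-step schedule would use a far larger k, because training runs for many thousands of steps.

```python
import math


def inverse_sigmoid(x, k):
    # probability of feeding the ground-truth token; starts near 1 and decays as x grows
    return k / (k + math.exp(x / k))


def gold_probs_by_training_step(train_step, seq_len, k):
    # original scheduled sampling: every position shares one probability,
    # set only by how far training has progressed
    return [inverse_sigmoid(train_step, k)] * seq_len


def gold_probs_by_decoding_step(seq_len, k):
    # decoding-step variant: the probability depends on the position t in the
    # target sentence, so later tokens are replaced by predictions more often
    return [inverse_sigmoid(t, k) for t in range(seq_len)]


print([round(p, 3) for p in gold_probs_by_training_step(train_step=0, seq_len=5, k=20.0)])
print([round(p, 3) for p in gold_probs_by_decoding_step(seq_len=5, k=20.0)])
```

Within one sentence the training-step schedule gives a flat list. The decoding-step schedule gives a decreasing one, so later positions are more likely to see model predictions.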
Background
We apply scheduled sampling to the Transformer with a two-pass decoder, as in the following pseudo-code:
# first-pass: the same as the standard Transformer decoder
first_decoder_outputs = decoder(first_decoder_inputs)
# sample each input token from either the model predictions or the ground-truth tokens
second_decoder_inputs = sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths)
# second-pass: run the decoder again on the sampled tokens
second_decoder_outputs = decoder(second_decoder_inputs)
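The pseudo-code above can be fleshed out into a small runnable sketch. It reflects our own assumptions about one reasonable setup, not the repository's implementation. A tiny causal Transformer built from `torch.nn.TransformerEncoder` stands in for the decoder, and the source-side encoder is omitted for brevity. The first pass runs without gradients. Its argmax predictions are embedded and shifted right by one position, because the prediction at step t is a guess for the input at step t + 1. Each position then keeps the ground-truth token with probability `exp_radix ** t`. Per-sentence lengths are ignored here. `ToyDecoder`, `two_pass_step`, `exp_radix = 0.9`, and all sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a Transformer decoder; the source encoder is omitted."""

    def __init__(self, vocab_size=50, d_model=32, nhead=4, num_layers=2, max_len=64):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=64, batch_first=True)
        self.layers = nn.TransformerEncoder(layer, num_layers)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, input_embeds):
        # input_embeds: [batch_size, seq_len, d_model] -> logits: [batch_size, seq_len, vocab_size]
        seq_len = input_embeds.size(1)
        positions = torch.arange(seq_len, device=input_embeds.device)
        h = input_embeds + self.pos_emb(positions)
        # additive causal mask: step t may only attend to steps <= t
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=input_embeds.device), diagonal=1
        )
        return self.proj(self.layers(h, mask=causal_mask))


def two_pass_step(model, decoder_input_ids, target_ids, exp_radix=0.9):
    """One scheduled-sampling training step with decoding-step-based thresholds."""
    batch_size, seq_len = decoder_input_ids.shape
    device = decoder_input_ids.device
    gold_embeds = model.tok_emb(decoder_input_ids)

    # first pass: standard teacher forcing, no gradient
    with torch.no_grad():
        predicted = model(gold_embeds).argmax(dim=-1)
    # shift predictions right; position 0 keeps the original first input (e.g. BOS)
    shifted_pred = torch.cat([decoder_input_ids[:, :1], predicted[:, :-1]], dim=1)
    pred_embeds = model.tok_emb(shifted_pred)

    # keep the ground truth with probability exp_radix ** t at decoding step t
    t = torch.arange(seq_len, dtype=torch.float, device=device)
    keep_gold = (torch.rand(batch_size, seq_len, device=device) < exp_radix ** t).unsqueeze(-1)
    second_inputs = torch.where(keep_gold, gold_embeds, pred_embeds)

    # second pass: the loss is computed on this pass only
    second_logits = model(second_inputs)
    return F.cross_entropy(second_logits.reshape(-1, second_logits.size(-1)), target_ids.reshape(-1))
```

For example, `two_pass_step(ToyDecoder(), torch.randint(0, 50, (3, 7)), torch.randint(0, 50, (3, 7)))` returns a scalar loss that can be back-propagated. Embedding the argmax predictions is only one option. Soft mixtures of embeddings are another common choice.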
Quick to Use
Our approaches are applicable to most autoregressive tasks. Please try the following code when conducting scheduled sampling:
import torch


def sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths,
                      sampling_strategy="exponential", exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    '''
    Conduct scheduled sampling based on the index of decoded tokens.
    param first_decoder_outputs: [batch_size, seq_len, hidden_size], model predictions
    param first_decoder_inputs: [batch_size, seq_len, hidden_size], ground-truth target tokens
    param max_seq_len: scalar, the maximum length of target sequences
    param tgt_lengths: [batch_size] LongTensor, the lengths of target sequences in a mini-batch
    sampling_strategy, exp_radix, sigmoid_k, epsilon: schedule hyperparameters
        (the default values here are illustrative placeholders, not tuned settings)
    '''
    batch_size, seq_len = first_decoder_inputs.shape[:2]
    device = first_decoder_inputs.device
    # indices of decoding steps: 0, 1, ..., max_seq_len - 1
    t = torch.arange(max_seq_len, dtype=torch.float, device=device)
    # different sampling strategies based on decoding steps; each gives the
    # probability of keeping the ground-truth token at step t
    if sampling_strategy == "exponential":
        threshold_table = exp_radix ** t
    elif sampling_strategy == "sigmoid":
        threshold_table = sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k))
    elif sampling_strategy == "linear":
        threshold_table = torch.clamp(1 - t / max_seq_len, min=epsilon)
    else:
        raise ValueError("Unknown sampling_strategy %s" % sampling_strategy)
    # row i of the table holds the thresholds for a sequence of length i + 1 (zeros past the end);
    # select one row per sentence -> [batch_size, max_seq_len], then crop to [batch_size, seq_len]
    threshold_table = threshold_table.unsqueeze(0).repeat(max_seq_len, 1).tril()
    thresholds = threshold_table[tgt_lengths - 1].view(-1, max_seq_len)
    thresholds = thresholds[:, :seq_len]
    # conduct sampling based on the above thresholds
    random_select_seed = torch.rand([batch_size, seq_len], device=device)
    # keep the ground truth where the random draw falls below the threshold;
    # unsqueeze(-1) broadcasts the [batch_size, seq_len] mask over hidden_size
    use_ground_truth = (random_select_seed < thresholds).unsqueeze(-1)
    second_decoder_inputs = torch.where(use_ground_truth, first_decoder_inputs, first_decoder_outputs)
    return second_decoder_inputs
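The lower-triangular threshold table is the least obvious step. The small self-contained sketch below builds the exponential table for `max_seq_len = 4` and selects one row per sentence. The radix 0.5, the lengths, and the helper name `build_thresholds` are hypothetical values chosen so the numbers are easy to read. This sketch indexes the table with `tgt_lengths - 1`, so a sentence of length L gets the row whose first L entries are non-zero. Positions past the end get threshold 0 and always take the model prediction. These are padding positions, which the loss normally ignores.

```python
import torch


def build_thresholds(tgt_lengths, max_seq_len, exp_radix=0.5):
    # exponential schedule: probability of keeping the ground truth at step t
    t = torch.arange(max_seq_len, dtype=torch.float)
    table = exp_radix ** t                                    # [max_seq_len]
    # row i keeps steps 0..i and zeroes the rest: a sentence of length i + 1
    table = table.unsqueeze(0).repeat(max_seq_len, 1).tril()  # [max_seq_len, max_seq_len]
    # one row per sentence -> [batch_size, max_seq_len]
    return table[tgt_lengths - 1]


thresholds = build_thresholds(torch.tensor([2, 4]), max_seq_len=4)
print([[round(x, 4) for x in row] for row in thresholds.tolist()])
# [[1.0, 0.5, 0.0, 0.0], [1.0, 0.5, 0.25, 0.125]]

# True -> feed the ground-truth token, False -> feed the model prediction
keep_gold = torch.rand(2, 4) < thresholds
```

Step 0 always keeps the ground truth, because its threshold is 1.0 and `torch.rand` draws from [0, 1). Later steps are replaced by predictions increasingly often.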