Scheduled Sampling Based on Decoding Steps for Neural Machine Translation (EMNLP-2021 main conference)

Overview

We propose to conduct scheduled sampling based on decoding steps instead of the original training steps. We observe that our proposal can more realistically simulate the distribution of real translation errors, thus better bridging the gap between training and inference. The paper has been accepted to the main conference of EMNLP-2021.
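For intuition, the probability of feeding the ground-truth token can be written as a function of the decoding step t rather than the training step. Below is a minimal sketch of the three decay schedules used later in this README; the hyperparameter values (`exp_radix`, `sigmoid_k`, `epsilon`) are illustrative, not the paper's settings:

```python
import math

def gt_threshold(t, strategy, max_seq_len=50,
                 exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    """Probability of keeping the ground-truth token at decoding step t."""
    if strategy == "exponential":
        return exp_radix ** t
    elif strategy == "sigmoid":
        return sigmoid_k / (sigmoid_k + math.exp(t / sigmoid_k))
    elif strategy == "linear":
        return max(epsilon, 1 - t / max_seq_len)
    raise ValueError("Unknown strategy %s" % strategy)

# all three schedules decay as decoding proceeds, so later (more error-prone)
# steps are more likely to see the model's own predictions during training
for s in ("exponential", "sigmoid", "linear"):
    curve = [gt_threshold(t, s) for t in range(50)]
    assert all(a >= b for a, b in zip(curve, curve[1:]))
```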

Background

We conduct scheduled sampling for the Transformer with a two-pass decoder. An example of pseudo-code is as follows:

# first-pass: the same as the standard Transformer decoder
first_decoder_outputs = decoder(first_decoder_inputs)

# sampling tokens between model predictions and ground-truth tokens
second_decoder_inputs = sampling_function(first_decoder_outputs, first_decoder_inputs)

# second-pass: computing the decoder again with the above sampled tokens
second_decoder_outputs = decoder(second_decoder_inputs)
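As a runnable illustration of the wiring above, here is a toy two-pass setup. A single linear layer stands in for the Transformer decoder and the mixing probability is a fixed constant; both are simplifications for the sketch only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, hidden_size = 2, 5, 8

# toy stand-in for the Transformer decoder (a real decoder is autoregressive
# and attends to the encoder; a linear layer just keeps the sketch runnable)
decoder = nn.Linear(hidden_size, hidden_size)

def mix_tokens(outputs, inputs, p_gt=0.7):
    # per position, keep the ground-truth representation with probability p_gt
    keep = torch.rand(inputs.size(0), inputs.size(1), 1) < p_gt
    return torch.where(keep, inputs, outputs)

first_decoder_inputs = torch.randn(batch_size, seq_len, hidden_size)

# first pass: same as the standard decoder
first_decoder_outputs = decoder(first_decoder_inputs)

# sample between model predictions and ground-truth tokens
second_decoder_inputs = mix_tokens(first_decoder_outputs, first_decoder_inputs)

# second pass: decode again from the sampled inputs
second_decoder_outputs = decoder(second_decoder_inputs)
print(tuple(second_decoder_outputs.shape))  # (2, 5, 8)
```

Note that both passes share the same decoder parameters; only the inputs to the second pass change.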

Quick to Use

Our approach is suitable for most autoregressive tasks. Try the following pseudo-code when conducting scheduled sampling:

import torch

def sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths,
                      sampling_strategy="exponential", exp_radix=0.99,
                      sigmoid_k=20.0, epsilon=0.1):
    '''
    conduct scheduled sampling based on the index of decoded tokens
    param first_decoder_outputs: [batch_size, seq_len, hidden_size], model predictions
    param first_decoder_inputs: [batch_size, seq_len, hidden_size], ground-truth target tokens
    param max_seq_len: scalar, the max length of target sequences
    param tgt_lengths: [batch_size], the lengths of target sequences in a mini-batch
    The schedule hyperparameters (sampling_strategy, exp_radix, sigmoid_k, epsilon)
    are exposed as keyword arguments; the defaults here are illustrative.
    '''
    batch_size, seq_len = first_decoder_inputs.size(0), first_decoder_inputs.size(1)

    # indices of decoding steps
    t = torch.arange(max_seq_len, dtype=torch.float)

    # different sampling strategies based on decoding steps
    if sampling_strategy == "exponential":
        threshold_table = exp_radix ** t
    elif sampling_strategy == "sigmoid":
        threshold_table = sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k))
    elif sampling_strategy == "linear":
        threshold_table = torch.clamp(1 - t / max_seq_len, min=epsilon)
    else:
        raise ValueError("Unknown sampling_strategy %s" % sampling_strategy)

    # expand to [max_seq_len, max_seq_len]: row i holds the thresholds for steps 0..i
    threshold_table = threshold_table.unsqueeze(0).repeat(max_seq_len, 1).tril()

    # select one row per sequence by its length, then trim to the batch's seq_len
    thresholds = threshold_table[tgt_lengths - 1].view(-1, max_seq_len)
    thresholds = thresholds[:, :seq_len]

    # conduct sampling: positions whose random seed falls below the threshold
    # keep the ground-truth token, the rest take the model prediction
    random_select_seed = torch.rand([batch_size, seq_len])
    select_ground_truth = (random_select_seed < thresholds).unsqueeze(-1)
    second_decoder_inputs = torch.where(select_ground_truth,
                                        first_decoder_inputs, first_decoder_outputs)

    return second_decoder_inputs
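To make the `unsqueeze`/`repeat`/`tril` step above concrete, here is the threshold table built in isolation with toy sizes and a linear schedule (all values illustrative). Row i of the lower-triangular table holds the thresholds for decoding steps 0..i, so indexing by length-1 gives each sequence its own threshold row:

```python
import torch

max_seq_len = 5
epsilon = 0.1  # illustrative floor for the linear schedule

# per-step thresholds under a linear decay
t = torch.arange(max_seq_len, dtype=torch.float)
table = torch.clamp(1 - t / max_seq_len, min=epsilon)     # [1.0, 0.8, 0.6, 0.4, 0.2]

# lower-triangular expansion: row i keeps the thresholds for steps 0..i
table = table.unsqueeze(0).repeat(max_seq_len, 1).tril()  # [5, 5]

# one row per sequence, selected by its length
tgt_lengths = torch.tensor([3, 5])
thresholds = table[tgt_lengths - 1]                       # [2, 5]
print(thresholds)
# tensor([[1.0000, 0.8000, 0.6000, 0.0000, 0.0000],
#         [1.0000, 0.8000, 0.6000, 0.4000, 0.2000]])
```

Positions beyond a sequence's length get a threshold of 0, so padded steps never sample the ground-truth token.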