ContextNet has a CNN-RNN-transducer architecture and features a fully convolutional encoder that incorporates global context information into the convolution layers by adding squeeze-and-excitation modules.
ContextNet also comes in three model sizes, small, medium, and large, and uses a global parameter alpha to scale the model by changing the number of channels in the convolution filters.
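As a rough illustration of the alpha scaling, here is a hypothetical helper (not the repository's API); the per-size alpha values 0.5, 1, and 2 and the base width of 256 channels follow the paper:

```python
# Hypothetical sketch of alpha-based width scaling; the mapping from model
# size to alpha (small: 0.5, medium: 1, large: 2) follows the ContextNet
# paper, and 256 is the paper's main encoder channel count.
ALPHA = {'small': 0.5, 'medium': 1.0, 'large': 2.0}

def scaled_channels(base_channels: int, model_size: str) -> int:
    """Scale a convolution block's channel count by the global parameter alpha."""
    return int(base_channels * ALPHA[model_size])

print(scaled_channels(256, 'small'))   # 128
print(scaled_channels(256, 'large'))   # 512
```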
This repository contains only the model code, but you can train ContextNet with openspeech.
- Configuration of the ContextNet encoder
If you choose a model size among small, medium, and large, the number of channels in the convolution filters is set using the global parameter alpha. If the stride of a convolution block is 2, its last conv layer has a stride of two while the rest of the conv layers have a stride of one.
- A convolution block architecture
ContextNet has 23 convolution blocks C0, ..., C22. All convolution blocks have five convolution layers except C0 and C22, which have only one convolution layer each. A skip connection with projection is applied to the output of the squeeze-and-excitation (SE) block (a minimal sketch of a block follows this list).
- 1D squeeze-and-excitation (SE) module
Average pooling condenses the convolution output into a 1D vector, which then passes through two fully connected (FC) layers with activation functions. The output goes through a sigmoid function to be mapped to (0, 1), and is then tiled and applied to the convolution output via pointwise multiplication.
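The SE module is straightforward to express in PyTorch. Below is a minimal sketch, assuming a (batch, channels, time) layout and a hypothetical reduction factor of 8; the repository's actual implementation may differ in details:

```python
import torch
import torch.nn as nn

class SqueezeExcitation1d(nn.Module):
    """Minimal 1D SE sketch: average-pool over time, two FC layers, sigmoid gate."""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),  # map the gate to (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        context = x.mean(dim=2)        # squeeze: average pool over time
        gate = self.fc(context)        # excitation: (batch, channels) in (0, 1)
        return x * gate.unsqueeze(2)   # tile over time, gate pointwise
```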
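And here is a simplified sketch of a five-layer convolution block built on the SE module above, showing the stride-on-last-layer rule and the projected skip connection. The exact layer ordering is an assumption; the Swish activation follows the paper:

```python
class ConvBlock(nn.Module):
    """Sketch of a ContextNet block: stacked 1D convs, an SE module, and a
    projected skip connection. Only the last conv carries the block's stride."""

    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: int = 5, kernel_size: int = 5, stride: int = 1):
        super().__init__()
        layers = []
        for i in range(num_layers):
            layers += [
                nn.Conv1d(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    kernel_size,
                    stride=stride if i == num_layers - 1 else 1,
                    padding=kernel_size // 2,
                ),
                nn.BatchNorm1d(out_channels),
                nn.SiLU(),  # Swish activation, as in the paper
            ]
        self.convs = nn.Sequential(*layers)
        self.se = SqueezeExcitation1d(out_channels)
        # 1x1 projection on the skip path so channels (and stride) match
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
            nn.BatchNorm1d(out_channels),
        )
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.se(self.convs(x)) + self.residual(x))
```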
Please check the paper for more details.
- Installation

```
pip install -e .
```
- Usage

```python
import torch

from contextnet.model import ContextNet

BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE, NUM_VOCABS = 3, 500, 80, 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ContextNet(
    model_size='large',
    num_vocabs=NUM_VOCABS,
).to(device)

inputs = torch.FloatTensor(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE).to(device)
input_lengths = torch.IntTensor([500, 450, 350])
targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2],
                            [1, 3, 3, 3, 3, 3, 4, 5, 2, 0],
                            [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device)
target_lengths = torch.LongTensor([9, 8, 7])

# Forward propagate
outputs = model(inputs, input_lengths, targets, target_lengths)

# Recognize input speech
outputs = model.recognize(inputs, input_lengths)
```