Perceiver IO

Unofficial PyTorch implementation of the Perceiver and Perceiver IO papers (see Citations).

This implementation supports training Perceiver IO models with PyTorch Lightning
on some example tasks via a command line interface. Perceiver IO models are constructed from generic encoder
and decoder classes and task-specific input and output adapters (see Model API).

Setup

conda env create -f environment.yml
conda activate perceiver-io
export PYTHONPATH=.
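
After activating the environment, a quick import check can catch a missing dependency or an unset PYTHONPATH before a long training run is started. The sketch below is not part of this repository: the helper name `missing_modules` is hypothetical, and the module list assumes that the environment provides `torch` and `pytorch_lightning` alongside the local `perceiver` package.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # `perceiver` is the local package made importable by `export PYTHONPATH=.`
    required = ["torch", "pytorch_lightning", "perceiver"]
    missing = missing_modules(required)
    print("missing modules:", missing if missing else "none")
```

If `perceiver` is reported as missing, the script was most likely run from outside the project root or without PYTHONPATH set.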

Tasks

In the following subsections, Perceiver IO models are trained on some example tasks at a smaller scale than in the
paper. They were trained on two NVIDIA GTX 1080 GPUs (8 GB memory each) using PyTorch Lightning’s support for
distributed data-parallel training. I didn’t really tune model architectures and other hyper-parameters, so you’ll
probably get better results with a bit of experimentation. Support for more datasets and tasks will be added later.
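
With distributed data-parallel training, each process typically loads its own batches, so the number of samples per optimizer step grows with the number of GPUs. The arithmetic below assumes that `--batch_size` is applied per GPU process, which is how a DataLoader batch size is usually treated under DDP. Whether the training scripts here behave this way is an assumption, and the function name is hypothetical.

```python
def effective_batch_size(per_gpu_batch_size, num_gpus, num_nodes=1):
    """Samples contributing to one optimizer step under data-parallel training."""
    return per_gpu_batch_size * num_gpus * num_nodes


# Two GPUs, as in the hardware setup described above
print(effective_batch_size(64, num_gpus=2))   # 128
print(effective_batch_size(128, num_gpus=2))  # 256
```

If results differ between a single-GPU and a multi-GPU run, the changed effective batch size is one thing to check, since the learning rate may need adjusting.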

Masked language modeling

Pretrain a Perceiver IO model on masked language modeling (MLM) with text from the IMDB training set. The
pretrained encoder is then used for training a sentiment classification model.

python train/train_mlm.py --dataset=imdb --learning_rate=1e-3 --batch_size=64 \
  --max_epochs=200 --dropout=0.0 --weight_decay=0.0 \
  --accelerator=ddp --gpus=-1

All available command line options and their default values can be displayed with python train/train_mlm.py -h.
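
For readers unfamiliar with MLM, the sketch below illustrates the general idea of corrupting input tokens and keeping the originals as prediction targets. It is not the masking code used by train/train_mlm.py: the 15% masking rate follows the common BERT convention and is an assumption here, as are the whitespace tokenization and the function name `mask_tokens`.

```python
import random


def mask_tokens(tokens, mask_prob=0.15, mask_token="[MASK]", seed=None):
    """Replace a random subset of tokens with mask_token.

    Returns (masked_tokens, labels) where labels[i] is the original token
    at masked positions and None elsewhere (no loss is computed there).
    """
    rng = random.Random(seed)
    masked, labels = [], []
    for token in tokens:
        if rng.random() < mask_prob:
            masked.append(mask_token)
            labels.append(token)
        else:
            masked.append(token)
            labels.append(None)
    return masked, labels


tokens = "i have watched this movie and it was awesome".split()
masked, labels = mask_tokens(tokens, mask_prob=0.15, seed=0)
print(masked)
print(labels)
```

The model is then trained to predict the original token at every position where the label is not None.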

Sentiment classification

Train a classification decoder using a frozen encoder from masked language modeling.
If you ran MLM yourself, you’ll need to modify the --mlm_checkpoint argument accordingly. Otherwise, download
checkpoints from here and extract them in the root directory of
this project.

python train/train_seq_clf.py --dataset=imdb --learning_rate=1e-3 --batch_size=128 \
  --max_epochs=15 --dropout=0.0 --weight_decay=1e-3 --freeze_encoder \
  --accelerator=ddp --gpus=-1 \
  --mlm_checkpoint 'logs/mlm/version_0/checkpoints/epoch=199-val_loss=4.899.ckpt'

Unfreeze the encoder and fine-tune it jointly with the decoder that was trained in the previous step.
If you ran the previous step yourself, you’ll need to modify the --clf_checkpoint argument accordingly. Otherwise,
download checkpoints from here.

python train/train_seq_clf.py --dataset=imdb --learning_rate=1e-4 --batch_size=128 \
  --max_epochs=15 --dropout=0.2 --weight_decay=1e-3 \
  --accelerator=ddp --gpus=-1 \
  --clf_checkpoint 'logs/seq_clf/version_0/checkpoints/epoch=014-val_loss=0.350.ckpt'

All available command line options and their default values can be displayed with python train/train_seq_clf.py -h.
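
The checkpoint filenames above encode the epoch and the validation loss, so the path to pass as --mlm_checkpoint or --clf_checkpoint can be picked programmatically. The following sketch assumes the `epoch=...-val_loss=....ckpt` naming pattern seen in the commands above. The helper names are hypothetical and not part of this repository.

```python
import re
from pathlib import Path

# Matches names such as 'epoch=199-val_loss=4.899.ckpt'
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_loss=(\d+(?:\.\d+)?)\.ckpt$")


def parse_checkpoint_name(name):
    """Return (epoch, val_loss) parsed from a checkpoint filename, or None."""
    match = CKPT_PATTERN.search(name)
    if match is None:
        return None
    return int(match.group(1)), float(match.group(2))


def best_checkpoint(checkpoint_dir):
    """Return the checkpoint path with the lowest val_loss, or None if none match."""
    candidates = []
    for path in Path(checkpoint_dir).glob("*.ckpt"):
        parsed = parse_checkpoint_name(path.name)
        if parsed is not None:
            candidates.append((parsed[1], path))
    if not candidates:
        return None
    return min(candidates, key=lambda item: item[0])[1]


print(parse_checkpoint_name("epoch=199-val_loss=4.899.ckpt"))  # (199, 4.899)
print(parse_checkpoint_name("epoch=014-val_loss=0.350.ckpt"))  # (14, 0.35)
```

For example, `best_checkpoint("logs/mlm/version_0/checkpoints")` would return the MLM checkpoint with the lowest validation loss in that directory, if one exists.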

Image classification

Classify MNIST images. See also Model API for details about the underlying Perceiver IO model.

python train/train_img_clf.py --dataset=mnist --learning_rate=1e-3 --batch_size=128 \
  --max_epochs=20 --dropout=0.0 --weight_decay=1e-4 \
  --accelerator=ddp --gpus=-1

All available command line options and their default values can be displayed with python train/train_img_clf.py -h.

Model API

The model API is based on generic encoder and decoder classes (PerceiverEncoder and
PerceiverDecoder) and task-specific input and output adapters. The following snippet
shows how they can be used to create an MNIST image classifier, for example:

from perceiver.adapter import ImageInputAdapter, ClassificationOutputAdapter
from perceiver.model import PerceiverIO, PerceiverEncoder, PerceiverDecoder

latent_shape = (32, 128)

# Fourier-encode pixel positions and flatten along spatial dimensions
input_adapter = ImageInputAdapter(image_shape=(28, 28, 1), num_frequency_bands=32)

# Project generic Perceiver decoder output to specified number of classes
output_adapter = ClassificationOutputAdapter(num_classes=10, num_output_channels=128)

# Generic Perceiver encoder
encoder = PerceiverEncoder(
    input_adapter=input_adapter,
    latent_shape=latent_shape,
    num_layers=3,
    num_cross_attention_heads=4,
    num_self_attention_heads=4,
    num_self_attention_layers_per_block=3,
    dropout=0.0)

# Generic Perceiver decoder
decoder = PerceiverDecoder(
    output_adapter=output_adapter,
    latent_shape=latent_shape,
    num_cross_attention_heads=1,
    dropout=0.0)

# MNIST classifier implemented as Perceiver IO model
mnist_classifier = PerceiverIO(encoder, decoder)
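
To see why the latent bottleneck keeps attention cheap for this configuration, the arithmetic below counts attention score entries for the MNIST setup above. It is a back-of-the-envelope sketch, not code from this repository. The input channel count assumes that `ImageInputAdapter` concatenates each pixel's raw position with sine and cosine features per frequency band, as described in the Perceiver paper. The adapter's actual channel layout may differ.

```python
def fourier_position_channels(num_spatial_dims, num_frequency_bands):
    """Channels per position: raw coordinate plus sin and cos per band, per dimension.

    Assumes the Perceiver paper's convention; the adapter in this repository
    may compute this differently.
    """
    return num_spatial_dims * (2 * num_frequency_bands + 1)


image_shape = (28, 28, 1)   # height, width, channels (as in the snippet above)
latent_shape = (32, 128)    # number of latents, latent channels
num_frequency_bands = 32

num_pixels = image_shape[0] * image_shape[1]
num_latents = latent_shape[0]

input_channels = image_shape[2] + fourier_position_channels(2, num_frequency_bands)

# Attention score matrix sizes per head
encoder_cross_attention = num_latents * num_pixels   # latents attend to pixels
latent_self_attention = num_latents * num_latents    # latents attend to latents
pixel_self_attention = num_pixels * num_pixels       # what a plain Transformer would need

print(num_pixels)               # 784
print(input_channels)           # 131
print(encoder_cross_attention)  # 25088
print(latent_self_attention)    # 1024
print(pixel_self_attention)     # 614656
```

The cross-attention cost grows linearly with the number of input pixels, while the repeated self-attention layers only ever operate on the 32 latents.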

TensorBoard

Commands in section Tasks write TensorBoard logs to the logs directory. They can be visualized with
tensorboard --logdir logs. MLM training additionally writes predictions of masked sample text to TensorBoard’s
TEXT page. For example, the command

python train/train_mlm.py --dataset=imdb --learning_rate=1e-3 --batch_size=64 \
  --max_epochs=200 --dropout=0.0 --weight_decay=0.0 \
  --accelerator=ddp --gpus=-1 --predict_k=5 \
  --predict_samples='i have watched this [MASK] and it was awesome'  

writes the top 5 predictions for "i have watched this [MASK] and it was awesome" to TensorBoard after each epoch:

i have watched this [MASK] and it was awesome
i have watched this movie and it was awesome
i have watched this show and it was awesome
i have watched this film and it was awesome
i have watched this series and it was awesome
i have watched this dvd and it was awesome
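
The listing above corresponds to selecting the k highest-scoring vocabulary entries at the masked position and substituting each into the sample. A minimal sketch of that selection step follows, using a made-up vocabulary and made-up scores. The real scores come from the trained model, and the function name `fill_mask_top_k` is hypothetical.

```python
import heapq


def fill_mask_top_k(text, vocab_scores, k=5, mask_token="[MASK]"):
    """Return k copies of text with mask_token replaced by the top-k scoring words."""
    top_words = heapq.nlargest(k, vocab_scores, key=vocab_scores.get)
    return [text.replace(mask_token, word, 1) for word in top_words]


# Illustrative scores only, not model output
scores = {"movie": 9.1, "show": 8.7, "film": 8.5, "series": 7.9, "dvd": 7.2, "book": 6.0}

for line in fill_mask_top_k("i have watched this [MASK] and it was awesome", scores, k=5):
    print(line)
```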

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

@misc{jaegle2021perceiverio,
    title   = {Perceiver IO: A General Architecture for Structured Inputs \& Outputs},
    author  = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
    year    = {2021},
    eprint  = {2107.14795},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

GitHub

https://github.com/krasserm/perceiver-io