JAX implementation of Swin-Transformer v2


This project compared the performance (training/validation speed, plus accuracy as a sanity check) of Swin-Transformer v2 implemented in JAX and in PyTorch. All experiments were run in a Colab environment with a Tesla V100-SXM2 GPU. Some Swin-Transformer v2 features, such as absolute positional embedding and the pretrained window size, have not yet been implemented (or were deliberately omitted) in the JAX version.
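Swin-Transformer v2 computes self-attention inside non-overlapping local windows and, unlike v1, replaces dot-product attention with scaled cosine attention that uses a learnable per-head temperature. The sketch below shows both pieces in plain `jax.numpy`. It is an illustration, not this project's code: the function names are hypothetical, the relative position bias (where the pretrained window size would come in) is left out, and the clamp at `log(1/0.01)` and the `log(10)` initialization are taken from the official Swin-Transformer v2 PyTorch implementation.

```python
import jax
import jax.numpy as jnp


def window_partition(x, window_size):
    """Split (B, H, W, C) feature maps into (B * num_windows, ws * ws, C) token groups."""
    B, H, W, C = x.shape
    ws = window_size
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/ws, W/ws, ws, ws, C)
    return x.reshape(-1, ws * ws, C)


def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: (B * num_windows, ws * ws, C) -> (B, H, W, C)."""
    ws = window_size
    C = windows.shape[-1]
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape(B, H // ws, W // ws, ws, ws, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/ws, ws, W/ws, ws, C)
    return x.reshape(B, H, W, C)


def scaled_cosine_attention(q, k, v, logit_scale):
    """Swin v2-style attention without the relative position bias (hypothetical helper).

    q, k, v: (num_windows, num_heads, N, head_dim)
    logit_scale: (num_heads, 1, 1), a learnable temperature stored in log space.
    """
    # Normalize queries and keys so that q @ k^T is a cosine similarity in [-1, 1].
    q = q / jnp.maximum(jnp.linalg.norm(q, axis=-1, keepdims=True), 1e-12)
    k = k / jnp.maximum(jnp.linalg.norm(k, axis=-1, keepdims=True), 1e-12)
    # Clamping the log-scale at log(1/0.01) caps the multiplier at 100.
    scale = jnp.exp(jnp.minimum(logit_scale, jnp.log(1.0 / 0.01)))
    attn = jnp.einsum("whnd,whmd->whnm", q, k) * scale
    attn = jax.nn.softmax(attn, axis=-1)
    return jnp.einsum("whnm,whmd->whnd", attn, v)


# Example: 2 feature maps of size 8x8 with 3 channels, split into 4x4 windows.
x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)
windows = window_partition(x, 4)  # shape (8, 16, 3): 2 images x 4 windows, 16 tokens each
```

Because `window_reverse` exactly undoes `window_partition`, per-window attention outputs (after merging heads) can be mapped back onto the feature map; the cosine formulation keeps attention logits bounded by the clamped scale.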

Getting Started


Since this project was developed in Colab, which comes with the major deep learning packages (PyTorch, JAX, TensorFlow) pre-installed, installation instructions for those packages are omitted. If you are not using Colab, follow the official installation guides for PyTorch, JAX, and TensorFlow.

All the remaining packages can be installed with the following command:

pip install -r requirements.txt
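After installing the requirements, a quick way to confirm that JAX was installed with GPU support is to list the devices it sees. This is a generic check rather than one of the project's scripts; on a CPU-only installation it simply reports CPU devices.

```python
import jax

# Lists the devices JAX can use; on a Colab GPU runtime this should include a GPU device.
devices = jax.devices()
print(devices)

# Name of the default platform, e.g. "gpu" with a CUDA-enabled jaxlib, otherwise "cpu".
backend = jax.default_backend()
print(backend)
```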

Download Imagenette dataset

This project used the Imagenette dataset, a subset of 10 ImageNet classes (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute). To train or test the models yourself, download the dataset from the Imagenette repository (fastai/imagenette on GitHub).

The dataset is 1.45 GB and contains 9,469 training images and 3,925 validation images.
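Once the archive is extracted, the train and validation splits can be indexed with a few lines of standard-library Python. The sketch below assumes an ImageFolder-style layout (`<split>/<class>/<image>`, with WordNet IDs such as `n01440764` as class folder names); that layout and the helper name `list_image_folder` are assumptions, not taken from this project's loader.

```python
import os


def list_image_folder(split_dir):
    """Return (paths, labels, class_names) for an ImageFolder-style directory (hypothetical helper).

    Assumes split_dir has one sub-directory per class holding the image files.
    Labels are indices into the sorted list of class names.
    """
    class_names = sorted(
        d for d in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, d))
    )
    paths, labels = [], []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(split_dir, name)
        for fname in sorted(os.listdir(class_dir)):
            # Keep image files only; skip stray files such as notes or checksums.
            if fname.lower().endswith((".jpeg", ".jpg", ".png")):
                paths.append(os.path.join(class_dir, fname))
                labels.append(label)
    return paths, labels, class_names
```

Called on the extracted training split (for example a hypothetical `imagenette2/train` path), it should report 10 classes and 9,469 images, and 3,925 images for the validation split, which is a cheap check that the download is complete.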

Train Swin-Transformer v2 (PyTorch/JAX)

Experiment & results

With a batch size of 64 and an image size of 256×256, the JAX implementation was 21.9% faster than the PyTorch implementation during training and 59.6% faster during validation.
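Timing JAX code fairly takes two precautions: the first calls of a jitted function include compilation, and JAX dispatches work asynchronously, so the clock must wait for results with `block_until_ready`. The sketch below shows one way to do this; it is not the project's benchmarking script, and `benchmark`, `percent_faster`, and the toy `step` function are hypothetical. The README does not say how "% faster" was computed, so `percent_faster` assumes one common convention (reduction in time per step relative to the PyTorch baseline); the PyTorch side would need an equivalent warm-up and `torch.cuda.synchronize()` before reading the clock.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_warmup=3, n_iters=10):
    """Mean wall-clock seconds per call of a JAX function.

    Warm-up calls absorb jit compilation; block_until_ready waits for
    asynchronous dispatch to finish so the timing covers the actual computation.
    """
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_iters):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_iters


def percent_faster(t_baseline, t_new):
    """Reduction in time per step, as a percentage of the baseline time (assumed convention)."""
    return 100.0 * (t_baseline - t_new) / t_baseline


@jax.jit
def step(x):
    # Toy stand-in for a training or validation step.
    return jnp.tanh(x @ x.T).sum()


seconds_per_step = benchmark(step, jnp.ones((256, 256)))
```

Under the alternative throughput convention (`t_baseline / t_new - 1`), the same timings give a larger percentage, so the two numbers are only comparable when the convention is stated.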




This project was inspired by Swin-Transformer and vision_transformer, and some of the code was borrowed from them.


