JAX implementation of Swin-Transformer v2
This project compares the performance (training/validation speed, plus accuracy as a sanity check) of Swin-Transformer v2 implemented in JAX and in PyTorch. All experiments were run in a Colab environment with a Tesla V100-SXM2 GPU. Some features of Swin-Transformer v2, such as absolute positional embedding and using a pretrained window, have not yet been implemented in the JAX version.
Because this project was developed in Colab, which comes with the main deep learning packages (PyTorch, JAX, TensorFlow) pre-installed, installation instructions for those packages are omitted. If you are not using Colab, please follow the official installation guide for each of those packages.
All the remaining packages can be installed with the following command:
pip install -r requirements.txt
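Outside Colab, it can help to confirm that the three frameworks are present before running anything. The following is a minimal sketch using only the standard library; the import names (`torch`, `jax`, `tensorflow`) and the helper `missing_packages` are assumptions made for illustration and are not part of this repository.

```python
from importlib.util import find_spec

# Assumed top-level import names; PyTorch is imported as "torch".
FRAMEWORKS = ("torch", "jax", "tensorflow")


def missing_packages(names=FRAMEWORKS):
    """Return the top-level package names that cannot be found."""
    # find_spec returns None when a top-level module is not installed.
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Not installed:", ", ".join(missing))
    else:
        print("All framework packages are importable.")
```

This only checks that the packages can be located; it does not verify versions or GPU support.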
Download Imagenette dataset
This project uses the Imagenette dataset, a subset of 10 classes from ImageNet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute). If you want to train or test the models yourself, please download the dataset from the Imagenette project page.
The download is 1.45 GB and contains 9,469 training images and 3,925 validation images.
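After extracting the archive, the image counts can be checked against the figures above. This is a rough sketch that assumes the archive unpacks into `train` and `val` folders with one subfolder per class and `.JPEG` files; the root path `imagenette2` and the helper `count_images` are hypothetical names, so adjust them to your setup.

```python
from pathlib import Path


def count_images(root, splits=("train", "val"), suffix=".jpeg"):
    """Count image files under each split directory of `root`."""
    counts = {}
    for split in splits:
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            # A missing split is reported as zero images.
            counts[split] = 0
            continue
        # Walk the class subfolders and match the extension case-insensitively.
        counts[split] = sum(
            1
            for path in split_dir.rglob("*")
            if path.is_file() and path.suffix.lower() == suffix
        )
    return counts


if __name__ == "__main__":
    root = Path("imagenette2")  # hypothetical extraction path
    if root.is_dir():
        print(count_images(root))
```

If the layout assumption holds for the full dataset, the counts should match the 9,469 training and 3,925 validation images stated above.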