Truly shift-invariant convolutional neural networks
Authors: Anadi Chaman and Ivan Dokmanić
Convolutional neural networks (CNNs) were long assumed to be shift invariant, until recent work showed that the classification accuracy of a trained CNN can drop sharply when the input image is shifted by just 1 pixel. One of the primary reasons for this problem is the use of downsampling (popularly known as stride) layers in the networks.
In this work, we present Adaptive Polyphase Sampling (APS), an easy-to-implement non-linear downsampling scheme that eliminates this problem. The resulting CNNs yield 100% consistency in classification performance under shifts without any loss in accuracy. In fact, unlike prior works, the networks exhibit perfect consistency even before training, making APS the first approach that makes CNNs truly shift invariant.
This repository contains our code in PyTorch to implement APS.
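To make the idea behind APS concrete, here is a minimal pure-Python sketch for stride 2 on a single-channel image stored as a nested list. It is an illustration only, not the repository's PyTorch implementation. It assumes the selection rule is "keep the polyphase component with the largest squared l2 norm", and the helper names (`polyphase_components`, `aps_downsample`, `circular_shift`) are hypothetical.

```python
def polyphase_components(img, stride=2):
    """Split a 2-D list into its stride*stride subsampled grids, keyed by offset."""
    return {
        (i, j): [row[j::stride] for row in img[i::stride]]
        for i in range(stride)
        for j in range(stride)
    }


def aps_downsample(img, stride=2):
    """Keep the polyphase component with the largest energy (assumed criterion)."""
    comps = polyphase_components(img, stride)

    def energy(comp):
        return sum(v * v for row in comp for v in row)

    best_offset = max(comps, key=lambda k: energy(comps[k]))
    return comps[best_offset]


def circular_shift(img, dy, dx):
    """Circularly shift a 2-D list down by dy rows and right by dx columns."""
    h, w = len(img), len(img[0])
    return [[img[(r - dy) % h][(c - dx) % w] for c in range(w)] for r in range(h)]


img = [
    [1, 2, 0, 1],
    [0, 5, 1, 4],
    [2, 1, 0, 0],
    [1, 6, 0, 3],
]

out = aps_downsample(img)                              # [[5, 4], [6, 3]]
out_shifted = aps_downsample(circular_shift(img, 1, 0))  # [[6, 3], [5, 4]]
```

Ordinary stride-2 sampling always keeps offset (0, 0), so a 1-pixel shift hands the next layer a different set of pixels. In this sketch the shifted input yields the same pixels, only circularly shifted, so any globally pooled statistic computed afterwards is unchanged. The example image was chosen so that one component has a strictly larger energy than the others; how ties are broken is not addressed here.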
To train a ResNet-18 model with APS on ImageNet, use the following commands (training and evaluation with circular shifts).
cd imagenet_exps
python3 main.py --out-dir OUT_DIR --arch resnet18_aps1 --seed 0 --data PATH-TO-DATASET
For training on multiple GPUs:
cd imagenet_exps
python3 main.py --out-dir OUT_DIR --arch resnet18_aps1 --seed 0 --data PATH-TO-DATASET --workers NUM_WORKERS --dist-url tcp://127.0.0.1:FREE-PORT --dist-backend nccl --multiprocessing-distributed --world-size 1 --rank 0
--arch specifies the architecture. To use ResNet-18 with an APS layer and a blur filter of size j, pass 'resnet18_apsj' as the argument to --arch. The list of currently supported network architectures is here.
--circular_data_aug can be used to additionally train the networks with random circular shifts.
Results are saved in OUT_DIR.
The following commands run our implementation on the CIFAR-10 dataset.
cd cifar10_exps
python3 main.py --arch 'resnet18_aps' --filter_size FILTER_SIZE --validate_consistency --seed_num 0 --device_id 0 --model_folder CURRENT_MODEL_DIRECTORY --results_root_path ROOT_DIRECTORY --dataset_path PATH-TO-DATASET
--data_augmentation_flag can be used to additionally train the networks with randomly shifted images. FILTER_SIZE can take values from 1 to 7. The list of CNN architectures currently supported can be found here.
The results are saved in the path: ROOT_DIRECTORY/CURRENT_MODEL_DIRECTORY/
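For readers who want a feel for what a consistency figure measures, the sketch below shows one common way to compute it: the fraction of trials in which a classifier gives the same label to two random circular shifts of the same image. Whether --validate_consistency computes exactly this quantity is an assumption. `shift_consistency`, `circular_shift`, and the toy classifier are hypothetical names and not part of the repository.

```python
import random


def circular_shift(img, dy, dx):
    """Circularly shift a 2-D list down by dy rows and right by dx columns."""
    h, w = len(img), len(img[0])
    return [[img[(r - dy) % h][(c - dx) % w] for c in range(w)] for r in range(h)]


def shift_consistency(predict, images, max_shift=3, trials=5, seed=0):
    """Fraction of trials where two random circular shifts get the same label."""
    rng = random.Random(seed)
    agree = 0
    total = 0
    for img in images:
        for _ in range(trials):
            shift_a = (rng.randint(0, max_shift), rng.randint(0, max_shift))
            shift_b = (rng.randint(0, max_shift), rng.randint(0, max_shift))
            label_a = predict(circular_shift(img, *shift_a))
            label_b = predict(circular_shift(img, *shift_b))
            agree += int(label_a == label_b)
            total += 1
    return agree / total


def toy_invariant_classifier(img):
    """Stand-in classifier: thresholds the pixel sum, which no shift can change."""
    return int(sum(sum(row) for row in img) > 10)


images = [
    [[1, 2, 0, 1], [0, 5, 1, 4], [2, 1, 0, 0], [1, 6, 0, 3]],
    [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
]
score = shift_consistency(toy_invariant_classifier, images)
```

A classifier whose output depends only on shift-invariant statistics scores 1.0 under this measure, while one that reads a fixed pixel location generally scores lower.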