Official PyTorch implementation and pretrained models of the paper Self-Supervised Classification Network. Self-Classifier is a self-supervised, end-to-end classification neural network that learns labels and representations simultaneously in a single stage.
Self-Classifier architecture. Two augmented views of the same image are processed by a shared network. The cross-entropy between the predictions for the two views is minimized to promote same-class prediction, while degenerate solutions are avoided by asserting a uniform prior over classes. The resulting model learns representations and class labels in a single-stage, end-to-end, unsupervised manner. CNN: Convolutional Neural Network; FC: Fully Connected.
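The caption can be illustrated with a small NumPy sketch of a symmetric two-view cross-entropy. Here the uniform prior is imposed through Bayes' rule: a softmax over the batch dimension gives p(x|y), and renormalizing each row over classes with p(y) = 1/C gives the predicted p(y|x). This is a hedged reading of the caption, not the repository's loss code. The function names and the temperatures `t_row`/`t_col` are hypothetical, and the exact formulation in the paper may differ.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def uniform_prior_predictions(logits: np.ndarray, t_col: float) -> np.ndarray:
    """Softmax over the batch gives p(x|y); renormalizing each row over
    classes is Bayes' rule with a uniform prior p(y) = 1/C."""
    p_x_given_y = softmax(logits / t_col, axis=0)
    return p_x_given_y / p_x_given_y.sum(axis=1, keepdims=True)


def two_view_loss(logits1, logits2, t_row=0.1, t_col=0.05, eps=1e-12):
    """Symmetric cross-entropy between two augmented views (hypothetical sketch).

    logits1, logits2: arrays of shape (N, C), one row per image in the batch.
    t_row, t_col: hypothetical temperatures, not values taken from the paper.
    """
    n = logits1.shape[0]
    targets1 = softmax(logits1 / t_row, axis=1)  # p(y|x1): plain softmax over classes
    targets2 = softmax(logits2 / t_row, axis=1)  # p(y|x2)
    preds1 = uniform_prior_predictions(logits1, t_col)
    preds2 = uniform_prior_predictions(logits2, t_col)
    ce_12 = -(targets2 * np.log(preds1 + eps)).sum() / n  # view 1 predicts view 2
    ce_21 = -(targets1 * np.log(preds2 + eps)).sum() / n  # view 2 predicts view 1
    return 0.5 * (ce_12 + ce_21)


# A collapsed network (every image gets the same logits) yields uniform predictions,
# so its loss is exactly log(C); confident, diverse, consistent assignments score near 0.
collapsed = np.tile(np.arange(4.0), (8, 1))
print(two_view_loss(collapsed, collapsed), np.log(4))
```

The collapsed case shows why the prior helps. Identical rows make every column softmax uniform, which forces the prediction to 1/C for every class. The loss therefore cannot drop below log C without spreading images across classes.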
Create the Conda environment:
conda env create -f ./environment.yml
Install Apex with the CUDA extension:
export TORCH_CUDA_ARCH_LIST="7.0" # see https://en.wikipedia.org/wiki/CUDA#GPUs_supported
pip install git+https://github.com/NVIDIA/apex.git@…7e31ca87514e17c3cd3bbc03f4204579d0 --install-option="--cuda_ext"
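`TORCH_CUDA_ARCH_LIST` should match the compute capability of the GPUs you build for. For example, "7.0" corresponds to V100. The helper below is hypothetical and not part of the repository. It formats a list of (major, minor) capabilities into the semicolon-separated form that PyTorch's extension builder accepts. On a machine with PyTorch installed, the capabilities can be read with `torch.cuda.get_device_capability(i)`.

```python
def cuda_arch_list(capabilities):
    """Format (major, minor) compute capabilities as a TORCH_CUDA_ARCH_LIST value."""
    unique = sorted(set(capabilities))
    return ";".join(f"{major}.{minor}" for major, minor in unique)


if __name__ == "__main__":
    # With PyTorch and GPUs available, the list could be collected as:
    #   import torch
    #   caps = [torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())]
    caps = [(7, 0), (7, 0)]  # e.g. a node with two V100 GPUs
    print(cuda_arch_list(caps))  # prints 7.0
```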
Training & Evaluation
Distributed training & evaluation is available via Slurm. See the SBATCH scripts here.
IMPORTANT: set DATASET_PATH, EXPERIMENT_PATH and PRETRAINED_PATH to match your local paths.
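The main run uses 4 nodes with 4 GPUs each, which gives 16 processes. Each process needs a global rank and the world size, for example to pass to `torch.distributed.init_process_group(backend, rank=..., world_size=...)`. The sketch below derives both from the standard Slurm variables `SLURM_NODEID`, `SLURM_LOCALID` and `SLURM_NNODES`. It assumes one Slurm task per GPU. The function name is hypothetical, and the repository's scripts may compute ranks differently (for instance, from `SLURM_PROCID`).

```python
import os


def slurm_dist_info(env=None, gpus_per_node=4):
    """Return (global_rank, local_rank, world_size) from standard Slurm variables.

    Assumes one Slurm task per GPU, so each node runs `gpus_per_node` tasks.
    """
    env = os.environ if env is None else env
    node_id = int(env["SLURM_NODEID"])      # index of this node within the job
    local_rank = int(env["SLURM_LOCALID"])  # index of this task within its node
    num_nodes = int(env["SLURM_NNODES"])    # number of nodes allocated to the job
    world_size = num_nodes * gpus_per_node
    global_rank = node_id * gpus_per_node + local_rank
    return global_rank, local_rank, world_size


# Last GPU on the last of 4 nodes: global rank 15 out of 16 processes.
print(slurm_dist_info({"SLURM_NODEID": "3", "SLURM_LOCALID": "3", "SLURM_NNODES": "4"}))
```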
For training Self-Classifier on 4 nodes of 4 GPUs each (16 GPUs in total) for 800 epochs, run:
Image Classification with Linear Models
For training a supervised linear classifier on a frozen backbone, run:
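A linear probe works as follows: the backbone is frozen, its features are computed once, and only a linear layer is trained on them with a cross-entropy loss. The minimal NumPy sketch below shows this with full-batch gradient descent on synthetic features. It is a generic illustration, not the repository's evaluation script, which trains in PyTorch with its own optimizer and hyperparameters. The function names, learning rate and epoch count are hypothetical.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=0.1, epochs=200):
    """Multinomial logistic regression on frozen (precomputed) features."""
    n, d = features.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = features @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        w -= lr * features.T @ grad  # only the linear layer is updated
        b -= lr * grad.sum(axis=0)
    return w, b


def predict(features, w, b):
    return np.argmax(features @ w + b, axis=1)


# Synthetic "backbone features": three well-separated clusters.
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 20)
feats = 3.0 * np.eye(3)[labels] + 0.1 * rng.standard_normal((60, 3))
w, b = train_linear_probe(feats, labels, num_classes=3)
print((predict(feats, w, b) == labels).mean())
```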
Unsupervised Image Classification
For computing unsupervised image classification metrics (NMI: Normalized Mutual Information, AMI: Adjusted Mutual Information, and ARI: Adjusted Rand Index) and generating qualitative examples, run:
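The metrics compare predicted class assignments with ground-truth labels and are invariant to how the predicted classes are numbered. The pure-Python sketch below implements NMI (with arithmetic-mean normalization, which is also scikit-learn's default) and ARI from their contingency-count definitions. AMI also requires the expected mutual information under random labelings, so it is omitted here. In practice, `sklearn.metrics` provides `normalized_mutual_info_score`, `adjusted_mutual_info_score` and `adjusted_rand_score`. This is not necessarily how the repository computes them.

```python
from collections import Counter
from math import comb, log


def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log(c / n) for c in Counter(labels).values())


def mutual_info(a, b):
    n = len(a)
    count_a, count_b = Counter(a), Counter(b)
    joint = Counter(zip(a, b))
    # sum over cells of p(x, y) * log(p(x, y) / (p(x) p(y)))
    return sum((c / n) * log(c * n / (count_a[x] * count_b[y])) for (x, y), c in joint.items())


def nmi(a, b):
    """Normalized mutual information, arithmetic-mean normalization."""
    ha, hb = entropy(a), entropy(b)
    if ha == 0 and hb == 0:
        return 1.0  # both labelings put everything in a single class
    return mutual_info(a, b) / ((ha + hb) / 2)


def ari(a, b):
    """Adjusted Rand index from pair counts."""
    n = len(a)
    sum_joint = sum(comb(c, 2) for c in Counter(zip(a, b)).values())
    sum_a = sum(comb(c, 2) for c in Counter(a).values())
    sum_b = sum(comb(c, 2) for c in Counter(b).values())
    expected = sum_a * sum_b / comb(n, 2)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        return 1.0
    return (sum_joint - expected) / (max_index - expected)


# Same partition with permuted class ids: both scores are 1.
print(nmi([0, 0, 1, 1, 2, 2], [1, 1, 0, 0, 2, 2]), ari([0, 0, 1, 1, 2, 2], [1, 1, 0, 0, 2, 2]))
```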
Image Classification with kNN
For running a k-nearest-neighbor (kNN) classifier on the ImageNet validation set, run:
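A kNN classifier on frozen features needs no training. Each validation image is labeled by a vote among the most similar training images. The NumPy sketch below uses cosine similarity on L2-normalized features and an unweighted majority vote. The function name, the default `k` and the voting rule are assumptions, and the repository's kNN evaluation may use a different `k` or weighted voting. At ImageNet scale, the similarity matrix should be computed in chunks of validation images rather than all at once.

```python
import numpy as np


def knn_predict(train_feats, train_labels, test_feats, k=20):
    """Majority-vote kNN on cosine similarity of L2-normalized features."""
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test @ train.T                      # (num_test, num_train) cosine similarities
    nn_idx = np.argsort(-sims, axis=1)[:, :k]  # indices of the k most similar train samples
    nn_labels = train_labels[nn_idx]           # (num_test, k) neighbor labels
    num_classes = int(train_labels.max()) + 1
    votes = np.zeros((test.shape[0], num_classes), dtype=int)
    for i, row in enumerate(nn_labels):
        votes[i] = np.bincount(row, minlength=num_classes)
    return votes.argmax(axis=1)


train = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
labels = np.array([0, 0, 1, 1])
test = np.array([[1.0, 0.05], [0.05, 1.0]])
print(knn_predict(train, labels, test, k=3))
```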
For training the 100-epoch ablation study baseline, run:
For training any of the ablation study runs presented in the paper, run:
Download the pretrained 100- and 800-epoch models here.
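Checkpoints saved from a model wrapped in `DistributedDataParallel` usually have every parameter name prefixed with `module.`. As a result, loading them into an unwrapped model can fail with missing or unexpected keys. The helper below is a hypothetical sketch that strips such a prefix. Whether the released checkpoints need it, and under which dictionary key the weights are stored, are assumptions to check against the actual files. A typical use would be `model.load_state_dict(strip_prefix(torch.load(path, map_location="cpu")["state_dict"]))`, where the `"state_dict"` key is an assumption.

```python
def strip_prefix(state_dict, prefix="module."):
    """Remove a wrapper prefix (e.g. from DistributedDataParallel) from state-dict keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


print(strip_prefix({"module.backbone.conv1.weight": 1, "head.bias": 2}))
```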
Qualitative Examples (classes predicted by Self-Classifier on ImageNet validation set)
Low-entropy classes predicted by Self-Classifier on the ImageNet validation set. Images are sampled randomly from each predicted class. Note that the predicted classes capture a large variety of backgrounds and viewpoints.
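The README does not define which entropy is used here. One plausible reading, assumed in the sketch below and not confirmed by the source, is the entropy of the ground-truth ImageNet label distribution among the images assigned to each predicted class. A class that contains mostly a single true label has entropy near zero. The sketch ranks predicted classes by this entropy and randomly samples image indices from the lowest-entropy ones. All function names and parameters are hypothetical.

```python
import random
from collections import Counter, defaultdict
from math import log


def class_entropies(pred, true):
    """Entropy of the ground-truth labels inside each predicted class (assumed definition)."""
    members = defaultdict(list)
    for p, t in zip(pred, true):
        members[p].append(t)
    result = {}
    for p, labels in members.items():
        n = len(labels)
        result[p] = -sum((c / n) * log(c / n) for c in Counter(labels).values())
    return result


def sample_low_entropy(pred, true, num_classes=2, per_class=3, seed=0):
    """Pick the lowest-entropy predicted classes and randomly sample image indices from each."""
    ent = class_entropies(pred, true)
    chosen = sorted(ent, key=ent.get)[:num_classes]
    rng = random.Random(seed)
    samples = {}
    for p in chosen:
        idx = [i for i, q in enumerate(pred) if q == p]
        samples[p] = rng.sample(idx, min(per_class, len(idx)))
    return samples


pred = [0, 0, 0, 1, 1, 1]
true = [5, 5, 5, 2, 3, 4]  # class 0 is pure, class 1 mixes three true labels
print(sample_low_entropy(pred, true, num_classes=1, per_class=2))
```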
To reproduce qualitative examples, run: