Minimal implementation of SimSiam (Exploring Simple Siamese Representation Learning by Xinlei Chen & Kaiming He) in TensorFlow 2.
The purpose of this repository is to demonstrate the workflow of SimSiam, NOT to implement it note for note; at the same time, I try not to miss the major bits discussed in the paper. For demonstration, I use the Flowers dataset (a minimal loading sketch follows).
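In case it helps, here is one way to pull the dataset with TensorFlow Datasets. This is a sketch under my own assumptions; the notebooks may load and preprocess the data differently:

```python
import tensorflow_datasets as tfds

# The TFDS "tf_flowers" dataset (~3,670 images, 5 classes).
# as_supervised=True yields (image, label) tuples; labels are
# discarded during self-supervised pre-training.
train_ds = tfds.load("tf_flowers", split="train", as_supervised=True)
```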
The following figure depicts the workflow of SimSiam (taken from the paper) -
The authors have also provided PyTorch-like pseudocode in the paper (how cool!) -
```python
# f: backbone + projection mlp
# h: prediction mlp

for x in loader:  # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x)  # random augmentation
    z1, z2 = f(x1), f(x2)  # projections, n-by-d
    p1, p2 = h(z1), h(z2)  # predictions, n-by-d
    L = D(p1, z2)/2 + D(p2, z1)/2  # loss
    L.backward()  # back-propagate
    update(f, h)  # SGD update

def D(p, z):  # negative cosine similarity
    z = z.detach()  # stop gradient
    p = normalize(p, dim=1)  # l2-normalize
    z = normalize(z, dim=1)  # l2-normalize
    return -(p*z).sum(dim=1).mean()
```
The authors emphasize the stop_gradient operation, which helps the network avoid collapsed solutions. Further details are available in the paper. SimSiam eliminates the need for large batch sizes, momentum encoders, memory banks, negative samples, etc., which are important components of modern self-supervised learning frameworks for visual recognition. This makes SimSiam an easily approachable framework for practical problems.
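For reference, here is a minimal TensorFlow 2 sketch of the symmetrized loss and a training step built from the pseudocode above. The names `f`, `h`, `negative_cosine_similarity`, and `train_step` are my own; this is not necessarily how the notebooks implement it:

```python
import tensorflow as tf

def negative_cosine_similarity(p, z):
    # D(p, z) from the pseudocode: stop the gradient through z,
    # l2-normalize both, and return the negative mean dot product.
    z = tf.stop_gradient(z)              # the crucial stop-gradient
    p = tf.math.l2_normalize(p, axis=1)  # l2-normalize predictions
    z = tf.math.l2_normalize(z, axis=1)  # l2-normalize projections
    return -tf.reduce_mean(tf.reduce_sum(p * z, axis=1))

@tf.function
def train_step(x1, x2, f, h, optimizer):
    # f: backbone + projection MLP, h: prediction MLP
    with tf.GradientTape() as tape:
        z1, z2 = f(x1, training=True), f(x2, training=True)  # projections
        p1, p2 = h(z1, training=True), h(z2, training=True)  # predictions
        loss = (negative_cosine_similarity(p1, z2)
                + negative_cosine_similarity(p2, z1)) / 2    # symmetrized loss
    variables = f.trainable_variables + h.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss
```

Note that `tf.stop_gradient` plays the role of `z.detach()` in the PyTorch pseudocode: gradients flow only through the prediction branch.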
About the notebooks
SimSiam_Pre_training.ipynb: Pre-trains a ResNet50 using SimSiam.
SimSiam_Evaluation.ipynb: Evaluates (linear evaluation) the ResNet50 pre-trained in SimSiam_Pre_training.ipynb.
|Pre-training Schedule|Validation Accuracy (Linear Evaluation)|
|---|---|
I think these scores can be improved with further hyperparameter tuning and regularization.
|Training Type|Validation Accuracy (Linear Evaluation)|
|---|---|
|Supervised ImageNet-trained ResNet50 Features|48.36%|
|From Scratch Training with ResNet50|63.64%|
The figure below shows the training loss plots from two different pre-training schedules (50 epochs and 75 epochs) -
We see that the loss plateaus after 35 epochs. We can experiment with the following components to improve this further -
- data augmentation pipeline
- architectures of the two MLP heads
- learning rate schedule used during pre-training
and so on. As an example for the last item, a sketch of a cosine learning rate schedule is given below.
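The paper pre-trains with SGD and a cosine-decayed learning rate; one way to set that up with Keras built-ins is sketched here. The constants are placeholders, not the notebooks' actual values:

```python
import tensorflow as tf

EPOCHS = 50            # placeholder; match your pre-training schedule
STEPS_PER_EPOCH = 100  # placeholder; roughly dataset_size // batch_size

# Cosine decay over the full pre-training run. The paper scales a base
# learning rate by batch_size / 256; 0.05 here is a placeholder base value.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.05,
    decay_steps=EPOCHS * STEPS_PER_EPOCH,
)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
```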