Implementation of RealFormer using pytorch. Includes comparison with classical Transformer on image classification task (ViT) wrt CIFAR-10 dataset.
Original Paper of the model : https://arxiv.org/abs/2012.11747
So how are RealFormers at vision tasks?
Run the train.py with
model = ViR( image_pix = 32, patch_pix = 4, class_cnt = 10, layer_cnt = 4 )
to Test how RealFormer works on CIFAR-10 dataset compared to just classical ViT, which is
model = ViT( image_pix = 32, patch_pix = 4, class_cnt = 10, layer_cnt = 4 )
... which is of course, much, much smaller version of ViT compared to the origianl ones ().
Model : layers = 4, hidden_dim = 128, feedforward_dim = 512, head_cnt = 4
Trained 10 epochs
After 10'th epoch, Realformer achieves 65.45% while Transformer achieves 64.59%
RealFormer seems to consistently have about 1% greater accuracy, which seems reasonable (as the papaer suggested simillar result)
Model : layers = 8, hidden_dim = 128, feedforward_dim = 512, head_cnt = 4
Having 4 more layers obviously improves in general, and still, RealFormer consistently wins in terms of accuracy (68.3% vs 66.3%). Notice that larger the model, bigger the difference seems to follow here too. (I wonder how much of difference it would make on ViT-Large)
When it comes to computation time, there was almost zero difference. (I guess adding residual attention score is O(L^2) operation, compared to matrix multiplication in softmax which is O(L^2 * D))
Use RealFormer. It benifits with almost zero additional resource!
To make a custom RealFormer for other tasks
Its not a pip package, but you can use the ResEncoderBlock module in the models.py to make a Encoder Only Transformer like the following :
import ResEncoderBlock from models def RealFormer(nn.Module): ... def __init__(self, ...): ... self.mains = nn.Sequential(*[ResEncoderBlock(emb_s = 32, head_cnt = 8, dp1 = 0.1, dp2 = 0.1) for _ in range(layer_cnt)]) ... def forward(self, x): ... prev = None for resencoder in self.mains: x, prev = resencoder(x, prev = prev) ... return x
If you're not really clear what is going on or what to do, request me to make this a pip package.