You can install TokenLearner via `pip`:

```sh
pip install tokenlearner-pytorch
```
You can then import the `TokenLearner` class from the `tokenlearner_pytorch` package and use this layer with a Vision Transformer, MLPMixer, or Video Vision Transformer, as done in the paper.
```python
import torch
from tokenlearner_pytorch import TokenLearner

tklr = TokenLearner(S=8)
x = torch.rand(512, 32, 32, 3)  # a batch of images, (B, H, W, C)
y = tklr(x)                     # [512, 8, 3]
```
You can also use `TokenFuser` together with Multi-head Self-Attention, as done in the paper:
```python
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

mhsa = nn.MultiheadAttention(3, 1)
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)

x = torch.rand(512, 32, 32, 3)  # a batch of images, (B, H, W, C)

y = tklr(x)                # [512, 8, 3]
y = y.transpose(0, 1)      # (seq_len, batch, embed_dim) for nn.MultiheadAttention
y, _ = mhsa(y, y, y)       # ignore attn weights
y = y.transpose(0, 1)      # back to (batch, seq_len, embed_dim)

out = tkfr(y, x)           # [512, 32, 32, 3]
```
- Add support for temporal dimension
If I’ve made any errors or you have any suggestions, feel free to open an Issue or PR. All contributions are welcome!