tokenlearner-pytorch
Unofficial PyTorch implementation of TokenLearner by Ryoo et al. from Google AI (abs, pdf)
Installation
You can install TokenLearner via pip:
pip install tokenlearner-pytorch
Usage
The TokenLearner class is exposed by the tokenlearner_pytorch package. You can use this layer with a Vision Transformer, MLP-Mixer, or Video Vision Transformer, as done in the paper.
import torch
from tokenlearner_pytorch import TokenLearner
tklr = TokenLearner(S=8)
x = torch.rand(512, 32, 32, 3) # a batch of 32x32 images with 3 channels
y = tklr(x) # [512, 8, 3] -- 8 learned tokens per image
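If you want to combine TokenLearner with a ViT-style backbone, as mentioned above, one option is to reshape the patch tokens back into their spatial grid before applying the layer, so that all subsequent blocks only attend over the S learned tokens. The sketch below is illustrative only: the grid size, embedding dimension, and encoder layer are placeholders, and it assumes TokenLearner accepts an arbitrary channel dimension (its constructor only takes S):
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner

B, H, W, C = 4, 16, 16, 192 # hypothetical patch grid and embedding size
tokens = torch.rand(B, H * W, C) # patch tokens from earlier ViT blocks
tklr = TokenLearner(S=8)
encoder = nn.TransformerEncoderLayer(d_model=C, nhead=4, batch_first=True)

grid = tokens.view(B, H, W, C) # restore the spatial grid TokenLearner expects
learned = tklr(grid) # [B, 8, C] -- only 8 tokens remain
out = encoder(learned) # later layers attend over 8 tokens instead of 256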
You can also use TokenLearner and TokenFuser together with Multi-head Self-Attention as done in the paper:
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser
mhsa = nn.MultiheadAttention(3, 1)
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)
x = torch.rand(512, 32, 32, 3) # a batch of images
y = tklr(x)
y = y.transpose(0, 1) # [8, 512, 3] -- nn.MultiheadAttention expects (seq, batch, embed) by default
y, _ = mhsa(y, y, y) # ignore attn weights
y = y.transpose(0, 1) # back to [512, 8, 3]
out = tkfr(y, x) # [512, 32, 32, 3]
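If you plan to stack this pattern, the three layers can be wrapped into a single module. The class below is only a sketch of one way to package them, not part of this library; it uses batch_first=True (PyTorch >= 1.9) to avoid the transposes in the example above:
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

class TokenLearnerBlock(nn.Module):
    """Hypothetical wrapper: learn S tokens, mix them with MHSA, fuse them back."""
    def __init__(self, H, W, C, S, num_heads=1):
        super().__init__()
        self.tklr = TokenLearner(S=S)
        self.mhsa = nn.MultiheadAttention(C, num_heads, batch_first=True)
        self.tkfr = TokenFuser(H=H, W=W, C=C, S=S)

    def forward(self, x):         # x: [B, H, W, C]
        y = self.tklr(x)          # [B, S, C]
        y, _ = self.mhsa(y, y, y) # self-attention over the S learned tokens
        return self.tkfr(y, x)    # fuse back to [B, H, W, C]

block = TokenLearnerBlock(H=32, W=32, C=3, S=8)
out = block(torch.rand(512, 32, 32, 3)) # [512, 32, 32, 3]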
TODO
- Add support for temporal dimension T
- Implement TokenFuser with ViT
- Implement TokenFuser with ViViT
Contributions
If I've made any errors or you have any suggestions, feel free to open an Issue or PR. All contributions are welcome!
