PyTorch implementations of various state-of-the-art architectures.
import torch
from MLP_Mixer import MLPMixer
# MLP-Mixer example: one 128x128 RGB image in, 10 class outputs out.
# NOTE(review): patch_dim equals img_dim (both 128); if patch_dim is the patch
# side length, the whole image becomes a single token -- confirm intended.
model = MLPMixer(
    classes=10,
    blocks=6,
    img_dim=128,
    patch_dim=128,
    in_channels=3,
    dim=512,
    token_dim=256,
    channel_dim=2048,
)
x = torch.randn((1, 3, 128, 128))  # dummy batch containing one RGB image
model(x)  # expected output shape: (1, 10)
import torch
from TransUNet import TransUNet
# TransUNet example: one 128x128 RGB image in, a 128x128 map over 2 classes out.
model = TransUNet(
    img_dim=128,
    patch_dim=16,
    in_channels=3,
    classes=2,
    blocks=6,
    heads=8,
    linear_dim=1024,
)
x = torch.randn(1, 3, 128, 128)  # dummy batch containing one RGB image
# NOTE(review): this shape was previously given as (2, 128, 128), which drops
# the batch dimension that every other example includes; presumably the real
# output is (batch, classes, H, W) = (1, 2, 128, 128) -- TODO confirm.
model(x)  # expected output shape: (1, 2, 128, 128)
import torch
from ViT import ViT
# Vision Transformer example in classification mode: one 128x128 RGB image
# split into 16-pixel patches, producing 10 class outputs.
model = ViT(
    img_dim=128,
    in_channels=3,
    patch_dim=16,
    classes=10,
    dim=512,
    blocks=6,
    heads=4,
    linear_dim=1024,
    classification=True,
)
x = torch.randn((1, 3, 128, 128))  # dummy batch containing one RGB image
model(x)  # expected output shape: (1, 10)
import torch
from fastformer import FastFormer
# FastFormer example: a sequence of 128 vectors of width 256 (in_dims) is
# mapped to width 512 (token_dim), per the shape note below.
model = FastFormer(
    in_dims=256,
    token_dim=512,
    num_heads=8,
)
x = torch.randn((1, 128, 256))  # dummy batch: 1 sequence, 128 steps, 256 features
model(x)  # expected output shape: (1, 128, 512)
import torch
from han import HAN
# HAN example: one 256x256 RGB image in; per the shape note, the output keeps
# 3 channels and doubles each spatial side.
model = HAN(in_channels=3)
x = torch.randn((1, 3, 256, 256))  # dummy batch containing one RGB image
model(x)  # expected output shape: (1, 3, 512, 512)