/Transformer-Unet

Unet implementation with transformer encoder

Primary LanguagePython

Transformer-Unet

Unet implementation with transformer encoder

Usage

from unet import TransUnet

a = torch.rand(2, 3, 128, 128)

model = TransUnet(in_channels=3, img_dim=128, vit_blocks=1, vit_dim_linear_mhsa_block=512, classes=5)
y = model(a)
print('final out shape:', y.shape)