pytorch/examples

Wrong positional embedding in examples/vision_transformer

pb07210028 opened this issue · 0 comments

The absolute position embedding used in examples/vision_transformer/main.py appears to be incorrect:

# Positional embedding
self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)

A ViT needs one embedding per token position, i.e. num_patches + 1 vectors (the extra one for the class token), with a leading dimension of 1 that broadcasts over the batch. As written, the parameter is instead tied to the batch size and every token position receives the same embedding. It should look like this:

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, self.latent_size)).to(self.device)
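For context, here is a minimal self-contained sketch of how the corrected embedding behaves; PatchEmbeddingSketch and its argument names are illustrative, not the example's actual code. The embedding is added to the class-token-plus-patch sequence and broadcasts over the batch, so the same module works for any batch size. (Registering the parameter as a plain module attribute also lets .to(device) on the whole module handle device placement.)

import torch
import torch.nn as nn

class PatchEmbeddingSketch(nn.Module):
    def __init__(self, num_patches, latent_size):
        super().__init__()
        # Learnable class token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, latent_size))
        # One embedding per token position (num_patches patches + 1 class
        # token); the leading 1 broadcasts over the batch dimension.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, latent_size))

    def forward(self, patch_tokens):
        # patch_tokens: (batch_size, num_patches, latent_size)
        b = patch_tokens.size(0)
        cls = self.cls_token.expand(b, -1, -1)     # (b, 1, latent_size)
        x = torch.cat([cls, patch_tokens], dim=1)  # (b, num_patches + 1, latent_size)
        return x + self.pos_embedding              # broadcasts over the batch

emb = PatchEmbeddingSketch(num_patches=196, latent_size=768)
print(emb(torch.randn(8, 196, 768)).shape)  # torch.Size([8, 197, 768])

With the original (batch_size, 1, latent_size) shape, the addition would either fail or silently broadcast a single vector across all 197 positions, so the model would learn no positional information at all.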