nachiket273/Vision_transformer_pytorch

Technical details on implementation

MjdMahasneh opened this issue · 2 comments

Hello,
hope you are keeping well :) First, let me thank you for the amazing implementation! I found the paper and some other implementations on GitHub hard to understand until I stumbled upon this repo!! THANK YOU.

Now to my questions:

1-It would be really helpful if you could provide information on any differences between this implementation and the original paper (if any).

2-Ideally, should the transformer unit be used on its own (e.g. replacing a backbone CNN), or should it be used along with a CNN (e.g. serving as a layer in the network)? For example, I have an encoder-decoder architecture in which I am trying to use a transformer unit. Would it make sense to place it in the latent space, where the features coming out of the encoder are flattened and passed to the transformer unit?

3-If the answer to question 2 is yes, do you think it is possible to generalize the implementation to a 3D CNN?

4-If the answer to question 3 is yes, are there any tips or things to look out for when attempting that?

Thank you so much; I will be looking forward to hearing from you.

Hi,

Thank you for the questions and sorry for the late reply.

Here are some of my thoughts.

  1. There is no difference between the paper and this implementation; in fact, the implementation closely follows the official JAX implementation of the paper (https://github.com/google-research/vision_transformer).

  2. The paper mainly suggests using the transformer in place of a CNN (it also evaluates a hybrid model where the feature map from a CNN backbone is fed to the transformer, but that variant doesn't perform any better); see the hybrid sketch after this list.

  3. The image is fed to the transformer as a series of image patches with positional embeddings; if the same can be done for a 3D image (splitting the volume into a series of cubic patches), it can be used for 3D images too. See the 3D patch-embedding sketch after this list.
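
For the hybrid variant mentioned in point 2, here is a minimal sketch (not code from this repo) of feeding a CNN feature map to a transformer: the spatial positions of the feature map are flattened into tokens, a class token and learnable positional embeddings are added, and the sequence goes through a standard encoder. The class name `HybridViT`, the use of `torch.nn.TransformerEncoder` instead of the repo's own blocks, and all hyperparameters are illustrative assumptions.

```python
import torch
import torch.nn as nn

class HybridViT(nn.Module):
    """Sketch of a hybrid model: a small CNN backbone produces a feature map
    whose spatial positions are flattened into tokens for a transformer encoder."""
    def __init__(self, img_size=64, in_ch=3, embed_dim=192,
                 depth=4, num_heads=4, num_classes=10):
        super().__init__()
        # Toy CNN backbone that downsamples by 4x; the paper uses a ResNet stem/stages here.
        self.backbone = nn.Sequential(
            nn.Conv2d(in_ch, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, embed_dim, kernel_size=3, stride=2, padding=1),
        )
        num_tokens = (img_size // 4) ** 2  # one token per spatial location of the feature map
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens + 1, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
            batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        feat = self.backbone(x)                   # (B, embed_dim, H/4, W/4)
        tokens = feat.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.encoder(tokens)
        return self.head(tokens[:, 0])            # classify from the class token


model = HybridViT()
out = model(torch.randn(2, 3, 64, 64))  # -> shape (2, 10)
```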
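
And to make point 3 concrete, here is a minimal sketch (again, not part of this repo, just one assumed way of doing it) of a 3D patch embedding: a strided `Conv3d` splits the volume into non-overlapping cubic patches, exactly analogous to the 2D case, and the resulting token sequence can then be fed to the same kind of transformer encoder. Class name, sizes, and the learnable positional embedding are illustrative.

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Split a 3D volume into non-overlapping cubic patches and project
    each patch to an embedding, mirroring the 2D ViT patch embedding."""
    def __init__(self, vol_size=64, patch_size=16, in_ch=1, embed_dim=192):
        super().__init__()
        assert vol_size % patch_size == 0, "volume must be divisible by the patch size"
        self.num_patches = (vol_size // patch_size) ** 3
        # A strided Conv3d is equivalent to flattening each patch and applying a linear layer.
        self.proj = nn.Conv3d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):                    # x: (B, C, D, H, W)
        x = self.proj(x)                     # (B, embed_dim, D/p, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)     # (B, num_patches, embed_dim)
        return x + self.pos_embed            # ready to feed to a transformer encoder


embed = PatchEmbed3D()
tokens = embed(torch.randn(2, 1, 64, 64, 64))  # -> shape (2, 64, 192)
```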

Thank you @nachiket273 for the detailed response and the great effort; it is very much appreciated.