yoxu515/aot-benchmark

Workaround for Pytorch-Correlation-extension


Hi, thanks for your nice work!

I noticed a lot of questions are about Pytorch-Correlation-extension (PCE). Since this package mainly supports PyTorch <= 1.7 and CUDA < 11.0, it is becoming more difficult to use AOT on recent GPUs (A40s, A100s). I am therefore wondering if there are any workarounds for this. I can live with slower speed, though.

For example, if I need to replace the modules that use Pytorch-Correlation-extension with a naive MultiheadAttention and retrain the model, I am happy to do so, as long as the performance is similar.

Thank you so much!

Best,

Ziqi

You could try setting enable_corr=false (for AOT/DeAOT) to avoid using Pytorch-Correlation-extension during inference/training.

By doing this, the short-term attention will be calculated like a MultiheadAttention with a mask.
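
For anyone curious what that fallback roughly amounts to, below is a minimal, self-contained sketch of short-term attention computed as plain masked attention. This is not the repository's actual implementation; the function name, window size, and tensor layout are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F


def local_attention_fallback(q, k, v, height, width, window=15):
    """Short-term attention as plain masked attention (illustrative sketch).

    q, k, v: (batch, num_heads, height * width, head_dim)
    Each position attends only to positions within a `window` x `window`
    spatial neighborhood; everything else is masked out.
    """
    device = q.device
    ys = torch.arange(height, device=device).repeat_interleave(width)  # row index per flattened position
    xs = torch.arange(width, device=device).repeat(height)             # column index per flattened position

    # Boolean (hw, hw) mask: True where two positions are spatially close.
    close_y = (ys[:, None] - ys[None, :]).abs() <= window // 2
    close_x = (xs[:, None] - xs[None, :]).abs() <= window // 2
    mask = close_y & close_x

    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale              # (batch, heads, hw, hw)
    attn = attn.masked_fill(~mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    return attn @ v                                        # (batch, heads, hw, head_dim)
```

Because this materializes a full (h*w) x (h*w) attention matrix instead of only the local neighborhoods, it is slower and more memory-hungry than the CUDA correlation kernel, which is the speed trade-off mentioned above.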

@z-x-yang Thanks for the quick answer! I previously tried enable_corr=false for inference on DAVIS using your provided checkpoints, and the performance was lower than reported in MODEL_ZOO.md. Just to double-check: is this expected, and is retraining the model needed when enable_corr=false?

Setting enable_corr to false or true should lead to the same performance. Which model did you test? What was the performance?

@z-x-yang Thanks for the reply! I forgot to update that I figured this out.

However, I believe that this line in MultiHeadLocalAttentionv3 needs an added .permute(2, 0, 1, 3).reshape(h * w, n, c) to make things work smoothly. Perhaps you could integrate this change in case others also need this workaround?
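
For illustration, here is a hedged sketch of what that reshape does, assuming the local-attention output comes out as (n, num_heads, h * w, c // num_heads) and the surrounding code expects (h * w, n, c). The concrete shapes are assumptions for this example, not taken from the repository.

```python
import torch

# Hypothetical shapes for illustration only:
n, num_heads, h, w, c = 2, 8, 30, 30, 256
out = torch.randn(n, num_heads, h * w, c // num_heads)

# The suggested fix: move the spatial axis to the front and merge the heads
# so the tensor matches the (h * w, n, c) layout expected downstream.
out = out.permute(2, 0, 1, 3).reshape(h * w, n, c)
print(out.shape)  # torch.Size([900, 2, 256])
```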

@z-x-yang What do you think about the suggested permute? It seems that the shapes do not match in the original formulation.