qile2000/LAMDA-TALENT

Why ModernNCA uses the double precision?

Yura52 opened this issue · 5 comments

Hi, thanks for the project! I noticed that ModernNCA relies on torch.float64 in several places:

I wonder what the motivation behind this choice is. Are there any specific parts of the computation that require double precision?

Thank you for your question! In our toolkit, all deep learning methods, including ModernNCA, are implemented in double precision (torch.float64). The primary motivation is to ensure a fair comparison across methods while keeping numerical accuracy high. However, if you prefer single precision (torch.float32) for performance or other reasons, that is certainly possible: it only requires a few minor modifications to the code.
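A minimal PyTorch sketch of such a switch, assuming the model is an ordinary `nn.Module` whose parameters and inputs were cast to `torch.float64`. The helper `to_precision` is hypothetical and not part of the LAMDA-TALENT code; any tensor created elsewhere with an explicit `dtype=torch.float64` would need the same change, otherwise PyTorch raises a dtype-mismatch error when float32 and float64 tensors meet in a matmul.

```python
import torch
import torch.nn as nn


def to_precision(model: nn.Module, X: torch.Tensor, dtype: torch.dtype):
    """Hypothetical helper: cast a model's floating-point parameters/buffers
    and an input batch to the same dtype."""
    # nn.Module.to casts in place and returns the module itself.
    return model.to(dtype=dtype), X.to(dtype=dtype)


torch.manual_seed(0)
# Stand-in for a model trained in double precision.
model = nn.Linear(8, 2).double()
X = torch.randn(4, 8, dtype=torch.float64)

model32, X32 = to_precision(model, X, torch.float32)
out = model32(X32)
print(out.dtype)  # torch.float32
```

Because `nn.Module.to` works in place, `model` and `model32` are the same object afterwards. Each float32 element occupies 4 bytes instead of 8 (`X32.element_size()` vs `X.element_size()`), which accounts for part of the efficiency gain.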

Thank you for the prompt reply! So if I am interested specifically in ModernNCA (not in other models), I can safely switch back to the default (torch.float32) for better efficiency?

Thank you for your follow-up! We haven’t conducted specific experiments on the effect of precision in ModernNCA. However, in another work from our group on a general tabular model called TabPTM (which also relies on nearest neighbors), we found that changing the precision could lead to slight differences in the nearest neighbors retrieved. As a result, we decided to use double precision for all deep learning methods in the toolkit, including models like FTT and TabR. Going forward, we plan to run experiments comparing ModernNCA's performance under different precision settings. If you're interested, we would be happy to share the results with you once they're available.

We have now run experiments specifically on ModernNCA and found no significant difference in performance between float32 and double precision, while float32 improves efficiency. As for TabPTM's sensitivity to precision, I suspect it is because TabPTM searches for nearest neighbors directly on raw features and relies on NumPy for certain computations. Based on this, we will add an option in a future release that lets users select their preferred precision for model training. Thank you again for your valuable question!
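To illustrate one possible mechanism behind the raw-feature sensitivity described above, here is a small standard-library sketch. The feature values are made up for illustration and are not taken from TabPTM. Near 1e6 the spacing between adjacent float32 values is 2**-4 = 0.0625. Rounding features of that magnitude to float32 can therefore reorder candidates whose float64 distances differ by less than that spacing. Float32 storage is simulated by rounding each value with `struct`.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def nearest(query, candidates):
    """Index of the candidate with the smallest squared Euclidean distance."""
    dists = [sum((q - c) ** 2 for q, c in zip(query, cand)) for cand in candidates]
    return min(range(len(candidates)), key=dists.__getitem__)


# Illustrative raw (unnormalized) features of large magnitude.
query = [1_000_000.03]
candidates = [
    [1_000_000.0],    # distance ~0.030 in float64
    [1_000_000.059],  # distance ~0.029 in float64: the true nearest neighbor
]

idx64 = nearest(query, candidates)

# Simulate float32 storage: round every feature before computing distances.
q32 = [to_float32(v) for v in query]
c32 = [[to_float32(v) for v in cand] for cand in candidates]
idx32 = nearest(q32, c32)

print(idx64, idx32)  # 1 0
```

In float32 the query and the first candidate both round to 1000000.0, while the second rounds to 1000000.0625, so the nearest neighbor flips. For standardized features of magnitude around 1, the float32 spacing is about 1.2e-7, so such flips need far closer ties. This is consistent with, but does not prove, the suspicion that searching on raw features makes the result more precision-sensitive.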

This totally answers my question, thank you for the insights and the additional experiments!