qile2000/LAMDA-TALENT

CUDA out of memory

xiaoye23 opened this issue · 1 comment

When I use a single RTX 4090 to train the ModernNCA model, CUDA runs out of memory. It seems that ModernNCA needs a large amount of GPU memory for training. My dataset has one million or more samples, with 10 categorical features and 31 numerical features. How can I make training work on a single RTX 4090? Thank you!

Thank you for trying ModernNCA! As a retrieval-based model, its memory usage during training grows linearly with the size of the retrieval set, so memory consumption can become significant on large-scale tabular datasets. If your computational resources are limited, you can try the following methods to reduce ModernNCA's memory usage:

We’ve found that even with extremely low sample rates (e.g., 1%–5%), performance drops only slightly on most large-scale datasets, provided the hyperparameters are tuned properly. You can therefore tune the model with a low sample rate during training and use a smaller batch size during inference, as sketched below. This lowers memory usage with little impact on performance.
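For illustration, here is a minimal sketch of the idea, not TALENT's actual API: subsample the retrieval set each training step so memory scales with `sample_rate * N` instead of `N`, and keep inference batches small. The function name and the model's `candidate_x`/`candidate_y` arguments are assumptions for the example.

```python
import torch

def sample_retrieval_set(X_train, y_train, sample_rate=0.05):
    """Randomly keep a small fraction of the training set as retrieval candidates."""
    n = X_train.shape[0]
    k = max(1, int(n * sample_rate))  # e.g. 1%-5% of one million rows
    idx = torch.randperm(n, device=X_train.device)[:k]
    return X_train[idx], y_train[idx]

# Hypothetical usage inside a training loop:
# for x_batch, y_batch in train_loader:
#     cand_x, cand_y = sample_retrieval_set(X_train, y_train, sample_rate=0.05)
#     logits = model(x_batch, candidate_x=cand_x, candidate_y=cand_y)  # signature assumed
#     ...
#
# At inference, keep the full retrieval set but process queries in smaller
# chunks, e.g. a prediction batch size of 256 instead of 8192.
```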

ModernNCA uses PLR (lite) embeddings for numerical features to further enhance its capability, but they also increase memory usage. In some cases, you can disable the PLR (lite) embedding by setting num_embeddings to None in the search space:
`"num_embeddings": ["categorical", [None]]`

Overall, if you have multiple GPUs available, we recommend using them for training so the model can fully realize its potential. You can achieve this by slightly modifying the framework with torch.nn.parallel.DataParallel; a sketch follows.
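A minimal sketch of the wrapping (the `nn.Sequential` model below is just a stand-in for the ModernNCA network built inside TALENT's trainer, and the feature count is taken from your dataset description):

```python
import torch
import torch.nn as nn

# Stand-in model; in TALENT this would be the ModernNCA network built by the trainer.
model = nn.Sequential(nn.Linear(41, 256), nn.ReLU(), nn.Linear(256, 2))

if torch.cuda.device_count() > 1:
    # DataParallel splits each input batch across the visible GPUs
    # and gathers the outputs back on the default device.
    model = nn.parallel.DataParallel(model)

model = model.to("cuda")
x = torch.randn(8, 41, device="cuda")  # 10 categorical (encoded) + 31 numerical features
out = model(x)
```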