MrNeRF/gaussian-splatting-cuda

Apex FusedAdam vs PyTorch Adam

fhahlbohm opened this issue · 2 comments

Hey, this looks like a cool project!

I was wondering whether you have tried replacing PyTorch Adam with the FusedAdam implementation from NVIDIA Apex (https://nvidia.github.io/apex/optimizers.html). Note that while PyTorch itself includes a fused Adam implementation, I found that it did not work properly when I first tested it earlier this year.
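For reference, the built-in fused path referred to here is the `fused=True` flag on `torch.optim.Adam`. A minimal sketch, assuming a recent PyTorch build and parameters on a CUDA device (the fused implementation requires CUDA tensors); the toy model is illustrative only:

```python
import torch

# Toy parameters on the GPU; the fused implementation requires CUDA tensors.
model = torch.nn.Linear(256, 256).cuda()

# PyTorch's built-in fused path: fused=True runs the element-wise Adam update
# as fused CUDA kernels instead of many small per-tensor ops.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(32, 256, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```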

I just tested this (PyTorch Adam -> Apex FusedAdam) within my re-implementation of the official Python implementation, and on my 4090 training the garden scene went from 24 minutes to about 18 minutes, roughly a 25% reduction in training time (a ~1.33x speedup).

As I prefer Python-based implementations for research-driven development, I won't look into this for your project myself, so I thought I'd just let you know.

MrNeRF commented

Cool, and thank you! I did not know about this one. I need to test it out; it would be interesting to see whether it has an effect on performance.
But let me make sure I understand. My profiling did not identify the optimizer step as a bottleneck; not much is happening there. Is this the speedup you measured? A 33% gain seems rather large for something that does not contribute much to the total runtime. The bottlenecks are found elsewhere, and as I have spent a lot of time measuring performance, I have a pretty good idea of what kills it.

Or do you mean that it minimizes more effectively? In that case the optimizer would need fewer training steps, which could speed up training considerably.

What exactly is your training setup for the garden scene (resolution, iterations)?
Since I also have an RTX 4090, I think it would be interesting to compare.
At the lowest resolution (648x420) I measured 87.54 s over 7000 iterations, ending with ~1.7 million splats.

fhahlbohm commented

My Python-based implementation trains the default mipNeRF360 garden scene configuration (i.e. 30k iterations, resolution of 1297x840) in 16:57 (29.48 it/s).

This is after replacing PyTorch Adam with FusedAdam. It's literally just a matter of installing the apex library, making all parameters contiguous when initializing them, removing set_to_none=True from the zero_grad() calls (apex's FusedAdam does this by default), and exchanging torch.optim.Adam() with apex.optimizers.FusedAdam(). See the sketch below.
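A minimal sketch of that swap, assuming apex is installed with its CUDA extensions; the parameter names, shapes, and learning rates below are placeholders, not the project's actual training configuration:

```python
import torch
from apex.optimizers import FusedAdam  # assumes apex was built with --cuda_ext

# Placeholder Gaussian-splatting-style parameter groups; shapes and learning
# rates are illustrative only, not the real training configuration.
means     = torch.nn.Parameter(torch.randn(100_000, 3,  device="cuda").contiguous())
features  = torch.nn.Parameter(torch.randn(100_000, 48, device="cuda").contiguous())
opacities = torch.nn.Parameter(torch.randn(100_000, 1,  device="cuda").contiguous())

param_groups = [
    {"params": [means],     "lr": 1.6e-4},
    {"params": [features],  "lr": 2.5e-3},
    {"params": [opacities], "lr": 5.0e-2},
]

# Drop-in replacement for torch.optim.Adam. adam_w_mode=False keeps plain Adam
# semantics (FusedAdam defaults to AdamW-style decoupled weight decay).
optimizer = FusedAdam(param_groups, adam_w_mode=False)

# The training step itself is unchanged; only zero_grad() loses its
# set_to_none=True argument, since FusedAdam sets gradients to None by default.
loss = means.square().sum() + features.square().sum() + opacities.square().sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```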

As far as I know, FusedAdam applies the updates to all parameters at once (hence the term "fused"), so the improvement is not due to better optimization behavior but purely due to a more performant implementation.
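A rough way to sanity-check that is to time the optimizer step in isolation. This micro-benchmark sketch assumes apex is installed; the tensor shapes are only loosely modeled on a ~1.7M-splat state, and the absolute numbers will vary by GPU:

```python
import time
import torch
from apex.optimizers import FusedAdam

def bench(optimizer, params, steps=200):
    # Attach dummy gradients and time optimizer.step() in isolation.
    for p in params:
        p.grad = torch.randn_like(p)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        optimizer.step()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps * 1e3  # ms per step

# A handful of large tensors, loosely shaped like a splatting model's parameters.
params_a = [torch.randn(1_700_000, n, device="cuda", requires_grad=True)
            for n in (3, 3, 4, 1, 48)]
params_b = [p.detach().clone().requires_grad_(True) for p in params_a]

t_plain = bench(torch.optim.Adam(params_a, lr=1e-3), params_a)
t_fused = bench(FusedAdam(params_b, lr=1e-3), params_b)
print(f"torch.optim.Adam: {t_plain:.3f} ms/step | FusedAdam: {t_fused:.3f} ms/step")
```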