Compatibility with gradient accumulation
quasimik opened this issue · 1 comment
I'm bringing my own PyTorch training script, and I'm interested in using SM Debugger to profile function calls in my training jobs. The API Glossary states:
Step: Step means the work done by the training job for one batch (i.e. one forward and backward pass).
I assume I will have to register my module with hook.register_module(module) in the training script for SM Debugger to work at all. I further assume that SM Debugger then registers its own hooks on the module's forward() and/or backward() passes to track when a "step" happens.
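For concreteness, here's roughly what I have in mind, a minimal sketch of manual registration (the out_dir path, save interval, model, and loss are placeholders, not my real setup):

```python
import torch.nn as nn
import smdebug.pytorch as smd

model = nn.Linear(10, 1)
criterion = nn.MSELoss()

# Create the hook manually in a bring-your-own-script setup.
hook = smd.Hook(
    out_dir="/opt/ml/output/tensors",              # placeholder output path
    save_config=smd.SaveConfig(save_interval=100),  # illustrative interval
)
hook.register_module(model)    # attach SM Debugger's hooks to the module
hook.register_loss(criterion)  # optionally track the loss module as well
```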
However, my training script accumulates gradients from several forward() passes before running a single backward() pass.
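The pattern looks roughly like this (model, data, and the accumulation factor are stand-ins):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4

optimizer.zero_grad()
losses = []
for _ in range(accum_steps):
    x, y = torch.randn(8, 10), torch.randn(8, 1)  # stand-in micro-batch
    losses.append(criterion(model(x), y))          # forward pass only
(torch.stack(losses).sum() / accum_steps).backward()  # single backward pass
optimizer.step()
```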
My questions:
1. Will this interfere with the functionality of SM Debugger?
2. Assuming this is okay, does SM Debugger consider the forward() or the backward() pass to be one "step"?
After looking at the code, I think I can answer question (2) for myself.
Here, register_module() registers a function self.forward_pre_hook() on the module's forward call (sagemaker-debugger/smdebug/pytorch/hook.py, line 603 at 99282cd).
Here, self.forward_pre_hook() increments the step count (sagemaker-debugger/smdebug/pytorch/hook.py, line 321 at 99282cd).
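So my reading is that the counter advances on every forward() call. A toy illustration of that behavior (this is not smdebug's code, just a plain PyTorch forward pre-hook standing in for it) applied to my accumulation pattern:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
step_count = 0

def forward_pre_hook(module, inputs):
    # Fires before every forward() call on the module.
    global step_count
    step_count += 1

model.register_forward_pre_hook(forward_pre_hook)

losses = [model(torch.randn(8, 10)).mean() for _ in range(4)]  # 4 forward passes
torch.stack(losses).sum().backward()                           # 1 backward pass
print(step_count)  # 4 -- the counter advances per forward, not per backward
```

If SM Debugger counts the same way, each accumulation micro-batch would be its own "step", even though there is only one backward pass and one optimizer update.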
I'm still curious about question (1), though. Any insight is appreciated.