a1600012888/MUTR3D

Using torch.checkpoint


Hi, I was wondering whether you have tried using torch.checkpoint to save GPU memory. I was thinking that the computation graph may not be static, so using torch.checkpoint might not be possible?

Looking forward to your reply.

Hi JianingWang

Very good questions.
Short answer: with the current implementation, it is hard to use memory checkpointing.

We used checkpoint in training DETR3D with huge backbones.
But for tracking there is one key difference: the loss is computed as a sum over multiple rounds of model forward! (In each frame we compute a loss, and these are finally summed together for a single backpropagation.)

In this case, you cannot use checkpoint unless you change the code to put all the forward passes into one call (or, in an easier case, change the code to have something like backbone.forward(image_at_all_frames) and use checkpoint in the backbone part).
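A minimal sketch of the "easier case", assuming a clip arrives as a `(T, B, C, H, W)` tensor: all frames go through the backbone in one checkpointed call, the head then runs frame by frame, and the per-frame losses are summed for a single backward. `backbone`, `head`, `clip_loss`, and the squared-output loss are toy placeholders, not MUTR3D's actual modules; `use_reentrant=False` selects PyTorch's non-reentrant checkpoint variant.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-ins for the real image backbone and tracking head.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()
)
head = nn.Linear(8, 4)

def clip_loss(frames: torch.Tensor) -> torch.Tensor:
    """frames: (T, B, C, H, W) video clip; returns the loss summed over all frames."""
    T, B = frames.shape[:2]
    # One backbone call over all T*B images, wrapped in a single checkpoint:
    # activations inside the backbone are recomputed during backward.
    feats = checkpoint(backbone, frames.flatten(0, 1), use_reentrant=False)
    feats = feats.reshape(T, B, -1)
    total = frames.new_zeros(())
    for t in range(T):  # the tracking head still runs frame by frame
        pred = head(feats[t])
        total = total + pred.pow(2).mean()  # placeholder per-frame loss
    return total

frames = torch.randn(3, 2, 3, 16, 16)  # T=3 frames, batch of 2 images each
loss = clip_loss(frames)
loss.backward()  # one backward over the summed loss
```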

Or you can change the logic of the model to perform forward-backward propagation at each frame independently. This change would make the whole training much more efficient.
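A minimal sketch of per-frame forward-backward, assuming the track state can be represented as a plain tensor (a hypothetical stand-in for `track_instances`). Each frame's loss is backpropagated immediately, so only one frame's graph is alive at a time, and the state is detached before the next frame. Gradients accumulate across frames and are applied with one optimizer step per clip. All modules, sizes, and the MSE loss are illustrative.

```python
import torch
import torch.nn as nn

# Hypothetical modules: a per-frame feature extractor, a recurrent update of
# the track state (stand-in for track_instances queries), and a prediction head.
backbone = nn.Linear(16, 8)
update = nn.GRUCell(8, 8)
head = nn.Linear(8, 4)
params = [*backbone.parameters(), *update.parameters(), *head.parameters()]
opt = torch.optim.SGD(params, lr=1e-2)

def train_clip(frames: torch.Tensor, targets: torch.Tensor) -> list[float]:
    """frames: (T, B, 16), targets: (T, B, 4). Runs one backward per frame."""
    opt.zero_grad()
    state = frames.new_zeros(frames.shape[1], 8)
    losses = []
    for x, y in zip(frames, targets):
        state = update(backbone(x), state)
        loss = nn.functional.mse_loss(head(state), y)
        loss.backward()           # frees this frame's graph right away
        losses.append(loss.item())
        state = state.detach()    # cut the graph: no gradient reaches earlier frames
    opt.step()                    # apply gradients accumulated over all frames
    return losses
```

Detaching the state means later frames no longer send gradients back into earlier ones (truncated backpropagation through time); that is the trade-off for keeping only one frame's activations in memory.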

To do this, you definitely need to change some parts of the model, and I believe this is also a hot topic in research on end-to-end tracking with transformers.

A thousand thanks for your detailed explanation!

To have forward-backward propagation at each frame, we would definitely need a global track_instances defined in tracker.py, right? And can the matching and loss calculation parts remain the same?

Looking forward to your reply.

Hi, I worked on using torch.checkpoint about one year ago. Here is my understanding.

You only need to run forward-backward once for the checkpointed part.
Suppose you only use checkpoint for the backbone feature extractor.

Then you can run the forward pass of the tracking head frame by frame.

Then you sum up all the losses and perform one backward pass.

I think this might be the correct implementation.

Also, I sometimes find it hard to get a checkpoint implementation 100% correct, because when it is implemented wrongly (the gradient for the checkpointed part might be None, i.e. no gradient at all), the training code still runs without raising errors, and the loss can still decrease because the tracking head is training. So I suggest checking the gradients after implementing checkpoint and comparing them with the gradients from the normal implementation.
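A minimal sketch of that gradient check on toy modules (`backbone`, `head`, and the squared-output loss are placeholders, not MUTR3D code). It follows the scheme described above: checkpoint only the backbone, run the head frame by frame, sum the losses, and backward once; then it compares the backbone gradients with and without checkpointing. One known way to hit the silent None-gradient case in PyTorch is the reentrant variant (`use_reentrant=True`) when none of the checkpointed inputs (e.g. raw images) require grad; PyTorch only emits a warning in that case.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
head = nn.Linear(8, 4)
frames = torch.randn(3, 2, 16)  # (T, B, features): toy stand-in for image frames

def backbone_grads(use_ckpt: bool) -> list[torch.Tensor]:
    """Head runs frame by frame, losses are summed, one backward; returns backbone grads."""
    for p in [*backbone.parameters(), *head.parameters()]:
        p.grad = None  # fresh grads; earlier returned tensors are left untouched
    total = frames.new_zeros(())
    for x in frames:
        feat = checkpoint(backbone, x, use_reentrant=False) if use_ckpt else backbone(x)
        total = total + head(feat).pow(2).mean()
    total.backward()
    return [p.grad for p in backbone.parameters()]

ref = backbone_grads(use_ckpt=False)
ckpt = backbone_grads(use_ckpt=True)
for g_ref, g_ckpt in zip(ref, ckpt):
    assert g_ckpt is not None, "checkpointed backbone received no gradient"
    assert torch.allclose(g_ref, g_ckpt, atol=1e-6)
print("checkpoint gradients match")
```

The same comparison can be run on one real batch before a long training job: a None or mismatching backbone gradient points to a broken checkpoint setup even when the loss is still going down.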

Thanks.

Thanks a lot for your explanation.

Is it possible to train the model on a complete nuScenes scene in each iteration, instead of video clips of 3 or 5 frames for ResNet50 and ResNet101, respectively?

I never tried this, but I feel it would be hard.
Each scene should be longer than ~10 seconds.

Thanks. May I ask where you handle the scene-switching scenario? Say, in one iteration, the first frame is from the previous scene while the other two are from the next scene? I haven't found the related code yet.

Hi.
For training, I don't need to change anything, because the tracking loss will figure out that it is not the same instance.

For inference, see: https://github.com/a1600012888/MUTR3D/blob/main/plugin/track/models/tracker.py#L674
I use the change in the timestamp to tell whether we have moved on to the next scene.
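A minimal sketch of the timestamp idea, not the code at the linked line: it assumes nuScenes-style timestamps in microseconds (keyframes are roughly 0.5 s apart) and treats a backwards step, or a forward jump larger than a hypothetical `max_gap_s` threshold, as a scene change. When it fires, the tracker would start again from empty tracks (e.g. a fresh `track_instances`).

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SceneSwitchDetector:
    """Flags a new scene when consecutive frame timestamps jump."""
    max_gap_s: float = 2.0          # hypothetical threshold, well above the ~0.5 s keyframe spacing
    last_ts: Optional[float] = None

    def is_new_scene(self, timestamp_us: int) -> bool:
        ts = timestamp_us / 1e6  # microseconds -> seconds
        new = (
            self.last_ts is None                    # very first frame
            or ts < self.last_ts                    # time went backwards
            or ts - self.last_ts > self.max_gap_s   # large forward jump
        )
        self.last_ts = ts
        return new

det = SceneSwitchDetector()
stamps = [1_000_000, 1_500_000, 2_000_000, 50_000_000, 50_500_000]
flags = [det.is_new_scene(t) for t in stamps]
# flags == [True, False, False, True, False]
```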

Thanks.

Thanks a lot.