taichi-dev/difftaichi

Unneeded clear_states() in mass_spring.py

Closed this issue · 4 comments

In the mass_spring.py optimization loop, clear() is called before each iteration, but I want to know whether it actually needs to be.
https://github.com/yuanming-hu/difftaichi/blob/d5a1bb8b19eba4859e4c03d19c3ffdd39c4eeee8/examples/mass_spring.py#L323-L327

https://github.com/yuanming-hu/difftaichi/blob/d5a1bb8b19eba4859e4c03d19c3ffdd39c4eeee8/examples/mass_spring.py#L274-L281

I was trying to clear other gradient values besides those listed in clear_states(), and I realized that all of the gradients were already set to 0 when I entered the `with ti.Tape()` block. Looking into the Taichi code, I found that the tape clears gradients by default:

https://github.com/taichi-dev/taichi/blob/b0b60a7da36ef2fb3a93924ebe8a44b4d2778622/python/taichi/lang/__init__.py#L266-L273
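For intuition, here is a minimal plain-Python sketch of that behavior. This is NOT Taichi's actual implementation (the `Field` and `Tape` classes below are toy stand-ins); it only mimics what the linked code does: entering the tape zeroes every registered gradient before recording.

```python
class Field:
    """Toy stand-in for a Taichi field with an attached gradient."""
    def __init__(self, value=0.0):
        self.value = value
        self.grad = 0.0


class Tape:
    """Toy autodiff tape: clears the gradients of all registered
    fields on entry, analogous to Taichi's clear-all-gradients step."""
    def __init__(self, fields):
        self.fields = fields

    def __enter__(self):
        for f in self.fields:
            f.grad = 0.0  # stale gradients are wiped here
        return self

    def __exit__(self, *exc):
        return False


x = Field(2.0)
x.grad = 123.0          # stale gradient left over from a previous iteration
with Tape([x]):
    pass                # gradient was cleared on entry
print(x.grad)           # 0.0
```

Note that this only touches `.grad`; the primal `.value` is untouched, which is exactly why gradient clearing alone is not enough for accumulated state.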

I figured I would just remove the call to clear() in my own code, but I wanted to double-check before doing so. Is there another reason clear() needs to be called, or is it leftover from older versions of Taichi?

Never mind, it does break things. I just need to figure out why some gradients are cleared while others are not.

Alright, I think I have figured it out. Tape does clear the gradients of all variables when you enter the with statement. The reason we still need clear_states() is that v_inc is a variable that accumulates values, so its primal values (not just its gradients) must be reset to 0 for each iteration:
```python
v_inc[t, i] = ti.Vector([0.0, 0.0])
```
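A minimal plain-Python analogy of the bug (hypothetical names, not the simulator's code): an accumulator that is not reset carries stale state into the next optimization iteration, and clearing gradients does nothing about it.

```python
def run_iteration(v_inc, contributions):
    """Accumulate contributions into v_inc, like v_inc[t, i] += ... per step."""
    for c in contributions:
        v_inc += c
    return v_inc


v_inc = 0.0
v_inc = run_iteration(v_inc, [1.0, 2.0])     # first iteration accumulates 3.0

# Without a reset, the second iteration starts from stale state:
stale = run_iteration(v_inc, [1.0, 2.0])     # 6.0 -- wrong

# With the reset that clear_states() performs:
v_inc = 0.0
fresh = run_iteration(v_inc, [1.0, 2.0])     # 3.0 -- correct
print(stale, fresh)
```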

I have a meeting early in the afternoon, but I can submit a pull request later this evening. I have never contributed to an open source project before, so it might take a bit longer, but I'll figure it out.

Resolved in pull request #30