How to apply force during for-loop in diffmpm
lynamer opened this issue · 3 comments
Take diffmpm_simple.py as an example. I was thinking of adding an additional force during the p2g step, that is, changing line 87 to:
grid_v_in[f, base + offset] += weight * (p_mass * v[f, p] - dt * x.grad[f, p] + affine @ dpos)
where x.grad is supposed to be the gradient of total_energy with respect to the particle positions x. However, since we already use ti.Tape(loss=loss) to store the gradient of init_v, I am wondering how to get both gradients.
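To make the intent of the modified line concrete: if x.grad holds dE/dx for some energy E, then -x.grad is the force, and -dt * x.grad is the impulse added to each particle's momentum before it is scattered to the grid. The sketch below is plain Python with a hypothetical two-particle 1D spring (k, rest, and the helper names are illustrative assumptions, not part of diffmpm_simple.py). It omits the B-spline weight and the affine term for brevity.

```python
# Hypothetical 1D two-particle spring; names and constants are illustrative only.
k = 100.0      # spring stiffness (assumed)
rest = 1.0     # rest length (assumed)
dt = 1e-3
p_mass = 1.0

def total_energy(x):
    """Elastic energy E(x) of a single spring between x[0] and x[1]."""
    stretch = x[1] - x[0] - rest
    return 0.5 * k * stretch ** 2

def energy_grad(x):
    """Analytic dE/dx, i.e. what x.grad would hold after backpropagating E."""
    stretch = x[1] - x[0] - rest
    return [-k * stretch, k * stretch]

def momentum_with_force(x, v):
    """Per-particle momentum p_mass * v - dt * dE/dx (weight and affine term omitted)."""
    g = energy_grad(x)
    return [p_mass * v[i] - dt * g[i] for i in range(len(x))]

# A spring stretched by 0.5: the impulse pulls the two particles toward each other.
print(momentum_with_force([0.0, 1.5], [0.0, 0.0]))
```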
Here is some more detailed information about my question. I first tried nesting ti.Tape() to record the gradients of both total_energy and loss:
with ti.Tape(loss=loss):
    set_v()
    for s in range(steps - 1):
        with ti.Tape(loss=total_energy):
            compute_total_energy(s)
        substep(s)
    compute_x_avg()
    compute_loss()
The program then produced no output at all.
Next, I tried calling func.grad() on each kernel manually instead of using ti.Tape:
# forward
for s in range(total_steps - 1):
    clear_grid()
    total_energy[None] = 0
    compute_total_energy(s)
    total_energy.grad[None] = 1
    p2g(s)
    grid_op(s)
    g2p(s)
x_avg[None] = [0, 0]
compute_x_avg()
compute_loss()
# backward
clear_particle_grad()
compute_loss.grad()
compute_x_avg.grad()
for s in reversed(range(steps - 1)):
    clear_grid()
    p2g(s)
    grid_op(s)
    g2p.grad(s)
    grid_op.grad(s)
    p2g.grad(s)
This did not work either: all the gradients were zero.
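One likely reason for the all-zero gradients (an inference, not confirmed in the thread): when you drive the backward pass yourself, nothing seeds the output adjoint, whereas, as far as I know, ti.Tape sets loss.grad[None] = 1 for you. In the code above, loss.grad is never set and compute_total_energy.grad is never called, so no nonzero adjoint ever enters the chain. The sketch below is plain Python. Field is a hypothetical stand-in for a Taichi field with .grad, and compute_x_avg_grad / compute_loss_grad are hand-written adjoints playing the role of compute_x_avg.grad() and compute_loss.grad().

```python
class Field:
    """Hypothetical stand-in for a Taichi field with a .grad companion."""
    def __init__(self, n):
        self.val = [0.0] * n
        self.grad = [0.0] * n

n = 4
x = Field(n)
x_avg = Field(1)
loss = Field(1)
target = 0.9  # assumed target position

def compute_x_avg():
    x_avg.val[0] = sum(x.val) / n

def compute_x_avg_grad():
    # d x_avg / d x_i = 1/n; adjoints accumulate (+=) like Taichi's grad kernels
    for i in range(n):
        x.grad[i] += x_avg.grad[0] / n

def compute_loss():
    loss.val[0] = (x_avg.val[0] - target) ** 2

def compute_loss_grad():
    x_avg.grad[0] += 2.0 * (x_avg.val[0] - target) * loss.grad[0]

def run(seed):
    """Forward, then manual backward; returns x.grad."""
    x.val = [0.1, 0.2, 0.3, 0.4]
    x.grad = [0.0] * n
    x_avg.grad = [0.0]
    loss.grad = [0.0]
    compute_x_avg()
    compute_loss()
    if seed:
        loss.grad[0] = 1.0  # the seed that ti.Tape would otherwise provide
    compute_loss_grad()
    compute_x_avg_grad()
    return list(x.grad)

print(run(seed=False))  # unseeded: every gradient stays zero
print(run(seed=True))
```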
Hi @lynamer,
I think nesting tapes is not how ti.Tape is designed to be used (i.e., it may induce unexpected behavior). The tape is designed for the scenario with one loss. For multiple losses, I would recommend using kernel.grad(), i.e., reversing the call trace manually. Here is the code I modified for your case:
# Forward
set_v()
for s in range(steps - 1):
    # total_energy.grad[None] = 1
    compute_total_energy(s)
    # compute_total_energy.grad(s)
    substep(s)
x_avg[None] = [0, 0]
compute_x_avg()
compute_loss()
# Backward
loss.grad[None] = 1
total_energy.grad[None] = 1
clear_particle_grad()
compute_loss.grad()
compute_x_avg.grad()
for s in reversed(range(steps - 1)):
    g2p.grad(s)
    grid_op.grad(s)
    p2g.grad(s)
    compute_total_energy.grad(s)
set_v.grad()
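A note on what this backward pass computes: reverse-mode gradients are linear in the seeds, so setting both loss.grad[None] = 1 and total_energy.grad[None] = 1 accumulates the gradient of loss + total_energy into the same .grad fields; the two contributions are not kept apart. The toy below (plain Python, scalar state, hypothetical forward/backward functions, no Taichi) mirrors the structure of the loop above: compute_total_energy(s) then substep(s) going forward, and their adjoints in reverse order going backward. If the two gradients are needed separately, one option is to run the backward pass once per seed and clear the gradients in between.

```python
# Hypothetical scalar toy mirroring the forward/backward structure above.
k, dt, steps = 2.0, 0.1, 3
init_v = 0.5
x0 = 1.0

def forward(v0):
    """set_v, then per step: compute_total_energy(s) and substep(s), then compute_loss."""
    xs = [x0]
    energy = 0.0
    for s in range(steps - 1):
        energy += 0.5 * k * xs[s] ** 2     # compute_total_energy(s)
        xs.append(xs[s] + dt * v0)         # substep(s): x[s+1] = x[s] + dt * v
    loss = xs[-1] ** 2                     # compute_loss()
    return xs, energy, loss

def backward(xs, loss_seed, energy_seed):
    """Call the adjoints in reverse order; returns d(seeded outputs)/d init_v."""
    x_grad = [0.0] * len(xs)
    x_grad[-1] += 2.0 * xs[-1] * loss_seed  # compute_loss.grad()
    v_grad = 0.0
    for s in reversed(range(steps - 1)):
        # substep.grad(s): propagate the adjoint of x[s+1] to x[s] and v
        x_grad[s] += x_grad[s + 1]
        v_grad += dt * x_grad[s + 1]
        # compute_total_energy.grad(s)
        x_grad[s] += k * xs[s] * energy_seed
    return v_grad  # set_v.grad(): v = init_v, so init_v.grad = v.grad

xs, energy, loss = forward(init_v)
print(backward(xs, 1.0, 1.0))  # gradient of loss + total_energy w.r.t. init_v
```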