relationship to torch.autograd.functional.jacobian
fhung65 opened this issue · 2 comments
Out of curiosity, what are the differences between PyTorch's autograd Jacobian functionality and the SerialChain jacobian function?
And, given that we can backprop through forward kinematics, is there a way to reproduce the calculation using the autograd Jacobian?
```python
# for example, comparing j1 and j2 (chain and th defined elsewhere):
def get_pt(th):
    return chain.forward_kinematics(th).transform_points(torch.zeros((1, 3)))

j1 = torch.autograd.functional.jacobian(get_pt, inputs=th)
j2 = chain.jacobian(th)
```
For one thing, j1 here is probably not set up to include the angular component; it ends up with shape (N, 3, DOF), while j2 ends up with shape (N, 6, DOF).
Oh, I might have found a cause.
My chain's URDF file has some unusual values for the joint axis, and the joint axis isn't used in jacobian.py for revolute joints.
A patch that made things work for me was replacing the d and delta calculations for revolute joints with:
```python
# d[n] = (axis x p[n]) @ R[n], with p = cur_transform[:, :3, 3] and R = cur_transform[:, :3, :3]
d = torch.einsum(
    'ni,nij->nj',
    torch.cross(f.joint.axis[None,], cur_transform[:, :3, 3]),
    cur_transform[:, :3, :3]
)
# delta[n] = axis @ R[n]
delta = torch.einsum('i,nij->nj', f.joint.axis, cur_transform[:, :3, :3])
```
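In case the einsum notation is opaque, here is a rough equivalent written with explicit ops. This is a sketch only: it assumes f.joint.axis has shape (3,) and cur_transform has shape (N, 4, 4), as in the patch above, and the function name is purely illustrative.

```python
import torch

def revolute_columns_explicit(axis, cur_transform):
    # sketch of what the two einsums above compute, assuming
    # axis: (3,) and cur_transform: (N, 4, 4)
    N = cur_transform.shape[0]
    R = cur_transform[:, :3, :3]   # (N, 3, 3) rotation part
    p = cur_transform[:, :3, 3]    # (N, 3) translation part
    axis_b = axis.expand(N, 3)     # broadcast the joint axis over the batch
    # d[n] = (axis x p[n]) @ R[n]
    d = torch.bmm(torch.cross(axis_b, p, dim=-1).unsqueeze(1), R).squeeze(1)
    # delta[n] = axis @ R[n]
    delta = torch.bmm(axis_b.unsqueeze(1), R).squeeze(1)
    return d, delta
```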
@fhung65 The Jacobian calculation has changed quite a bit, but I think the changes should have resolved the issues you had with revolute joints. I added some tests that compare against the autograd computation of the Jacobian (a minimal sketch of such a comparison follows this list); the differences are:
- the PK Jacobian is faster and more memory efficient
- the PK Jacobian computes the angular Jacobian as well
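For reference, here is a minimal sketch of that kind of comparison (not the actual test code). It reuses chain and get_pt from the earlier snippet, assumes a single configuration so the autograd output shapes stay simple, and uses get_joint_parameter_names only to get the DOF.

```python
import torch

dof = len(chain.get_joint_parameter_names())
th = torch.rand(1, dof)  # one random configuration

j_pk = chain.jacobian(th)  # (1, 6, DOF)
j_ad = torch.autograd.functional.jacobian(get_pt, inputs=th, vectorize=True)
# get_pt(th) has shape (1, 1, 3), so j_ad has shape (1, 1, 3, 1, DOF);
# keep the single point and batch entries to get a (1, 3, DOF) positional Jacobian
j_ad = j_ad[:, 0, :, 0, :]

# the translational rows of the PK Jacobian should agree with autograd
# (the tolerance may need loosening for float32 chains)
assert torch.allclose(j_pk[:, :3, :], j_ad, atol=1e-5)
```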
In terms of speed, we have:
- for N=1000 on cuda, autograd: 1060.3219540007558 ms
- for N=1000 on cuda, functorch: 16.28805300060776 ms
- for N=1000 on cuda, pytorch-kinematics: 8.923257000787999 ms
For autograd, we're actually computing the batch Jacobian of all N x 3 point outputs with respect to all N x DOF inputs. The ith point only has a non-zero Jacobian with respect to the ith input, so most of that computation is wasted.
```python
j1 = torch.autograd.functional.jacobian(get_pt, inputs=th, vectorize=True)
```
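To make the "mostly wasted" point concrete, here is a sketch of pulling out only the useful diagonal blocks from that full Jacobian; it assumes th has shape (N, DOF), so that get_pt returns (N, 1, 3) and j1 has shape (N, 1, 3, N, DOF).

```python
# only the blocks where the output batch index matches the input batch index are non-zero
N = th.shape[0]
idx = torch.arange(N)
j1_diag = j1[idx, 0, :, idx, :]  # (N, 3, DOF): per-configuration positional Jacobians
```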
To avoid this, you can try functorch, which will batch the Jacobian calculations and has other accelerations in the background:
```python
import functorch

grad_func = functorch.vmap(functorch.jacrev(get_pt))
j3 = grad_func(th).squeeze(1)
```
This will still be slower than the pytorch-kinematics Jacobian and also uses dramatically more memory.