dfm/extending-jax

nit on the need for transpose rules

shoyer opened this issue · 1 comment

I love the tutorial, thanks for putting together this great resource!

You write:

We're lucky in this case, and we don't need to add a "transpose rule", since JAX can actually work that out by itself (our JVP is linear in the tangents).

Every well-behaved JVP is linear in the tangents, by definition -- the tangents are the "vector" part of "Jacobian-vector product."
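A quick way to check this numerically is a small sketch using only the public `jax.jvp` API. The function `f` and the tangents below are arbitrary illustrative choices, not taken from the tutorial. Even though `f` is non-linear in `x`, its output tangent is a linear function of the input tangent:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative function (an assumption, not from the tutorial): non-linear in x.
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])
t1 = jnp.array([1.0, 0.0, 2.0])
t2 = jnp.array([0.0, 3.0, -1.0])
a, b = 2.0, -0.5

# jax.jvp returns (primal_out, tangent_out) for the given primals and tangents.
_, out1 = jax.jvp(f, (x,), (t1,))
_, out2 = jax.jvp(f, (x,), (t2,))
_, out_combo = jax.jvp(f, (x,), (a * t1 + b * t2,))

# Linearity in the tangents: jvp(a*t1 + b*t2) == a*jvp(t1) + b*jvp(t2),
# up to floating-point rounding.
print(jnp.allclose(out_combo, a * out1 + b * out2))  # True
```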

What's special about this primitive (and the reason we don't need a transpose rule) is that it is non-linear. That means it can't appear in the tangent calculation (because, again, output tangents are always a linear function of input tangents, per the chain rule), and only things that appear in the tangent calculation need to be transposed to calculate cotangents (VJPs).
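The same effect can be seen without writing a full primitive by using `jax.custom_jvp` as a stand-in. This is a sketch under that substitution; the `softplus` function here is a hypothetical example, not the tutorial's primitive. Its JVP rule only multiplies the tangent by a quantity computed from the primal, so `softplus` itself never acts on a tangent, and `jax.grad` works without any transpose rule for it:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def softplus(x):
    # Hypothetical non-linear function standing in for a custom primitive.
    return jnp.log1p(jnp.exp(x))

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = softplus(x)
    # Output tangent = (primal-only factor) * t, which is linear in t.
    # softplus is evaluated on the primal only, never on the tangent.
    return y, jax.nn.sigmoid(x) * t

# Reverse mode works: JAX only needs to transpose the multiplication by
# sigmoid(x), which already has a transpose rule.
g = jax.grad(softplus)(1.0)
print(jnp.allclose(g, jax.nn.sigmoid(1.0)))  # True
```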

So the key question is actually whether a primitive is a linear function of one or more of its arguments. If so, then yes, you need a transpose rule to support reverse-mode autodiff.
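To make the "transpose the linear part" idea concrete, here is a minimal sketch using the public `jax.linearize` and `jax.linear_transpose` functions (the function `f` is an illustrative assumption). `jax.linearize` returns the JVP as a linear function of the tangent, and transposing that linear map yields the same VJP that `jax.vjp` computes:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative non-linear function (an assumption for this sketch).
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])

# linearize evaluates f at x and returns the JVP as a function of the
# tangent alone; that function is linear, so it can be transposed.
y, f_jvp = jax.linearize(f, x)

# Transposing the linear JVP map gives the VJP (reverse mode).
# The second argument only supplies the input shape/dtype.
f_vjp_from_transpose = jax.linear_transpose(f_jvp, x)

ct = jnp.ones_like(y)
(grad_via_transpose,) = f_vjp_from_transpose(ct)

# Compare against JAX's own reverse-mode result.
_, f_vjp = jax.vjp(f, x)
(grad_via_vjp,) = f_vjp(ct)

print(jnp.allclose(grad_via_transpose, grad_via_vjp))  # True
```

Only the operations in the linear tangent map (here, multiplications by primal-derived constants) need transpose rules; `sin` is evaluated on the primal and never transposed.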

dfm commented

Good point - I'll try to fix the terminology shortly!