Rework ChainRules for DynamicExpressions
avik-pal commented
DynamicExpressions supports ChainRules starting with v0.17 (SymbolicML/DynamicExpressions.jl#71), so we can remove parts of our code by using `CRC.rrule_via_ad`. We still need to define a custom rule because we update the expression nodes in place. Additionally, we need to extract the node parameters into the final parameter gradient.
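To illustrate why a custom rule is still needed, here is a minimal sketch using only ChainRulesCore. It assumes a toy mutable expression whose constants are overwritten in place from `ps` before evaluation; `ToyExpr`, `toy_eval`, and `apply_toy` are hypothetical names, not Lux or DynamicExpressions APIs. The pullback maps the gradient with respect to those constants back onto `ps`. In the approach proposed here, the tree's own derivative would come from the DynamicExpressions rule instead of being written by hand as it is below.

```julia
using ChainRulesCore

# Hypothetical stand-in for an expression tree: it owns its constants, and
# they are overwritten in place from the layer parameters before evaluation.
mutable struct ToyExpr
    constants::Vector{Float64}
end

# Hypothetical evaluator: y = c1 * x + c2, applied elementwise.
toy_eval(e::ToyExpr, x::AbstractVector) = e.constants[1] .* x .+ e.constants[2]

# Mirrors the "update constants, then evaluate" pattern. The in-place write
# is why we supply a rule instead of letting AD trace through this function.
function apply_toy(e::ToyExpr, x::AbstractVector, ps::AbstractVector)
    e.constants .= ps
    return toy_eval(e, x)
end

function ChainRulesCore.rrule(::typeof(apply_toy), e::ToyExpr,
        x::AbstractVector, ps::AbstractVector)
    y = apply_toy(e, x, ps)
    # Capture forward-pass values so later mutation of `e` or `x` cannot
    # change what the pullback computes.
    c1 = ps[1]
    x_saved = copy(x)
    function apply_toy_pullback(Δ)
        Δ = unthunk(Δ)
        ∂x = c1 .* Δ
        # Gradient w.r.t. the tree constants, mapped back onto `ps`.
        ∂ps = [sum(Δ .* x_saved), sum(Δ)]
        # Tangents for: the function itself, `e`, `x`, `ps`.
        return NoTangent(), NoTangent(), ∂x, ∂ps
    end
    return y, apply_toy_pullback
end
```

For example, `rrule(apply_toy, ToyExpr([0.0, 0.0]), [1.0, 2.0, 3.0], [2.0, 1.0])` returns `y = [3.0, 5.0, 7.0]`, and calling its pullback with `ones(3)` gives `∂ps = [6.0, 3.0]`.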
avik-pal commented
This needs some investigation. I wasn't able to unthunk the `Tangent` coming from SymbolicML/DynamicExpressions.jl#71, so this will need some further thought.
```julia
function Lux.__apply_dynamic_expression_rrule(
        de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps)
    # Write the layer parameters into the expression's constant nodes (in place)
    Lux.__update_expression_constants!(expr, ps)
    @static if pkgversion(DynamicExpressions) < v"0.17"
        error("`DynamicExpressions` v0.17 or later is required for reverse mode to work.")
    end
    # `eval_tree_array` returns `(y, complete)`; only `y` is propagated
    (y, _), pb_f = CRC.rrule(eval_tree_array, expr, x, operator_enum; de.turbo, de.bumper)
    __∇apply_dynamic_expression = @closure Δ -> begin
        # The `complete` flag is a Bool and carries no cotangent
        _, ∂expr, ∂x, ∂operator_enum = pb_f((Δ, nothing))
        # The parameter gradient is read from the expression's tangent
        ∂ps = CRC.unthunk(∂expr).gradient
        return NoTangent(), NoTangent(), NoTangent(), ∂operator_enum, ∂x, ∂ps, NoTangent()
    end
    return y, __∇apply_dynamic_expression
end
```
This works, but we hit a clear regression with mixed precision. Once that is handled upstream, we may be able to use the rrule directly.
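For reference, this is how unthunking a `Tangent` and reading one of its fields normally works in ChainRulesCore, mirroring the `CRC.unthunk(∂expr).gradient` access in the snippet above. `ToyExprGrad` and its `gradient` field are hypothetical; the tangent actually returned by SymbolicML/DynamicExpressions.jl#71 may be structured differently.

```julia
using ChainRulesCore

# Hypothetical primal type standing in for the object whose tangent carries
# the constants' gradient. The field name `gradient` is an assumption.
struct ToyExprGrad
    gradient::Vector{Float64}
end

# A pullback may return its cotangent lazily, wrapped in a thunk.
lazy_∂expr = @thunk Tangent{ToyExprGrad}(; gradient = [6.0, 3.0])

# `unthunk` forces the thunk; property access on a `Tangent` then
# returns the stored component for that field.
∂expr = unthunk(lazy_∂expr)
∂ps = ∂expr.gradient
```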