LuxDL/Boltz.jl

Rework ChainRules for DynamicExpressions


DynamicExpressions supports ChainRules starting with v0.17 (SymbolicML/DynamicExpressions.jl#71), so we can remove parts of our code with `CRC.rrule_via_ad`. We still need to define a rule because we update the node constants in place. Additionally, we need to extract the gradients of the node parameters into the final parameter gradient.
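To make these two requirements concrete, here is a dependency-free toy sketch of the pattern. It uses neither ChainRulesCore nor DynamicExpressions: `ToyExpr`, `update_constants!`, `inner_rrule` and `toy_rrule` are hypothetical stand-ins for an expression whose constants live inside its nodes, the in-place constant update, the upstream rule and our wrapper rule. The point it illustrates is the shape of the wrapper: the inner tangent for the expression carries a gradient for its constants, and the wrapper hands that gradient back as the gradient of the layer parameters.

```julia
# Hypothetical toy "expression": y = c[1] * x^2 + c[2], with the constants
# stored inside the expression object (like node constants in a tree).
mutable struct ToyExpr
    constants::Vector{Float64}
end

# In-place update of the expression constants from the layer parameters.
function update_constants!(expr::ToyExpr, ps::AbstractVector)
    copyto!(expr.constants, ps)
    return expr
end

eval_toy(expr::ToyExpr, x::AbstractVector) = expr.constants[1] .* x .^ 2 .+ expr.constants[2]

# Stand-in for the upstream rule: the tangent for the expression is a
# NamedTuple whose `gradient` field holds the gradient of its constants.
function inner_rrule(expr::ToyExpr, x::AbstractVector)
    y = eval_toy(expr, x)
    c = copy(expr.constants)  # snapshot so later mutation cannot affect the pullback
    function inner_pullback(Δ)
        ∂expr = (gradient = [sum(Δ .* x .^ 2), sum(Δ)],)
        ∂x = Δ .* 2 .* c[1] .* x
        return ∂expr, ∂x
    end
    return y, inner_pullback
end

# Wrapper rule: performs the mutation itself (which is why a hand-written
# rule is needed), delegates to the inner rule, and extracts the constant
# gradient as the parameter gradient.
function toy_rrule(expr::ToyExpr, x::AbstractVector, ps::AbstractVector)
    update_constants!(expr, ps)
    y, inner_pb = inner_rrule(expr, x)
    function toy_pullback(Δ)
        ∂expr, ∂x = inner_pb(Δ)
        ∂ps = ∂expr.gradient  # node-constant gradient becomes the ps gradient
        return ∂x, ∂ps
    end
    return y, toy_pullback
end

expr = ToyExpr([0.0, 0.0])
y, pb = toy_rrule(expr, [1.0, 2.0], [3.0, 4.0])  # y == [7.0, 16.0]
∂x, ∂ps = pb([1.0, 1.0])                         # ∂x == [6.0, 12.0], ∂ps == [5.0, 2.0]
```

In the real rule the inner pullback is the one from DynamicExpressions and the `gradient` field sits on its expression tangent, which is exactly the part that has to be unthunked and pulled out.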

This needs some investigation: I wasn't able to unthunk the `Tangent` coming from SymbolicML/DynamicExpressions.jl#71.

This would need some further thought.

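```julia
# Imports assumed from the surrounding Lux extension module.
import ChainRulesCore as CRC
using ChainRulesCore: NoTangent
using DynamicExpressions: DynamicExpressions, eval_tree_array
using FastClosures: @closure
using Lux

function Lux.__apply_dynamic_expression_rrule(
        de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps)
    # Check the version before mutating `expr`.
    @static if pkgversion(DynamicExpressions) < v"0.17"
        error("`DynamicExpressions` v0.17 or later is required for reverse mode to work.")
    end
    # In-place constant update: the reason we still need our own rule.
    Lux.__update_expression_constants!(expr, ps)
    # Upstream rule; the primal is `(output, complete)`.
    (y, _), pb_f = CRC.rrule(eval_tree_array, expr, x, operator_enum; de.turbo, de.bumper)
    __∇apply_dynamic_expression = @closure Δ -> begin
        _, ∂expr, ∂x, ∂operator_enum = pb_f((Δ, nothing))
        # Pull the node-constant gradient out as the parameter gradient.
        ∂ps = CRC.unthunk(∂expr).gradient
        return NoTangent(), NoTangent(), NoTangent(), ∂operator_enum, ∂x, ∂ps, NoTangent()
    end
    return y, __∇apply_dynamic_expression
end
```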
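Whatever form the final rule takes, a hand-written pullback like the one above can be sanity-checked against central finite differences. The sketch below is dependency-free; `f`, `f_pullback` and `fd_gradient` are hypothetical names for a toy scalar function of inputs `x` and constants `c`, its hand-derived gradients, and a finite-difference estimator. In practice ChainRulesTestUtils.jl (`test_rrule`) would be the natural tool, but it is not used here.

```julia
# Hypothetical scalar function of inputs x and "constants" c.
f(x, c) = sum(c[1] .* x .^ 2 .+ c[2])

# Hand-derived gradients of f w.r.t. x and c, scaled by the cotangent Δ.
function f_pullback(x, c, Δ)
    ∂x = Δ .* 2 .* c[1] .* x
    ∂c = [Δ * sum(x .^ 2), Δ * length(x)]
    return ∂x, ∂c
end

# Central finite-difference gradient of a scalar function g at v.
function fd_gradient(g, v::Vector{Float64}; h = 1e-6)
    grad = similar(v)
    for i in eachindex(v)
        vp = copy(v); vp[i] += h
        vm = copy(v); vm[i] -= h
        grad[i] = (g(vp) - g(vm)) / (2h)
    end
    return grad
end

x = [1.0, 2.0, -0.5]
c = [3.0, 4.0]
∂x, ∂c = f_pullback(x, c, 1.0)

# Compare each argument's tangent with its finite-difference estimate.
@assert isapprox(∂x, fd_gradient(v -> f(v, c), x); rtol = 1e-5)
@assert isapprox(∂c, fd_gradient(v -> f(x, v), c); rtol = 1e-5)
```

Checking each argument separately also catches tangents returned in the wrong position of the tuple, which is easy to get wrong when the rule has several arguments.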

This works, but we hit a clear regression with mixed precision. Perhaps once that is handled upstream we can use the rrule directly.