yhtang/FunFact

[REQ] Support for access to internal variables + plugin output in factorize algorithm


For a `historic_convergence` plugin, we want to record the loss, the penalty, and their weighted combination separately for every vector instance.

This is currently not possible from an external plugin, because `GradientDescentState` only stores the loss combined over all vector instances.

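As a workaround, the plugin can be defined inside the `factorize` function, where it has access to internal variables such as `loss`, `fac`, `target`, `append`, and `penalty_weight`: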

    # Defined inside factorize(), so the closure can see its local variables
    # (loss, fac, target, append, penalty_weight).
    hc = []  # one record per (step, vector instance)

    @gradient_descent_plugin(every=1)
    def historic_convergence(state: GradientDescentState):
        # TODO: use external validation set
        # sum_vec=False keeps one loss value per vector instance.
        loss_val = ab.to_numpy(
            loss(fac(), target, sum_vec=False, vectorized_along_last=append)
        )
        # Penalty summed over all leaves, but still one value per instance.
        penalty_val = ab.to_numpy(
            fac.penalty(sum_leafs=True, sum_vec=False)
        )

        for i, (loss_i, pen_i) in enumerate(zip(loss_val, penalty_val)):
            hc.append(
                dict(
                    step=state.step,
                    vec=i,
                    loss=loss_i,
                    penalty=pen_i,
                    loss_and_penalty=loss_i + penalty_weight * pen_i,
                )
            )

It would be better to be able to write this as an external plugin, defined outside `factorize`, that still has access to the internal variables it needs.
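One possible design, sketched in plain Python and not FunFact's actual API: the optimizer passes plugins a richer state object that carries per-instance loss and penalty values (plus `penalty_weight`), so the plugin no longer needs to live inside `factorize`. `InstanceState` and `run_descent` are hypothetical names, and the losses are fake numbers rather than real gradient descent.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class InstanceState:
    """Hypothetical richer state: per-instance values instead of one combined loss."""
    step: int
    loss: Sequence[float]     # one entry per vector instance
    penalty: Sequence[float]  # one entry per vector instance
    penalty_weight: float


def run_descent(n_steps: int, plugins: List[Callable[[InstanceState], None]]) -> None:
    """Toy stand-in for the optimizer loop (no real gradient descent)."""
    for step in range(n_steps):
        # Fake per-instance values for two vector instances.
        losses = [1.0 / (step + 1 + i) for i in range(2)]
        penalties = [0.1 * (i + 1) for i in range(2)]
        state = InstanceState(step, losses, penalties, penalty_weight=0.5)
        for plugin in plugins:
            plugin(state)


# The plugin lives outside the optimizer and reads everything from the state.
hc = []


def historic_convergence(state: InstanceState) -> None:
    for i, (loss_i, pen_i) in enumerate(zip(state.loss, state.penalty)):
        hc.append(dict(
            step=state.step,
            vec=i,
            loss=loss_i,
            penalty=pen_i,
            loss_and_penalty=loss_i + state.penalty_weight * pen_i,
        ))


run_descent(3, [historic_convergence])
```

With this design, what a plugin can see is determined by the fields the state exposes, rather than by where the plugin is defined.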

Second, the user of the `historic_convergence` plugin needs access to the collected `hc` data afterwards, so there should also be a way for a plugin to return information to the caller.
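A minimal sketch of one way plugin output could work, assuming class-based plugins (hypothetical, not FunFact's current decorator API): each plugin keeps its own data and exposes it through an `output()` method, and the optimizer returns all outputs keyed by plugin name. `Plugin`, `HistoricConvergence`, and `factorize_sketch` are illustrative names only, and the losses are fake.

```python
from typing import Any, Dict, List


class Plugin:
    """Hypothetical base class: called every `every` steps, may expose an output."""
    name = "plugin"

    def __init__(self, every: int = 1):
        self.every = every

    def __call__(self, step: int, losses: List[float]) -> None:
        raise NotImplementedError

    def output(self) -> Any:
        return None


class HistoricConvergence(Plugin):
    name = "historic_convergence"

    def __init__(self, every: int = 1):
        super().__init__(every)
        self.hc: List[dict] = []  # the data the user wants back

    def __call__(self, step: int, losses: List[float]) -> None:
        for i, loss_i in enumerate(losses):
            self.hc.append(dict(step=step, vec=i, loss=loss_i))

    def output(self) -> List[dict]:
        return self.hc


def factorize_sketch(n_steps: int, plugins: List[Plugin]) -> Dict[str, Any]:
    """Toy optimizer loop returning each plugin's output keyed by plugin name.

    A real factorize would return these alongside the fitted model.
    """
    for step in range(n_steps):
        losses = [1.0 / (step + 1 + i) for i in range(2)]  # fake per-instance losses
        for p in plugins:
            if step % p.every == 0:
                p(step, losses)
    return {p.name: p.output() for p in plugins}


outputs = factorize_sketch(4, [HistoricConvergence(every=2)])
```

Here `outputs["historic_convergence"]` holds the records from steps 0 and 2. Keeping the data on the plugin object also means the caller can read it directly from the instance it passed in.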