[REQ] Support for access to internal variables + plugin output in factorize algorithm
campsd commented
For a `historic_convergence` plugin, we would want to save the `loss`, the `penalty`, and their weighted combination separately for every vector instance.
This is currently not possible from outside `factorize`, because `GradientDescentState` only stores the loss combined over all vector instances.
As a workaround, the plugin can be defined inside the `factorize` method, where it has access to the method's internal variables:
```python
# Defined inside `factorize`: `fac`, `target`, `loss`, `append` and
# `penalty_weight` come from the enclosing method's scope.
hc = []

# Plugin recording historic convergence data per vector instance.
@gradient_descent_plugin(every=1)
def historic_convergence(state: GradientDescentState):
    # TODO: use external validation set
    loss_val = ab.to_numpy(
        loss(fac(), target, sum_vec=False, vectorized_along_last=append)
    )
    penalty_val = ab.to_numpy(
        fac.penalty(sum_leafs=True, sum_vec=False)
    )
    for i, (loss_i, penalty_i) in enumerate(zip(loss_val, penalty_val)):
        hc.append(
            dict(
                step=state.step,
                vec=i,
                loss=loss_i,
                penalty=penalty_i,
                loss_and_penalty=loss_i + penalty_weight * penalty_i,
            )
        )
```
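It would be good to be able to do this with an external plugin, i.e. one defined outside `factorize`. That requires the state (or some other plugin argument) to expose the per-vector loss and penalty.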
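A minimal, self-contained sketch of how an external plugin could work if the state carried per-vector values. Everything here is hypothetical and for illustration only: `VectorizedState`, its fields, and `toy_factorize` are not part of the library. The toy driver treats each vector instance as a scalar `x_i` that minimizes `(x_i - t_i)**2 + penalty_weight * x_i**2` by plain gradient descent, standing in for the real `factorize`.

```python
import functools
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence


@dataclass
class VectorizedState:
    """Hypothetical state: unlike GradientDescentState, keeps per-vector values."""
    step: int
    loss_per_vec: List[float]
    penalty_per_vec: List[float]
    penalty_weight: float


def historic_convergence(state: VectorizedState, hc: List[Dict]) -> None:
    """External plugin: reads everything it needs from the state."""
    for i, (loss_i, penalty_i) in enumerate(
        zip(state.loss_per_vec, state.penalty_per_vec)
    ):
        hc.append(
            dict(
                step=state.step,
                vec=i,
                loss=loss_i,
                penalty=penalty_i,
                loss_and_penalty=loss_i + state.penalty_weight * penalty_i,
            )
        )


def toy_factorize(
    targets: List[float],
    penalty_weight: float = 0.1,
    lr: float = 0.1,
    steps: int = 3,
    plugins: Sequence[Callable[[VectorizedState], None]] = (),
) -> List[float]:
    """Toy stand-in for factorize: each vector instance i is a scalar x_i that
    minimizes (x_i - t_i)**2 + penalty_weight * x_i**2 by gradient descent."""
    x = [0.0] * len(targets)
    for step in range(steps):
        # One gradient step per vector instance.
        x = [
            xi - lr * (2 * (xi - ti) + 2 * penalty_weight * xi)
            for xi, ti in zip(x, targets)
        ]
        # The driver, not the plugin, computes the per-vector values.
        state = VectorizedState(
            step=step,
            loss_per_vec=[(xi - ti) ** 2 for xi, ti in zip(x, targets)],
            penalty_per_vec=[xi ** 2 for xi in x],
            penalty_weight=penalty_weight,
        )
        for plugin in plugins:
            plugin(state)
    return x


# Usage: bind the output list with functools.partial; no closure over factorize.
hc: List[Dict] = []
toy_factorize([1.0, 2.0], plugins=[functools.partial(historic_convergence, hc=hc)])
print(len(hc))  # 6 records: 3 steps x 2 vector instances
```

Binding `hc` via `functools.partial` is just one option. The essential change is that the state, not the enclosing scope, supplies the per-vector loss and penalty.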
Secondly, the user of a `historic_convergence` plugin would want access to the collected `hc` data, so there should also be a way for a plugin to return information.
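One possible shape for plugin output, again as a self-contained sketch with hypothetical names (`State`, `Plugin`, `HistoricConvergence`, `result()`, `toy_factorize`), none of which are the library's API. Plugins are objects that are called every step and expose a `result()` method. The optimizer returns all results keyed by plugin name alongside the solution. The `every` parameter mirrors `gradient_descent_plugin(every=1)` in the snippet above. To keep it short, the toy objective here is just `(x_i - t_i)**2` without a penalty.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Protocol, Sequence, Tuple


@dataclass
class State:
    """Hypothetical minimal state, not the library's GradientDescentState."""
    step: int
    loss_per_vec: List[float]


class Plugin(Protocol):
    """Hypothetical plugin interface: called every step, queried at the end."""
    name: str

    def __call__(self, state: State) -> None: ...

    def result(self) -> Any: ...


class HistoricConvergence:
    name = "historic_convergence"

    def __init__(self, every: int = 1) -> None:
        self.every = every
        self.hc: List[Dict[str, Any]] = []

    def __call__(self, state: State) -> None:
        # Record only every `every`-th step.
        if state.step % self.every != 0:
            return
        for i, loss_i in enumerate(state.loss_per_vec):
            self.hc.append(dict(step=state.step, vec=i, loss=loss_i))

    def result(self) -> List[Dict[str, Any]]:
        return self.hc


def toy_factorize(
    targets: List[float],
    steps: int = 4,
    lr: float = 0.25,
    plugins: Sequence[Plugin] = (),
) -> Tuple[List[float], Dict[str, Any]]:
    """Toy stand-in: x_i minimizes (x_i - t_i)**2; also returns plugin outputs."""
    x = [0.0] * len(targets)
    for step in range(steps):
        x = [xi - lr * 2 * (xi - ti) for xi, ti in zip(x, targets)]
        state = State(step, [(xi - ti) ** 2 for xi, ti in zip(x, targets)])
        for plugin in plugins:
            plugin(state)
    # Collect each plugin's output under its name.
    return x, {p.name: p.result() for p in plugins}


x, outputs = toy_factorize([1.0, 3.0], plugins=[HistoricConvergence(every=2)])
print(outputs["historic_convergence"])
```

Returning a dict keyed by plugin name keeps the return type stable regardless of how many plugins are attached. A lighter alternative is to keep a reference to the plugin object and read `plugin.hc` after the call, which needs no change to the return value at all.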