Stopping condition 'madsen-nielsen' incorrect
Joshuaalbert opened this issue.
The documentation says something different from the code. Specifically, the `-` in the code's `tree_l2_norm(params) - self.xtol` is inconsistent with the `+` in the docstring for the same term.
Docstring says:
```
the convergence is achieved once the
coeff update satisfies ``||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol) `` or
the gradient satisfies ``||grad(f)||_inf <= gtol``.
```
Code says:
```python
tree_mul_term = self.xtol * (tree_l2_norm(params) - self.xtol)
return jnp.all(jnp.array([
    tree_inf_norm(state.gradient) > self.gtol,
    tree_l2_norm(state.delta) > tree_mul_term
]))
```
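The sign matters most when the iterates converge towards the origin. The sketch below uses illustrative values (`params_norm = 1e-4`, `xtol = 1e-3`, chosen here rather than taken from the library) and plain Python floats in place of the tree norms. With `-`, the threshold `xtol * (||params||_2 - xtol)` is negative whenever `||params||_2 < xtol`, so `||delta||_2 > threshold` always holds and the step-size test can never stop the solver; only the gradient test can.

```python
def step_threshold(params_norm, xtol, sign):
    # sign=+1 follows the docstring:   xtol * (||params|| + xtol)
    # sign=-1 follows the current code: xtol * (||params|| - xtol)
    return xtol * (params_norm + sign * xtol)

params_norm = 1e-4  # illustrative: iterates close to the origin
xtol = 1e-3         # illustrative tolerance
delta_norm = 0.0    # even a zero-length step

doc_threshold = step_threshold(params_norm, xtol, +1)
code_threshold = step_threshold(params_norm, xtol, -1)

# The quoted code keeps iterating while ||delta||_2 > threshold.
print(delta_norm > doc_threshold)   # False: the docstring version would stop
print(delta_norm > code_threshold)  # True: the code version keeps going
```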
Additionally, rather than `jnp.all(jnp.array([...]))`, you should use `jnp.bitwise_and(..., ...)` or the `|`, `&` and `~` operators.
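For scalar boolean flags, the three styles give the same result; the operator forms simply avoid building a temporary array just to reduce it again. A small check over all four input combinations (the flag names `x` and `y` are placeholders for the two comparisons in the quoted code):

```python
import itertools

import jax.numpy as jnp

for a, b in itertools.product([False, True], repeat=2):
    x, y = jnp.array(a), jnp.array(b)
    # Current style: stack the flags into an array, then reduce with all().
    via_all = jnp.all(jnp.array([x, y]))
    # Suggested styles: combine the scalar booleans directly.
    via_and = x & y
    via_bitwise = jnp.bitwise_and(x, y)
    assert bool(via_all) == bool(via_and) == bool(via_bitwise) == (a and b)
```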
My suggestion
Upon reading up about the Madsen-Nielsen stopping condition, it seems that there is no single version of it. From my optimisation work, I find that incorporating both absolute and relative tolerances on parameter changes is quite useful. (Currently it looks like only a relative tolerance is used.)
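```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

def leaves_vec(tree_x):
    # Flatten every leaf of a pytree and concatenate them into one 1-D vector.
    return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))

# Body of the stopping-condition method (self, state and params are in scope there):
atol_cond = jnp.all(jnp.abs(leaves_vec(state.delta)) <= self.atol)
rtol_cond = jnp.all(jnp.abs(leaves_vec(state.delta)) <= self.rtol * jnp.abs(leaves_vec(params)))
grad_cond = jnp.max(jnp.abs(leaves_vec(state.gradient))) <= self.gtol
done = atol_cond | rtol_cond | grad_cond
return ~done  # keep iterating while no criterion is met

# defaults
atol = 0.  # effectively turned off unless user wants it on, to be backward compatible with current.
rtol = 1e-3
gtol = 1e-3
```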
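To exercise the suggested condition outside the solver, here is a small standalone sketch. The name `should_continue` and passing `params`, `delta` and `gradient` as separate pytrees (instead of a solver `state`) are assumptions made so the example runs on its own; the logic mirrors the suggestion, with the same defaults.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

def leaves_vec(tree_x):
    # Flatten all leaves of a pytree into one 1-D vector.
    return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))

def should_continue(params, delta, gradient, atol=0.0, rtol=1e-3, gtol=1e-3):
    # Hypothetical standalone version of the suggested stopping condition.
    d = jnp.abs(leaves_vec(delta))
    atol_cond = jnp.all(d <= atol)                                   # absolute step test
    rtol_cond = jnp.all(d <= rtol * jnp.abs(leaves_vec(params)))     # relative step test
    grad_cond = jnp.max(jnp.abs(leaves_vec(gradient))) <= gtol       # gradient inf-norm test
    done = atol_cond | rtol_cond | grad_cond
    return ~done

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
small_step = {"w": jnp.array([1e-5, 1e-5]), "b": jnp.array(1e-5)}
big_step = {"w": jnp.array([0.1, 0.1]), "b": jnp.array(0.1)}
big_grad = {"w": jnp.array([1.0, 1.0]), "b": jnp.array(1.0)}

print(should_continue(params, small_step, big_grad))  # False: relative step test fires
print(should_continue(params, big_step, big_grad))    # True: no criterion is met
```

With `atol = 0.` the absolute test only fires for an exactly zero step, which keeps the behaviour close to the current relative-only check unless a user opts in.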