google/learned_optimization

jnp.sign(mean_rms) is always 1

OhadRubin opened this issue · 0 comments

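The feature list for nn_adam includes the feature below. If I understand correctly, mean_rms is a root mean square and therefore never negative, so this feature is always 1 (or 0 in the degenerate case where mean_rms is exactly zero) and carries no information.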

inputs["mean_sign"] = jnp.sign(mean_rms)
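
A minimal sketch of what I mean, assuming mean_rms is computed as a root mean square. The `rms` helper and the `grads` values are my own stand-ins for illustration, not code from the repo:

```python
try:
    import jax.numpy as jnp
except ImportError:
    # Fallback so the check runs without JAX; NumPy's sign/sqrt/mean/square
    # behave the same way for this example.
    import numpy as jnp


def rms(x):
    # Hypothetical stand-in for however mean_rms is actually computed:
    # the square root of a mean of squares, which cannot be negative.
    return jnp.sqrt(jnp.mean(jnp.square(x)))


# Mixed-sign values: the RMS is still strictly positive.
grads = jnp.array([-3.0, 0.5, -1.2, 2.0])
mean_rms = rms(grads)
mean_sign = jnp.sign(mean_rms)

# Degenerate case: all zeros gives an RMS of exactly 0, so the sign is 0.
zero_sign = jnp.sign(rms(jnp.zeros(4)))

print(float(mean_rms), float(mean_sign), float(zero_sign))
```

If that assumption holds, the sign never depends on the sign of the underlying values, so the feature is constant for any non-zero input.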