FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
mainguyenanhvu opened this issue · 1 comment
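I think the code should be updated to avoid the deprecation warnings below. The calls to jax.tree_flatten and jax.tree_unflatten in mapping.py can be replaced with their jax.tree_util equivalents, as the warnings suggest.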
/home/tools/umol_package/Umol/src/net/model/mapping.py:49: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
values_tree_def = jax.tree_flatten(values)[1]
/home/tools/umol_package/Umol/src/net/model/mapping.py:53: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
return jax.tree_unflatten(values_tree_def, flat_axes)
/home/tools/umol_package/Umol/src/net/model/mapping.py:124: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
flat_sizes = jax.tree_flatten(in_sizes)[0]
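For reference, here is a minimal sketch of the kind of change the warnings ask for. It is not the actual code from mapping.py: the `values` structure and the `flat_axes` list are made-up stand-ins, and only the two function names come from the warning text.

```python
import jax

# Hypothetical nested structure standing in for the real `values` argument.
values = {"a": 1.0, "b": [2.0, 3.0]}

# Before (deprecated): jax.tree_flatten(values)
# After: the same call through jax.tree_util, which returns (leaves, treedef).
leaves, values_tree_def = jax.tree_util.tree_flatten(values)

# Illustrative per-leaf values (assumption: one entry per leaf).
flat_axes = [0 for _ in leaves]

# Before (deprecated): jax.tree_unflatten(values_tree_def, flat_axes)
# After: rebuild the original structure from the new leaves.
axes_tree = jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
```

The arguments and return values are unchanged, so the lines at 49, 53, and 124 should only need the `jax.tree_util.` prefix.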
Hi, you are right - thanks. Done.