New interface does not support custom empty pytree classes inherited from dict
ZaberKo opened this issue · 1 comment
ZaberKo commented
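Reproduction code:
import jax
import orbax.checkpoint as ocp

class PyTreeDict(dict):
    pass

# Register the dict subclass as a custom pytree node:
# children are the values, auxiliary data is the tuple of keys.
jax.tree_util.register_pytree_node(
    PyTreeDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),
    lambda keys, values: PyTreeDict(dict(zip(keys, values))),
)

# Assumption: `ckpt` is not defined in the original report; a Checkpointer
# wrapping a StandardCheckpointHandler is used here.
ckpt = ocp.Checkpointer(ocp.StandardCheckpointHandler())

a = {"a": PyTreeDict()}  # ValueError: Expected dict, got {}.
# a = PyTreeDict()  # ValueError: Found empty item
path = ocp.test_utils.erase_and_create_empty('./debug').resolve() / 'ckpt'
ckpt.save(path, a)
ckpt.restore(path, args=ocp.args.StandardRestore(a))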
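The pytree registration itself appears to behave correctly. A minimal sketch using only `jax.tree_util` (no Orbax) shows why an empty `PyTreeDict` is a corner case: it flattens to zero leaves, and therefore to zero key paths. Any layer that rebuilds the tree from per-leaf paths, rather than from the treedef, has nothing to reconstruct the empty node from. This illustrates that assumption and is not a description of Orbax internals.

```python
import jax


class PyTreeDict(dict):
    pass


# Same registration as in the report: children are the values,
# auxiliary data is the tuple of keys.
jax.tree_util.register_pytree_node(
    PyTreeDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),
    lambda keys, values: PyTreeDict(dict(zip(keys, values))),
)

tree = {"a": PyTreeDict()}

# An empty PyTreeDict contributes no leaves at all.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)

# Key paths are produced only for leaves, so the empty node gets none.
paths = [p for p, _ in jax.tree_util.tree_flatten_with_path(tree)[0]]
print(paths)

# The structure still round-trips when the treedef is kept.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(restored["a"]).__name__, restored)
```

Both `leaves` and `paths` come out empty, while `tree_unflatten` still returns `{"a": PyTreeDict()}`. In other words, the information about the empty node lives only in the treedef.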
This issue is related to #720 and a066d9c.
@niketkumar
cpgaffney1 commented
Thanks for reporting. We're looking into some refactoring that should resolve these empty-node issues.