google/jax

jax.tree_utils do not keep dict key order

Conchylicultor opened this issue · 18 comments

With Python 3.6+, dicts are guaranteed to keep insertion order, similarly to OrderedDict.

Both deepmind tree and tf.nest keep dict order, but jax.tree_util does not.

import tensorflow as tf
import tree
import jax

data = {'z': None, 'a': None}

print(tf.nest.map_structure(lambda _: None, data))  # {'z': None, 'a': None}
print(tree.map_structure(lambda _: None, data))     # {'z': None, 'a': None}
print(jax.tree_map(lambda _: None, data))           # {'a': None, 'z': None}  << Oops, key order inverted

The fact that dict and OrderedDict behave differently, even though dict is guaranteed to keep insertion order, feels inconsistent.

Hmmm... Looks like ordereddict and defaultdict keep keys in order:

jax/jax/tree_util.py

Lines 247 to 255 in bf041fb

register_pytree_node(
  collections.OrderedDict,
  lambda x: (list(x.values()), list(x.keys())),
  lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))
register_pytree_node(
  collections.defaultdict,
  lambda x: (tuple(x.values()), (x.default_factory, tuple(x.keys()))),
  lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))

While standard dicts alphabetize their keys:

jax/jaxlib/pytree.cc

Lines 132 to 134 in c7aff1d

if (PyList_Sort(keys.ptr())) {
  throw std::runtime_error("Dictionary key sort failed.");
}

That is surprising behavior: it would be nice to make this more consistent.

I agree, we should not sort dictionary keys.

To support Python 3.6 we need to sort keys for deterministic traversal order on equal but distinct dict objects, right?

That is, I think the OP may have been mistaken, according to the Python docs:

Changed in version 3.7: Dictionary order is guaranteed to be insertion order. This behavior was an implementation detail of CPython from 3.6.

To support Python 3.6 we need to sort keys for deterministic traversal order on equal but distinct dict objects, right?

This is still the case even in Python 3.7+. Dictionaries and dictionary keys preserve order, but compare equal if they have the same elements regardless of order:

In [14]: x = {1: 1, 2: 2}

In [15]: y = {2: 2, 1: 1}

In [16]: x
Out[16]: {1: 1, 2: 2}

In [17]: y
Out[17]: {2: 2, 1: 1}

In [18]: x == y
Out[18]: True

In [19]: x.keys() == y.keys()
Out[19]: True

In [20]: x.keys()
Out[20]: dict_keys([1, 2])

In [21]: y.keys()
Out[21]: dict_keys([2, 1])

I believe JAX should behave like dm-tree, where flattening any dict sorts the keys, but packing a dict restores the original dict key order.

import tree

x = {'z': 'z', 'a': 'a'}

print(tree.flatten(x))  # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1]))  # Key order restored: {'z': 1, 'a': 0}

This allows all dict types to have the same flattened representation, so they can be mixed together. Currently they cannot:

import collections
import jax

d0 = {'z': 'z', 'a': 'a'}
d1 = collections.defaultdict(int, d0)

assert jax.tree_leaves(d0) == jax.tree_leaves(d1)  # AssertionError: Oops ['a', 'z'] != ['z', 'a']
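For readers without dm-tree installed, the behavior described above can be sketched in plain Python. This is a one-level sketch under my own assumptions, not dm-tree's or JAX's actual code; `flatten_dict` and `unflatten_as` are illustrative names that only mimic the semantics shown in the dm-tree example:

```python
import collections


def flatten_dict(d):
    """Return the values in sorted-key order, whatever the mapping type."""
    return [d[k] for k in sorted(d)]


def unflatten_as(template, leaves):
    """Rebuild a mapping like `template`, keeping the template's key order."""
    by_key = dict(zip(sorted(template), leaves))
    items = [(k, by_key[k]) for k in template]  # iterate in original order
    if isinstance(template, collections.defaultdict):
        return collections.defaultdict(template.default_factory, items)
    return type(template)(items)


d0 = {'z': 'z', 'a': 'a'}
d1 = collections.defaultdict(int, d0)

# Same flattened representation for both mapping types.
print(flatten_dict(d0) == flatten_dict(d1))  # True

# Key order of the template survives the round trip.
restored = unflatten_as(d0, [0, 1])
print(list(restored.items()))  # [('z', 1), ('a', 0)]
```

The point of the design is that sorting happens only in the flat view, so equality of flattened dicts is preserved while the rebuilt dict still looks like the input.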

@Conchylicultor good point! That sounds plausible.

@mattjj Would it be a good idea to register dict as node similar to what is done for ordereddict and defaultdict?

register_pytree_node(
  dict,
  lambda x: (tuple(x.values()),  {key: None for key in x}.keys()),
  lambda keys, values: dict(safe_zip(keys, values))
)

Of course things related to kDict should also be removed from jaxlib/pytree.{h,cc}

Adding my support for maintaining dict key ordering over flattening operations.

A related question: for a dictionary d, does tuple(d.values()) internally use tree_flatten? Because that operation also does not maintain key ordering when building the tuple.

Now that support for Python 3.6 has been dropped, we can probably revisit this.

I believe this is an important bug to fix. Have we figured out why this is happening (or could someone point me to the relevant source code of the C++ implementation so I can dig into it)? I think the behavior should be consistent regardless of Python 3.6, but it'd be great to revisit this one.

@wookayin - it's happening because that's how tree flattening of dicts is implemented. The line where the sort is taking place is here: https://github.com/tensorflow/tensorflow/blob/eb8425f115e5a93274f709cdfaf254798f9aa4c7/tensorflow/compiler/xla/python/pytree.cc#L167

The problem is, "fixing" this is not as easy as just removing that sort. There are many parts of JAX that rely on equality comparisons of the flattened representation of dicts, and if you preserve insertion order in flattening, then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2), which has deep and subtle implications in the implementation of JAX transforms throughout the package.

For that reason, it's not clear to me whether this should be considered a bug, or just the way that flattening works in JAX (and it's why nobody as of yet has been eager to attempt making this change).

@jakevdp I don't understand your argument. This was already resolved in #4085 (comment)

Flattening would still be sorted, so if d1 == d2, then tree_flatten(d1) == tree_flatten(d2), irrespective of the key order of d1 and d2.

However, the key order would be restored during packing:

x = {'z': 'z', 'a': 'a'}

print(tree.flatten(x))  # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1]))  # Key order restored: {'z': 1, 'a': 0}

So all dict types (OrderedDict, ...) would have the same flattened representation, but would still preserve the key order when unflattened.

then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2)

This is exactly the problem with the current Jax implementation:

import collections
import jax
import tree

d0 = {'z': 'z', 'a': 'a'}
d1 = collections.OrderedDict(d0)

assert d0 == d1  # Works
assert jax.tree_leaves(d0) == jax.tree_leaves(d1)  # << AssertionError: Oops ['a', 'z'] != ['z', 'a']

# By comparison, DM `tree` / tf.nest works as expected:
assert tree.flatten(d0) == tree.flatten(d1)  # Works: ['a', 'z'] == ['a', 'z']

That makes sense, thanks. I'd missed that comment from a few years ago.

It seems that there's broad agreement here that this should be fixed; we just need someone to take on the project.

Could we add a flag (e.g. a global variable) to let the user decide whether to sort the keys or not? For example:

jax.tree_util.dict_key_sorted(True)  # default behavior
jax.tree_util.dict_key_sorted(False)

In this issue, all the keys are strings, which are sortable. There is another issue about dict key sorting #11871. For a general PyTree, the keys are not always comparable:

tree = {1: '1', 'a': 'a'}

sorted(tree)  # <- TypeError: '<' not supported between instances of 'str' and 'int'
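One possible way around non-comparable keys, sketched here purely as an illustration (this is an assumption on my part, not something JAX is confirmed to do, and `sorted_keys` is a made-up helper), is to fall back to ordering by type name first and by value second:

```python
def sorted_keys(d):
    """Sort dict keys, falling back to (type name, key) for mixed key types."""
    try:
        return sorted(d)
    except TypeError:
        # Group keys by fully qualified type name so that keys of different
        # types are never compared with each other directly.
        def by_type_then_value(k):
            t = type(k)
            return (f"{t.__module__}.{t.__qualname__}", k)
        return sorted(d, key=by_type_then_value)


mixed = {1: '1', 'a': 'a'}
print(sorted_keys(mixed))             # [1, 'a']
print(sorted_keys({'b': 0, 'a': 1}))  # ['a', 'b']
```

This still raises TypeError if two keys of the same type are not comparable with each other, so it narrows the problem rather than removing it.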

Could we add a flag (e.g. a global variable) to let the user decide whether to sort the keys or not? For example: (breaks referential transparency)

jax.tree_util.dict_key_sorted(True)  # default behavior
jax.tree_util.dict_key_sorted(False)

The problem is, "fixing" this is not as easy as just removing that sort. There are many parts of JAX that rely on equality comparisons of the flattened representation of dicts, and if you preserve insertion order in flattening, then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2), which has deep and subtle implications in the implementation of JAX transforms throughout the package.

For that reason, it's not clear to me whether this should be considered a bug, or just the way that flattening works in JAX (and it's why nobody as of yet has been eager to attempt making this change).

I agree, referential transparency should be a key feature of the pytree utilities: equal inputs imply equal outputs. However, the current implementation always sorts the keys and returns a new, sorted dict after unflattening. Nowadays, a lot of Python code relies on builtins.dict's guaranteed insertion order. This can cause bugs, since many people are not aware of this behavior in JAX pytrees (sorted keys after tree_unflatten).

import jax

d = {'b': 2, 'a': 1}
# Map with the identity function changes the key order
out = jax.tree_util.tree_map(lambda x: x, d)  # => {'a': 1, 'b': 2}
d == out  # => True
list(d) == list(out)  # => False  ['b', 'a'] != ['a', 'b']

For example, using tree_map to process kwargs (PEP 468: Preserving the order of **kwargs in a function):

def func(*args, **kwargs):
    args, kwargs = jax.tree_util.tree_map(do_something, (args, kwargs))  # changes key order in kwargs
    ...

In [1]: import jax

In [2]: from typing import NamedTuple

In [3]: class Ints(NamedTuple):
   ...:     foo: int
   ...:     bar: int
   ...:     

In [4]: Ints(1, 2)
Out[4]: Ints(foo=1, bar=2)

In [5]: Ints(1, 2).foo
Out[5]: 1

In [6]: Ints.__annotations__
Out[6]: {'foo': <class 'int'>, 'bar': <class 'int'>}

In [7]: Floats = NamedTuple('Floats', **jax.tree_util.tree_map(lambda ann: float, Ints.__annotations__))

In [8]: Floats(1.0, 2.0)
Out[8]: Floats(bar=1.0, foo=2.0)

In [9]: Floats(1.0, 2.0).foo
Out[9]: 2.0

One solution is to store the input dict keys in insertion order in Node during flatten, and update the PyTreeDef.unflatten method to respect the key order while reconstructing the output pytree.

leaves, treedef = jax.tree_util.tree_flatten({'b': 2, 'a': 1})
leaves   # [1, 2]
treedef  # PyTreeDef({'a': *, 'b': *})
treedef.unflatten([11, 22])  # {'b': 22, 'a': 11} # respect original key order
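A rough plain-Python sketch of that idea, for a single-level dict only. `DictTreeDef` and `flatten` are hypothetical names and not part of `jax.tree_util`. Excluding the insertion order from treedef equality is my own assumption, made so that equal dicts still produce equal treedefs:

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class DictTreeDef:
    sorted_keys: tuple
    # Remembered only for unflattening; ignored by == and hash().
    insertion_keys: tuple = field(compare=False)

    def unflatten(self, leaves):
        by_key = dict(zip(self.sorted_keys, leaves))
        # Rebuild in the order the keys were originally inserted.
        return {k: by_key[k] for k in self.insertion_keys}


def flatten(d):
    keys = tuple(sorted(d))
    return [d[k] for k in keys], DictTreeDef(keys, tuple(d))


leaves, treedef = flatten({'b': 2, 'a': 1})
print(leaves)                                    # [1, 2]
print(list(treedef.unflatten([11, 22]).items())) # [('b', 22), ('a', 11)]

# Equal dicts with different insertion order still give equal treedefs.
_, other = flatten({'a': 1, 'b': 2})
print(treedef == other)                          # True
```

Whether the real PyTreeDef should compare equal in that last case is exactly the kind of design decision the actual change would have to settle.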


Commenting to add that I just encountered this behavior and I find it quite annoying.

If I were to implement this, I'd use as the treedef a dictionary with the same keys but filled with None values. This way the implementation would completely piggyback on Python and remain consistent under all circumstances.
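As I read that suggestion, it could look roughly like the following plain-Python sketch (single-level dicts only; `flatten` and `unflatten` are illustrative names, not JAX functions). The template dict gets both properties from Python itself: `==` ignores key order, while iteration preserves it:

```python
def flatten(d):
    # Template: same keys in the same insertion order, all values None.
    template = dict.fromkeys(d)
    leaves = [d[k] for k in sorted(d)]  # leaves stay in sorted-key order
    return leaves, template


def unflatten(template, leaves):
    by_key = dict(zip(sorted(template), leaves))
    return {k: by_key[k] for k in template}  # follows the template's order


leaves_a, def_a = flatten({'b': 2, 'a': 1})
leaves_b, def_b = flatten({'a': 1, 'b': 2})

print(leaves_a == leaves_b)                      # True
print(def_a == def_b)                            # True (dict == ignores order)
print(list(unflatten(def_a, [11, 22]).items()))  # [('b', 22), ('a', 11)]
```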