google-deepmind/tree

`map_structure` seems to be slower than a Python implementation

mattbarrett98 opened this issue · 1 comments

A basic recursive Python function that does the same thing (or something similar) is faster than `map_structure`:

import tree
from time import perf_counter
import collections.abc  # import the submodule explicitly; `import collections` alone does not guarantee it is loaded

# recursive implementation
def nested_map(fn, nest):
    if isinstance(nest, list):
        return [nested_map(fn, v) for v in nest]
    elif isinstance(nest, tuple):
        return tuple(nested_map(fn, v) for v in nest)
    elif isinstance(nest, collections.abc.Mapping):
        return {k: nested_map(fn, v) for k, v in nest.items()}
    return fn(nest)

args = [1, 2, [3, 4], {"a": 5}]

s = perf_counter()
tree.map_structure(lambda x: x**2, args)
print(perf_counter() - s)  # 9.1e-5

s = perf_counter()
nested_map(lambda x: x**2, args)
print(perf_counter() - s)  # 9.6e-6
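
A single `perf_counter` measurement around one call is noisy and can include one-off costs, so the numbers above are only a rough indication. Below is a minimal sketch of a more stable measurement using the standard-library `timeit` module. The helper `best_time_per_call` and the function `square` are hypothetical names introduced for illustration, and the `number` and `repeat` values are arbitrary assumptions.

```python
import collections.abc
import timeit


def nested_map(fn, nest):
    # Same recursive helper as in the snippet above.
    if isinstance(nest, list):
        return [nested_map(fn, v) for v in nest]
    elif isinstance(nest, tuple):
        return tuple(nested_map(fn, v) for v in nest)
    elif isinstance(nest, collections.abc.Mapping):
        return {k: nested_map(fn, v) for k, v in nest.items()}
    return fn(nest)


def best_time_per_call(func, *func_args, number=10_000, repeat=5):
    # Hypothetical helper: run `func` `number` times per round, for `repeat`
    # rounds, and return the best average time per call in seconds.
    timer = timeit.Timer(lambda: func(*func_args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


def square(x):
    return x ** 2


args = [1, 2, [3, 4], {"a": 5}]
print(best_time_per_call(nested_map, square, args))
# With the library installed, the same helper could be used as:
# best_time_per_call(tree.map_structure, square, args)
```

Taking the minimum over several rounds reduces the influence of scheduling noise; it does not change which implementation does more work per call.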

This library looks like a really useful tool, but it doesn't seem to give the performance I would have expected.

`map_structure` is more generic and deals with many more edge cases than your function does, so the difference is not surprising. For instance, if you pass an `OrderedDict` as input, your implementation returns a plain `dict`, which is wrong. Doing all the necessary checks and calling the right constructor in every possible scenario comes at a cost. In my view, tree is about versatility rather than speed.
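
The `OrderedDict` case can be reproduced in plain Python. The sketch below shows the recursive function from the issue dropping the mapping type, followed by a hypothetical variant, `nested_map_keep_type`, that rebuilds each container with its own type. It is only an illustration of the kind of extra work involved, not how tree implements `map_structure`, and it assumes each container type can be constructed from an iterable.

```python
import collections
import collections.abc


def nested_map(fn, nest):
    # Recursive helper from the issue: every mapping becomes a plain dict.
    if isinstance(nest, list):
        return [nested_map(fn, v) for v in nest]
    elif isinstance(nest, tuple):
        return tuple(nested_map(fn, v) for v in nest)
    elif isinstance(nest, collections.abc.Mapping):
        return {k: nested_map(fn, v) for k, v in nest.items()}
    return fn(nest)


def nested_map_keep_type(fn, nest):
    # Hypothetical variant: rebuild each container with type(nest).
    # Assumption: the type accepts an iterable of items (or of key/value
    # pairs for mappings). Types such as defaultdict or namedtuple do not,
    # and would each need their own branch.
    if isinstance(nest, (list, tuple)):
        return type(nest)(nested_map_keep_type(fn, v) for v in nest)
    if isinstance(nest, collections.abc.Mapping):
        return type(nest)(
            (k, nested_map_keep_type(fn, v)) for k, v in nest.items()
        )
    return fn(nest)


od = collections.OrderedDict([("b", 1), ("a", 2)])
print(type(nested_map(lambda x: x ** 2, od)).__name__)            # dict
print(type(nested_map_keep_type(lambda x: x ** 2, od)).__name__)  # OrderedDict
```

Each additional container type needs another check and another constructor call on every node, which is where the overhead of a general-purpose function comes from.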