pytorch/tensordict

[Feature Request] DoubleDict class to use as TensorDict backend

vmoens opened this issue · 3 comments

Several methods in tensordict are notoriously slow. Many of them are slowed down by instance checks over the items gathered from the td, since the treatment of a TensorDict-like object and a tensor can differ drastically.

To solve this problem, I propose two solutions:

  1. Use a DoubleDict class, which internally stores two dictionaries (one for tensors, one for tensordicts):
import torch

class doubledict:
  def __init__(self, **kwargs):
      self._tensor_dict = {}
      self._dict_dict = {}
      for key, item in kwargs.items():
          if isinstance(item, torch.Tensor):
              self._tensor_dict[key] = item
          else:
              self._dict_dict[key] = item
  
  def __getitem__(self, key):
      result = self._tensor_dict.get(key, None)
      if result is None:
          return self._dict_dict[key]
      return result

  def __setitem__(self, key, value):
      if isinstance(value, torch.Tensor):
          self._tensor_dict[key] = value
          self._dict_dict.pop(key, None)
      else:
          self._dict_dict[key] = value
          self._tensor_dict.pop(key, None)

  def __iter__(self):
      yield from self._tensor_dict
      yield from self._dict_dict

  def keys(self):
      return tuple(self._tensor_dict.keys()) + tuple(self._dict_dict.keys())

  def keys_tensors(self):
      return self._tensor_dict.keys()

  def items(self):
      yield from self._tensor_dict.items()
      yield from self._dict_dict.items()

  def items_tensors(self):
      return self._tensor_dict.items()

Expected slowdown: writing and reading entries will be slow because we must attempt an access in both dicts every time.

Expected faster ops: iterating over tensors or tensordicts should be faster.

There could be a way of accelerating __getitem__/__setitem__ (maybe pop is too slow, for instance?).
A C++ version of this class could also be written, and would be a first step towards a C++-based tensordict (cc @fedebotu)
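One candidate micro-optimization for __getitem__ along those lines is a unique sentinel, which keeps a single dict.get on the common path and also distinguishes a missing key from a stored None. A minimal runnable sketch, with plain dicts standing in for TensorDicts so it doesn't need torch (in tensordict the branch would be isinstance(value, torch.Tensor)):

```python
_MISSING = object()  # unique sentinel: cannot collide with any stored value

class sentineldict:
    """Variant of doubledict whose __getitem__ does a single dict.get.
    Plain dicts stand in for TensorDicts here so this runs without torch;
    in tensordict the branch would be isinstance(value, torch.Tensor)."""

    def __init__(self, **kwargs):
        self._tensor_dict = {}  # leaves
        self._dict_dict = {}    # sub-dicts
        for key, item in kwargs.items():
            self[key] = item

    def __setitem__(self, key, value):
        if isinstance(value, dict):
            self._dict_dict[key] = value
            self._tensor_dict.pop(key, None)
        else:
            self._tensor_dict[key] = value
            self._dict_dict.pop(key, None)

    def __getitem__(self, key):
        # One lookup on the common (leaf) path; the sentinel also
        # distinguishes "absent" from a legitimately stored None.
        result = self._tensor_dict.get(key, _MISSING)
        if result is _MISSING:
            return self._dict_dict[key]
        return result
```

Whether this beats the None-default version in practice would need benchmarking; the point is only that the double lookup on the leaf path can be avoided.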

  2. Another way of solving this problem is to store a boolean alongside each item that indicates its type:
def __setitem__(self, key, value):
    super().__setitem__(key, (isinstance(value, torch.Tensor), value))

That way, while iterating over the keys we immediately know whether the value is a leaf. The problem with this solution is that retrieving all leaves or all tensordicts still requires going through the entire dict, but at least the instance check is done only once.

class tddict(dict):
  def __setitem__(self, key, value):
      return super().__setitem__(key, (isinstance(value, torch.Tensor), value))

  def __init__(self, **kwargs):
      for key, val in kwargs.items():
          self[key] = val
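For completeness, reads on tddict must unwrap the stored (flag, value) pair, and the cached flag makes items_tensors-style views straightforward. A hedged, self-contained sketch (the accessor names are hypothetical, and dict again stands in for torch.Tensor so it runs without torch):

```python
class tddict(dict):
    """Solution-2 store with read accessors. `dict` stands in for
    torch.Tensor as the leaf test so the sketch runs without torch;
    in tensordict the check would be isinstance(value, torch.Tensor)."""

    def __init__(self, **kwargs):
        super().__init__()
        for key, val in kwargs.items():
            self[key] = val

    def __setitem__(self, key, value):
        # The isinstance check runs once, at insertion time.
        super().__setitem__(key, (not isinstance(value, dict), value))

    def __getitem__(self, key):
        # Reads must unwrap the stored (is_leaf, value) pair.
        return super().__getitem__(key)[1]

    def items_tensors(self):
        # Still scans the whole dict, but branches on the cached flag
        # instead of re-running isinstance on every value.
        return ((key, val) for key, (is_leaf, val) in dict.items(self) if is_leaf)

    def items_nested(self):
        return ((key, val) for key, (is_leaf, val) in dict.items(self) if not is_leaf)
```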

I ran some quick experiments with this:

kwargs = {str(i): torch.zeros(()) for i in range(100)}
kwargs.update({str(i): i for i in range(100, 200)})

d0 = doubledict(**kwargs)

from copy import copy
d1 = copy(kwargs)

d2 = tddict(**kwargs)

tensor = torch.zeros(())

%%timeit
d0["0"] = tensor
# 256 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

%%timeit
d1["0"] = tensor
# 58.6 ns ± 2.84 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

%%timeit
d2["0"] = tensor
# 406 ns ± 4.81 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

%%timeit
list(d0.items_tensors())
# 2.66 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%%timeit
list((item for item in d1.items() if isinstance(item[1], torch.Tensor)))
# 32.4 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
list(((key, val[1]) for key, val in d2.items() if val[0]))
# 18.3 µs ± 944 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

The gain with solution 2 for querying items isn't fabulous, and its __setitem__ is actually slower than solution 1's!

A potential downside of doubledict is that we lose information about the insertion order of keys (so we can't return items in insertion order), though maybe insertion order was never maintained anyway?

That's a valid point!

Items would be "doubly" ordered (first tensors, then tensordicts, each ordered by arrival).

Given this ordering, we could make tensordict.keys(include_nested=True) yield all first-level entries first, then second-level, and so on. It also wouldn't be hard to make checks like if ("my", "nested", "key") in tensordict.keys(True): ... faster, since we can directly fetch the third level with some extra code on top of the above (a check like this is currently pretty expensive IIRC).
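A sketch of how such a containment check could descend level by level instead of enumerating keys(include_nested=True). The helper name and the minimal stand-in class are hypothetical; they only mirror the doubledict layout above:

```python
class node:
    """Minimal stand-in for the doubledict layout: leaves and
    sub-nodes live in two separate dicts."""

    def __init__(self, **kwargs):
        self._tensor_dict = {}
        self._dict_dict = {}
        for key, item in kwargs.items():
            if isinstance(item, node):
                self._dict_dict[key] = item
            else:
                self._tensor_dict[key] = item

def contains_nested(root, key):
    # Descend one level per tuple element through _dict_dict only,
    # instead of materializing keys(include_nested=True).
    if isinstance(key, str):
        key = (key,)
    current = root
    for k in key[:-1]:
        current = current._dict_dict.get(k)
        if current is None:
            return False
    last = key[-1]
    return last in current._tensor_dict or last in current._dict_dict
```

Only the branch named by the key is visited, so the cost is proportional to the key's depth rather than to the total number of entries.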

Currently the order is kept and consistent (we discussed this with @tcbegley back in the day), but it isn't used anywhere in our code base (users may rely on it, but I don't think we ever made it a feature, so it's a gray area I guess).

well results aren't super encouraging lol