[BUG] Cannot override __getitem__ or __setitem__ for tensorclass with non-tensor data
alexanderswerdlow opened this issue · 6 comments
Describe the bug
As described here, it should be possible to override __getitem__ and __setitem__ to implement indexing for non-tensor data types.
To Reproduce
Following the linked issue:
import torch
from typing import List

from tensordict import tensorclass

@tensorclass
class MyClass:
    images: torch.Tensor
    captions: List[str]

    def __getitem__(self, item):
        c = super().__getitem__(item)
        c.captions = self.captions[item]
        return c

data = MyClass(torch.randn(2, 3, 64, 64), ["a", "b"], batch_size=[2])
print(data[0])
Expected behavior
The printed class should contain captions=["a"].
Additional context
It appears that monkey patching works if I copy and modify _getitem from tensorclass.py.
MyClass.__getitem__ = _getitem
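Concretely, the modified copy could look something like this (a minimal sketch; _getitem is a private helper, so this may break across versions):

from tensordict.tensorclass import _getitem

def patched_getitem(self, item):
    # slice the tensor fields via the library's default behaviour...
    out = _getitem(self, item)
    # ...then slice the non-tensor field manually
    out.captions = self.captions[item]
    return out

MyClass.__getitem__ = patched_getitem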
In addition, this might be outside the scope of this issue and a separate feature request, but it seems that even with this monkey patching, stacking/concat does not properly handle this non-tensor data. Is this something that is supported?
I very often have non-tensor data that is associated with batch elements [e.g., a caption for each image], and it would be very weird to cat/stack and have the captions be unchanged.
Oddly I was just working on that :)
In general, I think #663 should solve this issue more generally, but we can patch it (though I'm not sure how I can make super() work, since there is no real inheritance...)!
Wow thanks for the fast response! I'd be happy to test it out when you have a working version :)
Now your new method will be used, but super() won't work. I can make a follow-up PR to make this a documented option. Because we don't explicitly inherit from anything but use a decorator instead, I think that making super() work wouldn't be a good idea anyway, since it would break Python convention and people would end up with a class that inherits from something they never asked for (in a way, you don't expect super() to do anything with a dataclass, so why should it here?).
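To illustrate with a plain dataclass (a minimal sketch, no tensordict involved): a decorated class still only inherits from object, so there is no parent method for super() to dispatch to.

from dataclasses import dataclass

@dataclass
class Plain:
    x: int

    def __getitem__(self, item):
        # object defines no __getitem__, so this raises
        # AttributeError: 'super' object has no attribute '__getitem__'
        return super().__getitem__(item)

Plain(1)[0]  # AttributeError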
To me the API should look more like
@tensorclass
class MyClass:
    X: torch.Tensor

    def __setitem__(self, name, value):
        # your code here
        self.__tensorclass_function__().__setitem__(name, value)

or

self.__tensorclass_function__("__setitem__", name, value)
whichever people think is more appropriate.
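A hypothetical usage of the first form for the captions case above (again, __tensorclass_function__ is only the name proposed here, not an implemented API):

@tensorclass
class MyClass:
    images: torch.Tensor
    captions: List[str]

    def __setitem__(self, item, value):
        # handle the non-tensor field ourselves...
        self.captions[item] = value.captions
        # ...then defer to the default tensorclass behaviour
        self.__tensorclass_function__().__setitem__(item, value)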
@vmoens Thanks so much. I tried playing around a bit, but I wasn't able to get things working as I'd expect. Perhaps my use case just isn't supported, but I wanted to let you know what I encountered.
I wasn't able to use __tensorclass_function__ as you described (perhaps this was meant to be a generic placeholder, e.g., $tensorclass_function?), but I was able to override, e.g., __getitem__ as shown below. I wasn't sure how to make things work for stack/cat, etc., so I tried overriding the _get and _get_at functions, but that didn't seem to do it.
import torch
from typing import List

from tensordict import tensorclass
from tensordict.tensorclass import _getitem, _setitem, _get, _get_at, _set, _set_at_

def my_setitem(self, item, value):
    print(f"Called setitem with {item} and type {type(value)}")
    _setitem(self, item, value)
    if isinstance(self.captions, list):
        self.captions[item] = value.captions
    else:
        raise ValueError(f"Invalid type for captions: {type(item)} and {type(self.captions)} and {type(value.captions)}")

def my_getitem(self, name):
    print(f"Called getitem with {name}")
    obj = _getitem(self, name)
    obj.captions = self.captions[name]
    return obj

def my_get(self, key):
    print("Called get")
    obj = _get(self, key)
    return obj

def my_get_at(self, key, idx):
    print("Called get_at")
    obj = _get_at(self, key, idx)
    return obj

def my_set_at_(self, key, value, idx):
    print("Called set_at_")
    _set_at_(self, key, value, idx)

def my_set(self, key, value):
    print("Called set")
    _set(self, key, value)

@tensorclass
class MyClass:
    images: torch.Tensor
    captions: List[str]

MyClass.__setitem__ = my_setitem
MyClass.__getitem__ = my_getitem
MyClass.get = my_get
MyClass.get_at = my_get_at
# MyClass.set = my_set  # Causes an error
MyClass.set_at_ = my_set_at_

data = MyClass(torch.ones(4, 3, 64, 64), ["a", "b", "c", "d"], batch_size=[4])
data.images = torch.randn(4, 3, 64, 64)  # Works
data.captions = ["d", "c", "b", "a"]  # Works
data[0].images = torch.randn(3, 64, 64)  # Works
data[0].captions = "a"  # Understandably does not work, but unintuitive
print(data[0].captions)  # Prints "d"
data[1] = MyClass(torch.randn(3, 64, 64), "b", batch_size=[])  # Works
data[2:4] = MyClass(torch.randn(2, 3, 64, 64), ["e", "f"], batch_size=[2])  # Works
print(data[2:4].captions)  # Prints ["e", "f"]
cat_data = torch.cat([data, data], dim=0)  # Does not modify captions
stack_data = torch.stack([data, data], dim=0)  # Does not modify captions
breakpoint()
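One possible workaround for the cat case (a hand-rolled helper of my own, not a tensordict API; it assumes concatenation along dim 0 and a flat list of captions):

def cat_with_captions(items, dim=0):
    # captions are a flat list tracked along dim 0 only
    if dim != 0:
        raise NotImplementedError("captions are only tracked along dim 0")
    # torch.cat handles the tensor fields but leaves captions untouched,
    # so stitch the per-item caption lists together manually afterwards
    out = torch.cat(items, dim=dim)
    out.captions = [c for item in items for c in item.captions]
    return out

cat_data = cat_with_captions([data, data])
print(len(cat_data.captions))  # 8, one caption per batch element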
The important part of my message was this
Now your new method will be used but super won't work. I can make a follow-up PR to make this a documented option.
i.e., I have not implemented it yet. You can overwrite the function, but you won't be able to call super() or anything similar as of now. I was asking for feedback about the feature before implementing it :)
Ah, my bad! I think it's a good workaround (personally I like the first option, where the function is called directly rather than via its string name, since the latter seems unfriendly to type checking), but it's not immediately clear to an end user how it would affect the variety of tensor operations that tensordict supports.
I think the extent of supported operations for non-tensor data should be well-defined [e.g., you need to implement these 4 methods to cover all supported operations]. Some operations may simply not be possible, and ideally a [disableable] warning message is shown when you perform an unsupported one. Otherwise, I'd be concerned that some operations work just fine while unexpected corner cases cause surprises. In the past I've implemented my own very basic version of a tensorclass that only supported a couple of operations but worked in a very understandable way, which I think is critical.
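For instance, the warning could be as simple as the following (a sketch of the suggestion above, not an existing tensordict feature; all names here are made up):

import warnings

class NonTensorDataWarning(UserWarning):
    pass

def warn_unsupported(op_name):
    # called whenever an operation cannot propagate non-tensor fields
    warnings.warn(
        f"{op_name} leaves non-tensor fields unchanged",
        NonTensorDataWarning,
        stacklevel=2,
    )

# and it is easy for users to silence:
# warnings.filterwarnings("ignore", category=NonTensorDataWarning)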
I'm not sure how complicated it would be to support all the different kinds of advanced indexing possible on tensors. One possibility I can think of would be to only support modifying non-tensor data along a single batch dimension, although I can easily imagine situations where I'd want, e.g., 2 dimensions.
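For what it's worth, one way to get tensor-style advanced indexing on a non-tensor field (my own suggestion, not something tensordict does today) is to store the strings in a numpy object array, which supports the same fancy indexing as tensors, including multiple batch dimensions:

import numpy as np

captions = np.array(["a", "b", "c", "d"], dtype=object)
print(captions[[0, 2]])   # ['a' 'c'] -- fancy indexing works
print(captions[1:3])      # ['b' 'c'] -- slicing works
captions_2d = captions.reshape(2, 2)
print(captions_2d[1, 0])  # c -- two batch dimensions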