coreylowman/dfdx

Please add a split method, to iterate over subtensors of the given tensor

emchristiansen opened this issue · 7 comments

Please add a method for splitting on a given axis, returning an iterator over subtensors selected in that axis.

For example, suppose mytensor is this 2x3 tensor:

[
  [1, 2, 3],
  [4, 5, 6],
]

Then mytensor.split(0) should return the iterator of length 2 containing tensors [1, 2, 3] and [4, 5, 6], and mytensor.split(1) should return the iterator of length 3 containing tensors [1, 4], [2, 5], and [3, 6].
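To pin down the requested semantics, here is a minimal sketch on a plain nested `Vec` rather than a dfdx tensor. The function name `split` and the `Vec<Vec<i32>>` representation are assumptions for illustration only, not dfdx API, and the result is collected into a `Vec` instead of being returned as a lazy iterator.

```rust
/// Split a 2D "tensor" (hypothetical stand-in: a slice of rows) along `axis`.
/// Axis 0 yields the rows; axis 1 yields the columns.
fn split(tensor: &[Vec<i32>], axis: usize) -> Vec<Vec<i32>> {
    match axis {
        // Each row is already a subtensor along axis 0.
        0 => tensor.to_vec(),
        // For axis 1, gather the c-th element of every row.
        1 => {
            let cols = tensor.first().map_or(0, |row| row.len());
            (0..cols)
                .map(|c| tensor.iter().map(|row| row[c]).collect())
                .collect()
        }
        _ => panic!("a 2D tensor only has axes 0 and 1"),
    }
}
```

For the 2x3 example above, `split(&t, 0)` gives the two rows and `split(&t, 1)` gives the three columns; calling `.into_iter()` on the result produces the iterator described in the request.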

Note, this is close to what select does (IIUC), but select consumes its input and so can't be used to construct an iterator like this.

Also, if the general case doesn't seem worth it, I'd be happy just having the function split_0 which would just split on the first axis.

Thanks!

BTW if you give me a pointer on how to do this I'd be happy to give it a go.

> Note, this is close to what select does (IIUC), but select consumes its input and so can't be used to construct an iterator like this.

Yeah split can be used for this. You'd have to clone the tensor since as you state it takes ownership of input.

I think that no matter what, since tensors currently own their data, we wouldn't be able to create an iterator over subslice references.

You could probably create an iterator that uses clone for each subslice though? 🤔 Might be worth sketching out what the API for it would look like using cloning

Since the data is stored behind Arc, cloning is very cheap

> Yeah split can be used for this. You'd have to clone the tensor since as you state it takes ownership of input.

You mean select, right?

I'd be happy with an iter containing owned slices (not references).
But the only way I see to do this now requires doing N - 1 clones, where N is the iter length.
(But maybe it doesn't matter if clones are cheap?)

I'll try to write a function that is semantically what I'm looking for, albeit inefficient.

Yep select sorry 😁

And yeah, cloning has a very tiny cost since it's just cloning the Arc (so it just bumps the reference count). The data itself isn't actually cloned.
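The cost argument can be checked with a small standalone sketch. `ToyTensor` and `clone_handles` are hypothetical names, not dfdx types; the only assumption carried over from the thread is that the storage sits behind an `Arc`.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for a tensor whose storage sits behind an `Arc`.
#[derive(Clone)]
struct ToyTensor {
    data: Arc<Vec<f32>>,
}

/// Build `n` owned handles by cloning. Each clone only increments the
/// reference count; the underlying buffer is shared, not copied.
fn clone_handles(t: &ToyTensor, n: usize) -> Vec<ToyTensor> {
    (0..n).map(|_| t.clone()).collect()
}
```

After `clone_handles(&t, 3)`, `Arc::strong_count(&t.data)` is 4 and `Arc::ptr_eq` confirms that every handle points at the same allocation, so N clones for an N-element iterator cost N reference-count increments.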

I took a stab at split_0 here, which I'm calling unstack for symmetry: #831

swfsql commented

Nice! I'm interested in your unstack impl, and I hope to try it out soon.

I'm actually interested in the opposite of a concatenate operation, e.g. (6, 2) -> [(3, 2), (3, 2)], but after reading #43 (comment) I understood that it applies both ways. I could just first reshape (6, 2) -> (2, 3, 2) and then run unstack, getting [(3, 2), (3, 2)].
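A rough sketch of the reshape-then-unstack idea on a flat row-major buffer, assuming equal-sized parts. `reshape_and_unstack` is a hypothetical helper, not the unstack from #831: viewing (6, 2) as (2, 3, 2) and unstacking the leading axis amounts to cutting the buffer into equal contiguous chunks.

```rust
/// Treat `data` as a row-major buffer, add a new leading axis of size `n`,
/// and unstack along it. For shape (6, 2) and n = 2 this yields two (3, 2)
/// parts, each stored as a flat Vec of 6 elements.
fn reshape_and_unstack(data: &[i32], n: usize) -> Vec<Vec<i32>> {
    assert!(n > 0 && !data.is_empty(), "need a non-empty buffer and n > 0");
    assert!(data.len() % n == 0, "buffer length must be divisible by n");
    // In row-major order, the n subtensors are contiguous, equal-sized chunks.
    data.chunks(data.len() / n).map(|c| c.to_vec()).collect()
}
```

With the 12 elements `0..12` and `n = 2`, the parts are `[0, 1, 2, 3, 4, 5]` and `[6, 7, 8, 9, 10, 11]`, i.e. the first three and last three rows of the original (6, 2) tensor.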

As a side note, some interesting new archs may require those (or similar) operations. My motivation so far is this candle/chunk operation.

swfsql commented

I've linked a PR because I think I'd need something different from a shape reduction: some tensors may need to be split into unequal sizes along one axis while the other axes stay the same. In other words, the opposite of concat instead of the opposite of stack.
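The difference from unstack can be sketched on a flat row-major 2D buffer. `split_rows` and its parameters are hypothetical and do not reflect the linked PR's API: the parts may have different row counts along axis 0 while the column count stays fixed.

```rust
/// Split a row-major (rows, cols) buffer along axis 0 into parts with the
/// given row counts, keeping `cols` unchanged. This is the inverse of
/// concatenating the parts along axis 0.
fn split_rows(data: &[i32], cols: usize, sizes: &[usize]) -> Vec<Vec<i32>> {
    let total_rows: usize = sizes.iter().sum();
    assert_eq!(total_rows * cols, data.len(), "sizes must cover every row");

    let mut parts = Vec::with_capacity(sizes.len());
    let mut start = 0;
    for &rows in sizes {
        // Each part is a contiguous run of `rows * cols` elements.
        let end = start + rows * cols;
        parts.push(data[start..end].to_vec());
        start = end;
    }
    parts
}
```

Splitting a (6, 2) buffer with `sizes = [4, 2]` returns a (4, 2) part and a (2, 2) part, which reshape followed by unstack cannot express because the parts are not the same size.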