rusty1s/pytorch_sparse

Branching gradient flow

KukumavMozolo opened this issue · 4 comments

Hi Rusty1s!,
I was wondering if there is a way to have branching that doesn't break gradients with your library, e.g. like this using torch.where:

x = torch.where(z == 0, y, x)

Thank you for this great project, it has been a life saver so far.

Sorry, can you clarify what you mean exactly?

Let me try:
Let x be a sparse tensor, and let f(x) = y be some transformation that potentially makes y less sparse.
I then want to ensure that wherever x was 0, y is also 0.
With a dense tensor it can be done like this:

y = torch.where(x == 0, torch.tensor([0]), y)
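To make the dense case concrete, here is a tiny pure-Python sketch of the same masking (made-up values, no torch dependency):

```python
# Dense masking with plain Python (made-up example values):
# zero out y wherever the corresponding entry of x is zero.
x = [0.0, 2.0, 0.0, 5.0]
y = [1.0, 1.0, 1.0, 1.0]

# equivalent of torch.where(x == 0, torch.tensor([0]), y)
masked = [0.0 if xi == 0 else yi for xi, yi in zip(x, y)]
print(masked)  # [0.0, 1.0, 0.0, 1.0]
```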

So is there a way to do the same thing if x and y are both sparse tensors?

It looks like what you want to do is take the intersection of two sparse matrices, keeping the value from y for entries that exist in both. This is generally possible (although not through torch.where syntax):

row, col, _ = x.coo()
flat_index_x = row * x.size(1) + col

row, col, value = y.coo()
flat_index_y = row * y.size(1) + col

# mask over the entries of y: True where the position also exists in x
mask = torch.isin(flat_index_y, flat_index_x)

# filter y by intersection:
y = SparseTensor(row=row[mask], col=col[mask], value=value[mask],
                 sparse_sizes=y.sparse_sizes())
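The flattened-index intersection above can be illustrated without torch or torch_sparse; this is a minimal pure-Python sketch of the same logic on made-up COO data:

```python
# COO coordinates of x (a 2x3 matrix): nonzeros at (0,0), (0,2), (1,1)
x_row, x_col = [0, 0, 1], [0, 2, 1]
num_cols = 3

# COO of y (same shape): nonzeros at (0,2), (1,0), (1,1)
y_row, y_col = [0, 1, 1], [2, 0, 1]
y_val = [10.0, 20.0, 30.0]

# Flatten each (row, col) pair into a single index: row * num_cols + col
flat_x = {r * num_cols + c for r, c in zip(x_row, x_col)}
flat_y = [r * num_cols + c for r, c in zip(y_row, y_col)]

# Keep only entries of y whose position also appears in x
mask = [i in flat_x for i in flat_y]
row = [r for r, m in zip(y_row, mask) if m]
col = [c for c, m in zip(y_col, mask) if m]
val = [v for v, m in zip(y_val, mask) if m]

print(row, col, val)  # positions (0,2) and (1,1) survive
```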

This works great; unfortunately, in my case mask = torch.isin(flat_index_x, flat_index_y) is quite slow since one array is quite large, so as it turned out, using torch.where is actually faster.
Anyhow thank you very much for your suggestion!