Branching gradient flow
KukumavMozolo opened this issue · 4 comments
Hi Rusty1s!
I was wondering if there is a way to do branching that doesn't break gradients with your library, e.g. like this using torch.where:
x = torch.where(z == 0, y, x)
Thank you for this great project, it has been a lifesaver so far.
Sorry, can you clarify what you mean exactly?
Let me try:
Let x be a sparse tensor and f(x) = y some transformation that potentially makes y less sparse.
I then want to ensure that wherever x was 0, y is also 0.
With a dense tensor it can be done like this:
y = torch.where(x == 0, torch.tensor([0]), y)
So is there a way to do the same thing if x and y are both sparse tensors?
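For reference, the dense behaviour described above can be checked with a small sketch. The tensors x and w and the stand-in transformation are hypothetical toy data, not taken from the issue; the point is only that torch.where zeroes the masked entries while gradients still reach the kept ones.

```python
import torch

# Hypothetical toy data: x has zeros, y stands in for f(x) and is denser.
x = torch.tensor([0.0, 2.0, 0.0, 3.0])
w = torch.ones(4, requires_grad=True)
y = x * w + w  # non-zero everywhere, so "less sparse" than x

# Zero out y wherever x is zero; torch.where stays differentiable w.r.t. y.
y_masked = torch.where(x == 0, torch.zeros_like(y), y)

# Gradients reach w only through the entries that were kept.
y_masked.sum().backward()
print(y_masked)
print(w.grad)
```

Using torch.zeros_like(y) instead of torch.tensor([0]) is a choice made here to keep dtype and device aligned with y.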
It looks like what you want to do is take the intersection of two sparse matrices, where for entries present in both you want to keep the value in y. This is generally possible (although not through torch.where syntax):
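import torch
from torch_sparse import SparseTensor

# flatten the (row, col) indices of x into unique 1-D keys:
row, col, _ = x.coo()
flat_index_x = row * x.size(1) + col
# same for y, keeping its values:
row, col, value = y.coo()
flat_index_y = row * y.size(1) + col
# True for every entry of y whose position is also stored in x:
mask = torch.isin(flat_index_y, flat_index_x)
# filter y by intersection:
y = SparseTensor(row=row[mask], col=col[mask], value=value[mask], sparse_sizes=y.sparse_sizes())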
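The index-flattening idea can be tried without torch_sparse on plain COO index tensors. The following is a minimal sketch with made-up 3x3 data (all names and values are hypothetical); it also checks that indexing the values with the mask keeps the gradient path for the retained entries.

```python
import torch

# Hypothetical 3x3 example in COO form; names and values are illustrative.
num_cols = 3
row_x = torch.tensor([0, 1, 2])
col_x = torch.tensor([1, 1, 0])
row_y = torch.tensor([0, 0, 1, 2, 2])
col_y = torch.tensor([0, 1, 1, 0, 2])
value_y = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0], requires_grad=True)

# Encode each (row, col) pair as a single integer key.
flat_x = row_x * num_cols + col_x
flat_y = row_y * num_cols + col_y

# Keep the entries of y whose position is also stored in x.
mask = torch.isin(flat_y, flat_x)
kept_value = value_y[mask]

# Boolean indexing is differentiable with respect to the selected values.
kept_value.sum().backward()
print(mask)
print(value_y.grad)
```

Note that the mask has one entry per stored element of y, which is why y's indices are the first argument to torch.isin.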
This works great. Unfortunately, in my case the torch.isin step is quite slow because one of the arrays is very large, so as it turned out, using torch.where is actually faster.
Anyhow thank you very much for your suggestion!
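If the membership test is the bottleneck, one option that might be worth benchmarking is a binary search with torch.searchsorted instead of torch.isin. This is only a sketch and was not part of the thread: the helper isin_sorted is a hypothetical name, it assumes the keys are already sorted in ascending order (otherwise sort them first with torch.sort), and whether it is actually faster depends on tensor sizes and device.

```python
import torch

def isin_sorted(query: torch.Tensor, sorted_keys: torch.Tensor) -> torch.Tensor:
    """Membership test for 1-D integer tensors.

    Hypothetical helper; assumes sorted_keys is sorted in ascending order.
    """
    if sorted_keys.numel() == 0:
        return torch.zeros_like(query, dtype=torch.bool)
    # Index of the first key that is >= each query element.
    pos = torch.searchsorted(sorted_keys, query)
    # Clamp so queries larger than every key still index safely.
    pos = pos.clamp(max=sorted_keys.numel() - 1)
    return sorted_keys[pos] == query

# Illustrative flat indices (made-up values).
flat_index_x = torch.tensor([1, 4, 6])
flat_index_y = torch.tensor([0, 1, 4, 6, 8])
mask = isin_sorted(flat_index_y, flat_index_x)
print(mask)
```

The resulting mask can be used in the same way as the one from torch.isin to filter the row, col and value tensors of y.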