snap-stanford/GEARS

A question about the direction loss

zhan8855 opened this issue · 3 comments

Hi, it seems the direction loss is not working, because torch.sign blocks gradient backpropagation: autograd returns a zero gradient through it.

https://github.com/snap-stanford/GEARS/blob/master/gears/utils.py#L388

Here is a toy experiment on my local machine with torch version 2.0.0:

```python
>>> import torch
>>> for i in range(-5, 5):
...     x = torch.tensor(i, dtype=float, requires_grad=True)
...     y = torch.sign(x)
...     y.backward()
...     print(x.grad)
...
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
```
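The same effect shows up at the level of a whole loss term. Below is a minimal sketch, assuming a simplified sign-based direction penalty; `direction_loss_sign`, `lam`, `ctrl` and the toy tensors are hypothetical names for illustration, not the code at the linked line. Two of the three predictions point the wrong way relative to control, so the penalty is nonzero, yet the gradient with respect to the predictions is all zeros.

```python
import torch

def direction_loss_sign(pred, y, ctrl, lam=1.0):
    # Hypothetical, simplified direction penalty (not the GEARS implementation):
    # squared difference between the sign of the true change and the sign of
    # the predicted change, both measured relative to control.
    return lam * ((torch.sign(y - ctrl) - torch.sign(pred - ctrl)) ** 2).mean()

ctrl = torch.zeros(1, 3)                                     # toy control expression
y = torch.tensor([[1.0, -2.0, 0.5]])                         # toy true expression
pred = torch.tensor([[-1.0, 2.0, 0.5]], requires_grad=True)  # toy predictions

loss = direction_loss_sign(pred, y, ctrl)
loss.backward()

# Two of three predictions have the wrong direction, so the penalty is nonzero...
print(loss.item())
# ...but every gradient entry is zero, so training cannot reduce this term.
print(torch.count_nonzero(pred.grad).item())
```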

Thank you for raising this. It does indeed look like PyTorch zeros out the gradient through this operation. I'm surprised we hadn't noticed this before. I think our assumption was that autograd automatically used some continuous approximation of the sign function, such as tanh, when computing the gradients.
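For comparison, a small sketch of what autograd actually returns for the two functions: no smoothing is substituted for torch.sign, while torch.tanh gets its analytic derivative, 1 - tanh(x)^2. The sample points are arbitrary choices for illustration.

```python
import torch

x = torch.linspace(-2.0, 2.0, steps=5, requires_grad=True)

# torch.sign: PyTorch defines its derivative as zero everywhere (no approximation).
(g_sign,) = torch.autograd.grad(torch.sign(x).sum(), x)

# torch.tanh: smooth, so autograd returns d/dx tanh(x) = 1 - tanh(x)**2.
(g_tanh,) = torch.autograd.grad(torch.tanh(x).sum(), x)

print(g_sign)  # all zeros
print(g_tanh)  # largest at x = 0, shrinking toward the tails

# Check the tanh gradient against its closed form.
assert torch.allclose(g_tanh, 1 - torch.tanh(x.detach()) ** 2)
```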

I'm continuing to run some tests and will experiment with explicitly changing the code to use torch.tanh instead. I will update the repo based on the results.

It looks like torch.tanh works fine and even seems to improve performance. I will run tests on all the standard datasets before updating the repo, but you can make the change locally in your own copy if you like.
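A minimal sketch of what such a local change could look like, assuming the same simplified penalty with torch.tanh swapped in for torch.sign on both terms; `direction_loss_tanh`, `lam` and the toy tensors are hypothetical, and this is not the patch that landed in the repo. The predictions with the wrong direction now receive nonzero gradients, and one plain gradient step lowers the penalty.

```python
import torch

def direction_loss_tanh(pred, y, ctrl, lam=1.0):
    # Hypothetical tanh-based direction penalty: torch.tanh replaces torch.sign,
    # so the term is differentiable with respect to pred.
    return lam * ((torch.tanh(y - ctrl) - torch.tanh(pred - ctrl)) ** 2).mean()

ctrl = torch.zeros(1, 3)                                     # toy control expression
y = torch.tensor([[1.0, -2.0, 0.5]])                         # toy true expression
pred = torch.tensor([[-1.0, 2.0, 0.5]], requires_grad=True)  # toy predictions

loss_before = direction_loss_tanh(pred, y, ctrl)
loss_before.backward()

# The two wrong-direction entries get a gradient; the already-correct one does not.
n_nonzero = torch.count_nonzero(pred.grad).item()
print(n_nonzero)

# One plain gradient-descent step (learning rate 1.0, chosen for illustration).
with torch.no_grad():
    pred -= 1.0 * pred.grad
loss_after = direction_loss_tanh(pred, y, ctrl)
print(loss_before.item(), loss_after.item())
```

One trade-off worth keeping in mind: tanh saturates, so a prediction far from control in the wrong direction receives only a small gradient (for the entry at 2.0, the factor 1 - tanh(2)^2 is roughly 0.07).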

Thank you very much for your insightful suggestion!