DependencyCRF marginals possible error
kmkurn opened this issue · 2 comments
kmkurn commented
Hi, while working on #63, I noticed that DependencyCRF marginals may have numerical errors:
>>> crf = DependencyCRF(torch.zeros(1,2,2))
>>> print(crf.partition.exp().item())
3.0
>>> crf.marginals.exp()
tensor([[[1.9477, 1.3956],
         [1.3956, 1.9477]]], grad_fn=<ExpBackward>)
crf.partition is correct; there are 3 trees. Since all edges have weight 1, I'd expect the marginals to be (very close to) 2 on the diagonal and 1 off the diagonal. But they're not. Is this an error, or am I misunderstanding something?
srush commented
I don't think there is a bug.
import torch
import torch_struct
crf = torch_struct.DependencyCRF(torch.zeros(1,2,2))
print(crf.partition.exp().item())
crf.marginals
3.0
tensor([[[0.6667, 0.3333],
         [0.3333, 0.6667]]], grad_fn=<SqueezeBackward1>)
Marginals are just probabilities (not in log space): p(arc | x).
I should have named partition log_partition for accuracy. I will update that.
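For anyone who wants to check the numbers by hand, here is a minimal pure-Python sketch that enumerates the trees for a 2-word sentence and counts how often each arc appears. It does not use torch_struct; the helper names (is_tree, brute_force_marginals) are made up for illustration, and it assumes the diagonal entry [i, i] holds the root arc to word i while [h, m] holds the arc from head h to modifier m. It also enumerates all trees without a projectivity check, which makes no difference for 2 words.

```python
import math
from fractions import Fraction
from itertools import product


def is_tree(heads):
    """heads[j-1] is the head of word j (0 means root).

    Returns True if every word reaches the root without hitting a cycle.
    """
    for start in range(1, len(heads) + 1):
        seen = set()
        node = start
        while node != 0:
            if node in seen:
                return False
            seen.add(node)
            node = heads[node - 1]
    return True


def brute_force_marginals(n):
    """Enumerate all dependency trees over n words with uniform weights.

    Hypothetical helper, not part of torch_struct. Returns the number of
    trees, the per-arc counts, and the counts divided by the number of trees.
    """
    # Each word j picks a head from {0 (root), 1..n} other than itself.
    choices = [[h for h in range(n + 1) if h != j] for j in range(1, n + 1)]
    trees = [heads for heads in product(*choices) if is_tree(heads)]

    counts = [[0] * n for _ in range(n)]
    for heads in trees:
        for j, h in enumerate(heads, start=1):
            # Assumed layout: root arcs on the diagonal, otherwise [head, modifier].
            row = j - 1 if h == 0 else h - 1
            counts[row][j - 1] += 1

    z = len(trees)
    marginals = [[Fraction(c, z) for c in row] for row in counts]
    return z, counts, marginals


z, counts, marginals = brute_force_marginals(2)
print(z)          # 3
print(counts)     # [[2, 1], [1, 2]]
print([[round(float(p), 4) for p in row] for row in marginals])
# [[0.6667, 0.3333], [0.3333, 0.6667]]
print([[round(math.exp(p), 4) for p in row] for row in marginals])
# [[1.9477, 1.3956], [1.3956, 1.9477]]
```

The counts (2 on the diagonal, 1 off it) are the unnormalized values expected in the original post; dividing by the 3 trees gives the probabilities that crf.marginals returns, and exponentiating those probabilities reproduces the 1.9477 / 1.3956 values from the first snippet.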
kmkurn commented
I see. Thanks for the clarification!