Dependency CRF add labels
allanj opened this issue · 5 comments
I just checked the documentation of DependencyCRF, and it seems that dependency labels (i.e., relations) are not considered (yet). Am I right?
Good point. I will add labels to the class.
Hi, is there an update on this? I'd love to use it for labeled parsing as well.
If not, are you aware of any workarounds, especially for marginals? For computing the partition function and argmax, we can simply take a logsumexp or max over the labels and then use `DependencyCRF`. Is there a similar trick for marginal probabilities?
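Why a similar trick should work for marginals: labels factor per arc, so the labeled marginal of (head, modifier, label) ought to equal the unlabeled arc marginal computed on logsumexp-reduced scores, multiplied by a softmax over that arc's labels. The pure-Python sketch below checks this by brute force on a three-word sentence. It is a toy under stated assumptions. All function names are hypothetical. It enumerates every head assignment that reaches ROOT (index 0) without cycles and ignores the projectivity and single-root constraints that `DependencyCRF` may impose. The factorization argument does not depend on which set of trees is used.

```python
import itertools
import math
import random

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def reaches_root(heads, m):
    # Follow head pointers from word m; revisiting a node means a cycle.
    seen = set()
    while m != 0:
        if m in seen:
            return False
        seen.add(m)
        m = heads[m - 1]
    return True

def trees(n):
    # heads[m - 1] is the head of word m (words 1..n, 0 = ROOT).
    # Assumption: no projectivity or single-root constraint.
    for heads in itertools.product(range(n + 1), repeat=n):
        if all(reaches_root(heads, m) for m in range(1, n + 1)):
            yield heads

def labeled_marginals_brute(s, n, L):
    # Enumerate every (tree, labeling) pair explicitly.
    configs = [(t, ls) for t in trees(n) for ls in itertools.product(range(L), repeat=n)]
    scores = [sum(s[h][m][l] for m, (h, l) in enumerate(zip(t, ls), start=1))
              for t, ls in configs]
    log_z = logsumexp(scores)
    mu = [[[0.0] * L for _ in range(n + 1)] for _ in range(n + 1)]
    for (t, ls), sc in zip(configs, scores):
        for m, (h, l) in enumerate(zip(t, ls), start=1):
            mu[h][m][l] += math.exp(sc - log_z)
    return mu

def labeled_marginals_factored(s, n, L):
    # 1. Collapse the label axis with logsumexp.
    arc = [[logsumexp(s[h][m]) for m in range(n + 1)] for h in range(n + 1)]
    # 2. Unlabeled arc marginals on the reduced scores.
    ts = list(trees(n))
    scores = [sum(arc[h][m] for m, h in enumerate(t, start=1)) for t in ts]
    log_z = logsumexp(scores)
    p_arc = [[0.0] * (n + 1) for _ in range(n + 1)]
    for t, sc in zip(ts, scores):
        for m, h in enumerate(t, start=1):
            p_arc[h][m] += math.exp(sc - log_z)
    # 3. Split each arc's probability across its labels with a softmax.
    return [[[p_arc[h][m] * math.exp(s[h][m][l] - arc[h][m]) for l in range(L)]
             for m in range(n + 1)] for h in range(n + 1)]

rng = random.Random(0)
n, L = 3, 2  # three words, two dependency labels
s = [[[rng.gauss(0.0, 1.0) for _ in range(L)] for _ in range(n + 1)] for _ in range(n + 1)]
brute = labeled_marginals_brute(s, n, L)
fact = labeled_marginals_factored(s, n, L)
max_diff = max(abs(brute[h][m][l] - fact[h][m][l])
               for h in range(n + 1) for m in range(n + 1) for l in range(L))
print(max_diff < 1e-9)  # True
```

The check passes because the labeled partition function equals the unlabeled one on the reduced scores, and fixing one arc's label only rescales that arc's contribution by its softmax weight.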
@kmkurn let's add labels. Would you be interested in trying it?
It's very straightforward. Simply change the log-potentials to have shape BATCH x N x N x LABELS; if the user passes a BATCH x N x N tensor, convert it to BATCH x N x N x 1.
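A minimal sketch of that shape handling, assuming PyTorch tensors. `with_label_dim` is a hypothetical helper name, not part of the library. Because a logsumexp or max over a size-1 axis returns its input, unlabeled inputs should behave exactly as before once the label axis is reduced.

```python
import torch

def with_label_dim(log_potentials: torch.Tensor) -> torch.Tensor:
    """Return arc log-potentials with shape BATCH x N x N x LABELS.

    Hypothetical helper: a BATCH x N x N input is treated as having a
    single label and gets a trailing axis of size 1.
    """
    if log_potentials.dim() == 3:
        return log_potentials.unsqueeze(-1)  # BATCH x N x N x 1
    if log_potentials.dim() == 4:
        return log_potentials  # already labeled
    raise ValueError(
        f"expected a 3-D or 4-D tensor, got shape {tuple(log_potentials.shape)}"
    )

unlabeled = torch.randn(2, 5, 5)
labeled = with_label_dim(unlabeled)
print(labeled.shape)  # torch.Size([2, 5, 5, 1])
# Reducing over the size-1 label axis recovers the original scores.
print(torch.allclose(torch.logsumexp(labeled, dim=-1), unlabeled))  # True
```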
Before running the main algorithm, on this line:
https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/deptree.py#L53
call `arc_scores = semiring.sum(arc_scores)`. That will logsumexp/max out over the labels, just as you say.
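A minimal PyTorch sketch of that reduction step, assuming log-potentials of shape BATCH x N x N x LABELS. `log_semiring_sum` and `max_semiring_sum` are hypothetical stand-ins written with plain `torch` calls, not the library's semiring classes. The sketch also suggests why the trick covers marginals: backpropagating through the logsumexp multiplies the upstream gradient by a softmax over labels, so if the upstream gradient is the unlabeled arc marginal (the gradient of log Z with respect to the reduced scores), the result is the labeled marginal. Random numbers stand in for those arc marginals here.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for a semiring "sum" over the trailing label axis.
def log_semiring_sum(x: torch.Tensor) -> torch.Tensor:
    # Log semiring: "sum" is logsumexp (partition function, marginals).
    return torch.logsumexp(x, dim=-1)

def max_semiring_sum(x: torch.Tensor) -> torch.Tensor:
    # Max semiring: "sum" is max (argmax / best tree).
    return x.amax(dim=-1)

batch, n, labels = 2, 4, 3
arc_scores = torch.randn(batch, n, n, labels, requires_grad=True)

# Collapse labels before running the unlabeled tree algorithm.
reduced = log_semiring_sum(arc_scores)        # BATCH x N x N
best = max_semiring_sum(arc_scores.detach())  # BATCH x N x N

# Stand-in for d(log Z)/d(reduced), i.e. the unlabeled arc marginals.
arc_marginals = torch.rand(batch, n, n)
(reduced * arc_marginals).sum().backward()

# Gradient w.r.t. the labeled scores = arc marginal * softmax over labels.
labeled = arc_scores.grad
expected = arc_marginals.unsqueeze(-1) * torch.softmax(arc_scores.detach(), dim=-1)
print(torch.allclose(labeled, expected, atol=1e-6))  # True
```

With `max_semiring_sum` in place of the logsumexp, the gradient would instead go to the highest-scoring label of each arc (ties aside), which is one way to read off labels for the argmax tree.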
Merged.