harvardnlp/pytorch-struct

Dependency CRF add labels

allanj opened this issue · 5 comments

I just checked the documentation for the dependency CRF, and it seems that dependency labels (i.e., relations) are not considered yet. Am I right?

srush commented

Good point. I will add labels to the class.

Hi, is there an update on this? I'd love to use it for labeled parsing as well.

If not, are you aware of any workarounds, especially for marginals? For computing the partition function and the argmax, we can simply take a logsumexp or max over the labels and then use DependencyCRF. Is there a similar trick for marginal probabilities?

srush commented

@kmkurn let's add labels. Would you be interested in trying it?

It's very straightforward. Simply change the class to take log-potentials of the form BATCH x N x N x LABELS; if the user gives a BATCH x N x N tensor, convert it to BATCH x N x N x 1.

Before running the main algorithm, on this line
https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/deptree.py#L53

Call arc_scores = semiring.sum(arc_scores). That will logsumexp/max out over the labels just as you say.
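For readers following along, here is a minimal sketch of the two steps described above, written in plain PyTorch rather than against the library's internals. The function name `collapse_labels` and the `mode` argument are hypothetical and used only for illustration. `torch.logsumexp` and `Tensor.max` stand in for what `semiring.sum` is described as doing under the log and max semirings, and the label dimension is assumed to be the last one.

```python
import torch


def collapse_labels(arc_scores: torch.Tensor, mode: str = "log") -> torch.Tensor:
    """Reduce labeled arc scores (BATCH x N x N x LABELS) to BATCH x N x N.

    Hypothetical helper for illustration; not part of pytorch-struct.
    """
    if arc_scores.dim() == 3:
        # Unlabeled input: add a singleton label dimension (BATCH x N x N x 1).
        arc_scores = arc_scores.unsqueeze(-1)
    if arc_scores.dim() != 4:
        raise ValueError("expected a BATCH x N x N or BATCH x N x N x LABELS tensor")

    if mode == "log":
        # Log semiring: sum over labels in probability space.
        return torch.logsumexp(arc_scores, dim=-1)
    if mode == "max":
        # Max semiring: keep the best-scoring label for each arc.
        return arc_scores.max(dim=-1).values
    raise ValueError(f"unknown mode: {mode}")


batch, n, labels = 2, 4, 3
labeled = torch.randn(batch, n, n, labels)
unlabeled = torch.randn(batch, n, n)

collapsed = collapse_labels(labeled)             # BATCH x N x N
passthrough = collapse_labels(unlabeled)         # unchanged values, same shape
best = collapse_labels(labeled, mode="max")      # BATCH x N x N
```

With a singleton label dimension, both reductions return the input values unchanged, so existing unlabeled callers should see the same scores as before.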

@srush Sure. I've made a draft PR on this, please check it out!

srush commented

Merged.