HazyResearch/metal

Make the special (but common) case of k=2 more efficient

bhancock8 opened this issue · 1 comment

For k > 2, it's good to have our task head output k logits that can be passed through a softmax. But when k=2, calculating two logits instead of just one slows us down and hurts stability. When k=2, I think we can have the head output a single logit and then, in predictions, expand it into a 2D tensor with one column per class so that it stays compatible with the rest of MeTaL. It may be as simple as two if statements, or we may decide to subclass EndModel into BinaryEndModel.
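The single-logit idea works because a softmax over two logits depends only on their difference: softmax([0, z]) = [1 - σ(z), σ(z)], where σ is the logistic sigmoid. The following is a minimal sketch in plain Python, assuming the binary head emits one logit per example. `expand_binary_probs` and `stable_sigmoid` are hypothetical helpers for illustration, not part of MeTaL's API.

```python
import math
from typing import List


def stable_sigmoid(z: float) -> float:
    """Logistic function that never calls exp() on a large positive number,
    so it cannot overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(logits: List[float]) -> List[float]:
    """Reference softmax (max-subtracted for stability), used only for comparison."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expand_binary_probs(logits: List[float]) -> List[List[float]]:
    """Hypothetical helper: map one logit per example (the k=2 head output)
    to an n x 2 table of class probabilities, so code expecting k columns
    still works. Column 0 is the first class, column 1 the second."""
    probs = []
    for z in logits:
        p = stable_sigmoid(z)
        probs.append([1.0 - p, p])
    return probs


if __name__ == "__main__":
    # Compare the one-logit expansion against a two-logit softmax over [0, z].
    for z, row in zip([-2.0, 0.0, 3.5], expand_binary_probs([-2.0, 0.0, 3.5])):
        print(z, row, softmax([0.0, z]))
```

During training, the single logit would typically be paired with a binary loss on the raw logit (in PyTorch, for example, `torch.nn.BCEWithLogitsLoss`), which avoids computing a second, redundant output while prediction code still receives k=2 columns.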

The age-old question! For now, Snorkel Classification is the latest implementation of the classifier package. Feel free to re-raise this issue in that repo if it is still relevant!