An implementation of CRF (Conditional Random Fields) in PyTorch 1.0

Torch CRF

Implementation of CRF (Conditional Random Fields) in PyTorch

Requirements

  • python3 (>=3.6)
  • PyTorch (>=1.0)

Installation

$ pip install TorchCRF

Usage

>>> import torch
>>> from TorchCRF import CRF
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> batch_size = 2
>>> sequence_size = 3
>>> num_labels = 5
>>> mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0]]).to(device)  # (batch_size, sequence_size)
>>> labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).to(device)  # (batch_size, sequence_size)
>>> hidden = torch.randn((batch_size, sequence_size, num_labels), requires_grad=True, device=device)  # (batch_size, sequence_size, num_labels)
>>> crf = CRF(num_labels).to(device)
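
When sequences in a batch have different lengths, the mask can be derived from a list of lengths. The helper below is a hypothetical convenience function (it is not part of TorchCRF) and assumes right-padded sequences, as in the example above, where 1 marks a real token and 0 marks padding:

```python
from typing import List


def lengths_to_mask(lengths: List[int], max_len: int) -> List[List[int]]:
    """Build a right-padded 0/1 mask of shape (batch_size, max_len).

    Hypothetical helper: 1 marks a real token, 0 marks padding.
    """
    return [[1 if t < n else 0 for t in range(max_len)] for n in lengths]
```

For example, torch.ByteTensor(lengths_to_mask([3, 2], sequence_size)).to(device) builds the same mask as the one used above.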

Computing the log-likelihood of each labeled sequence (this is what forward returns; the exact values vary because hidden is random)

>>> crf.forward(hidden, labels, mask)
tensor([-7.6204, -3.6124], device='cuda:0', grad_fn=<ThSubBackward>)
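
To make the returned numbers concrete, here is a plain-Python sketch of the log-likelihood of a single sequence under a standard linear-chain CRF: the score of the gold label path minus log Z, where log Z sums over all label paths using the forward algorithm. It assumes start, end, and transition scores (trans[i][j] scores moving from label i to label j) and a mask whose ones form a prefix. The function names are hypothetical, and TorchCRF's internals may differ.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(emissions: Matrix, labels: List[int], start: List[float],
               end: List[float], trans: Matrix) -> float:
    """Unnormalized score of one label path: start + emissions + transitions + end."""
    score = start[labels[0]] + emissions[0][labels[0]]
    for t in range(1, len(labels)):
        # trans[i][j] is the score of moving from label i to label j
        score += trans[labels[t - 1]][labels[t]] + emissions[t][labels[t]]
    return score + end[labels[-1]]


def log_partition(emissions: Matrix, start: List[float],
                  end: List[float], trans: Matrix) -> float:
    """log Z over every possible label path (forward algorithm, O(n * L^2))."""
    num_labels = len(start)
    # alpha[j]: log-sum of scores of all partial paths ending in label j
    alpha = [start[j] + emissions[0][j] for j in range(num_labels)]
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + trans[i][j] for i in range(num_labels)])
            + emissions[t][j]
            for j in range(num_labels)
        ]
    return logsumexp([alpha[j] + end[j] for j in range(num_labels)])


def crf_log_likelihood(emissions: Matrix, labels: List[int], mask: List[int],
                       start: List[float], end: List[float], trans: Matrix) -> float:
    """log p(labels | emissions) for one sequence whose mask is a prefix of ones."""
    n = sum(mask)  # number of unpadded positions, e.g. 2 for [1, 1, 0]
    emissions, labels = emissions[:n], labels[:n]
    return (path_score(emissions, labels, start, end, trans)
            - log_partition(emissions, start, end, trans))
```

Because each value is a log-probability, it is always negative. In training, a common choice is to minimize the negative log-likelihood, i.e. the negated mean of the values returned by forward.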

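Decoding (predicting the most likely label sequence for each input with the Viterbi algorithm; masked positions are dropped, so the second sequence gets only two labels)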

>>> crf.viterbi_decode(hidden, mask)
[[0, 2, 2], [4, 0]]
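
Viterbi decoding follows the same recursion as the log-likelihood computation but replaces the log-sum with a max and records backpointers, so the best path can be recovered at the end. The plain-Python sketch below decodes one sequence under a standard linear-chain CRF, assuming start, end, and transition scores (trans[i][j] scores moving from label i to label j) and a mask whose ones form a prefix. The function name and argument layout are hypothetical and are not TorchCRF's API.

```python
from typing import List

Matrix = List[List[float]]


def viterbi_decode(emissions: Matrix, mask: List[int], start: List[float],
                   end: List[float], trans: Matrix) -> List[int]:
    """Highest-scoring label path for one sequence; mask is a prefix of ones."""
    n = sum(mask)  # decode only the unpadded positions (assumes n >= 1)
    num_labels = len(start)
    # score[j]: best score of any partial path ending in label j
    score = [start[j] + emissions[0][j] for j in range(num_labels)]
    backpointers: List[List[int]] = []
    for t in range(1, n):
        best_prev, new_score = [], []
        for j in range(num_labels):
            candidates = [score[i] + trans[i][j] for i in range(num_labels)]
            i_best = max(range(num_labels), key=candidates.__getitem__)
            best_prev.append(i_best)
            new_score.append(candidates[i_best] + emissions[t][j])
        backpointers.append(best_prev)
        score = new_score
    # Add the end scores, pick the best final label, then follow backpointers
    last = max(range(num_labels), key=lambda j: score[j] + end[j])
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path
```

Running it on each row of a batch with that row's mask yields a ragged list of label sequences, one label per unpadded position, which is the shape of the output shown above.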

License

MIT

References