
gmtorch - Graphical Modeling in PyTorch

gmtorch is a library for graphical modeling that benefits from the strengths of PyTorch: easy parallelism, GPU usage, and automatic differentiation. Currently, only Markov random fields (MRFs) are supported.

gmtorch is heavily inspired by the duality between graphical models and tensor networks (Robeva and Seigal, 2018), and in fact uses the opt_einsum tensor contraction library (Smith and Gray, 2018) to marginalize graphical models.
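
To make that duality concrete, here is a minimal standalone sketch (plain unnamed tensors and opt_einsum only, independent of gmtorch's internals) that marginalizes the same three all-ones potentials used in the example below onto variable A by contracting away every other index:

import torch
import opt_einsum as oe

# Three all-ones potentials over variables A, B, C, D, E (cardinalities 2, 2, 3, 4, 5)
phi_ab = torch.ones(2, 2)
phi_bcd = torch.ones(2, 3, 4)
phi_ce = torch.ones(3, 5)

# Marginal over A: sum_{b,c,d,e} phi(a,b) * phi(b,c,d) * phi(c,e)
marginal_a = oe.contract('ab,bcd,ce->a', phi_ab, phi_bcd, phi_ce)
print(marginal_a)  # tensor([120., 120.])

opt_einsum searches for an efficient contraction order for this sum, which is what keeps marginalization tractable on larger graphs.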

Example

gmtorch uses PyTorch's named tensors (currently an experimental PyTorch feature) to encode MRF potentials. Example:

import gmtorch as gm
import torch

# Build an MRF over variables A, ..., E by adding all-ones potentials
g = gm.Graph()
g.add_factor(torch.ones(2, 2, names=['A', 'B']))
g.add_factor(torch.ones(2, 3, 4, names=['B', 'C', 'D']))
g.add_factor(torch.ones(3, 5, names=['C', 'E']))

# Visualize the graph, including factor nodes and variable cardinalities
gm.plot(g, show_factors=True, show_cardinality=True)

You can marginalize like this:

print(g['A'])

# Result: tensor([120., 120.], names=('A',))

Since every potential is all ones, the marginal over A simply counts the 2·3·4·5 = 120 joint configurations of B, C, D, and E.
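
Because potentials are ordinary PyTorch tensors, automatic differentiation is expected to work through marginalization (gmtorch contracts the tensors with opt_einsum on top of PyTorch). The following is only a hedged sketch of that idea, assuming that gradients indeed propagate through g['A']; it is not taken from the library's documentation:

import gmtorch as gm
import torch

# Assumption: a factor created with requires_grad=True lets PyTorch's autodiff
# track the marginalization, so we can differentiate through it.
phi = torch.ones(2, 2, names=['A', 'B'], requires_grad=True)

g = gm.Graph()
g.add_factor(phi)
g.add_factor(torch.ones(2, 3, names=['B', 'C']))

marginal = g['A']          # marginal over A, a named tensor
marginal.sum().backward()  # differentiate the total mass w.r.t. phi
print(phi.grad)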

Installation

You can install gmtorch from source as follows:

git clone https://github.com/rballester/gmtorch.git
cd gmtorch
pip install .
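
If you intend to modify the code, an editable install (a standard pip option, not specific to gmtorch) may be more convenient:

pip install -e .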

Main dependencies:

  • NumPy
  • pgmpy (for reading networks and moralizing Bayesian networks)
  • PyTorch (as numerical and autodiff backend)
  • opt_einsum (for efficient marginalization)

Tests

We use pytest, and the tests depend on tntorch. To run them, do:

cd tests/
pytest

Contributing

Pull requests are welcome!

Besides using the issue tracker, also feel free to contact me at rafael.ballester@ie.edu.