The `dagma` library is a Python 3 package for learning DAGs (a.k.a. Bayesian networks) from data.
DAGMA works by optimizing a given score/loss function, where the structure relating the variables is constrained to be a directed acyclic graph (DAG). Because the number of DAGs grows super-exponentially with the number of variables, the vanilla formulation is a hard combinatorial optimization problem. DAGMA reformulates this problem by replacing the combinatorial constraint with a non-convex, differentiable function that exactly characterizes DAGs, making the optimization amenable to continuous methods such as gradient descent.
This is an implementation of the following paper:
[1] Bello K., Aragam B., Ravikumar P. (2022). DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization. NeurIPS'22.
If you find this code useful, please consider citing:
@inproceedings{bello2022dagma,
author = {Bello, Kevin and Aragam, Bryon and Ravikumar, Pradeep},
booktitle = {Advances in Neural Information Processing Systems},
title = {{DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization}},
year = {2022}
}
- Supports continuous data for linear (see `dagma.linear`) and nonlinear models (see `dagma.nonlinear`).
- Supports binary (0/1) data for generalized linear models, via `dagma.linear.DagmaLinear` with `logistic` as the score (see the sketch after this list).
- Faster than other continuous optimization methods for structure learning, e.g., NOTEARS, GOLEM.
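For example, fitting a generalized linear model with the logistic score could look like the following minimal sketch; the data matrix below is a random placeholder with no causal structure, and the argument names (`loss_type`, `lambda1`) should be checked against the package documentation.

```python
# Minimal sketch: X is assumed to be an (n, d) array with 0/1 entries.
import numpy as np
from dagma.linear import DagmaLinear

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(500, 10)).astype(float)  # placeholder binary data

model = DagmaLinear(loss_type='logistic')  # logistic score for 0/1 data
W_est = model.fit(X, lambda1=0.02)         # estimated weighted adjacency matrix (d x d)
```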
We recommend using a virtual environment via `virtualenv` or `conda`, and using `pip` to install the `dagma` package.
$ pip install dagma
See an example of how to use `dagma` in this iPython notebook.
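As a rough end-to-end sketch (the simulation helpers and argument names below mirror the package's `utils` and `linear` modules, but treat the exact signatures as assumptions and defer to the notebook):

```python
import numpy as np
from dagma import utils
from dagma.linear import DagmaLinear

utils.set_random_seed(1)

# Simulate a ground-truth DAG and data from a linear Gaussian SEM.
n, d, s0 = 500, 20, 20                              # samples, nodes, expected edges
B_true = utils.simulate_dag(d, s0, 'ER')            # binary adjacency of a random DAG
W_true = utils.simulate_parameter(B_true)           # weighted adjacency
X = utils.simulate_linear_sem(W_true, n, 'gauss')   # (n, d) data matrix

# Learn the structure with the linear model and an l2 score.
model = DagmaLinear(loss_type='l2')
W_est = model.fit(X, lambda1=0.02)                  # estimated weighted adjacency

# Compare the estimated support against the ground truth (FDR, TPR, SHD, ...).
print(utils.count_accuracy(B_true, W_est != 0))
```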
We propose a new acyclicity characterization of DAGs via a log-det function for learning DAGs from observational data. Similar to previously proposed acyclicity functions (e.g., NOTEARS), our characterization is also exact and differentiable. However, compared to existing characterizations, our log-det function: (1) is better at detecting large cycles; (2) has better-behaved gradients; and (3) runs about an order of magnitude faster in practice. These advantages of our log-det formulation, together with a path-following scheme, lead to significant improvements in structure accuracy (e.g., SHD).
Let $W \in \mathbb{R}^{d\times d}$ be a weighted adjacency matrix of a graph on $d$ nodes. The log-det acyclicity function takes the form

$$h^{s}(W) = -\log \det (sI - W \circ W) + d \log s,$$

where $I$ is the identity matrix, $s > 0$ is a given scalar, and $\circ$ denotes the Hadamard (elementwise) product. Over the set of matrices $W$ for which $sI - W \circ W$ is an M-matrix, $h^{s}(W) \geq 0$, with equality if and only if $W$ represents a DAG.
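To make the characterization concrete, the following numeric sketch (our own toy matrices, with $s = 1$) evaluates $h^{s}$ on an acyclic and on a cyclic weight matrix:

```python
# Illustrative only: evaluate h_s(W) = -log det(sI - W∘W) + d*log(s).
import numpy as np

def h_logdet(W, s=1.0):
    d = W.shape[0]
    M = s * np.eye(d) - W * W                 # sI - W∘W (Hadamard square)
    sign, logabsdet = np.linalg.slogdet(M)
    assert sign > 0, "W is outside the M-matrix domain for this s"
    return -logabsdet + d * np.log(s)

W_dag = np.array([[0.0, 0.8, 0.0],
                  [0.0, 0.0, 1.2],
                  [0.0, 0.0, 0.0]])           # 1 -> 2 -> 3, acyclic
W_cyc = np.array([[0.0, 0.5],
                  [0.5, 0.0]])                # 1 <-> 2, contains a cycle

print(h_logdet(W_dag))   # 0.0   (zero exactly when W represents a DAG)
print(h_logdet(W_cyc))   # ~0.06 (strictly positive in the presence of a cycle)
```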
Given this exact differentiable characterization of a DAG, we are interested in solving the following optimization problem:

$$\min_{W} \; Q(W; \mathbf{X}) \quad \text{subject to} \quad h^{s}(W) = 0,$$

where $Q$ is a given score function (e.g., least squares) that depends on $W$ and the data matrix $\mathbf{X}$. To solve this constrained problem, DAGMA uses a path-following scheme that solves a sequence of unconstrained problems

$$\min_{W} \; \mu^{(k)} \, Q(W; \mathbf{X}) + h^{s}(W),$$

where the coefficient $\mu^{(k)} \to 0$ as $k$ increases. Leveraging the properties of $h^{s}$, the solutions of these unconstrained problems approach a DAG in the limit.
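As a rough, illustrative sketch of the path-following idea (not the library's actual optimizer, which uses a more careful adaptive scheme), plain gradient descent on the penalized objective $\mu\,Q + h^{s}$ with a decreasing sequence of $\mu$ could look as follows; the step size, schedule, and toy data are arbitrary choices:

```python
import numpy as np

def h_val_grad(W, s=1.0):
    """Value and gradient of h_s(W) = -log det(sI - W∘W) + d*log(s)."""
    d = W.shape[0]
    M = s * np.eye(d) - W * W
    sign, logdet = np.linalg.slogdet(M)
    if sign <= 0:                               # left the M-matrix domain
        return np.inf, None
    return -logdet + d * np.log(s), 2.0 * W * np.linalg.inv(M).T

def q_grad(W, X):
    """Gradient of the least-squares score Q(W; X) = ||X - XW||_F^2 / (2n)."""
    return -X.T @ (X - X @ W) / X.shape[0]

rng = np.random.default_rng(0)
n, d = 500, 4
B_true = np.triu(rng.uniform(0.5, 1.0, (d, d)), k=1)        # toy ground-truth DAG
X = rng.normal(size=(n, d)) @ np.linalg.inv(np.eye(d) - B_true)

W, lr = np.zeros((d, d)), 0.02
for mu in [1.0, 0.1, 0.01, 0.001]:               # decreasing weights mu^(k)
    for _ in range(5000):                        # gradient descent on mu*Q + h
        step = lr * (mu * q_grad(W, X) + h_val_grad(W)[1])
        while h_val_grad(W - step)[0] == np.inf: # halve the step to stay in the domain
            step *= 0.5
        W -= step
print(np.round(W, 2))    # near-acyclic as mu -> 0; roughly recovers B_true here
```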
Let us give an illustration of how DAGMA works on a two-node graph, where $Q$ is the least-squares score (see Figure 1 in [1] for more details). Figure 1 shows four plots, each illustrating the solution to an unconstrained problem for a different value of $\mu$: as $\mu$ decreases, the solution is driven toward a DAG.
- Python 3.7+
- `numpy`
- `scipy`
- `igraph`
- `tqdm`
- `torch`: only used for nonlinear models.
- `linear.py` - implementation of DAGMA for linear models with l1 regularization (supports L2 and logistic losses).
- `nonlinear.py` - implementation of DAGMA for nonlinear models using an MLP (see the sketch after this list).
- `locally_connected.py` - special layer structure used for the MLP.
- `utils.py` - graph simulation, data simulation, and accuracy evaluation.
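For completeness, a nonlinear run built on these modules might look like the sketch below; the class names (`DagmaMLP`, `DagmaNonlinear`) and arguments follow the current package layout, but treat the exact signatures as assumptions.

```python
import torch
from dagma import utils
from dagma.nonlinear import DagmaMLP, DagmaNonlinear

utils.set_random_seed(1)

# Simulate data from a random nonlinear (MLP-mechanism) SEM.
n, d, s0 = 500, 10, 10
B_true = utils.simulate_dag(d, s0, 'ER')
X = utils.simulate_nonlinear_sem(B_true, n, 'mlp')

# One MLP per variable, wrapped by the nonlinear DAGMA solver.
eq_model = DagmaMLP(dims=[d, 10, 1], bias=True, dtype=torch.double)
model = DagmaNonlinear(eq_model, dtype=torch.double)
W_est = model.fit(X, lambda1=0.02, lambda2=0.005)   # estimated weighted adjacency

print(utils.count_accuracy(B_true, W_est != 0))
```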
We thank the authors of the NOTEARS repo for making their code available. Part of our code is based on their implementation, especially the `utils.py` file and some code from their implementation of nonlinear models.