`dagrad` is a Python package that provides an extensible, modular platform for developing and experimenting with differentiable (gradient-based) structure learning methods.
It builds upon the NOTEARS framework and also functions as an updated repository of state-of-the-art implementations for various methods.
`dagrad` provides the following key features:
- A universal framework for implementing general score-based methods that are end-to-end differentiable
- A modular implementation that makes it easy to swap out different score functions, acyclicity constraints, regularizers, and optimizers
- An extensible approach that allows users to implement their own objectives and constraints using PyTorch
- Speed and scalability enabled through GPU acceleration
A directed acyclic graphical model (also known as a Bayesian network) with `d` nodes represents a distribution over a random vector of size `d`. The focus of this library is on Bayesian Network Structure Learning (BNSL): Given samples `X` from a distribution, how can we estimate the underlying graph `G`?
The problem can be formulated as a differentiable continuous optimization problem: minimize a smooth score function subject to a differentiable acyclicity constraint.
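As a sketch in NOTEARS-style notation, the program can be written as below. The symbols are assumed here for illustration: $Q$ is the score (a loss plus an optional regularizer), $\mathbf{X}$ is the data matrix, $W$ is the weighted adjacency matrix, and $h$ is a differentiable function that vanishes exactly when $W$ encodes a DAG.

```latex
\min_{W \in \mathbb{R}^{d \times d}} \; Q(W; \mathbf{X})
\quad \text{subject to} \quad h(W) = 0
```

One well-known choice of constraint, from NOTEARS [1], is $h(W) = \operatorname{tr}\left(e^{W \circ W}\right) - d$, where $\circ$ denotes the elementwise product.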
This formulation is versatile enough to encompass both linear and nonlinear models with any smooth objective (e.g. log-likelihood, least-squares, cross-entropy, etc.).
`dagrad` provides a unified framework that allows users to employ either predefined or customized loss functions.
GPU acceleration is also supported.
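To make the acyclicity constraint concrete, here is a minimal pure-Python sketch of one differentiable characterization, the polynomial form tr((I + (W∘W)/d)^d) - d. The function name `poly_acyclicity` is hypothetical and is not part of `dagrad`. It evaluates to zero on a DAG and to a positive value as soon as the graph contains a cycle.

```python
def poly_acyclicity(W):
    """Return tr((I + (W∘W)/d)^d) - d for a square matrix W given as nested lists."""
    d = len(W)
    # M = I + (W ∘ W) / d, where ∘ is the elementwise product
    M = [[(1.0 if i == j else 0.0) + W[i][j] ** 2 / d for j in range(d)]
         for i in range(d)]
    # P = M^d, computed by repeated multiplication starting from the identity
    P = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    for _ in range(d):
        P = [[sum(P[i][k] * M[k][j] for k in range(d)) for j in range(d)]
             for i in range(d)]
    return sum(P[i][i] for i in range(d)) - d


dag = [[0.0, 1.5, 0.0],
       [0.0, 0.0, -2.0],
       [0.0, 0.0, 0.0]]   # edges 0 -> 1 -> 2, no cycle
cyclic = [[0.0, 1.0],
          [1.0, 0.0]]     # edges 0 -> 1 and 1 -> 0 form a cycle

print(poly_acyclicity(dag))     # zero for a DAG
print(poly_acyclicity(cyclic))  # strictly positive when a cycle exists
```

Squaring the entries keeps every term nonnegative, so closed walks cannot cancel each other out in the trace.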
To install `dagrad`:
$ git clone https://github.com/Duntrain/dagrad.git
$ cd dagrad/
$ pip install -e .
$ cd tests
$ python test_fast.py
The above installs `dagrad` and runs the original NOTEARS method [1] on a randomly generated 10-node Erdős-Rényi graph with 1000 samples. The output should look like the following:
{'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 10}
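The keys in this output are standard structure-recovery metrics: false discovery rate (`fdr`), true positive rate (`tpr`), false positive rate (`fpr`), structural Hamming distance (`shd`), and the number of estimated edges (`nnz`). The sketch below shows one way such metrics can be computed for binary adjacency matrices. It follows the definitions commonly used in the NOTEARS reference code, which is an assumption; the helper `edge_metrics` is hypothetical and the package's own `count_accuracy` may differ in details.

```python
def edge_metrics(true_adj, est_adj):
    """Compare two binary adjacency matrices (nested lists, no self-loops)."""
    d = len(true_adj)
    true_edges = {(i, j) for i in range(d) for j in range(d) if true_adj[i][j]}
    est_edges = {(i, j) for i in range(d) for j in range(d) if est_adj[i][j]}

    true_pos = est_edges & true_edges
    # Estimated edges that point the wrong way along a true edge
    reversed_edges = {(i, j) for (i, j) in est_edges - true_edges
                      if (j, i) in true_edges}
    # Estimated edges between pairs of nodes that are not adjacent in the truth
    false_pos = est_edges - true_edges - reversed_edges

    # Undirected skeletons, used for the structural Hamming distance
    true_skel = {frozenset(e) for e in true_edges}
    est_skel = {frozenset(e) for e in est_edges}
    extra, missing = est_skel - true_skel, true_skel - est_skel

    nnz = len(est_edges)
    cond_neg = d * (d - 1) // 2 - len(true_edges)
    wrong = len(reversed_edges) + len(false_pos)
    return {
        'fdr': wrong / max(nnz, 1),
        'tpr': len(true_pos) / max(len(true_edges), 1),
        'fpr': wrong / max(cond_neg, 1),
        'shd': len(extra) + len(missing) + len(reversed_edges),
        'nnz': nnz,
    }


truth = [[0, 1, 0],
         [0, 0, 1],
         [0, 0, 0]]       # 0 -> 1 -> 2
flipped = [[0, 0, 0],
           [1, 0, 1],
           [0, 0, 0]]     # 1 -> 0 (reversed) and 1 -> 2 (correct)

print(edge_metrics(truth, truth))
print(edge_metrics(truth, flipped))
```

A perfect recovery therefore has `fdr`, `fpr`, and `shd` equal to zero and `tpr` equal to one, as in the output shown above.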
Want to try more examples? See this IPython notebook.
Here is a simple demo:
import dagrad
n, d, s0, graph_type, sem_type = 1000, 10, 10, 'ER', 'gauss'  # samples, nodes, edges, graph type, noise type
X, W, G = dagrad.generate_linear_data(n, d, s0, graph_type, sem_type)
W_est = dagrad.dagrad(X=X, model='linear', method='notears')  # estimate the weighted adjacency matrix
acc = dagrad.count_accuracy(G, W_est != 0)  # compare estimated edges against the true graph
print('Accuracy: ', acc)
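For readers curious about what the data-generation step involves, the following is a rough, dependency-free stand-in for simulating a linear Gaussian structural equation model on a random DAG. It is not the package's implementation: the function `simulate_linear_sem`, the weight range, and the choice to sample exactly `s0` edges are all assumptions made for illustration.

```python
import random


def simulate_linear_sem(n, d, s0, seed=0):
    """Sample a random DAG with s0 edges and n rows of linear Gaussian data."""
    rng = random.Random(seed)
    order = list(range(d))
    rng.shuffle(order)  # order[k] is the node at topological position k

    # Edges only go from earlier to later positions, so the graph is acyclic
    candidates = [(order[a], order[b]) for a in range(d) for b in range(a + 1, d)]
    W = [[0.0] * d for _ in range(d)]
    for i, j in rng.sample(candidates, s0):
        W[i][j] = rng.choice([-1, 1]) * rng.uniform(0.5, 2.0)

    X = []
    for _ in range(n):
        x = [0.0] * d
        for j in order:  # parents of j are always filled in before j
            x[j] = sum(W[i][j] * x[i] for i in range(d)) + rng.gauss(0.0, 1.0)
        X.append(x)
    return X, W, order


X_sim, W_sim, order = simulate_linear_sem(n=100, d=5, s0=4)
print(len(X_sim), len(X_sim[0]))  # number of samples, number of nodes
```

Each variable is a weighted sum of its parents plus independent standard Gaussian noise, which is the linear model that the `'linear'` setting with `'gauss'` noise refers to in the demo.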
Below is an overview of the functionalities provided by the package:
| Method (`method`) | Model (`model`) | Loss (`loss_fn`) | Regularizers (`reg`) | h (`h_fn`) | Optimizer (`optimizer`) | Computation Library (`compute_lib`) | Device (`device`) |
|---|---|---|---|---|---|---|---|
| `'notears'` [1] | `'linear'`, `'nonlinear'` | `'l2'`, `'logll'`, `'user_loss'` | `'l1'`, `'l2'`, `'mcp'`, `'none'`, `'user_reg'` | `'h_exp_sq'`, `'h_poly_sq'`, `'h_poly_abs'`, `'user_h'` | Adam (`'adam'`), LBFGS (`'lbfgs'`) | NumPy (`'numpy'`), Torch (`'torch'`) | CPU (`'cpu'`), CUDA (`'cuda'`) |
| `'dagma'` [3] | `'linear'`, `'nonlinear'` | `'l2'`, `'logll'`, `'user_loss'` | `'l1'`, `'l2'`, `'mcp'`, `'none'`, `'user_reg'` | `'h_logdet_sq'`, `'h_logdet_abs'`, `'user_h'` | Adam (`'adam'`) | NumPy (`'numpy'`), Torch (`'torch'`) | CPU (`'cpu'`), CUDA (`'cuda'`) |
| `'topo'` [4] | `'linear'`, `'nonlinear'` | `'l2'`, `'logll'`, `'user_loss'` | `'l1'`, `'l2'`, `'mcp'`, `'none'`, `'user_reg'` | `'h_exp_topo'`, `'h_logdet_topo'`, `'h_poly_topo'`, `'user_h'` | Adam (`'adam'`), LBFGS (`'lbfgs'`) | NumPy (`'numpy'`) for linear, Torch (`'torch'`) for nonlinear | CPU (`'cpu'`) |
- For the linear (`'linear'`) model, the loss function (`loss_fn`) can be configured as logistic loss (`'logistic'`) for all three methods.
- In the linear (`'linear'`) model, the default optimizer (`optimizer`) for TOPO (`'topo'`) is scikit-learn (`'sklearn'`), a state-of-the-art package for solving linear model problems.
- In the linear (`'linear'`) model, NOTEARS (`'notears'`) and DAGMA (`'dagma'`) also support computation libraries (`compute_lib`) such as Torch (`'torch'`), and can perform computations on either CPU (`'cpu'`) or GPU (`'cuda'`).
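Since not every option is available for every method, it can help to check a configuration before launching a long run. The sketch below encodes part of the overview above as a lookup table. The helper `check_options` is hypothetical and not part of `dagrad`, and passing the validated options to `dagrad.dagrad` as keyword arguments is an assumption based on the parameter names listed above.

```python
# Options per method, transcribed from the overview (model-specific
# restrictions, such as NumPy-only linear TOPO, are not encoded here).
SUPPORTED = {
    "notears": {
        "h_fn": {"h_exp_sq", "h_poly_sq", "h_poly_abs", "user_h"},
        "optimizer": {"adam", "lbfgs"},
        "compute_lib": {"numpy", "torch"},
        "device": {"cpu", "cuda"},
    },
    "dagma": {
        "h_fn": {"h_logdet_sq", "h_logdet_abs", "user_h"},
        "optimizer": {"adam"},
        "compute_lib": {"numpy", "torch"},
        "device": {"cpu", "cuda"},
    },
    "topo": {
        "h_fn": {"h_exp_topo", "h_logdet_topo", "h_poly_topo", "user_h"},
        "optimizer": {"adam", "lbfgs", "sklearn"},
        "compute_lib": {"numpy", "torch"},
        "device": {"cpu"},
    },
}


def check_options(method, **options):
    """Raise ValueError if an option value is not listed for the given method."""
    if method not in SUPPORTED:
        raise ValueError(f"unknown method: {method!r}")
    for name, value in options.items():
        allowed = SUPPORTED[method].get(name)
        if allowed is not None and value not in allowed:
            raise ValueError(
                f"{name}={value!r} is not listed for method {method!r}; "
                f"expected one of {sorted(allowed)}"
            )
    return {"method": method, **options}


kwargs = check_options("dagma", h_fn="h_logdet_sq", device="cuda")
print(kwargs)
# Assumed usage, not verified against the package:
# W_est = dagrad.dagrad(X=X, model='linear', **kwargs)
```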
- Python 3.7+
- `numpy`
- `scipy`
- `scikit-learn`
- `python-igraph`
- `tqdm`
- `dagma`
- `notears`: installed from the GitHub repo
- `torch`: used for models with GPU acceleration
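A quick way to see which of these dependencies are importable in the current environment is sketched below, using only the standard library. The helper `missing_modules` is hypothetical. Note that the import names for `scikit-learn` and `python-igraph` are `sklearn` and `igraph`.

```python
import importlib.util

# Import names of the listed requirements
MODULES = ["numpy", "scipy", "sklearn", "igraph", "tqdm", "dagma", "notears", "torch"]


def missing_modules(names=MODULES):
    """Return the module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print("missing:", missing_modules())
```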
[1] Zheng X, Aragam B, Ravikumar P, & Xing EP. DAGs with NO TEARS: Continuous optimization for structure learning (NeurIPS 2018, Spotlight).
[2] Zheng X, Dan C, Aragam B, Ravikumar P, & Xing EP. Learning sparse nonparametric DAGs (AISTATS 2020).
[3] Bello K, Aragam B, & Ravikumar P. DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization (NeurIPS 2022).
[4] Deng C, Bello K, Aragam B, & Ravikumar P. Optimizing NOTEARS Objectives via Topological Swaps (ICML 2023).
[5] Deng C, Bello K, Ravikumar P, & Aragam B. Likelihood-based differentiable structure learning (NeurIPS 2024).