Skip to content

Commit

Permalink
Update y0 usage and add notebook utils (#25)
Browse files Browse the repository at this point in the history
This update does the following:

1. Externalizes some functionality that has been re-implemented upstream
in `y0`.
   - `add_ci_undirected_edges` is now part of `y0`
- `find_all_nodes_in_causal_paths` now has a more performant
implementation in `y0`
- the actual call to the evans simplification functions from y0 is now
used, which has a much more clean interface.
2. Moves shared code from the case study notebooks into
`eliater.notebook_utils` such that it can be reused. Further, this code
automatically adds descriptive text as well as more context than was
previously available in the case study notebooks. There are still some
open questions about the workflow for interpreting real data - from now
on, it makes more sense to implement this in the Python package rather
than having long runaway notebooks with tons and tons of code
3. the `_adjustment_set` argument was added into the eliater linear
regression to allow for pre-caching inside the notebook workflow
4. update to modernize setup.cfg. This includes the updates based on
y0-causal-inference/y0#218 and
y0-causal-inference/y0#219
  • Loading branch information
cthoyt committed Apr 25, 2024
1 parent d4888d7 commit a952744
Show file tree
Hide file tree
Showing 13 changed files with 16,275 additions and 3,119 deletions.
2,265 changes: 1,452 additions & 813 deletions notebooks/Case_study1_The_Sars_cov2_model.ipynb

Large diffs are not rendered by default.

4,063 changes: 3,265 additions & 798 deletions notebooks/Case_study2_The_Tsignaling_pathway.ipynb

Large diffs are not rendered by default.

5,240 changes: 5,191 additions & 49 deletions notebooks/Case_study3_The_EColi.ipynb

Large diffs are not rendered by default.

6,911 changes: 5,571 additions & 1,340 deletions notebooks/motivating_example.ipynb

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ keywords =

[options]
install_requires =
y0>=0.2.7
y0>=0.2.8
scipy
numpy
ananke-causal>=0.5.0
Expand Down Expand Up @@ -128,12 +128,16 @@ strictness = short
#########################
[flake8]
ignore =
S301 # pickle
S403 # pickle
# pickle
S301
# pickle
S403
S404
S603
W503 # Line break before binary operator (flake8 is wrong)
E203 # whitespace before ':'
# Line break before binary operator (flake8 is wrong)
W503
# whitespace before ':'
E203
exclude =
.tox,
.git,
Expand Down
21 changes: 20 additions & 1 deletion src/eliater/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,33 @@

from .api import workflow
from .discover_latent_nodes import remove_nuisance_variables
from .network_validation import add_ci_undirected_edges, plot_ci_size_dependence
from .network_validation import (
add_ci_undirected_edges,
discretize_binary,
plot_ci_size_dependence,
plot_treatment_and_outcome,
)
from .notebook_utils import (
step_1_notebook,
step_2_notebook,
step_3_notebook,
step_5_notebook_real,
step_5_notebook_synthetic,
)
from .version import get_version

__all__ = [
"workflow",
"remove_nuisance_variables",
"add_ci_undirected_edges",
"plot_ci_size_dependence",
"plot_treatment_and_outcome",
"discretize_binary",
"step_1_notebook",
"step_2_notebook",
"step_3_notebook",
"step_5_notebook_real",
"step_5_notebook_synthetic",
]


Expand Down
119 changes: 26 additions & 93 deletions src/eliater/discover_latent_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,108 +141,40 @@
$Z_4$ is removed as its children are a subset of $Z_1$'s children.
"""

import itertools
from typing import Iterable, Optional, Set, Union
import warnings
from typing import Set, Union

import networkx as nx

from y0.algorithm.simplify_latent import simplify_latent_dag
from y0.algorithm.simplify_latent import evans_simplify
from y0.dsl import Variable
from y0.graph import DEFAULT_TAG, NxMixedGraph
from y0.graph import NxMixedGraph, _ensure_set, get_nodes_in_directed_paths

__all__ = [
"remove_nuisance_variables",
"find_nuisance_variables",
]


def remove_nuisance_variables(
graph: NxMixedGraph,
treatments: Union[Variable, Set[Variable]],
outcomes: Union[Variable, Set[Variable]],
tag: Optional[str] = None,
) -> NxMixedGraph:
"""Find all nuisance variables and remove them based on Evans' simplification rules.
:param graph: an NxMixedGraph
:param treatments: a list of treatments
:param outcomes: a list of outcomes
:param tag: The tag for which variables are latent
:return: the new graph after simplification
"""
rv = NxMixedGraph(
directed=graph.directed.copy(),
undirected=graph.undirected.copy(),
)
lv_dag = mark_nuisance_variables_as_latent(
graph=rv, treatments=treatments, outcomes=outcomes, tag=tag
)
simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag)
return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag)


def mark_nuisance_variables_as_latent(
graph: NxMixedGraph,
treatments: Union[Variable, Set[Variable]],
outcomes: Union[Variable, Set[Variable]],
tag: Optional[str] = None,
) -> nx.DiGraph:
"""Find all the nuisance variables and mark them as latent.
Mark nuisance variables as latent by first identifying them, then creating a new graph where these
nodes are marked as latent. Nuisance variables are the descendants of nodes in all proper causal paths
that are not ancestors of the outcome variables nodes. A proper causal path is a directed path from
treatments to the outcome. Nuisance variables should not be included in the estimation of the causal
effect as they increase the variance.
:param graph: an NxMixedGraph
:param treatments: a list of treatments
:param outcomes: a list of outcomes
:param tag: The tag for which variables are latent
:return: the modified graph after simplification, in place
"""
if tag is None:
tag = DEFAULT_TAG
nuisance_variables = find_nuisance_variables(graph, treatments=treatments, outcomes=outcomes)
lv_dag = NxMixedGraph.to_latent_variable_dag(graph, tag=tag)
# Set nuisance variables as latent
for node, data in lv_dag.nodes(data=True):
if Variable(node) in nuisance_variables:
data[tag] = True
return lv_dag


def find_all_nodes_in_causal_paths(
graph: NxMixedGraph,
treatments: Union[Variable, Set[Variable]],
outcomes: Union[Variable, Set[Variable]],
) -> Set[Variable]:
"""Find all the nodes in proper causal paths from treatments to outcomes.
A proper causal path is a directed path from treatments to the outcome.
:param graph: an NxMixedGraph
:param treatments: a list of treatments
:param outcomes: a list of outcomes
:return: the nodes on all causal paths from treatments to outcomes.
"""
if isinstance(treatments, Variable):
treatments = {treatments}
if isinstance(outcomes, Variable):
outcomes = {outcomes}

return {
node
for treatment, outcome in itertools.product(treatments, outcomes)
for causal_path in nx.all_simple_paths(graph.directed, treatment, outcome)
for node in causal_path
}
return evans_simplify(graph, latents=nuisance_variables)


def find_nuisance_variables(
graph: NxMixedGraph,
treatments: Union[Variable, Set[Variable]],
outcomes: Union[Variable, Set[Variable]],
) -> Iterable[Variable]:
) -> Set[Variable]:
"""Find the nuisance variables in the graph.
Nuisance variables are the descendants of nodes in all proper causal paths that are
Expand All @@ -255,25 +187,26 @@ def find_nuisance_variables(
:param outcomes: a list of outcomes
:returns: The nuisance variables.
"""
if isinstance(treatments, Variable):
treatments = {treatments}
if isinstance(outcomes, Variable):
outcomes = {outcomes}

# Find the nodes on all causal paths
nodes_on_causal_paths = find_all_nodes_in_causal_paths(
graph=graph, treatments=treatments, outcomes=outcomes
treatments = _ensure_set(treatments)
outcomes = _ensure_set(outcomes)
intermediaries = get_nodes_in_directed_paths(graph, treatments, outcomes)
return (
graph.descendants_inclusive(intermediaries)
- graph.ancestors_inclusive(outcomes)
- treatments
- outcomes
)

# Find the descendants of interest
descendants_of_nodes_on_causal_paths = graph.descendants_inclusive(nodes_on_causal_paths)

# Find the ancestors of outcome variables
ancestors_of_outcomes = graph.ancestors_inclusive(outcomes)

descendants_not_ancestors = descendants_of_nodes_on_causal_paths.difference(
ancestors_of_outcomes
def find_all_nodes_in_causal_paths(
graph: NxMixedGraph,
treatments: Union[Variable, Set[Variable]],
outcomes: Union[Variable, Set[Variable]],
) -> Set[Variable]:
"""Find all the nodes in proper causal paths from treatments to outcomes."""
warnings.warn(
"This has been replaced with an efficient implementation in y0",
DeprecationWarning,
stacklevel=1,
)

nuisance_variables = descendants_not_ancestors.difference(treatments.union(outcomes))
return nuisance_variables
return get_nodes_in_directed_paths(graph, treatments, outcomes)
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def generate(
*,
seed: int | None = None,
) -> pd.DataFrame:
"""Generate discrete testing data for the multiple_mediators_with_multiple_confounders_nuisances_discrete case study.
"""Generate discrete test data for the multiple_mediators_with_multiple_confounders_nuisances_discrete case study.
:param num_samples: The number of samples to generate. Try 1000.
:param treatments: An optional dictionary of the values to fix each variable to.
Expand Down
3 changes: 2 additions & 1 deletion src/eliater/examples/t_cell_signaling_pathway.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# does not want to read the reference. Spoon feed the important information
# 2. Is there associated data to go with this graph? Commit it into the examples folder


from eliater.data import load_sachs_df
from y0.algorithm.identify import Query
from y0.examples import Example
from y0.graph import NxMixedGraph
Expand All @@ -40,6 +40,7 @@
reference="K. Sachs, O. Perez, D. Pe’er, D. A. Lauffenburger, and G. P. Nolan. Causal protein-signaling"
"networks derived from multiparameter single-cell data. Science, 308(5721): 523–529, 2005.",
graph=graph,
data=load_sachs_df(),
description="This is an example of a protein signaling network of the T cell signaling pathway"
"It models the molecular mechanisms and regulatory processes of human cells involved"
"in T cell activation, proliferation, and function. The observational data consisted of quantitative"
Expand Down
52 changes: 38 additions & 14 deletions src/eliater/network_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,31 @@
"""

import time
import warnings
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from numpy import mean, quantile
from sklearn.preprocessing import KBinsDiscretizer
from tabulate import tabulate
from tqdm.auto import trange

from y0.algorithm.conditional_independencies import get_conditional_independencies
import y0.algorithm.conditional_independencies
from y0.algorithm.falsification import get_graph_falsifications
from y0.graph import NxMixedGraph
from y0.struct import CITest, _ensure_method, get_conditional_independence_tests
from y0.struct import (
DEFAULT_SIGNIFICANCE,
CITest,
_ensure_method,
get_conditional_independence_tests,
)

__all__ = [
"discretize_binary",
"plot_treatment_and_outcome",
"add_ci_undirected_edges",
"print_graph_falsifications",
"p_value_of_bootstrap_data",
Expand All @@ -182,7 +191,25 @@
]

TESTS = get_conditional_independence_tests()
DEFAULT_SIGNIFICANCE = 0.01


def plot_treatment_and_outcome(data, treatment, outcome, figsize=(8, 2.5)) -> None:
"""Plot the treatment and outcome histograms."""
fig, (lax, rax) = plt.subplots(1, 2, figsize=figsize)
sns.histplot(data=data, x=treatment.name, ax=lax)
lax.axvline(data[treatment.name].mean(), color="red")
lax.set_title("Treatment")

sns.histplot(data=data, x=outcome.name, ax=rax)
rax.axvline(data[outcome.name].mean(), color="red")
rax.set_ylabel("")
rax.set_title("Outcome")


def discretize_binary(data: pd.DataFrame) -> pd.DataFrame:
"""Discretize continuous data into binary data using K-Bins Discretization."""
kbins = KBinsDiscretizer(n_bins=2, encode="ordinal", strategy="uniform")
return pd.DataFrame(kbins.fit_transform(data), columns=data.columns)


def add_ci_undirected_edges(
Expand All @@ -204,18 +231,15 @@ def add_ci_undirected_edges(
the tested variables. If none, defaults to 0.05.
:returns: A copy of the input graph potentially with new undirected edges added
"""
rv = NxMixedGraph(
directed=graph.directed.copy(),
undirected=graph.undirected.copy(),
warnings.warn(
"This method has been replaced by a refactored implementation in "
"y0.algorithm.conditional_independencies.add_ci_undirected_edges",
DeprecationWarning,
stacklevel=1,
)
return y0.algorithm.conditional_independencies.add_ci_undirected_edges(
graph=graph, data=data, method=method, significance_level=significance_level
)
if significance_level is None:
significance_level = DEFAULT_SIGNIFICANCE
for judgement in get_conditional_independencies(rv):
if not judgement.test(
data, boolean=True, method=method, significance_level=significance_level
):
rv.add_undirected_edge(judgement.left, judgement.right)
return rv


def print_graph_falsifications(
Expand Down
Loading

0 comments on commit a952744

Please sign in to comment.