Add an explicit treatment input argument to causaltree/forest #776

Merged · 3 commits · Jul 5, 2024
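This PR threads an explicit `treatment` array through the internal tree-building stack (abstract estimator, criterion, splitter, and builders) and exposes it in the abstract `fit` signature. A minimal usage sketch against the new signature; the estimator class name and the random data are illustrative, not part of this diff:

```python
import numpy as np
from causalml.inference.tree import CausalTreeRegressor  # assumed to expose the new fit signature

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 5))                   # covariates
treatment = rng.binomial(1, 0.5, size=1000)      # binary treatment indicator (assumed encoding)
y = rng.normal(size=1000) + 0.5 * treatment      # outcome with a constant treatment effect

# The abstract fit() in _classes.py now reads:
#   fit(self, X, treatment, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated")
tree = CausalTreeRegressor()
tree.fit(X=X, treatment=treatment, y=y)
```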
8 changes: 7 additions & 1 deletion causalml/inference/tree/_tree/_classes.py
@@ -89,7 +89,13 @@ def __init__(

@abstractmethod
def fit(
self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated"
self,
X,
treatment,
y,
sample_weight=None,
check_input=True,
X_idx_sorted="deprecated",
):
pass

3 changes: 2 additions & 1 deletion causalml/inference/tree/_tree/_criterion.pxd
@@ -29,6 +29,7 @@ cdef class Criterion:

# Internal structures
cdef const DOUBLE_t[:, ::1] y # Values of y
cdef DOUBLE_t* treatment # Treatment assignment
cdef DOUBLE_t* sample_weight # Sample weights

cdef SIZE_t* samples # Sample indices in X, y
@@ -56,7 +57,7 @@ cdef class Criterion:
# statistics correspond to samples[start:pos] and samples[pos:end].

# Methods
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1
cdef int reset(self) nogil except -1
8 changes: 6 additions & 2 deletions causalml/inference/tree/_tree/_criterion.pyx
@@ -48,7 +48,7 @@ cdef class Criterion:
def __setstate__(self, d):
pass

cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1:
"""Placeholder for a method which will initialize the criterion.
@@ -60,6 +60,8 @@
----------
y : array-like, dtype=DOUBLE_t
y is a buffer that can store values for n_outputs target variables
treatment : array-like, dtype=DOUBLE_t
The treatment assignment of each sample.
sample_weight : array-like, dtype=DOUBLE_t
The weight of each sample
weighted_n_samples : double
@@ -224,6 +226,7 @@ cdef class RegressionCriterion(Criterion):
The total number of samples to fit on
"""
# Default values
self.treatment = NULL
self.sample_weight = NULL

self.samples = NULL
@@ -259,7 +262,7 @@ cdef class RegressionCriterion(Criterion):
def __reduce__(self):
return (type(self), (self.n_outputs, self.n_samples), self.__getstate__())

cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1:
"""Initialize the criterion.
@@ -269,6 +272,7 @@ cdef class RegressionCriterion(Criterion):
"""
# Initialize fields
self.y = y
self.treatment = treatment
self.sample_weight = sample_weight
self.samples = samples
self.start = start
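With `treatment` stored on the criterion, each node's impurity computation can distinguish treated from control samples among `samples[start:end]`. A rough pure-Python analogue of the kind of quantity a causal criterion can now compute per node; this is an illustration of the data flow only, not the Cython implementation:

```python
import numpy as np

def node_treatment_effect(y, treatment, samples, start, end):
    """Illustrative: mean treated-minus-control outcome over samples[start:end]."""
    idx = samples[start:end]
    treated = treatment[idx] == 1
    if treated.sum() == 0 or (~treated).sum() == 0:
        return 0.0  # degenerate node in this sketch; the real criterion handles this case itself
    return float(y[idx][treated].mean() - y[idx][~treated].mean())
```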
3 changes: 2 additions & 1 deletion causalml/inference/tree/_tree/_splitter.pxd
@@ -63,6 +63,7 @@ cdef class Splitter:
cdef SIZE_t end # End position for the current node

cdef const DOUBLE_t[:, ::1] y
cdef DOUBLE_t* treatment
cdef DOUBLE_t* sample_weight

# The samples vector `samples` is maintained by the Splitter object such
@@ -83,7 +84,7 @@

# Methods
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight) except -1
DOUBLE_t* treatment, DOUBLE_t* sample_weight) except -1

cdef int node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) nogil except -1
13 changes: 11 additions & 2 deletions causalml/inference/tree/_tree/_splitter.pyx
@@ -94,6 +94,7 @@ cdef class Splitter:
self.n_features = 0
self.feature_values = NULL

self.treatment = NULL
self.sample_weight = NULL

self.max_features = max_features
@@ -118,6 +119,7 @@
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* treatment,
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter.

@@ -134,6 +136,9 @@
y : ndarray, dtype=DOUBLE_t
This is the vector of targets, or true labels, for the samples

treatment : DOUBLE_t*
The treatment assignments of the samples.

sample_weight : DOUBLE_t*
The weights of the samples, where higher weighted samples are fit
closer than lower weight samples. If not provided, all samples
@@ -180,6 +185,7 @@
self.y = y

self.sample_weight = sample_weight
self.treatment = treatment
return 0

cdef int node_reset(self, SIZE_t start, SIZE_t end,
@@ -203,6 +209,7 @@
self.end = end

self.criterion.init(self.y,
self.treatment,
self.sample_weight,
self.weighted_n_samples,
self.samples,
@@ -243,6 +250,7 @@ cdef class BaseDenseSplitter(Splitter):
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* treatment,
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter

@@ -251,7 +259,7 @@
"""

# Call parent init
Splitter.init(self, X, y, sample_weight)
Splitter.init(self, X, y, treatment, sample_weight)

self.X = X
return 0
@@ -802,14 +810,15 @@ cdef class BaseSparseSplitter(Splitter):
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* treatment,
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter

Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
# Call parent init
Splitter.init(self, X, y, sample_weight)
Splitter.init(self, X, y, treatment, sample_weight)

if not isinstance(X, csc_matrix):
raise ValueError("X should be in csc format")
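The splitter itself only stores the treatment buffer at `init` time and forwards it to the criterion when a node is reset. A pure-Python sketch of that wiring (illustrative class, not the Cython `Splitter`):

```python
class SplitterSketch:
    """Illustrative analogue of Splitter.init / node_reset after this change."""

    def init(self, X, y, treatment, sample_weight):
        # Store the buffers, as Splitter.init now does with the extra treatment argument.
        self.X, self.y = X, y
        self.treatment = treatment
        self.sample_weight = sample_weight

    def node_reset(self, criterion, weighted_n_samples, samples, start, end):
        # Forward the stored treatment to the criterion for the current node,
        # mirroring the updated self.criterion.init(...) call.
        criterion.init(self.y, self.treatment, self.sample_weight,
                       weighted_n_samples, samples, start, end)
```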
2 changes: 2 additions & 0 deletions causalml/inference/tree/_tree/_tree.pxd
@@ -103,12 +103,14 @@ cdef class TreeBuilder:
Tree tree,
object X,
cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=*,
)

cdef _check_input(
self,
object X,
cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight,
)
19 changes: 14 additions & 5 deletions causalml/inference/tree/_tree/_tree.pyx
@@ -97,11 +97,13 @@ cdef class TreeBuilder:
"""Interface for different tree building strategies."""

cpdef build(self, Tree tree, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""
pass

cdef inline _check_input(self, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight):
"""Check input dtype, layout and format"""
if issparse(X):
@@ -122,13 +124,16 @@
if y.dtype != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

if treatment.dtype != DOUBLE or not treatment.flags.contiguous:
treatment = np.ascontiguousarray(treatment, dtype=DOUBLE)

if (sample_weight is not None and
(sample_weight.dtype != DOUBLE or
not sample_weight.flags.contiguous)):
sample_weight = np.asarray(sample_weight, dtype=DOUBLE,
order="C")

return X, y, sample_weight
return X, y, treatment, sample_weight

# Depth first builder ---------------------------------------------------------

@@ -146,12 +151,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -175,7 +182,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
cdef double min_impurity_decrease = self.min_impurity_decrease

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef SIZE_t start
cdef SIZE_t end
@@ -328,12 +335,14 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -346,7 +355,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_split = self.min_samples_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef vector[FrontierRecord] frontier
cdef FrontierRecord record
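`_check_input` now applies the same coercion to `treatment` that it applies to `y`: float64 dtype and C-contiguous layout, so the raw buffer can be handed to the splitter as a `DOUBLE_t*`. A small sketch of the equivalent NumPy check, useful when preparing inputs outside the estimator (function name is illustrative):

```python
import numpy as np

def as_treatment_buffer(treatment):
    """Mirror the builder's coercion of the treatment array (float64, C-contiguous)."""
    treatment = np.asarray(treatment)
    if treatment.dtype != np.float64 or not treatment.flags.c_contiguous:
        treatment = np.ascontiguousarray(treatment, dtype=np.float64)
    return treatment
```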
21 changes: 15 additions & 6 deletions causalml/inference/tree/causal/_builder.pyx
@@ -51,12 +51,14 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder):
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray treatment,
np.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -80,7 +82,7 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder):
cdef double min_impurity_decrease = self.min_impurity_decrease

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef SIZE_t start
cdef SIZE_t end
@@ -239,13 +241,20 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder):
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None):
cpdef build(
self,
Tree tree,
object X,
np.ndarray y,
np.ndarray treatment,
np.ndarray sample_weight=None
):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -258,7 +267,7 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_split = self.min_samples_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef vector[FrontierRecord] frontier
cdef FrontierRecord record
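The causal builders now follow the same pattern as the base `TreeBuilder`: `build()` accepts the treatment array, `_check_input` coerces it, and its pointer is passed to `splitter.init`. The forest side of the PR title is expected to follow the same calling convention at the user-facing level; a hedged sketch (class name, constructor defaults, and data are assumptions, not shown in this diff):

```python
import numpy as np
from causalml.inference.tree import CausalRandomForestRegressor  # assumed to follow the same fit signature

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
treatment = rng.binomial(1, 0.5, size=1000).astype(float)
y = rng.normal(size=1000) + 0.5 * treatment

forest = CausalRandomForestRegressor(n_estimators=100)
forest.fit(X=X, treatment=treatment, y=y)
cate = forest.predict(X)  # per-sample treatment-effect estimates
```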