rcf is a simplified, pure-R implementation of causal forests based on some of the features of Julie Tibshirani’s, Susan Athey’s, Stefan Wager’s and the grf lab team’s causal forest functionality (https://grf-labs.github.io/grf/). This was always meant to be a fun project: a way to get stuck into the workings of the algorithm and to get to grips with the theory behind it (as far as I understood it) by building the functions from scratch. I wouldn’t suggest using this for any high-stakes work, for the following reasons:
- this is really just an experimental implementation of my understanding of the grf causal forest algorithm
- the functionality is much more limited than what grf offers
That being said, running these causal forest functions and the grf implementation on the same data yields quite similar results (the conditional average treatment effect estimates produced by the two implementations are in the same ballpark and well correlated; see the comparison below). So if you prefer working in R and want to experiment with some causal forest code, hopefully this is somewhat useful.
Rina Friedberg, Julie Tibshirani, Susan Athey, and Stefan Wager. Local Linear Forests. 2018 (https://arxiv.org/abs/1807.11408)
Stefan Wager and Susan Athey. Estimation and Inference of Heterogeneous Treatment Effects using Random Forests. Journal of the American Statistical Association, 113(523), 2018. (https://arxiv.org/abs/1510.04342)
Susan Athey and Stefan Wager. Estimating Treatment Effects with Causal Forests: An Application. Observational Studies, 5, 2019. (https://arxiv.org/abs/1902.07409)
Susan Athey, Julie Tibshirani and Stefan Wager. Generalized Random Forests. Annals of Statistics, 47(2), 2019. (https://arxiv.org/abs/1610.01271)
You can install the development version of rcf from GitHub (https://github.com/) with:

``` r
devtools::install_github("till-tietz/rcf")
```
``` r
# generate some data: V1-V9 as covariates, V10 as outcome, plus a random binary treatment
data <- as.data.frame(do.call(cbind, replicate(10, rnorm(100), simplify = FALSE)))
data[["treat"]] <- rbinom(nrow(data), 1, 0.5)
vars <- colnames(data)[1:(ncol(data) - 2)]

# set up parallel processing
future::plan("multisession")

# build causal forest
cf <- rcf::causal_forest(n_trees = 1000, data = data, outcome = "V10",
                         covariates = vars, treat = "treat", minsize = 5,
                         alpha = 0.05, feature_fraction = 0.5, sample_fraction = 0.5,
                         honest_split = TRUE, honesty_fraction = 0.5)

# predict CATEs (out-of-bag)
cate <- rcf::predict_causal_forest(data = data, cf = cf, predict_oob = TRUE)
```
`predict_causal_forest` returns a data.frame of observation ids and CATE estimates:
obs | cate |
---|---|
1 | 0.1446461 |
2 | 0.0681436 |
3 | 0.0930128 |
4 | 0.1624110 |
5 | 0.0203669 |
6 | 0.0689765 |
7 | 0.1531094 |
8 | 0.0721602 |
9 | 0.0791071 |
10 | 0.1382835 |
`variable_importance` generates a data.frame of variable importance metrics:

``` r
var_importance <- rcf::variable_importance(cf = cf, covariates = vars, n = 4, d = 2)
```
variable | importance |
---|---|
V8 | 0.172652118100128 |
V6 | 0.108693196405648 |
V5 | 0.104996148908858 |
V3 | 0.102408215661104 |
V2 | 0.0898382541720154 |
V7 | 0.0854017971758665 |
V4 | 0.0783774069319641 |
V9 | 0.0735712451861361 |
V1 | 0.0621103979460847 |
We’ll build 500 grf and rcf causal forests respectively and compare the means of their CATE predictions for each observation.
``` r
# fit one grf causal forest and return its OOB CATE predictions as a one-row data.frame
grf_sim <- function(x) {
  grf <- grf::causal_forest(X = data[, vars], Y = data[, "V10"], W = data[, "treat"],
                            num.trees = 1000, mtry = 5, min.node.size = 5,
                            honesty = TRUE, honesty.fraction = 0.5, alpha = 0.05)
  results <- as.data.frame(t(predict(grf)[["predictions"]]))
  return(results)
}

results_grf <- furrr::future_map_dfr(1:500, ~grf_sim(.x), .progress = TRUE) %>%
  dplyr::summarise_all(mean) %>%
  t()
```
``` r
# fit one rcf causal forest and return its OOB CATE predictions as a one-row data.frame
rcf_sim <- function(x) {
  cf <- rcf::causal_forest(n_trees = 1000, data = data, outcome = "V10",
                           covariates = vars, treat = "treat", minsize = 5,
                           alpha = 0.05, feature_fraction = 0.5, honest_split = TRUE,
                           honesty_fraction = 0.5)
  results <- as.data.frame(t(rcf::predict_causal_forest(data = data, cf = cf, predict_oob = TRUE)[["cate"]]))
  return(results)
}

results_rcf <- furrr::future_map_dfr(1:500, ~rcf_sim(.x), .progress = TRUE) %>%
  dplyr::summarise_all(mean) %>%
  t()
```
The rcf CATE predictions match those generated by grf relatively well.
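As a quick check (a minimal sketch using the `results_grf` and `results_rcf` objects computed above), you can correlate and plot the two sets of mean predictions:

``` r
# correlation between the mean grf and rcf CATE predictions per observation
comparison <- data.frame(grf = as.numeric(results_grf),
                         rcf = as.numeric(results_rcf))
cor(comparison$grf, comparison$rcf)

# scatter plot with a 45-degree reference line
plot(comparison$grf, comparison$rcf,
     xlab = "grf mean CATE", ylab = "rcf mean CATE")
abline(0, 1, lty = 2)
```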
We’ll test the performance of the rcf causal forest against a linear regression and a k-nearest-neighbours approach to estimating heterogeneous treatment effects. We’ll use a simulated data set with explicit treatment effect heterogeneity across two variables; a sketch of such a simulation is shown below.
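The benchmark data could be simulated along these lines (a hypothetical sketch; the variable names, coefficients and effect function are illustrative assumptions rather than the exact simulation used):

``` r
set.seed(42)

n <- 1000
X <- as.data.frame(matrix(rnorm(n * 10), ncol = 10))
W <- rbinom(n, 1, 0.5)

# treatment effect varies explicitly with V1 and V2
tau <- 1 + 2 * X$V1 - 1.5 * X$V2

# outcome: a main effect, the heterogeneous treatment effect and noise
Y <- X$V3 + tau * W + rnorm(n)

sim_data <- cbind(X, treat = W, outcome = Y)
```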
rcf serves as an estimator for conditional average treatment effects by explicitly optimizing on treatment effect heterogeneity. This is achieved by recursively splitting a sample so as to maximize the following quantity of interest:
the mean squared difference in estimated treatment effects ($\hat{\tau}$) across the sub-samples created by a candidate partition $p$ (drawn from the set $P$ of all feasible partitions), minus the sum of variances in outcomes for treatment and control units within each sub-sample, with the two components weighted by the parameter $\alpha$:

$$
\alpha \, \big(\hat{\tau}_{1}(p) - \hat{\tau}_{2}(p)\big)^{2} \;-\; (1 - \alpha) \sum_{s \in \{1,2\}} \Big( \mathrm{Var}\big[\,Y_{s} \mid W = 1\,\big] + \mathrm{Var}\big[\,Y_{s} \mid W = 0\,\big] \Big)
$$
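To make the criterion concrete, here is a minimal R sketch of how it could be evaluated for a single candidate split (the function and the exact weighting are illustrative assumptions based on the description above, not rcf's internal implementation):

``` r
# quantity of interest for one candidate split into `left` and `right` sub-samples
# y: outcome column name, w: treatment column name, alpha: weight on heterogeneity
split_criterion <- function(left, right, y, w, alpha) {
  # estimated treatment effect (difference in means) within a sub-sample
  tau <- function(d) mean(d[d[[w]] == 1, y]) - mean(d[d[[w]] == 0, y])
  # outcome variance of treated plus control units within a sub-sample
  var_tc <- function(d) var(d[d[[w]] == 1, y]) + var(d[d[[w]] == 0, y])

  heterogeneity <- (tau(left) - tau(right))^2
  variance_penalty <- var_tc(left) + var_tc(right)

  alpha * heterogeneity - (1 - alpha) * variance_penalty
}
```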
1. Draw a sample of `size = n_data * sample_fraction` without replacement.
2. If `honest_split` is `TRUE`, split this sample into a tree fitting sample of `size = n_sample * (1 - honesty_fraction)` and an honest estimation sample of `size = n_sample * honesty_fraction`.
3. Draw a sample of covariates of `size = n_covariates * feature_fraction`.
4. Find the unique values of all sampled covariates in the tree fitting sample.
5. Split the tree fitting sample at each unique value and check that there are `n > minsize` treatment and control observations in each sub-sample created by the split; keep only those split points where the minsize requirement is met.
6. For each valid split point, compute the above quantity of interest (the weighted difference in treatment effects across sub-samples minus the sum of variances in outcomes for treatment and control units in each sub-sample) and choose the split that maximizes this value.
7. Keep recursively splitting each sub-sample of the tree fitting sample until no split can satisfy the minsize requirement; the tree is fully grown at this point.
8. Push the honest estimation sample down the tree (i.e. subset the honest estimation sample according to the splitting rules of the tree grown with the tree fitting sample).
9. Repeat steps 1-8 `n_trees` times.
10. Push a test sample down each tree in the forest (i.e. subset the test sample according to the splitting rules of each tree). For each observation, record the honest sample observations in each terminal leaf it falls into and compute its CATE from these honest-sample neighbours (see the sketch below).
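As an illustration of this last step, the CATE for a single observation could be computed like this (a minimal sketch assuming `honest_neighbours` holds the honest estimation observations sharing a terminal leaf with that observation; it mirrors the description above rather than rcf's internal code):

``` r
# difference-in-means CATE estimate from an observation's honest neighbours
estimate_cate <- function(honest_neighbours, y = "V10", w = "treat") {
  treated <- honest_neighbours[honest_neighbours[[w]] == 1, y]
  control <- honest_neighbours[honest_neighbours[[w]] == 0, y]
  mean(treated) - mean(control)
}
```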
Variable importance is computed as a weighted sum of how often a variable was split on at depth k within a tree.
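If this mirrors grf's variable importance measure (an assumption suggested by the `n` and `d` parameters above, read as the maximum depth considered and a decay exponent), the importance of a variable $x_j$ would be

$$
\mathrm{importance}(x_j) = \frac{\sum_{k=1}^{n} k^{-d} \, f_{jk}}{\sum_{k=1}^{n} k^{-d}}
$$

where $f_{jk}$ is the fraction of all depth-$k$ splits in the forest that split on $x_j$.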