---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
```
# rcf
<!-- badges: start -->
[![R-CMD-check](https://github.com/till-tietz/rcf/workflows/R-CMD-check/badge.svg)](https://github.com/till-tietz/rcf/actions)
<!-- badges: end -->
rcf is a simplified, fully R-based implementation of causal forests, modelled on some features of the causal forest functionality developed by Julie Tibshirani, Susan Athey, Stefan Wager and the grf lab team (https://grf-labs.github.io/grf/).

This was always meant to be a fun project: a way to dig into the workings of the algorithm and get to grips with the theory behind it (as far as I understood it) by building the functions from scratch. I wouldn't suggest using it for any high-stakes work, for two reasons:

1. it is an experimental implementation of my understanding of the grf causal forest algorithm
2. its functionality is much more limited than what grf offers

That said, running these causal forest functions and the grf implementation on the same data produces quite similar results (the conditional average treatment effect estimates from the two implementations are strongly correlated). So if you prefer working in R and want to experiment with some causal forest code, hopefully this is useful.
##### References and Reading:

Rina Friedberg, Julie Tibshirani, Susan Athey, and Stefan Wager. Local Linear Forests. 2018. (https://arxiv.org/abs/1807.11408)

Stefan Wager and Susan Athey. Estimation and Inference of Heterogeneous Treatment Effects using Random Forests. Journal of the American Statistical Association, 113(523), 2018. (https://arxiv.org/abs/1510.04342)

Susan Athey and Stefan Wager. Estimating Treatment Effects with Causal Forests: An Application. Observational Studies, 5, 2019. (https://arxiv.org/abs/1902.07409)

Susan Athey, Julie Tibshirani and Stefan Wager. Generalized Random Forests. Annals of Statistics, 47(2), 2019. (https://arxiv.org/abs/1610.01271)
## Installation
You can install the development version of rcf from GitHub (https://github.com/) with:
```{r, warning = FALSE, eval = FALSE}
devtools::install_github("till-tietz/rcf")
```
## Usage
```{r, warning = FALSE}
# generate some data
data <- as.data.frame(do.call(cbind, replicate(10, rnorm(100), simplify=FALSE)))
data[["treat"]] <- rbinom(nrow(data),1,0.5)
vars <- colnames(data)[1:(ncol(data)-2)]
# set up parallel processing
future::plan("multisession")
# build causal forest
cf <- rcf::causal_forest(n_trees = 1000, data = data, outcome = "V10",
covariates = vars, treat = "treat", minsize = 5,
alpha = 0.05, feature_fraction = 0.5, sample_fraction = 0.5,
honest_split = TRUE, honesty_fraction = 0.5)
# predict cates
cate <- rcf::predict_causal_forest(data = data, cf = cf, predict_oob = TRUE)
```
`predict_causal_forest` returns a data.frame of observation ids and CATE estimates:
```{r echo = FALSE, results = "asis", warning = FALSE}
knitr::kable(cate[c(1:10),])
```
`variable_importance` generates a data.frame of variable importance metrics:
```{r, warning = FALSE}
var_importance <- rcf::variable_importance(cf = cf, covariates = vars, n = 4, d = 2)
```
```{r echo = FALSE, results = "asis", warning = FALSE}
knitr::kable(var_importance)
```
## Performance compared to grf
We'll build 500 grf and rcf causal forests respectively and compare the means of their CATE predictions for each observation.
```{r, warning = FALSE, eval = FALSE}
library(magrittr) # provides the %>% pipe used below

grf_sim <- function(x) {
  grf <- grf::causal_forest(X = data[, vars], Y = data[, "V10"], W = data[, "treat"],
                            num.trees = 1000, mtry = 5, min.node.size = 5,
                            honesty = TRUE, honesty.fraction = 0.5, alpha = 0.05)
  results <- as.data.frame(t(predict(grf)[["predictions"]]))
  return(results)
}

results_grf <- furrr::future_map_dfr(1:500, ~grf_sim(.x), .progress = TRUE) %>%
  dplyr::summarise_all(mean) %>%
  t()

rcf_sim <- function(x) {
  cf <- rcf::causal_forest(n_trees = 1000, data = data, outcome = "V10",
                           covariates = vars, treat = "treat", minsize = 5,
                           alpha = 0.05, feature_fraction = 0.5, honest_split = TRUE,
                           honesty_fraction = 0.5)
  results <- as.data.frame(t(rcf::predict_causal_forest(data = data, cf = cf, predict_oob = TRUE)[["cate"]]))
  return(results)
}

results_rcf <- furrr::future_map_dfr(1:500, ~rcf_sim(.x), .progress = TRUE) %>%
  dplyr::summarise_all(mean) %>%
  t()
```
The rcf CATE predictions match those generated by grf relatively well.
```{r echo = FALSE}
plot <- readRDS("man/figures/grf_rcf_plot.rds")
plot
```
## Performance compared to other methods
We'll test the performance of the rcf causal forest against linear regression and k-nearest-neighbour (knn) approaches to estimating heterogeneous treatment effects, using a simulated data set with explicit treatment effect heterogeneity across two variables.
![](man/figures/performance.png)
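As a rough illustration of what such a baseline can look like (the exact specifications behind the figure above are not shown here), a linear-regression approach might interact treatment with the heterogeneity variables; `V1` and `V2` below are stand-ins for the two variables driving heterogeneity in the simulation:

```{r, eval = FALSE}
# Illustrative linear-regression baseline for heterogeneous effects
# (assumed specification, not necessarily the one used for the figure).
lm_fit <- lm(V10 ~ treat * (V1 + V2), data = data)

# Per-observation effect estimate: predicted outcome under treatment
# minus predicted outcome under control.
d1 <- transform(data, treat = 1)
d0 <- transform(data, treat = 0)
cate_lm <- predict(lm_fit, newdata = d1) - predict(lm_fit, newdata = d0)
```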
## Methodology
### Explicitly Optimizing Heterogeneity
rcf serves as an estimator for conditional average treatment effects by explicitly optimizing on treatment effect heterogeneity. It does so by recursively splitting a sample so as to maximize the following quantity of interest:
![](man/figures/eq_1.PNG)
i.e. the mean squared difference in treatment effects $\tau$ across the sub-samples created by a candidate partition of a sample $P$, minus the sum of the variances in outcomes for treatment and control units, summed across sub-samples. The two components of the criterion are weighted by the parameter $\alpha$.
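Since the criterion above is embedded as an image, here is one plausible rendering in LaTeX for a binary split into sub-samples $s_L$ and $s_R$, reconstructed from the description above. The $\alpha$ / $(1-\alpha)$ weighting is an assumption on my part, and the exact notation in the figure may differ:

$$
\Delta(s_L, s_R) = \alpha \left(\hat{\tau}_{s_L} - \hat{\tau}_{s_R}\right)^2 - (1 - \alpha) \sum_{s \in \{s_L,\, s_R\}} \left[ \mathrm{Var}\!\left(Y_s^{\mathrm{treat}}\right) + \mathrm{Var}\!\left(Y_s^{\mathrm{control}}\right) \right]
$$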
### Algorithm
1. Draw a sample of `size = n_data * sample_fraction` without replacement
2. If `honest_split` is `TRUE`, split this sample into a tree-fitting sample of `size = n_sample * (1 - honesty_fraction)` and an honest estimation sample of `size = n_sample * honesty_fraction`
3. Draw a sample of covariates of `size = n_covariates * feature_fraction`
4. Find the unique values of all sampled covariates in the tree-fitting sample
5. Split the tree-fitting sample at each unique value and check that each sub-sample created by the split contains `n > minsize` treatment and control observations (keep only those split points where the `minsize` requirement is met)
6. For each valid split point, compute the quantity of interest above (variance of treatment effects across sub-samples minus the sum of variances in outcomes for treatment and control units in each sub-sample) and choose the split that maximizes this value (see the sketch after this list)
7. Keep recursively splitting each sub-sample of the tree-fitting sample until no split can satisfy the `minsize` requirement; at this point the tree is fully grown
8. Push the honest estimation sample down the tree (i.e. subset the honest estimation sample according to the splitting rules of the tree grown with the tree-fitting sample)
9. Repeat 1-8 `n_trees` times
10. Push a test sample down each tree in the forest (i.e. subset the test sample according to the splitting rules of each tree in the forest). For each observation, record the honest-sample observations in each terminal leaf it falls into, and compute its CATE from those honest-sample neighbours.
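To make steps 5 and 6 concrete, here is a minimal sketch of how a single split candidate might be scored. The function and its name (`split_score`) are illustrative, not the package's internals, and the $\alpha$ weighting follows the assumed form of the criterion above:

```{r, eval = FALSE}
# Illustrative split score for steps 5-6 (not rcf's actual internals).
# y: outcome vector, w: 0/1 treatment vector,
# left: logical index marking the left sub-sample of a candidate split
split_score <- function(y, w, left, minsize = 5, alpha = 0.05) {
  part <- function(idx) {
    y_t <- y[idx & w == 1]
    y_c <- y[idx & w == 0]
    # enforce the minsize requirement on treated and control counts
    if (length(y_t) <= minsize || length(y_c) <= minsize) return(NULL)
    list(tau = mean(y_t) - mean(y_c), var = var(y_t) + var(y_c))
  }
  l <- part(left)
  r <- part(!left)
  if (is.null(l) || is.null(r)) return(-Inf) # invalid split
  # heterogeneity reward minus variance penalty, weighted by alpha
  alpha * (l$tau - r$tau)^2 - (1 - alpha) * (l$var + r$var)
}

# e.g. score splitting V1 at one of its unique values:
# split_score(y = data[["V10"]], w = data[["treat"]], left = data[["V1"]] <= 0)
```

The tree grower would evaluate this score at every valid split point of every sampled covariate and keep the maximizer.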
### Variable Importance
Variable importance is computed as a weighted sum of how often a variable was split on at depth $k$ within each tree of the forest.
![](man/figures/eq_2.PNG)
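The formula above is embedded as an image; a plausible reading, consistent with the `n` (maximum depth considered) and `d` (decay exponent) arguments used in the `variable_importance` call earlier and with grf's analogous metric, is:

$$
\mathrm{importance}(x_j) = \frac{\sum_{k=1}^{n} \left( \dfrac{\#\{\text{splits on } x_j \text{ at depth } k\}}{\#\{\text{splits at depth } k\}} \right) k^{-d}}{\sum_{k=1}^{n} k^{-d}}
$$

Treat this as an interpretation rather than the figure's exact notation.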