Skip to content

Commit

Permalink
Many to one propensity matching without replacement (#786)
Browse files Browse the repository at this point in the history
* Many to one matching without replacement.
* black formatting fixes
  • Loading branch information
spohngellert-o authored Aug 1, 2024
1 parent b54c201 commit 084a6d0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
18 changes: 11 additions & 7 deletions causalml/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class NearestNeighborMatch:
Attributes:
caliper (float): threshold to be considered as a match.
replace (bool): whether to match with replacement or not
ratio (int): ratio of control / treatment to be matched. used only if
replace=True.
ratio (int): ratio of control / treatment to be matched.
shuffle (bool): whether to shuffle the treatment group data before
matching
random_state (numpy.random.RandomState or int): RandomState or an int
Expand All @@ -112,6 +111,7 @@ def __init__(
Args:
caliper (float): threshold to be considered as a match.
replace (bool): whether to match with replacement or not
ratio (int): ratio of control / treatment to be matched.
shuffle (bool): whether to shuffle the treatment group data before
matching or not
random_state (numpy.random.RandomState or int): RandomState or an
Expand Down Expand Up @@ -200,11 +200,15 @@ def match(self, data, treatment_col, score_cols):
control.loc[control.unmatched, score_col]
- treatment.loc[t_idx, score_col]
)
c_idx_min = dist.idxmin()
if dist[c_idx_min] <= sdcal:
t_idx_matched.append(t_idx)
c_idx_matched.append(c_idx_min)
control.loc[c_idx_min, "unmatched"] = False
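# Get the positions of the self.ratio smallest distances. Note that
# np.argpartition returns them in no particular order among themselves.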
c_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio]
c_idx_list = dist.index[c_np_idx_list]
for i, c_idx in enumerate(c_idx_list):
if dist[c_idx] <= sdcal:
if i == 0:
t_idx_matched.append(t_idx)
c_idx_matched.append(c_idx)
control.loc[c_idx, "unmatched"] = False

return data.loc[
np.concatenate([np.array(t_idx_matched), np.array(c_idx_matched)])
Expand Down
10 changes: 9 additions & 1 deletion tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,18 @@ def _generate_data():
yield _generate_data


def test_nearest_neighbor_match_ratio_2(generate_unmatched_data):
    """Matching with ratio=2 and replace=False keeps two controls per treated row."""
    df, _ = generate_unmatched_data()

    matcher = NearestNeighborMatch(replace=False, ratio=2, random_state=RANDOM_SEED)
    matched = matcher.match(
        data=df,
        treatment_col=TREATMENT_COL,
        score_cols=[SCORE_COL],
    )

    # Rows whose treatment value is 0 are counted as controls; every other
    # row in the matched frame is counted as treated.
    is_control = matched[TREATMENT_COL] == 0
    n_control = is_control.sum()
    n_treated = (~is_control).sum()
    assert n_control == 2 * n_treated


def test_nearest_neighbor_match_by_group(generate_unmatched_data):
df, features = generate_unmatched_data()

psm = NearestNeighborMatch(replace=False, ratio=1.0, random_state=RANDOM_SEED)
psm = NearestNeighborMatch(replace=False, ratio=1, random_state=RANDOM_SEED)

matched = psm.match_by_group(
data=df,
Expand Down

0 comments on commit 084a6d0

Please sign in to comment.