diff --git a/src/samplics/weighting/adjustment.py b/src/samplics/weighting/adjustment.py index 1696a59a..8dba703e 100644 --- a/src/samplics/weighting/adjustment.py +++ b/src/samplics/weighting/adjustment.py @@ -358,6 +358,7 @@ def poststratify( raise AssertionError("control or factor must be specified.") if isinstance(control, dict): + # breakpoint() if (np.unique(domain) != np.unique(list(control.keys()))).any(): raise ValueError("control dictionary keys do not match domain values.") @@ -396,6 +397,7 @@ def rake( ll_bound: Optional[Union[DictStrNum, Number]] = None, up_bound: Optional[Union[DictStrNum, Number]] = None, tol: float = 1e-4, + ctrl_tol: float = 1e-4, max_iter: int = 100, display_iter: bool = False, ) -> np.ndarray: @@ -413,49 +415,58 @@ def rake( print(f"\nIteration {iter + 1}") if iter == 0: + rk_wgt = samp_weight wgt_prev = samp_weight for margin in margins: domain = formats.numpy_array(margins[margin]) if control is not None: - wgt = self.poststratify(samp_weight=wgt_prev, control=control[margin], domain=domain) + rk_wgt = self.poststratify(samp_weight=rk_wgt, control=control[margin], domain=domain) elif factor is not None: - wgt = self.poststratify(samp_weight=wgt_prev, factor=factor[margin], domain=domain) + rk_wgt = self.poststratify(samp_weight=rk_wgt, factor=factor[margin], domain=domain) else: raise AssertionError("control or factor must be specified!") - wgt_prev = wgt - sum_wgt = {} + sum_prev_wgt = {} for margin in margins: domain = formats.numpy_array(margins[margin]) sum_wgt_domain = {} + sum_prev_wgt_domain = {} for d in control[margin]: - sum_wgt_domain[d] = np.sum(wgt[domain == d]) + sum_wgt_domain[d] = np.sum(rk_wgt[domain == d]) + sum_prev_wgt_domain[d] = np.sum(wgt_prev[domain == d]) sum_wgt[margin] = sum_wgt_domain + sum_prev_wgt[margin] = sum_prev_wgt_domain # diff = {} max_diff = 0 + max_ctrl_diff = 0 for margin in margins: if display_iter: print(f" Margin: {margin}") diff_margin = {} + diff_ctrl_margin = {} for d in control[margin]: - diff_margin[d] = np.abs(control[margin][d] - sum_wgt[margin][d]) + diff_margin[d] = np.abs(sum_wgt[margin][d] - sum_prev_wgt[margin][d]) / sum_prev_wgt[margin][d] + diff_ctrl_margin[d] = np.abs(sum_wgt[margin][d] - control[margin][d]) / control[margin][d] if display_iter: - print(f" Difference for '{d}': {diff_margin[d]}") + print(f" Difference against previous value for '{d}': {diff_margin[d]}") + print(f" Difference against control value for '{d}': {diff_ctrl_margin[d]}") # diff[margin] = diff_margin max_diff = max(max_diff, max(diff_margin.values())) + max_ctrl_diff = max(max_ctrl_diff, max(diff_ctrl_margin.values())) obs_tol = max_diff + obs_ctrl_tol = max_ctrl_diff - if obs_tol <= tol: + if obs_tol <= tol and obs_ctrl_tol <= ctrl_tol: converged = True if ll_bound is not None or up_bound is not None: - wgt_ratios = wgt / samp_weight + wgt_ratios = rk_wgt / samp_weight min_ratio = np.min(wgt_ratios) max_ratio = np.max(wgt_ratios) @@ -474,9 +485,12 @@ def rake( else: bounded = True + wgt_prev = rk_wgt iter += 1 - return wgt + self.adj_method = "raking" + + return rk_wgt @staticmethod def _calib_covariates( diff --git a/tests/weighting/test_adjustment.py b/tests/weighting/test_adjustment.py index 53087c36..3886fcd5 100644 --- a/tests/weighting/test_adjustment.py +++ b/tests/weighting/test_adjustment.py @@ -4,6 +4,40 @@ from samplics.weighting import SampleWeight +# stata example + +# nhis_sam = pl.read_csv("~/Downloads/nhis_sam.csv").with_columns( +# pl.when(pl.col("hisp") == 4).then(pl.lit(3)).otherwise(pl.col("hisp")).alias("hisp") +# ) + +# age_grp = { +# "<18": 5991, +# "18-24": 2014, +# "25-44": 6124, +# "45-64": 5011, +# "65+": 2448, +# } +# hisp_race = {1: 5031, 2: 12637, 3: 3920} +# control = {"age_grp": age_grp, "hisp": hisp_race} + +# # breakpoint() + +# ll = 0.8 +# ul = 1.2 + +# margins = { +# "age_grp": nhis_sam["age_grp"].to_list(), +# "hisp": nhis_sam["hisp"].to_list(), +# } + +# nhis_sam_rk = SampleWeight() + +# nhis_sam = nhis_sam.with_columns( +# rake_wt_2=nhis_sam_rk.rake( +# samp_weight=nhis_sam["wt"], control=control, margins=margins, display_iter=True, tol=1e-6 +# ) +# ).with_columns(diff=pl.col("rake_wt_2") - pl.col("rake_wt")) + # synthetic data for testing wgt = np.random.uniform(0, 1, 1000) @@ -275,9 +309,12 @@ def test_ps_wgt_with_class(): sample_wgt_rk_not_bound = SampleWeight() -rk_wgt_not_bound = sample_wgt_rk_not_bound.rake( - samp_weight=income_sample2["design_wgt"], control=control, margins=margins, display_iter=True -) +# rk_wgt_not_bound = sample_wgt_rk_not_bound.rake( +# samp_weight=income_sample2["design_wgt"], control=control, margins=margins, display_iter=True, tol=1e-4 +# ) + + + # breakpoint() # age_grp = {"<18": 21588, age}