Skip to content

Commit

Permalink
Add support for weights in objects.Est (#3580)
Browse files Browse the repository at this point in the history
* Add support for weights in objects.Est

* Fix garbled docstring text
  • Loading branch information
mwaskom authored Dec 6, 2023
1 parent f013633 commit 2bb945c
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 4 deletions.
18 changes: 18 additions & 0 deletions doc/_docstrings/objects.Est.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,30 @@
"p.add(so.Range(), so.Est(seed=0))"
]
},
{
"cell_type": "markdown",
"id": "df807ef8-b5fb-4eac-b539-1bd4e797ddc2",
"metadata": {},
"source": [
"To compute a weighted estimate (and confidence interval), assign a `weight` variable in the layer where you use the stat:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e4a0594-e1ee-4f72-971e-3763dd626e8b",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Range(), so.Est(), weight=\"price\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d0c34d7-fb76-44cf-9079-3ec7f45741d0",
"metadata": {},
"outputs": [],
"source": []
}
],
Expand Down
56 changes: 56 additions & 0 deletions seaborn/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,62 @@ def __call__(self, data, var):
return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max})


class WeightedEstimateAggregator:

def __init__(self, estimator, errorbar=None, **boot_kws):
"""
Data aggregator that produces a weighted estimate and error bar interval.
Parameters
----------
estimator : string
Function (or method name) that maps a vector to a scalar. Currently
supports only "mean".
errorbar : string or (string, number) tuple
Name of errorbar method or a tuple with a method name and a level parameter.
Currently the only supported method is "ci".
boot_kws
Additional keywords are passed to bootstrap when error_method is "ci".
"""
if estimator != "mean":
# Note that, while other weighted estimators may make sense (e.g. median),
# I'm not aware of an implementation in our dependencies. We can add one
# in seaborn later, if there is sufficient interest. For now, limit to mean.
raise ValueError(f"Weighted estimator must be 'mean', not {estimator!r}.")
self.estimator = estimator

method, level = _validate_errorbar_arg(errorbar)
if method is not None and method != "ci":
# As with the estimator, weighted 'sd' or 'pi' error bars may make sense.
# But we'll keep things simple for now and limit to (bootstrap) CI.
raise ValueError(f"Error bar method must be 'ci', not {method!r}.")
self.error_method = method
self.error_level = level

self.boot_kws = boot_kws

def __call__(self, data, var):
"""Aggregate over `var` column of `data` with estimate and error interval."""
vals = data[var]
weights = data["weight"]

estimate = np.average(vals, weights=weights)

if self.error_method == "ci" and len(data) > 1:

def error_func(x, w):
return np.average(x, weights=w)

boots = bootstrap(vals, weights, func=error_func, **self.boot_kws)
err_min, err_max = _percentile_interval(boots, self.error_level)

else:
err_min = err_max = np.nan

return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max})


class LetterValues:

def __init__(self, k_depth, outlier_prop, trust_alpha):
Expand Down
20 changes: 16 additions & 4 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from seaborn._core.scales import Scale
from seaborn._core.groupby import GroupBy
from seaborn._stats.base import Stat
from seaborn._statistics import EstimateAggregator
from seaborn._statistics import (
EstimateAggregator,
WeightedEstimateAggregator,
)
from seaborn._core.typing import Vector


Expand Down Expand Up @@ -54,8 +57,14 @@ class Est(Stat):
"""
Calculate a point estimate and error bar interval.
For additional information about the various `errorbar` choices, see
the :doc:`errorbar tutorial </tutorial/error_bars>`.
For more information about the various `errorbar` choices, see the
:doc:`errorbar tutorial </tutorial/error_bars>`.
Additional variables:
- **weight**: When passed to a layer that uses this stat, a weighted estimate
will be computed. Note that use of weights currently limits the choice of
function and error bar method to `"mean"` and `"ci"`, respectively.
Parameters
----------
Expand Down Expand Up @@ -95,7 +104,10 @@ def __call__(
) -> DataFrame:

boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)
if "weight" in data:
engine = WeightedEstimateAggregator(self.func, self.errorbar, **boot_kws)
else:
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

var = {"x": "y", "y": "x"}[orient]
res = (
Expand Down
11 changes: 11 additions & 0 deletions tests/_stats/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ def test_median_pi(self, df):
expected = est.assign(ymin=grouped.min()["y"], ymax=grouped.max()["y"])
assert_frame_equal(res, expected)

def test_weighted_mean(self, df, rng):

weights = rng.uniform(0, 5, len(df))
gb = self.get_groupby(df[["x", "y"]], "x")
df = df.assign(weight=weights)
res = Est("mean")(df, gb, "x", {})
for _, res_row in res.iterrows():
rows = df[df["x"] == res_row["x"]]
expected = np.average(rows["y"], weights=rows["weight"])
assert res_row["y"] == expected

def test_seed(self, df):

ori = "x"
Expand Down
34 changes: 34 additions & 0 deletions tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ECDF,
EstimateAggregator,
LetterValues,
WeightedEstimateAggregator,
_validate_errorbar_arg,
_no_scipy,
)
Expand Down Expand Up @@ -632,6 +633,39 @@ def test_errorbar_validation(self):
_validate_errorbar_arg(arg)


class TestWeightedEstimateAggregator:

def test_weighted_mean(self, long_df):

long_df["weight"] = long_df["x"]
est = WeightedEstimateAggregator("mean")
out = est(long_df, "y")
expected = np.average(long_df["y"], weights=long_df["weight"])
assert_array_equal(out["y"], expected)
assert_array_equal(out["ymin"], np.nan)
assert_array_equal(out["ymax"], np.nan)

def test_weighted_ci(self, long_df):

long_df["weight"] = long_df["x"]
est = WeightedEstimateAggregator("mean", "ci")
out = est(long_df, "y")
expected = np.average(long_df["y"], weights=long_df["weight"])
assert_array_equal(out["y"], expected)
assert (out["ymin"] <= out["y"]).all()
assert (out["ymax"] >= out["y"]).all()

def test_limited_estimator(self):

with pytest.raises(ValueError, match="Weighted estimator must be 'mean'"):
WeightedEstimateAggregator("median")

def test_limited_ci(self):

with pytest.raises(ValueError, match="Error bar method must be 'ci'"):
WeightedEstimateAggregator("mean", "sd")


class TestLetterValues:

@pytest.fixture
Expand Down

0 comments on commit 2bb945c

Please sign in to comment.