Skip to content

Commit

Permalink
map_elements() -> map_batches() (#131)
Browse files Browse the repository at this point in the history
* map_elements -> map_batches

* toml

* maps_elements -> map_batches

* lint
  • Loading branch information
vincentarelbundock authored Oct 19, 2024
1 parent e0b817a commit 1d599fb
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
9 changes: 7 additions & 2 deletions marginaleffects/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ def __str__(self):
tmp = self.select(valid).rename(self.mapping)
for col in tmp.columns:
if tmp[col].dtype.is_numeric():
tmp = tmp.with_columns(
pl.col(col).map_elements(lambda x: f"{x:.3g}", return_dtype=pl.Utf8)

def fmt(x):
out = pl.Series([f"{i:.3g}" for i in x])
return out

tmp.with_columns(
pl.col(col).map_batches(fmt, return_dtype=pl.Utf8).alias(col)
)
out += tmp.__str__()
out = out + f"\n\nColumns: {', '.join(self.columns)}\n"
Expand Down
8 changes: 4 additions & 4 deletions marginaleffects/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ def get_equivalence(
if np.isinf(df):
x = x.with_columns(
pl.col("statistic_noninf")
.map_elements(lambda x: 1 - norm.cdf(x), return_dtype=pl.Float64)
.map_batches(lambda x: 1 - norm.cdf(x), return_dtype=pl.Float64)
.alias("p_value_noninf"),
pl.col("statistic_nonsup")
.map_elements(lambda x: norm.cdf(x), return_dtype=pl.Float64)
.map_batches(lambda x: norm.cdf(x), return_dtype=pl.Float64)
.alias("p_value_nonsup"),
)
else:
x = x.with_columns(
pl.col("statistic_noninf")
.map_elements(lambda x: 1 - t.cdf(x), return_dtype=pl.Float64)
.map_batches(lambda x: 1 - t.cdf(x), return_dtype=pl.Float64)
.alias("p_value_noninf"),
pl.col("statistic_nonsup")
.map_elements(lambda x: t.cdf(x), return_dtype=pl.Float64)
.map_batches(lambda x: t.cdf(x), return_dtype=pl.Float64)
.alias("p_value_nonsup"),
)

Expand Down
4 changes: 2 additions & 2 deletions marginaleffects/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0):

df = df.with_columns(
pl.col("statistic")
.map_elements(
.map_batches(
lambda x: (2 * (1 - stats.t.cdf(np.abs(x), dof))), return_dtype=pl.Float64
)
.alias("p_value")
Expand All @@ -76,7 +76,7 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0):
try:
df = df.with_columns(
pl.col("p_value")
.map_elements(lambda x: -np.log2(x), return_dtype=pl.Float64)
.map_batches(lambda x: -np.log2(x), return_dtype=pl.Float64)
.alias("s_value")
)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "marginaleffects"
version = "0.0.13"
version = "0.0.13.1"
description = "Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models."
readme = "README.md"
requires-python = ">=3.10"
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1d599fb

Please sign in to comment.