-
Notifications
You must be signed in to change notification settings - Fork 11
/
test_predictions.py
97 lines (71 loc) · 2.76 KB
/
test_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import polars as pl
from polars.testing import assert_series_equal
import statsmodels.formula.api as smf
from tests.conftest import guerry, penguins, diamonds
import marginaleffects
from marginaleffects import *
from tests.utilities import *
# Shared fixtures for the tests in this module.
# Tag every Guerry row with its original position so that the sort by
# Region below has a deterministic tie-breaker.
_row_ids = pl.Series(range(guerry.shape[0])).alias("row_id")
df = guerry.with_columns(_row_ids).sort("Region", "row_id")

# OLS model with an interaction term, fitted once and reused across tests.
mod_py = smf.ols("Literacy ~ Pop1831 * Desertion", df).fit()
def test_newdata_balanced():
    """Predictions on a ``"balanced"`` grid of the penguins model have 9 rows."""
    formula = "body_mass_g ~ flipper_length_mm * species * bill_length_mm + island"
    fitted = smf.ols(formula, penguins.to_pandas()).fit()

    balanced = predictions(fitted, newdata="balanced")

    # NOTE(review): 9 presumably corresponds to the species x island
    # combinations in the penguins data; confirm against the dataset.
    n_rows = balanced.shape[0]
    assert n_rows == 9
def test_predictions():
    """Unit-level predictions from ``mod_py`` match the stored R reference."""
    py_result = predictions(mod_py)
    r_reference = pl.read_csv("tests/r/test_predictions_01.csv")
    compare_r_to_py(r_reference, py_result)
def test_by():
    """Predictions aggregated ``by="Region"`` match the stored R reference."""
    py_result = predictions(mod_py, by="Region")
    r_reference = pl.read_csv("tests/r/test_predictions_02.csv")
    compare_r_to_py(r_reference, py_result)
def test_by_hypothesis():
    """By-Region predictions with a hypothesis match the stored R reference.

    The hypothesis string ``"b0 * b2 = b2*2"`` is passed through to
    ``predictions`` together with ``by="Region"``; the result is compared
    against ``tests/r/test_predictions_03.csv``.
    """
    # A plain by="Region" call used to precede this one, but its result was
    # overwritten immediately; that scenario is exercised by test_by.
    pre_py = predictions(mod_py, by="Region", hypothesis="b0 * b2 = b2*2")
    pre_r = pl.read_csv("tests/r/test_predictions_03.csv")
    compare_r_to_py(pre_r, pre_py)
def test_class_manipulation():
    """The result keeps both of its types after a ``head()`` call."""
    full = predictions(mod_py)
    assert isinstance(full, pl.DataFrame)
    assert isinstance(full, marginaleffects.classes.MarginaleffectsDataFrame)

    # The frame returned by head() must still be the marginaleffects subclass,
    # not just a plain polars frame.
    truncated = full.head()
    assert isinstance(truncated, pl.DataFrame)
    assert isinstance(truncated, marginaleffects.classes.MarginaleffectsDataFrame)
def issue_38():
    """``avg_predictions`` returns a single row with ``by=True`` and without ``by``.

    NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    discovery presumably does not collect it; confirm whether that is
    intentional.
    """
    # Same check for both call styles: first by=True, then the default.
    for extra_kwargs in ({"by": True}, {}):
        averaged = avg_predictions(mod_py, **extra_kwargs)
        assert averaged.shape[0] == 1
def issue_59():
    """With ``vcov=False``, predictions keep one row per observation and >20 columns.

    NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    discovery presumably does not collect it; confirm whether that is
    intentional.
    """
    result = predictions(mod_py, vcov=False)
    n_rows, n_cols = result.shape
    assert n_rows == df.shape[0]
    assert n_cols > 20
def test_issue_83():
    """Predictions return one row per ``newdata`` row when training data has nulls.

    The model is fitted on a copy of ``cut`` where "Ideal" is replaced by
    null; ``newdata`` uses the first 20 diamonds with "Ideal" replaced by
    "Premium" instead.
    """

    def _recode_ideal(replacement):
        # Expression that swaps "Ideal" in `cut` for the given literal and
        # leaves every other value unchanged.
        return (
            pl.when(pl.col("cut") == "Ideal")
            .then(pl.lit(replacement))
            .otherwise(pl.col("cut"))
        )

    training = diamonds.with_columns(cut_ideal_null=_recode_ideal(None))
    model = smf.ols("price ~ cut_ideal_null", training.to_pandas()).fit()

    newdata = diamonds.slice(0, 20).with_columns(
        cut_ideal_null=_recode_ideal("Premium")
    )

    p = predictions(model, newdata=newdata)
    assert p.shape[0] == newdata.shape[0]
def test_issue_95():
    """``predictions(..., by="cut")`` equals the per-cut mean of ``model.predict``."""
    model = smf.ols("price ~ cut + clarity + color", diamonds.to_pandas()).fit()
    subset = diamonds.slice(0, 20)

    # Estimates from marginaleffects, ordered by cut for the comparison.
    by_cut = predictions(model, newdata=subset, by="cut").sort(by="cut")

    # Manual reference: raw model predictions averaged within each cut.
    raw_pred = pl.Series(model.predict(subset.to_pandas()))
    manual = (
        subset.with_columns(pred=raw_pred)
        .group_by("cut")
        .agg(pl.col("pred").mean())
        .sort(by="cut")
    )

    assert_series_equal(by_cut["estimate"], manual["pred"], check_names=False)