-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix mrmr working with categoricals (#1311)
- Loading branch information
1 parent
e486a24
commit 0225a51
Showing
5 changed files
with
63 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,43 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from etna.datasets import TSDataset | ||
from etna.datasets import duplicate_data | ||
|
||
|
||
@pytest.fixture(autouse=True)
def close_plots():
    """Close every open matplotlib figure after each test.

    The fixture is ``autouse``, so it wraps each test in its scope: nothing is
    done before the test runs, and the code after ``yield`` is the teardown.
    """
    yield
    # A bare ``plt.close()`` closes only the *current* figure; a test that
    # opens several figures would leak the others into subsequent tests.
    plt.close("all")
|
||
|
||
@pytest.fixture
def exog_and_target_dfs():
    """Build a two-segment target dataset and a matching exogenous frame.

    Returns
    -------
    tuple
        ``(ts, df_exog)`` where ``ts`` is the result of
        ``TSDataset.to_dataset`` applied to a long frame with segments ``"a"``
        and ``"b"`` (30 timestamps each, target ``0..59``), and ``df_exog`` is
        the result of ``duplicate_data`` applied to a 30-row exogenous frame
        for the same two segments.
    """
    n_points = 30
    segment_names = ["a", "b"]

    # First 30 timestamps of the 2020-01-01 .. 2021-01-01 range.
    full_range = pd.date_range("2020-01-01", "2021-01-01")
    time = list(full_range[:n_points])

    # Long-format target frame: all rows of "a" first, then all rows of "b".
    segment_column = [name for name in segment_names for _ in range(n_points)]
    target_long = pd.DataFrame(
        {
            "segment": segment_column,
            "timestamp": time + time,
            "target": np.arange(2 * n_points),
        }
    )
    ts = TSDataset.to_dataset(target_long)

    # String-valued column whose entries are all numeric-looking, with a gap at index 19.
    castable_values = ["1.1"] * 10 + ["2"] * 9 + [None] + ["56.1"] * 10
    # String-valued column containing a non-numeric label ("two").
    non_castable_values = ["1.1"] * 10 + ["two"] * 10 + ["56.1"] * 10
    # Plain numeric column with a single missing value at index 10.
    numeric_with_gap = [1] * 10 + [None] + [2] * 9 + [56.1] * 10

    exog = pd.DataFrame(
        {
            "timestamp": time,
            "exog1": np.arange(100, 70, -1),
            "exog2": np.sin(np.arange(30) / 10),
            "exog3": np.exp(np.arange(30)),
            "cast": castable_values,
            "no_cast": non_castable_values,
            "none": numeric_with_gap,
        }
    )
    # Both string columns are stored with the pandas "category" dtype.
    for column in ("cast", "no_cast"):
        exog[column] = exog[column].astype("category")

    df_exog = duplicate_data(exog, segments=segment_names)
    return ts, df_exog
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0225a51
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉 Published on https://etna-docs.netlify.app as production
🚀 Deployed on https://64ac208678edb205c8dacbf8--etna-docs.netlify.app