Skip to content

Commit

Permalink
before forecasting 3
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Jul 7, 2023
1 parent 084d5c7 commit e72c63c
Show file tree
Hide file tree
Showing 10 changed files with 5,938 additions and 5,158 deletions.
2 changes: 1 addition & 1 deletion atom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
"""

from atom.api import ATOMClassifier, ATOMModel, ATOMRegressor
from atom.api import ATOMClassifier, ATOMForecaster, ATOMModel, ATOMRegressor
from atom.utils import __version__
3 changes: 2 additions & 1 deletion atom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ class ATOMForecaster(BaseTransformer, ATOM):
n_rows: int or float, default=1
Subsample of the dataset to use. The cut is made from the head
of the dataset. The default value selects all rows.
of the dataset (older entries are dropped when sorted by date
ascending). The default value selects all rows.
- If <=1: Fraction of the dataset to select.
- If >1: Exact number of rows to select. Only if `arrays` is X
Expand Down
4 changes: 3 additions & 1 deletion atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,8 @@ def _subsample(df: DATAFRAME) -> DATAFRAME:

if self.shuffle:
return df.iloc[random.sample(range(len(df)), k=n_rows)]
elif self.goal == "fc":
return df.iloc[-n_rows:] # For time series, select from tail
else:
return df.iloc[sorted(random.sample(range(len(df)), k=n_rows))]

Expand Down Expand Up @@ -847,7 +849,7 @@ def _has_data_sets(
return self.branch._data, self.branch._idx, self.branch._holdout

Check notice on line 849 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class

Check notice on line 849 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

Accessing a protected member of a class or a module

Access to a protected member _idx of a class

Check notice on line 849 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

Accessing a protected member of a class or a module

Access to a protected member _holdout of a class

elif len(arrays) == 1:
# arrays=(X,)
# arrays=(X,) or arrays=(y,) for forecasting
sets = _no_data_sets(*self._prepare_input(arrays[0], y=y))

elif len(arrays) == 2:
Expand Down
6 changes: 2 additions & 4 deletions atom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,9 +1972,6 @@ def get_custom_scorer(metric: str | Callable | Scorer) -> Scorer:
def infer_task(y: PANDAS, goal: str = "class") -> str:
"""Infer the task corresponding to a target column.
If goal is provided, only look at number of unique values to
determine the classification task.
Parameters
----------
y: series or dataframe
Expand All @@ -1994,6 +1991,8 @@ def infer_task(y: PANDAS, goal: str = "class") -> str:
return "regression"
else:
return "multioutput regression"
elif goal == "fc":
return "forecasting"

if y.ndim > 1:
if all(y[col].nunique() == 2 for col in y):
Expand Down Expand Up @@ -2467,7 +2466,6 @@ def prepare_df(out: FEATURES, og: DATAFRAME) -> DATAFRAME:
)
if isinstance(y, DATAFRAME_TYPES):
y_new = prepare_df(y_new, y)

elif "X" in params and X is not None and any(c in X for c in inc):
# X in -> X out
X_new = prepare_df(out, X)

Check notice on line 2471 in atom/utils.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

PEP 8 naming convention violation

Variable in function should be lowercase
Expand Down
8 changes: 4 additions & 4 deletions docs_sources/dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ And operating systems:
ATOM is built on top of several existing Python libraries. These
packages are necessary for its correct functioning.

* **[category-encoders](https://contrib.scikit-learn.org/categorical-encoding/index.html)** (>=2.5.1)
* **[category-encoders](https://contrib.scikit-learn.org/categorical-encoding/index.html)** (>=2.6.1)
* **[dill](https://pypi.org/project/dill/)** (>=0.3.6)
* **[dagshub](https://github.com/DagsHub/client)** (<=0.2.10)
* **[dagshub](https://github.com/DagsHub/client)** (>=0.2.18)
* **[gplearn](https://gplearn.readthedocs.io/en/stable/index.html)** (>=0.4.2)
* **[imbalanced-learn](https://imbalanced-learn.readthedocs.io/en/stable/api.html)** (>=0.10.1)
* **[ipython](https://ipython.readthedocs.io/en/stable/)** (>=8.11.0)
Expand Down Expand Up @@ -87,8 +87,8 @@ using `pip install atom-ml[dev]`.
* **[pytest](https://docs.pytest.org/en/latest/)** (>=7.2.1)
* **[pytest-cov](https://pytest-cov.readthedocs.io/en/latest/)** (>=4.0.0)
* **[pytest-xdist](https://github.com/pytest-dev/pytest-xdist)** (>=3.2.0)
* **[scikeras](https://github.com/adriangb/scikeras)** (>=0.10.0)
* **[tensorflow](https://www.tensorflow.org/learn)** (>=2.11.0)
* **[scikeras](https://github.com/adriangb/scikeras)** (>=0.11.0)
* **[tensorflow](https://www.tensorflow.org/learn)** (>=2.13.0)

**Documentation**

Expand Down
3 changes: 2 additions & 1 deletion docs_sources/user_guide/nomenclature.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,15 @@ an estimator's fit method.

<div id="task"><strong>task</strong></div>
<div markdown style="margin: -1em 0 0 1.2em">
One of the six supervised machine learning approaches that ATOM supports:
One of the supervised machine learning approaches that ATOM supports:
<ul style="line-height:1.2em;margin-top:-10px">
<li><a href="https://en.wikipedia.org/wiki/Binary_classification">binary classification</a></li>
<li><a href="https://en.wikipedia.org/wiki/Multiclass_classification">multiclass classification</a></li>
<li><a href="https://scikit-learn.org/stable/modules/multiclass.html#multilabel-classification">multilabel classification</a></li>
<li><a href="https://scikit-learn.org/stable/modules/multiclass.html#multiclass-multioutput-classification">multiclass-multioutput classification</a></li>
<li><a href="https://en.wikipedia.org/wiki/Regression_analysis">regression</a></li>
<li><a href="https://scikit-learn.org/stable/modules/multiclass.html#multioutput-regression">multioutput regression</a></li>
<li><a href="https://en.wikipedia.org/wiki/Time_series">forecasting</a></li>
</ul>
</div>

Expand Down
26 changes: 14 additions & 12 deletions examples/automated_feature_scaling.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit e72c63c

Please sign in to comment.