Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: prediction no longer takes a time series dataset only table #838

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3c51747
prediction no longer takes a time series dataset only table
Gerhardsa0 Jun 12, 2024
35e3a7a
added type notations for mypy
Gerhardsa0 Jun 12, 2024
456c651
returns a TimeSeriesDataset again
Gerhardsa0 Jun 12, 2024
8fc059e
style: apply automated linter fixes
megalinter-bot Jun 12, 2024
33ddba9
style: apply automated linter fixes
megalinter-bot Jun 12, 2024
edd9883
added code cov for is_predict_data_valid
Gerhardsa0 Jun 12, 2024
d34e66a
Merge remote-tracking branch 'origin/837-feat-change-that-nn-takes-ti…
Gerhardsa0 Jun 12, 2024
716f917
fixed init
Gerhardsa0 Jun 20, 2024
b35ec7f
for some reason the numerical grouped snapshot fails
Gerhardsa0 Jun 20, 2024
791956f
fixed linter
Gerhardsa0 Jun 20, 2024
d1f12f6
fixed linter
Gerhardsa0 Jun 20, 2024
33d3812
fixed linter
Gerhardsa0 Jun 20, 2024
b7005d2
Merge branch 'main' into 837-feat-change-that-nn-takes-time-series-fo…
Gerhardsa0 Jun 20, 2024
73f7681
removed that it takes also time series dataset
Gerhardsa0 Jun 20, 2024
140ef46
Merge remote-tracking branch 'origin/837-feat-change-that-nn-takes-ti…
Gerhardsa0 Jun 20, 2024
a155ad0
fixed linter
Gerhardsa0 Jun 20, 2024
d26d9d3
style: apply automated linter fixes
megalinter-bot Jun 20, 2024
7e5d564
Merge branch 'main' into 837-feat-change-that-nn-takes-time-series-fo…
Gerhardsa0 Jun 24, 2024
8469ab9
added extranames to eq hash and size_of
Gerhardsa0 Jun 27, 2024
a365bef
Merge branch 'main' into 837-feat-change-that-nn-takes-time-series-fo…
Gerhardsa0 Jun 27, 2024
3db24fb
Merge branch 'main' into 837-feat-change-that-nn-takes-time-series-fo…
Gerhardsa0 Jun 28, 2024
adad8c2
Merge branch 'main' into 837-feat-change-that-nn-takes-time-series-fo…
Gerhardsa0 Jun 30, 2024
8180b00
Merge branch 'main' into 837-feat-change-that-nn-takes-time-series-fo…
Gerhardsa0 Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 26 additions & 34 deletions docs/tutorials/time_series_forecasting.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/safeds/data/tabular/containers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"Column": "._column:Column",
"Row": "._row:Row",
"StringCell": "._string_cell:StringCell",
"TemporalCell": "._temporal_cell:",
"TemporalCell": "._temporal_cell:TemporalCell",
"Table": "._table:Table",
},
)
Expand Down
59 changes: 42 additions & 17 deletions src/safeds/ml/nn/converters/_input_converter_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from safeds._utils import _structural_hash
from safeds.data.labeled.containers import TimeSeriesDataset
from safeds.data.tabular.containers import Column
from safeds.data.tabular.containers import Column, Table

from ._input_converter import InputConversion

Expand All @@ -14,7 +14,7 @@
from torch.utils.data import DataLoader


class InputConversionTimeSeries(InputConversion[TimeSeriesDataset, TimeSeriesDataset]):
class InputConversionTimeSeries(InputConversion[TimeSeriesDataset, Table]):
"""The input conversion for a neural network, defines the input parameters for the neural network."""

def __init__(
Expand All @@ -24,9 +24,9 @@ def __init__(
self._forecast_horizon = 0
self._first = True
self._target_name: str = ""
self._time_name: str = ""
self._feature_names: list[str] = []
self._continuous: bool = False
self._extra_names: list[str] = []
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved

def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
Expand All @@ -38,6 +38,7 @@ def __eq__(self, other: object) -> bool:
and self._target_name == other._target_name
and self._feature_names == other._feature_names
and self._continuous == other._continuous
and self._extra_names == other._extra_names
)

def __hash__(self) -> int:
Expand All @@ -46,19 +47,19 @@ def __hash__(self) -> int:
self._window_size,
self._forecast_horizon,
self._target_name,
self._time_name,
self._feature_names,
self._continuous,
self._extra_names,
)

def __sizeof__(self) -> int:
return (
sys.getsizeof(self._window_size)
+ sys.getsizeof(self._forecast_horizon)
+ sys.getsizeof(self._target_name)
+ sys.getsizeof(self._time_name)
+ sys.getsizeof(self._feature_names)
+ sys.getsizeof(self._continuous)
+ sys.getsizeof(self._extra_names)
)

@property
Expand Down Expand Up @@ -87,26 +88,41 @@ def _data_conversion_fit(
continuous=self._continuous,
)

def _data_conversion_predict(self, input_data: TimeSeriesDataset, batch_size: int) -> DataLoader:
return input_data._into_dataloader_with_window_predict(self._window_size, self._forecast_horizon, batch_size)
def _data_conversion_predict(self, input_data: Table, batch_size: int) -> DataLoader:
    """
    Turn a plain table into the windowed data loader used for prediction.

    The table is first wrapped into a `TimeSeriesDataset` using the target name, window size, extra column
    names, forecast horizon, and `continuous` flag stored on this converter. That dataset is then asked for
    its prediction data loader.

    Parameters
    ----------
    input_data:
        The table to predict on.
    batch_size:
        The batch size handed to the data loader.

    Returns
    -------
    data_loader:
        The data loader produced by the intermediate time series dataset.
    """
    window_size = self._window_size
    forecast_horizon = self._forecast_horizon

    time_series: TimeSeriesDataset = input_data.to_time_series_dataset(
        target_name=self._target_name,
        window_size=window_size,
        extra_names=self._extra_names,
        forecast_horizon=forecast_horizon,
        continuous=self._continuous,
    )
    return time_series._into_dataloader_with_window_predict(window_size, forecast_horizon, batch_size)

def _data_conversion_output(
self,
input_data: TimeSeriesDataset,
input_data: Table,
output_data: Tensor,
) -> TimeSeriesDataset:
table_data: Table
window_size: int = self._window_size
forecast_horizon: int = self._forecast_horizon
input_data_table = input_data.to_table()
input_data_table = input_data_table.slice_rows(start=window_size + forecast_horizon)
table_data = input_data
input_data_table = table_data.slice_rows(start=window_size + forecast_horizon)

return input_data_table.replace_column(
self._target_name,
[Column(self._target_name, output_data.tolist())],
).to_time_series_dataset(
target_name=self._target_name,
extra_names=input_data.extras.column_names,
window_size=window_size,
window_size=self._window_size,
extra_names=self._extra_names,
forecast_horizon=self._forecast_horizon,
continuous=self._continuous,
)

def _is_fit_data_valid(self, input_data: TimeSeriesDataset) -> bool:
Expand All @@ -117,9 +133,18 @@ def _is_fit_data_valid(self, input_data: TimeSeriesDataset) -> bool:
self._target_name = input_data.target.name
self._continuous = input_data._continuous
self._first = False
return (sorted(input_data.features.column_names)).__eq__(
sorted(self._feature_names),
) and input_data.target.name == self._target_name
self._extra_names = input_data.extras.column_names
return (
sorted(input_data.features.column_names).__eq__(
sorted(self._feature_names),
)
and input_data.target.name == self._target_name
)

def _is_predict_data_valid(self, input_data: TimeSeriesDataset) -> bool:
return self._is_fit_data_valid(input_data)
def _is_predict_data_valid(self, input_data: Table) -> bool:
    """
    Check whether a table contains every column needed for prediction.

    Parameters
    ----------
    input_data:
        The table that is about to be used for prediction.

    Returns
    -------
    is_valid:
        True only if all stored feature names and the stored target name are among the table's column names.
    """
    required_names = [*self._feature_names, self._target_name]
    available_names = input_data.column_names
    return all(name in available_names for name in required_names)
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ def test_should_raise_if_is_fitted_is_set_correctly_lstm() -> None:
)
assert not model.is_fitted
model = model.fit(ts)
model.predict(ts)
model.predict(ts.to_table())
assert model.is_fitted


def test_is_predict_data_valid() -> None:
    """Check that prediction input is accepted only if it has all feature columns and the target column."""
    input_conv = InputConversionTimeSeries()
    data = Table({"target": [1, 1, 1, 1], "time": [0, 0, 0, 0], "feat": [0, 0, 0, 0]})

    # Freshly constructed converter: the target name is still the empty string, which is not a column.
    assert not input_conv._is_predict_data_valid(data)

    # A required feature column is missing from the table.
    input_conv._feature_names = ["XYZ"]
    assert not input_conv._is_predict_data_valid(data)

    # The feature column is present, but the target name is still not a column.
    input_conv._feature_names = ["feat"]
    assert not input_conv._is_predict_data_valid(data)

    # Accepting path: without this, an implementation that always returns False would pass the test.
    input_conv._target_name = "target"
    assert input_conv._is_predict_data_valid(data)


class TestEq:
@pytest.mark.parametrize(
("output_conversion_ts1", "output_conversion_ts2"),
Expand Down
20 changes: 2 additions & 18 deletions tests/safeds/ml/nn/test_lstm_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,7 @@ def test_lstm_model(device: Device) -> None:
epoch_size=1,
)

trained_model.predict(
test_table.to_time_series_dataset(
"value",
window_size=7,
forecast_horizon=12,
continuous=True,
extra_names=["date"],
),
)
trained_model.predict(test_table)
trained_model_2 = model_2.fit(
train_table.to_time_series_dataset(
"value",
Expand All @@ -68,14 +60,6 @@ def test_lstm_model(device: Device) -> None:
epoch_size=1,
)

trained_model_2.predict(
test_table.to_time_series_dataset(
"value",
window_size=7,
forecast_horizon=12,
continuous=False,
extra_names=["date"],
),
)
trained_model_2.predict(test_table)
assert trained_model._model is not None
assert trained_model._model.state_dict()["_pytorch_layers.0._layer.weight"].device == _get_device()