diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index d9aecd7b54..c886589fae 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -6,7 +6,7 @@ Base interface for accessing the stored benchmark trial data. """ from abc import ABCMeta, abstractmethod -from datetime import datetime +from datetime import datetime, UTC from typing import Any, Dict, Optional, TYPE_CHECKING import pandas @@ -38,6 +38,8 @@ def __init__(self, *, self._experiment_id = experiment_id self._trial_id = trial_id self._tunable_config_id = tunable_config_id + assert ts_start.tzinfo == UTC, "ts_start must be in UTC" + assert ts_end is None or ts_end.tzinfo == UTC, "ts_end must be in UTC if not None" self._ts_start = ts_start self._ts_end = ts_end self._status = status diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index ce08e839b3..8ecff95b45 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -14,6 +14,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_trial_data import TrialData from mlos_bench.storage.sql.schema import DbSchema +from mlos_bench.storage.util import utcify_timestamp, utcify_nullable_timestamp def get_trials( @@ -48,8 +49,8 @@ def get_trials( experiment_id=experiment_id, trial_id=trial.trial_id, config_id=trial.config_id, - ts_start=trial.ts_start, - ts_end=trial.ts_end, + ts_start=utcify_timestamp(trial.ts_start, origin="utc"), + ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"), status=Status[trial.status], ) for trial in trials.fetchall() @@ -107,9 +108,23 @@ def get_results_df( ) cur_trials = conn.execute(cur_trials_stmt) trials_df = pandas.DataFrame( - [(row.trial_id, row.ts_start, row.ts_end, row.config_id, row.tunable_config_trial_group_id, row.status) - for row in cur_trials.fetchall()], - columns=['trial_id', 'ts_start', 'ts_end', 'tunable_config_id', 'tunable_config_trial_group_id', 'status']) + [( + row.trial_id, + utcify_timestamp(row.ts_start, origin="utc"), + utcify_nullable_timestamp(row.ts_end, origin="utc"), + row.config_id, + row.tunable_config_trial_group_id, + row.status, + ) for row in cur_trials.fetchall()], + columns=[ + 'trial_id', + 'ts_start', + 'ts_end', + 'tunable_config_id', + 'tunable_config_trial_group_id', + 'status', + ] + ) # Get each trial's config in wide format. configs_stmt = schema.trial.select().with_only_columns( diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index b1244b285d..1aa3f21e9c 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -18,6 +18,7 @@ from mlos_bench.storage.base_storage import Storage from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.storage.sql.trial import Trial +from mlos_bench.storage.util import utcify_timestamp from mlos_bench.util import nullable _LOG = logging.getLogger(__name__) @@ -120,7 +121,9 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: self._schema.trial_telemetry.c.metric_id, ) ) - return [(row.ts, row.metric_id, row.metric_value) + # Not all storage backends store the original zone info. + # We try to ensure data is entered in UTC and augment it on return again here. + return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()] def load(self, @@ -183,6 +186,7 @@ def _save_params(conn: Connection, table: Table, ]) def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]: + timestamp = utcify_timestamp(timestamp, origin="local") _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp) if running: pending_status = ['PENDING', 'READY', 'RUNNING'] @@ -238,7 +242,8 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, config: Optional[Dict[str, Any]] = None) -> Storage.Trial: - _LOG.debug("Create trial: %s:%d", self._experiment_id, self._trial_id) + ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") + _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) with self._engine.begin() as conn: try: config_id = self._get_config_id(conn, tunables) @@ -246,7 +251,7 @@ def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None exp_id=self._experiment_id, trial_id=self._trial_id, config_id=config_id, - ts_start=ts_start or datetime.now(UTC), + ts_start=ts_start, status='PENDING', )) diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 9fc801b3eb..c59adc1c67 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -155,7 +155,7 @@ def __init__(self, engine: Engine): self._meta, Column("exp_id", String(self._ID_LEN), nullable=False), Column("trial_id", Integer, nullable=False), - Column("ts", DateTime, nullable=False, default="now"), + Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("status", String(self._STATUS_LEN), nullable=False), UniqueConstraint("exp_id", "trial_id", "ts"), @@ -181,7 +181,7 @@ def __init__(self, engine: Engine): self._meta, Column("exp_id", String(self._ID_LEN), nullable=False), Column("trial_id", Integer, nullable=False), - Column("ts", DateTime, nullable=False, default="now"), + Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 74c80e158c..1fd38684e8 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -17,6 +17,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.storage.base_storage import Storage from mlos_bench.storage.sql.schema import DbSchema +from mlos_bench.storage.util import utcify_timestamp from mlos_bench.util import nullable _LOG = logging.getLogger(__name__) @@ -47,6 +48,8 @@ def __init__(self, *, def update(self, status: Status, timestamp: datetime, metrics: Optional[Union[Dict[str, Any], float]] = None ) -> Optional[Dict[str, Any]]: + # Make sure to convert the timestamp to UTC before storing it in the database. + timestamp = utcify_timestamp(timestamp, origin="local") metrics = super().update(status, timestamp, metrics) with self._engine.begin() as conn: self._update_status(conn, status, timestamp) @@ -106,6 +109,9 @@ def update(self, status: Status, timestamp: datetime, def update_telemetry(self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]]) -> None: super().update_telemetry(status, timestamp, metrics) + # Make sure to convert the timestamp to UTC before storing it in the database. + timestamp = utcify_timestamp(timestamp, origin="local") + metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics] # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()` # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of # a bulk upsert. @@ -130,6 +136,8 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime) Insert a new status record into the database. This call is idempotent. """ + # Make sure to convert the timestamp to UTC before storing it in the database. + timestamp = utcify_timestamp(timestamp, origin="local") try: conn.execute(self._schema.trial_status.insert().values( exp_id=self._experiment_id, diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index e59664272e..71df2b3a45 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -16,6 +16,7 @@ from mlos_bench.environments.status import Status from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData +from mlos_bench.storage.util import utcify_timestamp if TYPE_CHECKING: from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData @@ -100,8 +101,10 @@ def telemetry_df(self) -> pandas.DataFrame: self._schema.trial_telemetry.c.metric_id, ) ) + # Not all storage backends store the original zone info. + # We try to ensure data is entered in UTC and augment it on return again here. return pandas.DataFrame( - [(row.ts, row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], + [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], columns=['ts', 'metric', 'value']) @property diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index a4610da8de..3deb52c23e 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -6,7 +6,8 @@ Utility functions for the storage subsystem. """ -from typing import Dict, Optional +from datetime import datetime, UTC +from typing import Dict, Literal, Optional import pandas @@ -14,6 +15,51 @@ from mlos_bench.util import try_parse_val +def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime: + """ + Augment a timestamp with zoneinfo if missing and convert it to UTC. + + Parameters + ---------- + timestamp : datetime + A timestamp to convert to UTC. + Note: The original datetime may or may not have tzinfo associated with it. + + origin : Literal["utc", "local"] + Whether the source timestamp is considered to be in UTC or local time. + In the case of loading data from storage, where we intentionally convert all + timestamps to UTC, this can help us retrieve the original timezone when the + storage backend doesn't explicitly store it. + In the case of receiving data from a client or other source, this can help us + convert the timestamp to UTC if it's not already. + + Returns + ------- + datetime + A datetime with zoneinfo in UTC. + """ + if timestamp.tzinfo is not None or origin == "local": + # A timestamp with no zoneinfo is interpretted as "local" time + # (e.g., according to the TZ environment variable). + # That could be UTC or some other timezone, but either way we convert it to + # be explicitly UTC with zone info. + return timestamp.astimezone(UTC) + elif origin == "utc": + # If the timestamp is already in UTC, we just add the zoneinfo without conversion. + # Converting with astimezone() when the local time is *not* UTC would cause + # a timestamp conversion which we don't want. + return timestamp.replace(tzinfo=UTC) + else: + raise ValueError(f"Invalid origin: {origin}") + + +def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: + """ + A nullable version of utcify_timestamp. + """ + return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None + + def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValue]]: """ Utility function to convert certain flat key-value dataframe formats used by the diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 03cc2bf780..09366d83a4 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -34,6 +34,7 @@ def storage() -> SqlStorage: config={ "drivername": "sqlite", "database": ":memory:", + # "database": "mlos_bench.pytest.db", } ) diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index 090b946243..3c7cbd18ac 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -5,8 +5,9 @@ """ Unit tests for saving and restoring the telemetry data. """ -from datetime import datetime, timedelta, UTC +from datetime import datetime, timedelta, tzinfo, UTC from typing import Any, List, Optional, Tuple +from zoneinfo import ZoneInfo import pytest @@ -18,8 +19,7 @@ # pylint: disable=redefined-outer-name -@pytest.fixture -def telemetry_data() -> List[Tuple[datetime, str, Any]]: +def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]: """ Mock telemetry data for the trial. @@ -28,7 +28,7 @@ def telemetry_data() -> List[Tuple[datetime, str, Any]]: List[Tuple[datetime, str, str]] A list of (timestamp, metric_id, metric_value) """ - timestamp1 = datetime.now(UTC) + timestamp1 = datetime.now(zone_info) timestamp2 = timestamp1 + timedelta(seconds=1) return sorted([ (timestamp1, "cpu_load", 10.1), @@ -40,40 +40,57 @@ def telemetry_data() -> List[Tuple[datetime, str, Any]]: ]) +ZONE_INFO: List[Optional[tzinfo]] = [ + # Explicit time zones. + UTC, + ZoneInfo("America/Chicago"), + ZoneInfo("America/Los_Angeles"), + # Implicit local time zone. + None, +] + + def _telemetry_str(data: List[Tuple[datetime, str, Any]] ) -> List[Tuple[datetime, str, Optional[str]]]: """ Convert telemetry values to strings. """ - return [(ts, key, nullable(str, val)) for (ts, key, val) in data] + # All retrieved timestamps should have been converted to UTC. + return [(ts.astimezone(UTC), key, nullable(str, val)) for (ts, key, val) in data] +@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) def test_update_telemetry(storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups, - telemetry_data: List[Tuple[datetime, str, Any]]) -> None: + origin_zone_info: Optional[tzinfo]) -> None: """ Make sure update_telemetry() and load_telemetry() methods work. """ + telemetry_data = zoned_telemetry_data(origin_zone_info) trial = exp_storage.new_trial(tunable_groups) assert exp_storage.load_telemetry(trial.trial_id) == [] - trial.update_telemetry(Status.RUNNING, datetime.now(UTC), telemetry_data) + trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), telemetry_data) assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(telemetry_data) # Also check that the TrialData telemetry looks right. trial_data = storage.experiments[exp_storage.experiment_id].trials[trial.trial_id] - assert _telemetry_str([tuple(r) for r in trial_data.telemetry_df.to_numpy()]) == _telemetry_str(telemetry_data) + trial_telemetry_df = trial_data.telemetry_df + trial_telemetry_data = [tuple(r) for r in trial_telemetry_df.to_numpy()] + assert _telemetry_str(trial_telemetry_data) == _telemetry_str(telemetry_data) +@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) def test_update_telemetry_twice(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, - telemetry_data: List[Tuple[datetime, str, Any]]) -> None: + origin_zone_info: Optional[tzinfo]) -> None: """ Make sure update_telemetry() call is idempotent. """ + telemetry_data = zoned_telemetry_data(origin_zone_info) trial = exp_storage.new_trial(tunable_groups) - timestamp = datetime.now(UTC) + timestamp = datetime.now(origin_zone_info) trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data) trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data) trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data) diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test_alt_tz.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test_alt_tz.py new file mode 100644 index 0000000000..ee89d494d7 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test_alt_tz.py @@ -0,0 +1,33 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Unit tests for saving and restoring the telemetry data when host timezone is in a different timezone. +""" + +from subprocess import run +import os +import sys +from typing import Optional + +import pytest + + +@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") +@pytest.mark.parametrize(("tz_name"), [None, "America/Chicago", "America/Los_Angeles", "UTC"]) +def test_trial_telemetry_alt_tz(tz_name: Optional[str]) -> None: + """ + Run the trial telemetry tests under alternative default (un-named) TZ info. + """ + env = os.environ.copy() + if tz_name is None: + env.pop("TZ", None) + else: + env["TZ"] = tz_name + cmd = run( + [sys.executable, "-m", "pytest", "-n0", f"{os.path.dirname(__file__)}/trial_telemetry_test.py"], + env=env, + check=True, + ) + assert cmd.returncode == 0