Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce storage/retrieval conversion to UTC #12

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlos_bench/mlos_bench/storage/base_trial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Base interface for accessing the stored benchmark trial data.
"""
from abc import ABCMeta, abstractmethod
from datetime import datetime
from datetime import datetime, UTC
from typing import Any, Dict, Optional, TYPE_CHECKING

import pandas
Expand Down Expand Up @@ -38,6 +38,8 @@ def __init__(self, *,
self._experiment_id = experiment_id
self._trial_id = trial_id
self._tunable_config_id = tunable_config_id
assert ts_start.tzinfo == UTC, "ts_start must be in UTC"
assert ts_end is None or ts_end.tzinfo == UTC, "ts_end must be in UTC if not None"
self._ts_start = ts_start
self._ts_end = ts_end
self._status = status
Expand Down
25 changes: 20 additions & 5 deletions mlos_bench/mlos_bench/storage/sql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.util import utcify_timestamp, utcify_nullable_timestamp


def get_trials(
Expand Down Expand Up @@ -48,8 +49,8 @@ def get_trials(
experiment_id=experiment_id,
trial_id=trial.trial_id,
config_id=trial.config_id,
ts_start=trial.ts_start,
ts_end=trial.ts_end,
ts_start=utcify_timestamp(trial.ts_start, origin="utc"),
ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"),
status=Status[trial.status],
)
for trial in trials.fetchall()
Expand Down Expand Up @@ -107,9 +108,23 @@ def get_results_df(
)
cur_trials = conn.execute(cur_trials_stmt)
trials_df = pandas.DataFrame(
[(row.trial_id, row.ts_start, row.ts_end, row.config_id, row.tunable_config_trial_group_id, row.status)
for row in cur_trials.fetchall()],
columns=['trial_id', 'ts_start', 'ts_end', 'tunable_config_id', 'tunable_config_trial_group_id', 'status'])
[(
row.trial_id,
utcify_timestamp(row.ts_start, origin="utc"),
utcify_nullable_timestamp(row.ts_end, origin="utc"),
row.config_id,
row.tunable_config_trial_group_id,
row.status,
) for row in cur_trials.fetchall()],
columns=[
'trial_id',
'ts_start',
'ts_end',
'tunable_config_id',
'tunable_config_trial_group_id',
'status',
]
)

# Get each trial's config in wide format.
configs_stmt = schema.trial.select().with_only_columns(
Expand Down
11 changes: 8 additions & 3 deletions mlos_bench/mlos_bench/storage/sql/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.trial import Trial
from mlos_bench.storage.util import utcify_timestamp
from mlos_bench.util import nullable

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -120,7 +121,9 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
self._schema.trial_telemetry.c.metric_id,
)
)
return [(row.ts, row.metric_id, row.metric_value)
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and augment it on return again here.
return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
for row in cur_telemetry.fetchall()]

def load(self,
Expand Down Expand Up @@ -184,6 +187,7 @@ def _save_params(conn: Connection, table: Table,

def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
_LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
timestamp = timestamp.astimezone(UTC) if timestamp.tzinfo else timestamp.replace(tzinfo=UTC)
if running:
pending_status = ['PENDING', 'READY', 'RUNNING']
else:
Expand Down Expand Up @@ -238,15 +242,16 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:

def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
_LOG.debug("Create trial: %s:%d", self._experiment_id, self._trial_id)
ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local")
_LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
with self._engine.begin() as conn:
try:
config_id = self._get_config_id(conn, tunables)
conn.execute(self._schema.trial.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
config_id=config_id,
ts_start=ts_start or datetime.now(UTC),
ts_start=ts_start,
status='PENDING',
))

Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/storage/sql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self, engine: Engine):
self._meta,
Column("exp_id", String(self._ID_LEN), nullable=False),
Column("trial_id", Integer, nullable=False),
Column("ts", DateTime, nullable=False, default="now"),
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("status", String(self._STATUS_LEN), nullable=False),

UniqueConstraint("exp_id", "trial_id", "ts"),
Expand All @@ -181,7 +181,7 @@ def __init__(self, engine: Engine):
self._meta,
Column("exp_id", String(self._ID_LEN), nullable=False),
Column("trial_id", Integer, nullable=False),
Column("ts", DateTime, nullable=False, default="now"),
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("metric_id", String(self._ID_LEN), nullable=False),
Column("metric_value", String(self._METRIC_VALUE_LEN)),

Expand Down
8 changes: 8 additions & 0 deletions mlos_bench/mlos_bench/storage/sql/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.util import utcify_timestamp
from mlos_bench.util import nullable

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,6 +48,8 @@ def __init__(self, *,
def update(self, status: Status, timestamp: datetime,
metrics: Optional[Union[Dict[str, Any], float]] = None
) -> Optional[Dict[str, Any]]:
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
metrics = super().update(status, timestamp, metrics)
with self._engine.begin() as conn:
self._update_status(conn, status, timestamp)
Expand Down Expand Up @@ -106,6 +109,9 @@ def update(self, status: Status, timestamp: datetime,
def update_telemetry(self, status: Status, timestamp: datetime,
metrics: List[Tuple[datetime, str, Any]]) -> None:
super().update_telemetry(status, timestamp, metrics)
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
# NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
# and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
# a bulk upsert.
Expand All @@ -130,6 +136,8 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime)
Insert a new status record into the database.
This call is idempotent.
"""
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
try:
conn.execute(self._schema.trial_status.insert().values(
exp_id=self._experiment_id,
Expand Down
7 changes: 5 additions & 2 deletions mlos_bench/mlos_bench/storage/sql/trial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
An interface to access the benchmark trial data stored in SQL DB.
"""
from datetime import datetime
from datetime import datetime, UTC
from typing import Optional, TYPE_CHECKING

import pandas
Expand All @@ -16,6 +16,7 @@
from mlos_bench.environments.status import Status
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
from mlos_bench.storage.util import utcify_timestamp

if TYPE_CHECKING:
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
Expand Down Expand Up @@ -100,8 +101,10 @@ def telemetry_df(self) -> pandas.DataFrame:
self._schema.trial_telemetry.c.metric_id,
)
)
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and augment it on return again here.
return pandas.DataFrame(
[(row.ts, row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()],
[(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()],
columns=['ts', 'metric', 'value'])

@property
Expand Down
45 changes: 44 additions & 1 deletion mlos_bench/mlos_bench/storage/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,57 @@
Utility functions for the storage subsystem.
"""

from typing import Dict, Optional
from datetime import datetime, UTC
from typing import Dict, Literal, Optional

import pandas

from mlos_bench.tunables.tunable import TunableValue, TunableValueTypeTuple
from mlos_bench.util import try_parse_val


def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
"""
Augment a timestamp with zoneinfo if missing and convert it to UTC.

Parameters
----------
timestamp : datetime
A timestamp to convert to UTC.
Note: The original datetime may or may not have tzinfo associated with it.

origin : Literal["utc", "local"]
Whether the source timestamp is considered to be in UTC or local time.
In the case of loading data from storage, where we intentionally convert all
timestamps to UTC, this can help us retrieve the original timezone when the
storage backend doesn't explicitly store it.
Returns
-------
datetime
A datetime with zoneinfo in UTC.
"""
if timestamp.tzinfo is not None or origin == "local":
# A timestamp with no zoneinfo is interpretted as "local" time
# (e.g., according to the TZ environment variable).
# That could be UTC or some other timezone, but either way we convert it to
# be explicitly UTC with zone info.
return timestamp.astimezone(UTC)
elif origin == "utc":
# If the timestamp is already in UTC, we just add the zoneinfo without conversion.
# Converting with astimezone() when the local time is *not* UTC would cause
# a timestamp conversion which we don't want.
return timestamp.replace(tzinfo=UTC)
else:
raise ValueError(f"Invalid origin: {origin}")


def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]:
"""
A nullable version of utcify_timestamp.
"""
return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None


def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValue]]:
"""
Utility function to convert certain flat key-value dataframe formats used by the
Expand Down
1 change: 1 addition & 0 deletions mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def storage() -> SqlStorage:
config={
"drivername": "sqlite",
"database": ":memory:",
# "database": "mlos_bench.pytest.db",
}
)

Expand Down
61 changes: 51 additions & 10 deletions mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
"""
Unit tests for saving and restoring the telemetry data.
"""
from datetime import datetime, timedelta, UTC
from datetime import datetime, timedelta, tzinfo, UTC
from typing import Any, List, Optional, Tuple
from zoneinfo import ZoneInfo

import pytest
from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture

from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable_groups import TunableGroups
Expand All @@ -18,8 +20,7 @@
# pylint: disable=redefined-outer-name


@pytest.fixture
def telemetry_data() -> List[Tuple[datetime, str, Any]]:
def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]:
"""
Mock telemetry data for the trial.

Expand All @@ -28,7 +29,7 @@ def telemetry_data() -> List[Tuple[datetime, str, Any]]:
List[Tuple[datetime, str, str]]
A list of (timestamp, metric_id, metric_value)
"""
timestamp1 = datetime.now(UTC)
timestamp1 = datetime.now(zone_info)
timestamp2 = timestamp1 + timedelta(seconds=1)
return sorted([
(timestamp1, "cpu_load", 10.1),
Expand All @@ -40,40 +41,80 @@ def telemetry_data() -> List[Tuple[datetime, str, Any]]:
])


@pytest.fixture
def telemetry_data_implicit_local() -> List[Tuple[datetime, str, Any]]:
    """Mock telemetry built with zone_info=None, i.e., implicit (missing) local timezone info."""
    naive_local_data = zoned_telemetry_data(zone_info=None)
    return naive_local_data


@pytest.fixture
def telemetry_data_utc() -> List[Tuple[datetime, str, Any]]:
    """Mock telemetry built with explicit UTC timezone info."""
    utc_data = zoned_telemetry_data(zone_info=UTC)
    return utc_data


@pytest.fixture
def telemetry_data_explicit() -> List[Tuple[datetime, str, Any]]:
    """Mock telemetry built with an explicit non-UTC timezone (America/Chicago)."""
    chicago = ZoneInfo("America/Chicago")
    # Sanity check: the chosen zone must actually differ from UTC,
    # otherwise this fixture would not exercise any conversion.
    assert chicago.utcoffset(datetime.now(UTC)) != timedelta(hours=0)
    return zoned_telemetry_data(chicago)


ZONE_INFO: List[Optional[tzinfo]] = [UTC, ZoneInfo("America/Chicago"), None]


def _telemetry_str(data: List[Tuple[datetime, str, Any]]
) -> List[Tuple[datetime, str, Optional[str]]]:
"""
Convert telemetry values to strings.
"""
return [(ts, key, nullable(str, val)) for (ts, key, val) in data]
# All retrieved timestamps should have been converted to UTC.
return [(ts.astimezone(UTC), key, nullable(str, val)) for (ts, key, val) in data]


@pytest.mark.parametrize(("telemetry_data"), [
(lazy_fixture("telemetry_data_implicit_local")),
(lazy_fixture("telemetry_data_utc")),
(lazy_fixture("telemetry_data_explicit")),
])
@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO)
def test_update_telemetry(storage: Storage,
exp_storage: Storage.Experiment,
tunable_groups: TunableGroups,
telemetry_data: List[Tuple[datetime, str, Any]]) -> None:
telemetry_data: List[Tuple[datetime, str, Any]],
origin_zone_info: Optional[tzinfo]) -> None:
"""
Make sure update_telemetry() and load_telemetry() methods work.
"""
trial = exp_storage.new_trial(tunable_groups)
assert exp_storage.load_telemetry(trial.trial_id) == []

trial.update_telemetry(Status.RUNNING, datetime.now(UTC), telemetry_data)
trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), telemetry_data)
assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(telemetry_data)

# Also check that the TrialData telemetry looks right.
trial_data = storage.experiments[exp_storage.experiment_id].trials[trial.trial_id]
assert _telemetry_str([tuple(r) for r in trial_data.telemetry_df.to_numpy()]) == _telemetry_str(telemetry_data)
trial_telemetry_df = trial_data.telemetry_df
trial_telemetry_data = [tuple(r) for r in trial_telemetry_df.to_numpy()]
assert _telemetry_str(trial_telemetry_data) == _telemetry_str(telemetry_data)


@pytest.mark.parametrize(("telemetry_data"), [
(lazy_fixture("telemetry_data_implicit_local")),
(lazy_fixture("telemetry_data_utc")),
(lazy_fixture("telemetry_data_explicit")),
])
@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO)
def test_update_telemetry_twice(exp_storage: Storage.Experiment,
tunable_groups: TunableGroups,
telemetry_data: List[Tuple[datetime, str, Any]]) -> None:
telemetry_data: List[Tuple[datetime, str, Any]],
origin_zone_info: Optional[tzinfo]) -> None:
"""
Make sure update_telemetry() call is idempotent.
"""
trial = exp_storage.new_trial(tunable_groups)
timestamp = datetime.now(UTC)
timestamp = datetime.now(origin_zone_info)
trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data)
trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data)
trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data)
Expand Down
Loading