Skip to content

Commit

Permalink
wip: canonicalize storage and return of data in utc format
Browse files Browse the repository at this point in the history
  • Loading branch information
bpkroth committed Mar 18, 2024
1 parent 08c3180 commit 6b0f411
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 15 deletions.
4 changes: 3 additions & 1 deletion mlos_bench/mlos_bench/storage/base_trial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Base interface for accessing the stored benchmark trial data.
"""
from abc import ABCMeta, abstractmethod
from datetime import datetime
from datetime import datetime, UTC
from typing import Any, Dict, Optional, TYPE_CHECKING

import pandas
Expand Down Expand Up @@ -38,6 +38,8 @@ def __init__(self, *,
self._experiment_id = experiment_id
self._trial_id = trial_id
self._tunable_config_id = tunable_config_id
assert ts_start.tzinfo == UTC, "ts_start must be in UTC"
assert ts_end is None or ts_end.tzinfo == UTC, "ts_end must be in UTC if not None"
self._ts_start = ts_start
self._ts_end = ts_end
self._status = status
Expand Down
25 changes: 20 additions & 5 deletions mlos_bench/mlos_bench/storage/sql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.util import utcify_timestamp, utcify_nullable_timestamp


def get_trials(
Expand Down Expand Up @@ -48,8 +49,8 @@ def get_trials(
experiment_id=experiment_id,
trial_id=trial.trial_id,
config_id=trial.config_id,
ts_start=trial.ts_start,
ts_end=trial.ts_end,
ts_start=utcify_timestamp(trial.ts_start, origin="utc"),
ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"),
status=Status[trial.status],
)
for trial in trials.fetchall()
Expand Down Expand Up @@ -107,9 +108,23 @@ def get_results_df(
)
cur_trials = conn.execute(cur_trials_stmt)
trials_df = pandas.DataFrame(
[(row.trial_id, row.ts_start, row.ts_end, row.config_id, row.tunable_config_trial_group_id, row.status)
for row in cur_trials.fetchall()],
columns=['trial_id', 'ts_start', 'ts_end', 'tunable_config_id', 'tunable_config_trial_group_id', 'status'])
[(
row.trial_id,
utcify_timestamp(row.ts_start, origin="utc"),
utcify_nullable_timestamp(row.ts_end, origin="utc"),
row.config_id,
row.tunable_config_trial_group_id,
row.status,
) for row in cur_trials.fetchall()],
columns=[
'trial_id',
'ts_start',
'ts_end',
'tunable_config_id',
'tunable_config_trial_group_id',
'status',
]
)

# Get each trial's config in wide format.
configs_stmt = schema.trial.select().with_only_columns(
Expand Down
11 changes: 8 additions & 3 deletions mlos_bench/mlos_bench/storage/sql/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.trial import Trial
from mlos_bench.storage.util import utcify_timestamp
from mlos_bench.util import nullable

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -120,7 +121,9 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
self._schema.trial_telemetry.c.metric_id,
)
)
return [(row.ts, row.metric_id, row.metric_value)
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and augment it on return again here.
return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
for row in cur_telemetry.fetchall()]

def load(self,
Expand Down Expand Up @@ -184,6 +187,7 @@ def _save_params(conn: Connection, table: Table,

def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
_LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
timestamp = timestamp.astimezone(UTC) if timestamp.tzinfo else timestamp.replace(tzinfo=UTC)
if running:
pending_status = ['PENDING', 'READY', 'RUNNING']
else:
Expand Down Expand Up @@ -238,15 +242,16 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:

def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
_LOG.debug("Create trial: %s:%d", self._experiment_id, self._trial_id)
ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local")
_LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
with self._engine.begin() as conn:
try:
config_id = self._get_config_id(conn, tunables)
conn.execute(self._schema.trial.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
config_id=config_id,
ts_start=ts_start or datetime.now(UTC),
ts_start=ts_start,
status='PENDING',
))

Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/storage/sql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self, engine: Engine):
self._meta,
Column("exp_id", String(self._ID_LEN), nullable=False),
Column("trial_id", Integer, nullable=False),
Column("ts", DateTime, nullable=False, default="now"),
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("status", String(self._STATUS_LEN), nullable=False),

UniqueConstraint("exp_id", "trial_id", "ts"),
Expand All @@ -181,7 +181,7 @@ def __init__(self, engine: Engine):
self._meta,
Column("exp_id", String(self._ID_LEN), nullable=False),
Column("trial_id", Integer, nullable=False),
Column("ts", DateTime, nullable=False, default="now"),
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("metric_id", String(self._ID_LEN), nullable=False),
Column("metric_value", String(self._METRIC_VALUE_LEN)),

Expand Down
8 changes: 8 additions & 0 deletions mlos_bench/mlos_bench/storage/sql/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.util import utcify_timestamp
from mlos_bench.util import nullable

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,6 +48,8 @@ def __init__(self, *,
def update(self, status: Status, timestamp: datetime,
metrics: Optional[Union[Dict[str, Any], float]] = None
) -> Optional[Dict[str, Any]]:
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
metrics = super().update(status, timestamp, metrics)
with self._engine.begin() as conn:
self._update_status(conn, status, timestamp)
Expand Down Expand Up @@ -106,6 +109,9 @@ def update(self, status: Status, timestamp: datetime,
def update_telemetry(self, status: Status, timestamp: datetime,
metrics: List[Tuple[datetime, str, Any]]) -> None:
super().update_telemetry(status, timestamp, metrics)
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
# NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
# and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
# a bulk upsert.
Expand All @@ -130,6 +136,8 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime)
Insert a new status record into the database.
This call is idempotent.
"""
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
try:
conn.execute(self._schema.trial_status.insert().values(
exp_id=self._experiment_id,
Expand Down
7 changes: 5 additions & 2 deletions mlos_bench/mlos_bench/storage/sql/trial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
An interface to access the benchmark trial data stored in SQL DB.
"""
from datetime import datetime
from datetime import datetime, UTC
from typing import Optional, TYPE_CHECKING

import pandas
Expand All @@ -16,6 +16,7 @@
from mlos_bench.environments.status import Status
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
from mlos_bench.storage.util import utcify_timestamp

if TYPE_CHECKING:
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
Expand Down Expand Up @@ -100,8 +101,10 @@ def telemetry_df(self) -> pandas.DataFrame:
self._schema.trial_telemetry.c.metric_id,
)
)
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and augment it on return again here.
return pandas.DataFrame(
[(row.ts, row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()],
[(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()],
columns=['ts', 'metric', 'value'])

@property
Expand Down
45 changes: 44 additions & 1 deletion mlos_bench/mlos_bench/storage/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,57 @@
Utility functions for the storage subsystem.
"""

from typing import Dict, Optional
from datetime import datetime, UTC
from typing import Dict, Literal, Optional

import pandas

from mlos_bench.tunables.tunable import TunableValue, TunableValueTypeTuple
from mlos_bench.util import try_parse_val


def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
"""
Augment a timestamp with zoneinfo if missing and convert it to UTC.
Parameters
----------
timestamp : datetime
A timestamp to convert to UTC.
Note: The original datetime may or may not have tzinfo associated with it.
origin : Literal["utc", "local"]
Whether the source timestamp is considered to be in UTC or local time.
In the case of loading data from storage, where we intentionally convert all
timestamps to UTC, this can help us retrieve the original timezone when the
storage backend doesn't explicitly store it.
Returns
-------
datetime
A datetime with zoneinfo in UTC.
"""
if timestamp.tzinfo is not None or origin == "local":
# A timestamp with no zoneinfo is interpretted as "local" time
# (e.g., according to the TZ environment variable).
# That could be UTC or some other timezone, but either way we convert it to
# be explicitly UTC with zone info.
return timestamp.astimezone(UTC)
elif origin == "utc":
# If the timestamp is already in UTC, we just add the zoneinfo without conversion.
# Converting with astimezone() when the local time is *not* UTC would cause
# a timestamp conversion which we don't want.
return timestamp.replace(tzinfo=UTC)
else:
raise ValueError(f"Invalid origin: {origin}")


def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]:
    """
    Convert an optional timestamp to UTC, passing ``None`` through unchanged.

    Parameters
    ----------
    timestamp : Optional[datetime]
        The timestamp to convert, or None.
    origin : Literal["utc", "local"]
        Forwarded as-is to utcify_timestamp.

    Returns
    -------
    Optional[datetime]
        None if ``timestamp`` is None, else the result of utcify_timestamp.
    """
    if timestamp is None:
        # Nothing to convert (e.g., a trial that has not finished yet).
        return None
    return utcify_timestamp(timestamp, origin=origin)


def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValue]]:
"""
Utility function to convert certain flat key-value dataframe formats used by the
Expand Down
1 change: 1 addition & 0 deletions mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def storage() -> SqlStorage:
config={
"drivername": "sqlite",
"database": ":memory:",
# "database": "mlos_bench.pytest.db",
}
)

Expand Down
4 changes: 3 additions & 1 deletion mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def test_update_telemetry(storage: Storage,

# Also check that the TrialData telemetry looks right.
trial_data = storage.experiments[exp_storage.experiment_id].trials[trial.trial_id]
assert _telemetry_str([tuple(r) for r in trial_data.telemetry_df.to_numpy()]) == _telemetry_str(telemetry_data)
trial_telemetry_df = trial_data.telemetry_df
trial_telemetry_data = [tuple(r) for r in trial_telemetry_df.to_numpy()]
assert _telemetry_str(trial_telemetry_data) == _telemetry_str(telemetry_data)


def test_update_telemetry_twice(exp_storage: Storage.Experiment,
Expand Down

0 comments on commit 6b0f411

Please sign in to comment.