Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert store and param registry to sqlite #114

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
45 changes: 45 additions & 0 deletions curifactory/dbschema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""The metadata and table definitions for sqlalchemy for the experiment store."""

# TODO: will need to decide beteen Core and ORM usage
from sqlalchemy import Boolean, Column, DateTime, Integer, MetaData, String, Table

metadata_obj = MetaData()

store_info = Table(
"store_info",
metadata_obj,
Column("key", String, primary_key=True),
Column("value", String),
)

runs_table = Table(
"run",
metadata_obj,
Column("reference", String, primary_key=True),
Column("experiment_name", String),
Column("run_number", Integer),
Column("timestamp", DateTime),
Column("commit", String),
Column("param_files", String), # NOTE: this will be a json.dumps,
# since this is likely to change in later cf versions, I don't want
# to bother correctly normalizing this part of the table, since I
# don't think there will be need to query on it anyway.
Column("params", String),
Column("workdir_dirty", Boolean),
Column("full_store", Boolean),
Column("status", String),
Column("cli", String),
Column("hostname", String),
Column("user", String),
Column("notes", String),
Column("uncommited_patch", String),
Column("pip_freeze", String),
Column("conda_env", String),
Column("conda_env_full", String),
Column("os", String),
Column("reproduce", String),
)


# https://docs.sqlalchemy.org/en/20/tutorial/metadata.html
# https://docs.sqlalchemy.org/en/20/tutorial/data_insert.html#tutorial-core-insert
74 changes: 74 additions & 0 deletions curifactory/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Class for tracking run metadata."""

from dataclasses import dataclass
from datetime import datetime


@dataclass
class RunMetadata:
"""Data structure for tracking all the relevant metadata for a run,
making it easier/providing a consistent interface for accessing
the information and converting it into formats necessary for saving.
"""

reference: str
"""The full reference name of the experiment, usually
``[experiment_name]_[run_number]_[timestamp]``."""
experiment_name: str
"""The name of the experiment and/or the prefix used for caching."""
run_number: int
"""The run counter for experiments with the given name."""
timestamp: datetime
"""The datetime timestamp for when the manager is initialized (and usually
also when the experiment starts running.)"""

param_files: list[str]
"""The list of parameter file names (minus extension, as they would be
passed into the CLI.)"""
params: dict[str, list[list[str, str]]]
"""A dictionary of parameter file names for keys, where each value is an array of arrays,
each inner array containing the parameter set name and the parameter set hash, for the
parameter sets that come from that parameter file.

e.g. ``{"my_param_file": [ [ "my_param_set", "44b5e428e7165975a3e4f0d1674dbe5f" ] ] }``
"""
full_store: bool
"""Whether this store was being fully exported or not."""

commit: str
"""The current git commit hash."""
workdir_dirty: bool
"""True if there are uncommited changes in the git repo."""
uncommited_patch: str
"""The output of ``git diff -p`` at runtime, to help more precisely reconstruct current codebase."""

status: str
"""Ran status: success/incomplete/error/etc."""
cli: str
"""The CLI line this run was created with."""
reproduce: str
"""The translated CLI line to reproduce this run."""

hostname: str
"""The name of the machine this experiment ran on."""
user: str
"""The name of the user account the experiment was run with."""
notes: str
"""User-entered notes associated with a session/run to output into the report etc."""

pip_freeze: str
"""The output from a ``pip freeze`` command."""
conda_env: str
"""The output from ``conda env export --from-history``."""
conda_env_full: str
"""The output from ``conda env export``."""
os: str
"""The name of the current OS running curifactory."""

def as_sql_safe_dict(self) -> dict:
"""Meant to be used when inserting/updating values in the Runs
sql table.

The targeted column names can be found in dbschema.py.
"""
pass
139 changes: 139 additions & 0 deletions curifactory/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,146 @@
import json
import os

from sqlalchemy import create_engine, func, insert, select

from curifactory import utils
from curifactory.dbschema import metadata_obj, runs_table


class SQLStore:
"""EXPERIMENTAL, making an sqlite version of the data below."""

# TODO: (11/8/2023) make this take the full store path instead
def __init__(self, manager_cache_path: str):
self.path = manager_cache_path
"""The location to store the ``store.db``."""

if self.path[-1] != "/":
self.path += "/"

self.path += "store.db"

self.engine = create_engine(f"sqlite:///{self.path}")

self._ensure_tables()

def _ensure_tables(self):
"""Check for the existence of (and create if necessary) all of the tables
listed in dbscheme.py"""
metadata_obj.create_all(self.engine)

def get_run(self, ref_name: str) -> dict:
"""Get the metadata block for the run with the specified reference name.

Args:
ref_name (str): The run reference name, following the [experiment_name]_[run_number]_[timestamp] format.

Returns:
A dictionary (metadata block) for the run with the requested reference name, and the
index of the run in the table.
"""
# https://docs.sqlalchemy.org/en/20/tutorial/data_select.html

with self.engine.connect() as conn:
stmt = select(runs_table).where(runs_table.c.reference == ref_name)
# TODO: do I need to use prepare?
result = conn.execute(stmt)

# if we didn't get any rows back, this run doesn't exist.
if len(result) == 0:
return None

run = result[0]._asdict()
# NOTE: documented function of namedtuple, _ here doesn't imply hidden/not intended for use
run.param_files = json.loads(run.param_files)
return run

def add_run(self, mngr) -> dict:
"""Add a new metadata block to the store for the passed ``ArtifactManager`` instance.

Note that this automatically calls the ``save()`` function.

Args:
mngr (ArtifactManager): The manager to grab run metadata from.

Returns:
The newly created dictionary (metadata block) for the current manager's run.
"""

# get the new run number
with self.engine.connect() as conn:
stmt = select(func.count()).where(
runs_table.c.experiment_name == mngr.experiment_name
)
result = conn.execute(stmt)

# TODO: (11/6/2023) none of these should be set here...
run_count = result.all()[0][0]
mngr.experiment_run_number = run_count + 1
mngr.git_commit_hash = utils.get_current_commit()
mngr.git_workdir_dirty = utils.check_git_dirty_workingdir()

# insert the new entry
with self.engine.connect() as conn:
stmt = insert(runs_table).values(
reference=mngr.get_reference_name(),
experiment_name=mngr.experiment_name,
run_number=mngr.experiment_run_number,
timestamp=mngr.run_timestamp,
commit=mngr.git_commit_hash,
workdir_dirty=mngr.git_workdir_dirty,
param_files=str(mngr.parameter_files),
params=str(mngr.param_file_param_sets),
full_store=mngr.store_full,
status="incomplete",
cli=mngr.run_line,
hostname=mngr.hostname,
notes=mngr.notes,
)
conn.execute(stmt)

# create the metadata block
run_dict = {
"reference": mngr.get_reference_name(),
"experiment_name": mngr.experiment_name,
"run_number": mngr.experiment_run_number,
"timestamp": mngr.get_str_timestamp(),
"commit": mngr.git_commit_hash,
"workdir_dirty": mngr.git_workdir_dirty,
"param_files": mngr.parameter_files,
"params": mngr.param_file_param_sets,
"full_store": mngr.store_full,
"status": "incomplete",
"cli": mngr.run_line,
"hostname": mngr.hostname,
"notes": mngr.notes,
}

# sanitize reproduction cli command
if mngr.store_full:
run_dict = self._get_reproduction_line(mngr, run_dict)

return run_dict

# NOTE: we have to call this both from add_run and update_run because manager stores itself on init, but if
# someone _later_ sets store_full (maybe in a live run) we need to be able to handle this being added to the run_info
def _get_reproduction_line(self, mngr, run: dict) -> dict:
sanitized_run_line = mngr.run_line
if "--overwrite " in sanitized_run_line:
sanitized_run_line = sanitized_run_line.replace("--overwrite ", "")
if sanitized_run_line.endswith("--overwrite"):
sanitized_run_line = sanitized_run_line[:-12]
sanitized_run_line = sanitized_run_line.replace("--store-full ", "")
if sanitized_run_line.endswith("--store-full"):
sanitized_run_line = sanitized_run_line[:-13]

cache_path = mngr.get_run_output_path()
mngr.reproduction_line = (
f"{sanitized_run_line} --cache {cache_path}/artifacts --dry-cache"
)

run["reproduce"] = mngr.reproduction_line
return run


class ManagerStore:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ pyarrow>=7
lxml
openpyxl
tables
sqlalchemy
15 changes: 15 additions & 0 deletions test/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from curifactory.experiment import run_experiment
from curifactory.store import SQLStore


def test_add_run(configured_test_manager):
store = SQLStore(configured_test_manager.manager_cache_path)

results, mngr = run_experiment(
"simple_cache",
["simple_cache"],
param_set_names=["thing1", "thing2"],
mngr=configured_test_manager,
)

store.add_run(mngr)
Loading