Bulk commit to DB and Cache IDs (#162)
* Bulk commit and cache database responses
simonhkswan authored Feb 5, 2024
1 parent 68fbadb commit ae2bc77
Showing 7 changed files with 129 additions and 28 deletions.
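The change has two parts: name-to-id lookups are memoised with cachetools, keyed on the name alone so the `Session` argument does not defeat the cache, and a session can now be passed down through the metric calls so that many results are committed in a single transaction. A minimal sketch of the memoisation pattern, where `FAKE_IDS` and `lookup_id` are illustrative stand-ins rather than the project's API:

```python
from cachetools import cached
from cachetools.keys import hashkey

# Illustrative stand-in for the real SQLAlchemy name-to-id lookup.
FAKE_IDS = {"age_distribution": 1, "mean": 2}


@cached(cache={}, key=lambda name, session: hashkey(name))
def lookup_id(name: str, session: object) -> int:
    # `session` is deliberately excluded from the cache key: two calls with
    # the same name but different sessions still share one cache entry.
    print(f"querying the DB for {name!r}")
    return FAKE_IDS[name]


lookup_id("mean", session="session-1")  # runs the body, "queries" once
lookup_id("mean", session="session-2")  # cache hit, no second query
```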
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -23,14 +23,15 @@ repos:
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.7.1
hooks:
- id: mypy
files: src
additional_dependencies:
- numpy>=1.21
- sqlalchemy[mypy]
- alembic
- types-cachetools
args: [--install-types, --non-interactive]
# Note that using --install-types is problematic if running in
# parallel, as mutating the pre-commit env at runtime breaks the cache.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"pandas >= 1.2",
"scipy >= 1.5",
"seaborn >= 0.11.0",
"cachetools >= 5.0",
]

[project.optional-dependencies]
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
alembic==1.7.6
cachetools==5.3.1
contourpy==1.1.0
cycler==0.11.0
fonttools==4.42.1
30 changes: 30 additions & 0 deletions src/insight/database/utils.py
@@ -1,9 +1,12 @@
"""Utils for fetching information from the backend DB."""

import os
import re
import typing as ty

import pandas as pd
from cachetools import cached
from cachetools.keys import hashkey
from sqlalchemy import create_engine
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
@@ -14,6 +17,8 @@
NamedModelType = ty.TypeVar("NamedModelType", model.Dataset, model.Metric, model.Version)

_database_fail_note = "Failure to communicate with the database"
_DATASET_ID_MAPPING: ty.Optional[ty.Dict[str, int]] = None
_METRIC_ID_MAPPING: ty.Optional[ty.Dict[str, int]] = None


def get_df(url_or_path: str):
@@ -25,6 +30,7 @@ def get_df(url_or_path: str):
return df


@cached(cache={}, key=lambda df_name, session, **kwargs: hashkey(df_name))
def get_df_id(
df_name: str,
session: Session,
@@ -40,6 +46,17 @@ def get_df_id(
num_columns (int): The number of columns in the dataframe. Optional.
"""
global _DATASET_ID_MAPPING # pylint: disable=global-statement
# create a mapping of df_names to ids
if _DATASET_ID_MAPPING is None:
with session:
df_names = session.query(model.Dataset).all()
_DATASET_ID_MAPPING = {df.name: df.id for df in df_names if df.name is not None}

df_id = _DATASET_ID_MAPPING.get(df_name)
if df_id is not None:
return df_id

dataset = get_object_from_db_by_name(df_name, session, model.Dataset)
if dataset is None:
with session:
@@ -51,6 +68,7 @@
return int(dataset.id)


@cached(cache={}, key=lambda metric, session, **kwargs: hashkey(metric))
def get_metric_id(metric: str, session: Session, category: ty.Optional[str] = None) -> int:
"""Get the id of a metric in the database. If it doesn't exist, create it.
@@ -59,6 +77,17 @@ def get_metric_id(metric: str, session: Session, category: ty.Optional[str] = None) -> int:
session (Session): The database session.
category (str): The category of the metric. Optional.
"""
global _METRIC_ID_MAPPING # pylint: disable=global-statement
# create a mapping of metric names to ids
if _METRIC_ID_MAPPING is None:
with session:
metrics = session.query(model.Metric).all()
_METRIC_ID_MAPPING = {m.name: m.id for m in metrics if m.name is not None}

metric_id = _METRIC_ID_MAPPING.get(metric)
if metric_id is not None:
return metric_id

db_metric = get_object_from_db_by_name(metric, session, model.Metric)

if db_metric is None:
@@ -71,6 +100,7 @@ def get_metric_id(metric: str, session: Session, category: ty.Optional[str] = None) -> int:
return int(db_metric.id)


@cached(cache={}, key=lambda version, session: hashkey(version))
def get_version_id(version: str, session: Session) -> int:
"""Get the id of a version in the database. If it doesn't exist, create it.
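The lookup strategy added to `get_df_id` and `get_metric_id` above is two-tier: the first call pulls the whole name-to-id table into a module-level dict in one bulk query, and cache misses fall back to a per-name query that creates the missing row. A self-contained sketch of the idea, with a plain dict standing in for the database table:

```python
import typing as ty

# A plain dict standing in for the real `model.Dataset` table.
_FAKE_TABLE: ty.Dict[str, int] = {"dataset_a": 1, "dataset_b": 2}

_ID_MAPPING: ty.Optional[ty.Dict[str, int]] = None


def get_id(name: str) -> int:
    """Bulk-prefetch all ids on first use, then fall back per name."""
    global _ID_MAPPING
    if _ID_MAPPING is None:
        # First call: one bulk query instead of one query per name.
        _ID_MAPPING = dict(_FAKE_TABLE)
    cached_id = _ID_MAPPING.get(name)
    if cached_id is not None:
        return cached_id
    # Miss: create the row, mirroring the insert-if-absent fallback.
    new_id = max(_FAKE_TABLE.values(), default=0) + 1
    _FAKE_TABLE[name] = new_id
    _ID_MAPPING[name] = new_id
    return new_id


assert get_id("dataset_a") == 1  # served from the prefetched mapping
assert get_id("dataset_c") == 3  # fallback path creates a new row
```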
29 changes: 25 additions & 4 deletions src/insight/metrics/base.py
@@ -1,4 +1,5 @@
"""This module contains the base classes for the metrics used across synthesized."""

import os
import typing as ty
from abc import ABC, abstractmethod
@@ -71,6 +72,7 @@ def _add_to_database(
dataset_rows: ty.Optional[int] = None,
dataset_cols: ty.Optional[int] = None,
category: ty.Optional[str] = None,
session: ty.Optional[Session] = None,
):
"""
Adds the metric result to the database. The metric result should be specified as value.
@@ -101,7 +103,23 @@ def _add_to_database(
if hasattr(value, "item"):
value = value.item()

with self._session as session:
if session is None:
with self._session as session:
metric_id = utils.get_metric_id(self.name, session, category=category)
version_id = utils.get_version_id(version, session)
dataset_id = utils.get_df_id(
dataset_name, session, num_rows=dataset_rows, num_columns=dataset_cols
)
result = model.Result(
metric_id=metric_id,
dataset_id=dataset_id,
version_id=version_id,
value=value,
run_id=run_id,
)
session.add(result)
session.commit()
else:
metric_id = utils.get_metric_id(self.name, session, category=category)
version_id = utils.get_version_id(version, session)
dataset_id = utils.get_df_id(
@@ -115,7 +133,6 @@
run_id=run_id,
)
session.add(result)
session.commit()


class OneColumnMetric(_Metric):
@@ -167,7 +184,7 @@ def check_column_types(cls, sr: pd.Series, check: Check = ColumnCheck()) -> bool:
def _compute_metric(self, sr: pd.Series):
...

def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None):
def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None, session=None):
if not self.check_column_types(sr, self.check):
value = None
else:
@@ -181,6 +198,7 @@ def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None):
dataset_rows=len(sr),
category="OneColumnMetric",
dataset_cols=1,
session=session,
)

return value
@@ -237,7 +255,9 @@ def check_column_types(cls, sr_a: pd.Series, sr_b: pd.Series, check: Check = ColumnCheck()) -> bool:
def _compute_metric(self, sr_a: pd.Series, sr_b: pd.Series):
...

def __call__(self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[str] = None):
def __call__(
self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[str] = None, session=None
):
if not self.check_column_types(sr_a, sr_b, self.check):
value = None
else:
@@ -251,6 +271,7 @@ def __call__(self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[str] = None):
dataset_rows=len(sr_a),
category="TwoColumnMetric",
dataset_cols=1,
session=session,
)

return value
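The new optional `session` argument makes `_add_to_database` batch-friendly: with no session, the method opens one itself, adds the row, and commits immediately; a caller-supplied session is only staged into, leaving the single `commit()` to the caller. A generic sketch of this ownership pattern, using a toy session class rather than a real SQLAlchemy Session:

```python
import typing as ty


class ToySession:
    """Stand-in for a SQLAlchemy Session: stages objects, commits in bulk."""

    def __init__(self) -> None:
        self.pending: ty.List[float] = []
        self.committed: ty.List[float] = []

    def add(self, obj: float) -> None:
        self.pending.append(obj)

    def commit(self) -> None:
        self.committed.extend(self.pending)
        self.pending.clear()


def add_result(value: float, session: ty.Optional[ToySession] = None) -> None:
    if session is None:
        # No session supplied: own the transaction and commit immediately.
        own = ToySession()
        own.add(value)
        own.commit()
    else:
        # Caller owns the session, so only stage the row here.
        session.add(value)


shared = ToySession()
for value in (0.1, 0.2, 0.3):
    add_result(value, session=shared)
shared.commit()  # one commit covers all three staged results
assert shared.committed == [0.1, 0.2, 0.3]
```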
78 changes: 58 additions & 20 deletions src/insight/metrics/metrics_usage.py
@@ -17,12 +17,23 @@ def __init__(self, metric: OneColumnMetric):
self.name = f"{metric.name}_map"

def _compute_result(self, df: pd.DataFrame) -> pd.DataFrame:
columns_map = {
col: self._metric(df[col], dataset_name=df.attrs.get("name", "") + f"_{col}")
for col in df.columns
}
result = pd.DataFrame(data=columns_map.values(), index=df.columns, columns=[self.name])
dataset_name = df.attrs.get("name", "")
if self._session is not None:
with self._session as session:
columns_map = {
col: self._metric(
df[col], dataset_name=f"{dataset_name}_{col}", session=session
)
for col in df.columns
}
session.commit()
else:
columns_map = {
col: self._metric(df[col], dataset_name=f"{dataset_name}_{col}", session=None)
for col in df.columns
}

result = pd.DataFrame(data=columns_map.values(), index=df.columns, columns=[self.name])
result.name = self._metric.name
return result

@@ -57,12 +68,24 @@ def _compute_result(self, df: pd.DataFrame) -> pd.DataFrame:
columns = df.columns
matrix = pd.DataFrame(index=columns, columns=columns)

for col_a, col_b in permutations(columns, 2):
matrix[col_a][col_b] = self._metric(
df[col_a],
df[col_b],
dataset_name=df.attrs.get("name", "") + f"_{col_a}_{col_b}",
)
if self._session is not None:
with self._session as session:
for col_a, col_b in permutations(columns, 2):
matrix[col_a][col_b] = self._metric(
df[col_a],
df[col_b],
dataset_name=df.attrs.get("name", "") + f"_{col_a}_{col_b}",
session=session,
)
session.commit()
else:
for col_a, col_b in permutations(columns, 2):
matrix[col_a][col_b] = self._metric(
df[col_a],
df[col_b],
dataset_name=df.attrs.get("name", "") + f"_{col_a}_{col_b}",
session=None,
)

return pd.DataFrame(matrix.astype(np.float32)) # explicit casting for mypy

@@ -105,16 +128,31 @@ def __init__(self, metric: TwoColumnMetric):
self.name = f"{metric.name}_map"

def _compute_result(self, df_old: pd.DataFrame, df_new: pd.DataFrame) -> pd.DataFrame:
columns_map = {
col: self._metric(
df_old[col],
df_new[col],
dataset_name=df_old.attrs.get("name", "") + f"_{col}",
)
for col in df_old.columns
}
result = pd.DataFrame(data=columns_map.values(), index=df_old.columns, columns=[self.name])

if self._session is not None:
with self._session as session:
columns_map = {
col: self._metric(
df_old[col],
df_new[col],
dataset_name=df_old.attrs.get("name", "") + f"_{col}",
session=session,
)
for col in df_old.columns
}
session.commit()
else:
columns_map = {
col: self._metric(
df_old[col],
df_new[col],
dataset_name=df_old.attrs.get("name", "") + f"_{col}",
session=None,
)
for col in df_old.columns
}

result = pd.DataFrame(data=columns_map.values(), index=df_old.columns, columns=[self.name])
result.name = self._metric.name
return result

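The per-column dataset names in these maps come from pandas' `DataFrame.attrs` metadata dict: callers set `df.attrs["name"]` before computing a map or matrix, and an unset name falls back to an empty prefix. For example:

```python
import pandas as pd

df = pd.DataFrame({"age": [34, 29, 41], "height": [165, 180, 172]})
df.attrs["name"] = "patients"  # picked up as the dataset-name prefix

# Each column's result is recorded under "<name>_<column>".
print([f"{df.attrs.get('name', '')}_{col}" for col in df.columns])
# ['patients_age', 'patients_height']
```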
15 changes: 12 additions & 3 deletions tests/test_database/test_db.py
@@ -23,8 +23,18 @@ def tables(engine):
Base.metadata.drop_all(engine)


@pytest.fixture
def db_session(engine, tables):
@pytest.fixture(scope="function")
def clear_utils_cache():
yield utils
utils.get_df_id.cache_clear()
utils.get_metric_id.cache_clear()
utils.get_version_id.cache_clear()
utils._DATASET_ID_MAPPING = None
utils._METRIC_ID_MAPPING = None


@pytest.fixture(scope="function")
def db_session(engine, tables, clear_utils_cache):
connection = engine.connect()
transaction = connection.begin()
session = Session(bind=connection, expire_on_commit=False)
@@ -35,7 +45,6 @@ def db_session(engine, tables):
base.TwoColumnMetric._session = session
base.DataFrameMetric._session = session
base.TwoDataFrameMetric._session = session

yield session

# Return class variables to their original state.
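Because the id caches now live at module scope, state would leak from one test to the next without the new `clear_utils_cache` fixture. A minimal stand-alone illustration of the problem and the fixture pattern, using a hypothetical cached `double` function:

```python
import pytest
from cachetools import cached
from cachetools.keys import hashkey

_cache: dict = {}
calls = []


@cached(cache=_cache, key=lambda x: hashkey(x))
def double(x: int) -> int:
    calls.append(x)  # record real invocations so tests can observe caching
    return x * 2


@pytest.fixture()
def clear_cache():
    yield
    _cache.clear()  # reset after each test so cached values cannot leak


def test_double_is_cached(clear_cache):
    assert double(2) == 4
    assert double(2) == 4
    assert calls == [2]  # the second call was served from the cache
```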
