Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add usage logging #1920

Merged
merged 19 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
'deepecho>=0.6.0',
'rdt>=1.12.0',
'sdmetrics>=0.14.0',
'platformdirs>=4.0'
pvk-developer marked this conversation as resolved.
Show resolved Hide resolved
]

[project.urls]
Expand Down
3 changes: 1 addition & 2 deletions sdv/logging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Module for configuring loggers within the SDV library."""

from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config
from sdv.logging.utils import get_sdv_logger, get_sdv_logger_config

__all__ = (
'disable_single_table_logger',
'get_sdv_logger',
'get_sdv_logger_config',
)
86 changes: 36 additions & 50 deletions sdv/logging/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Utilities for configuring logging within the SDV library."""

import contextlib
import logging
import logging.config
from functools import lru_cache
from pathlib import Path

import platformdirs
import yaml


Expand All @@ -16,34 +15,16 @@ def get_sdv_logger_config():
logger_conf = yaml.safe_load(f)

# Logfile to be in this same directory
store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev'))
store_path.mkdir(parents=True, exist_ok=True)
for logger in logger_conf.get('loggers', {}).values():
handler = logger.get('handlers', {})
if handler.get('filename') == 'sdv_logs.log':
handler['filename'] = logging_path / handler['filename']
handler['filename'] = store_path / handler['filename']

return logger_conf


@contextlib.contextmanager
def disable_single_table_logger():
    """Suppress ``SingleTableSynthesizer`` log output within a context.

    On entry, detaches every handler currently attached to the
    ``SingleTableSynthesizer`` logger so nothing is emitted by that logger
    while the context is active. On exit the saved handlers are reattached,
    even if the body raises.
    """
    target_logger = logging.getLogger('SingleTableSynthesizer')
    # Keep a reference to the existing handler list, then swap in a fresh
    # empty one so the logger emits nothing inside the context.
    saved_handlers = target_logger.handlers
    target_logger.handlers = []
    try:
        yield
    finally:
        # Reattach via ``addHandler`` (it skips duplicates) to restore the
        # logger exactly as it was before entering the context.
        for saved in saved_handlers:
            target_logger.addHandler(saved)


@lru_cache()
def get_sdv_logger(logger_name):
"""Get a logger instance with the specified name and configuration.
Expand All @@ -62,30 +43,35 @@ def get_sdv_logger(logger_name):
and the specific settings for the given logger name.
"""
logger_conf = get_sdv_logger_config()
logger = logging.getLogger(logger_name)
if logger_name in logger_conf.get('loggers'):
formatter = None
config = logger_conf.get('loggers').get(logger_name)
log_level = getattr(logging, config.get('level', 'INFO'))
if config.get('format'):
formatter = logging.Formatter(config.get('format'))

logger.setLevel(log_level)
logger.propagate = config.get('propagate', False)
handler = config.get('handlers')
handlers = handler.get('class')
handlers = [handlers] if isinstance(handlers, str) else handlers
for handler_class in handlers:
if handler_class == 'logging.FileHandler':
logfile = handler.get('filename')
file_handler = logging.FileHandler(logfile)
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'):
ch = logging.StreamHandler()
ch.setLevel(log_level)
ch.setFormatter(formatter)
logger.addHandler(ch)

return logger
if logger_conf.get('log_registry') is None:
# Return a logger without any extra settings and avoid writing into files or other streams
return logging.getLogger(logger_name)

if logger_conf.get('log_registry') == 'local':
logger = logging.getLogger(logger_name)
if logger_name in logger_conf.get('loggers'):
formatter = None
config = logger_conf.get('loggers').get(logger_name)
log_level = getattr(logging, config.get('level', 'INFO'))
if config.get('format'):
formatter = logging.Formatter(config.get('format'))

logger.setLevel(log_level)
logger.propagate = config.get('propagate', False)
handler = config.get('handlers')
handlers = handler.get('class')
handlers = [handlers] if isinstance(handlers, str) else handlers
for handler_class in handlers:
if handler_class == 'logging.FileHandler':
logfile = handler.get('filename')
file_handler = logging.FileHandler(logfile)
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'):
ch = logging.StreamHandler()
ch.setLevel(log_level)
ch.setFormatter(formatter)
logger.addHandler(ch)

return logger
27 changes: 14 additions & 13 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
_validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version,
generate_synthesizer_id)
from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.logging import get_sdv_logger
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand Down Expand Up @@ -59,14 +59,14 @@ def _set_temp_numpy_seed(self):
np.random.set_state(initial_state)

def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
**synthesizer_parameters
)
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
table_name=table_name,
**synthesizer_parameters
)

def _get_pbar_args(self, **kwargs):
"""Return a dictionary with the updated keyword args for a progress bar."""
Expand Down Expand Up @@ -199,6 +199,8 @@ def set_table_parameters(self, table_name, table_parameters):
A dictionary with the parameters as keys and the values to be used to instantiate
the table's synthesizer.
"""
# Ensure that we set the name of the table no matter what
table_parameters.update({'table_name': table_name})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=self.metadata.tables[table_name],
**table_parameters
Expand Down Expand Up @@ -407,9 +409,8 @@ def fit_processed_data(self, processed_data):
self._synthesizer_id,
)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
with disable_single_table_logger():
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)

self._fitted = True
self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')
Expand Down Expand Up @@ -478,7 +479,7 @@ def sample(self, scale=1.0):
raise SynthesizerInputError(
f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.")

with self._set_temp_numpy_seed(), disable_single_table_logger():
with self._set_temp_numpy_seed():
sampled_data = self._sample(scale=scale)

total_rows = 0
Expand Down
18 changes: 15 additions & 3 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
row = pd.Series({'num_rows': len(child_rows)})
row.index = f'__{child_name}__{foreign_key}__' + row.index
else:
synthesizer_parameters = self._table_parameters[child_name]
synthesizer_parameters.update({'table_name': child_name})
synthesizer = self._synthesizer(
table_meta,
**self._table_parameters[child_name]
**synthesizer_parameters
)
synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
row = synthesizer._get_parameters()
Expand Down Expand Up @@ -521,7 +523,12 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})

table_meta = self.metadata.tables[child_name]
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
synthesizer_parameters = self._table_parameters[child_name]
synthesizer_parameters.update({'table_name': child_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
)
synthesizer._set_parameters(parameters, default_parameters)
synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor

Expand Down Expand Up @@ -615,7 +622,12 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
for parent_id, row in parent_rows.iterrows():
parameters = self._extract_parameters(row, table_name, foreign_key)
table_meta = self._table_synthesizers[table_name].get_metadata()
synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name])
synthesizer_parameters = self._table_parameters[table_name]
synthesizer_parameters.update({'table_name': table_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
)
synthesizer._set_parameters(parameters)
try:
likelihoods[parent_id] = synthesizer._get_likelihood(table_rows)
Expand Down
Loading
Loading