Add usage logging #1920

Merged
merged 19 commits on Apr 25, 2024
Changes from 10 commits
9 changes: 9 additions & 0 deletions sdv/logging/__init__.py
@@ -0,0 +1,9 @@
"""Module for configuring loggers within the SDV library."""

from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config

__all__ = (
'disable_single_table_logger',
'get_sdv_logger',
'get_sdv_logger_config',
)
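
For orientation, a minimal usage sketch of this new public API (illustrative only; the logger name must match an entry in sdv_logger_config.yml, shown next):

from sdv.logging import get_sdv_logger, get_sdv_logger_config

config = get_sdv_logger_config()  # parsed contents of sdv_logger_config.yml
logger = get_sdv_logger('MultiTableSynthesizer')  # a configured logging.Logger
logger.info('hypothetical message')  # appended to sdv_logs.log per the config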
27 changes: 27 additions & 0 deletions sdv/logging/sdv_logger_config.yml
@@ -0,0 +1,27 @@
log_registry: 'local'
version: 1
loggers:
SingleTableSynthesizer:
level: INFO
propagate: false
handlers:
class: logging.FileHandler
filename: sdv_logs.log
MultiTableSynthesizer:
level: INFO
propagate: false
handlers:
class: logging.FileHandler
filename: sdv_logs.log
MultiTableMetadata:
level: INFO
propagate: false
handlers:
class: logging.FileHandler
filename: sdv_logs.log
SingleTableMetadata:
level: INFO
propagate: false
handlers:
class: logging.FileHandler
filename: sdv_logs.log
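
Note that this file is not a standard logging.config dictConfig schema: each handlers key holds a single class/filename mapping that sdv/logging/utils.py (below) interprets itself. As a sketch, yaml.safe_load turns the file into a dictionary shaped like this (one logger shown):

{
    'log_registry': 'local',
    'version': 1,
    'loggers': {
        'SingleTableSynthesizer': {
            'level': 'INFO',
            'propagate': False,
            'handlers': {'class': 'logging.FileHandler', 'filename': 'sdv_logs.log'},
        },
        # ... and likewise for MultiTableSynthesizer and the two metadata loggers
    },
}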
91 changes: 91 additions & 0 deletions sdv/logging/utils.py
@@ -0,0 +1,91 @@
"""Utilities for configuring logging within the SDV library."""

import contextlib
import logging
import logging.config
from functools import lru_cache
from pathlib import Path

import yaml


def get_sdv_logger_config():
"""Return a dictionary with the logging configuration."""
logging_path = Path(__file__).parent
with open(logging_path / 'sdv_logger_config.yml', 'r') as f:
logger_conf = yaml.safe_load(f)

    # Store the logfile in this same directory
for logger in logger_conf.get('loggers', {}).values():
handler = logger.get('handlers', {})
if handler.get('filename') == 'sdv_logs.log':
handler['filename'] = logging_path / handler['filename']

return logger_conf


@contextlib.contextmanager
def disable_single_table_logger():
"""Temporarily disables logging for the single table synthesizers.

This context manager temporarily removes all handlers associated with
the ``SingleTableSynthesizer`` logger, disabling logging for that module
within the current context. After the context exits, the
removed handlers are restored to the logger.
"""
# Logging without ``SingleTableSynthesizer``
single_table_logger = logging.getLogger('SingleTableSynthesizer')
handlers = single_table_logger.handlers
single_table_logger.handlers = []
try:
yield
finally:
for handler in handlers:
single_table_logger.addHandler(handler)


@lru_cache()
def get_sdv_logger(logger_name):
"""Get a logger instance with the specified name and configuration.

This function retrieves or creates a logger instance with the specified name
and applies configuration settings based on the logger's name and the logging
configuration.

Args:
logger_name (str):
The name of the logger to retrieve or create.

Returns:
logging.Logger:
A logger instance configured according to the logging configuration
and the specific settings for the given logger name.
"""
logger_conf = get_sdv_logger_config()
logger = logging.getLogger(logger_name)
    if logger_name in logger_conf.get('loggers', {}):
formatter = None
config = logger_conf.get('loggers').get(logger_name)
log_level = getattr(logging, config.get('level', 'INFO'))
if config.get('format'):
formatter = logging.Formatter(config.get('format'))

logger.setLevel(log_level)
logger.propagate = config.get('propagate', False)
        # ``handlers`` in the config is a single mapping; its ``class`` value can
        # be one handler class name or a list of names.
        handler = config.get('handlers')
        handlers = handler.get('class')
        handlers = [handlers] if isinstance(handlers, str) else handlers
for handler_class in handlers:
if handler_class == 'logging.FileHandler':
logfile = handler.get('filename')
file_handler = logging.FileHandler(logfile)
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'):
ch = logging.StreamHandler()
ch.setLevel(log_level)
ch.setFormatter(formatter)
logger.addHandler(ch)

return logger
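
A short usage sketch of the two helpers above (the messages are hypothetical):

from sdv.logging import disable_single_table_logger, get_sdv_logger

logger = get_sdv_logger('SingleTableSynthesizer')
logger.info('written to sdv_logs.log')

with disable_single_table_logger():
    logger.info('not written; the file handler is temporarily detached')

# get_sdv_logger is wrapped in lru_cache, so repeated calls with the same
# name return the same logger without attaching duplicate handlers.
assert get_sdv_logger('SingleTableSynthesizer') is logger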
19 changes: 19 additions & 0 deletions sdv/metadata/multi_table.py
@@ -1,5 +1,6 @@
"""Multi Table Metadata."""

import datetime
import json
import logging
import warnings
@@ -11,6 +12,7 @@

from sdv._utils import _cast_to_iterable, _load_data_from_csv
from sdv.errors import InvalidDataError
from sdv.logging import get_sdv_logger
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.single_table import SingleTableMetadata
@@ -19,6 +21,7 @@
create_columns_node, create_summarized_columns_node, visualize_graph)

LOGGER = logging.getLogger(__name__)
MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata')
WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format']


@@ -1040,6 +1043,22 @@ def save_to_json(self, filepath):
"""
validate_file_does_not_exist(filepath)
metadata = self.to_dict()
total_columns = 0
for table in self.tables.values():
total_columns += len(table.columns)

MULTITABLEMETADATA_LOGGER.info(
'\nMetadata Save:\n'
' Timestamp: %s\n'
' Statistics about the metadata:\n'
' Total number of tables: %s\n'
' Total number of columns: %s\n'
' Total number of relationships: %s',
datetime.datetime.now(),
len(self.tables),
total_columns,
len(self.relationships)
)
with open(filepath, 'w', encoding='utf-8') as metadata_file:
json.dump(metadata, metadata_file, indent=4)

12 changes: 12 additions & 0 deletions sdv/metadata/single_table.py
@@ -16,13 +16,15 @@
_cast_to_iterable, _format_invalid_values_string, _get_datetime_format, _is_boolean_type,
_is_datetime_type, _is_numerical_type, _load_data_from_csv, _validate_datetime_format)
from sdv.errors import InvalidDataError
from sdv.logging import get_sdv_logger
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)

LOGGER = logging.getLogger(__name__)
SINGLETABLEMETADATA_LOGGER = get_sdv_logger('SingleTableMetadata')


class SingleTableMetadata:
@@ -1206,6 +1208,16 @@ def save_to_json(self, filepath):
validate_file_does_not_exist(filepath)
metadata = self.to_dict()
metadata['METADATA_SPEC_VERSION'] = self.METADATA_SPEC_VERSION
        SINGLETABLEMETADATA_LOGGER.info(
            '\nMetadata Save:\n'
            '  Timestamp: %s\n'
            '  Statistics about the metadata:\n'
            '    Total number of tables: 1\n'
            '    Total number of columns: %s\n'
            '    Total number of relationships: 0',
            datetime.now(),
            len(self.columns)
        )
with open(filepath, 'w', encoding='utf-8') as metadata_file:
json.dump(metadata, metadata_file, indent=4)

120 changes: 110 additions & 10 deletions sdv/multi_table/base.py
@@ -16,8 +16,11 @@
_validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version,
generate_synthesizer_id)
from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')


class BaseMultiTableSynthesizer:
"""Base class for multi table synthesizers.
@@ -56,13 +59,14 @@ def _set_temp_numpy_seed(self):
np.random.set_state(initial_state)

def _initialize_models(self):
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
**synthesizer_parameters
)
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
**synthesizer_parameters
)

def _get_pbar_args(self, **kwargs):
"""Return a dictionary with the updated keyword args for a progress bar."""
@@ -113,6 +117,15 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self._fitted_sdv_version = None
self._fitted_sdv_enterprise_version = None
self._synthesizer_id = generate_synthesizer_id(self)
SYNTHESIZER_LOGGER.info(
'\nInstance:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
self.__class__.__name__,
self._synthesizer_id
)

def _get_root_parents(self):
"""Get the set of root parents in the graph."""
@@ -371,9 +384,33 @@ def fit_processed_data(self, processed_data):
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
total_rows = 0
total_columns = 0
for table in processed_data.values():
total_rows += len(table)
total_columns += len(table.columns)

SYNTHESIZER_LOGGER.info(
'\nFit processed data:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Statistics of the fit processed data:\n'
' Total number of tables: %s\n'
' Total number of rows: %s\n'
' Total number of columns: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
self.__class__.__name__,
len(processed_data),
total_rows,
total_columns,
self._synthesizer_id,
)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)
with disable_single_table_logger():
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)

self._fitted = True
self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')
self._fitted_sdv_version = getattr(version, 'public', None)
@@ -387,6 +424,28 @@ def fit(self, data):
Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format
(before any transformations).
"""
total_rows = 0
total_columns = 0
for table in data.values():
total_rows += len(table)
total_columns += len(table.columns)

SYNTHESIZER_LOGGER.info(
'\nFit:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Statistics of the fit data:\n'
' Total number of tables: %s\n'
' Total number of rows: %s\n'
' Total number of columns: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
self.__class__.__name__,
len(data),
total_rows,
total_columns,
self._synthesizer_id,
)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
_validate_foreign_keys_not_null(self.metadata, data)
self._check_metadata_updated()
@@ -419,9 +478,31 @@ def sample(self, scale=1.0):
raise SynthesizerInputError(
f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.")

with self._set_temp_numpy_seed():
with self._set_temp_numpy_seed(), disable_single_table_logger():
sampled_data = self._sample(scale=scale)

total_rows = 0
total_columns = 0
for table in sampled_data.values():
total_rows += len(table)
total_columns += len(table.columns)

SYNTHESIZER_LOGGER.info(
'\nSample:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Statistics of the sample size:\n'
' Total number of tables: %s\n'
' Total number of rows: %s\n'
' Total number of columns: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
self.__class__.__name__,
len(sampled_data),
total_rows,
total_columns,
self._synthesizer_id,
)
return sampled_data

def get_learned_distributions(self, table_name):
@@ -586,6 +667,16 @@ def save(self, filepath):
filepath (str):
Path where the instance will be serialized.
"""
synthesizer_id = getattr(self, '_synthesizer_id', None)
SYNTHESIZER_LOGGER.info(
'\nSave:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
self.__class__.__name__,
synthesizer_id
)
with open(filepath, 'wb') as output:
cloudpickle.dump(self, output)

@@ -609,4 +700,13 @@ def load(cls, filepath):
if getattr(synthesizer, '_synthesizer_id', None) is None:
synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer)

SYNTHESIZER_LOGGER.info(
'\nLoad:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
synthesizer.__class__.__name__,
synthesizer._synthesizer_id,
)
return synthesizer
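
Taken together, a typical multi-table workflow now emits Instance, Fit, Fit processed data, Sample, Save, and Load entries. A hedged end-to-end sketch (HMASynthesizer and download_demo come from the wider SDV API, not from this diff):

from sdv.datasets.demo import download_demo
from sdv.multi_table import HMASynthesizer

data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels')

synthesizer = HMASynthesizer(metadata)            # logs an 'Instance' entry
synthesizer.fit(data)                             # logs 'Fit' and 'Fit processed data' entries
synthetic_data = synthesizer.sample(scale=1.0)    # logs a 'Sample' entry
synthesizer.save('synthesizer.pkl')               # logs a 'Save' entry
loaded = HMASynthesizer.load('synthesizer.pkl')   # logs a 'Load' entry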