diff --git a/.gitignore b/.gitignore
index dcbbb4e0b..81d92bed1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.vscode/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/Makefile b/Makefile
index 5f1547db4..65ec08061 100644
--- a/Makefile
+++ b/Makefile
@@ -99,14 +99,13 @@ check-dependencies: ## test if there are any broken dependencies
 	pip check

 .PHONY: lint
-lint: ## check style with flake8 and isort
+lint:
 	invoke lint

 .PHONY: fix-lint
-fix-lint: ## fix lint issues using autoflake, autopep8, and isort
-	find sdv tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
-	autopep8 --in-place --recursive --aggressive sdv tests
-	isort --apply --atomic sdv tests
+fix-lint:
+	ruff check --fix .
+	ruff format


 # TEST TARGETS
diff --git a/pyproject.toml b/pyproject.toml
index 1e3d58f64..422a6fdb4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,32 +86,7 @@ dev = [
     'Jinja2>=2,<4',

     # style check
-    'flake8>=3.7.7,<8',
-    'flake8-absolute-import>=1.0,<2',
-    'flake8-builtins>=1.5.3,<3',
-    'flake8-comprehensions>=3.6.1,<4',
-    'flake8-debugger>=4.0.0,<5',
-    'flake8-docstrings>=1.5.0,<2',
-    'flake8-eradicate>=1.1.0,<2',
-    'flake8-fixme>=1.1.1,<1.2',
-    'flake8-mock>=0.3,<1',
-    'flake8-multiline-containers>=0.0.18,<0.1',
-    'flake8-mutable>=1.2.0,<1.3',
-    'flake8-expression-complexity>=0.0.9,<0.1',
-    'flake8-print>=4.0.0,<4.1',
-    'flake8-pytest-style>=2.0.0,<3',
-    'flake8-quotes>=3.3.0,<4',
-    'flake8-sfs>=0.0.3,<2',
-    'flake8-variables-names>=0.0.4,<0.1',
-    'dlint>=0.11.0,<1',
-    'isort>=5.13.2,<6',
-    'pandas-vet>=0.2.3,<2024',
-    'pep8-naming>=0.12.1,<1',
-    'pydocstyle>=6.1.1,<7',
-
-    # fix style issues
-    'autoflake>=1.1,<3',
-    'autopep8>=1.4.3,<3',
+    'ruff>=0.4.5,<1',

     # distribute on PyPI
     'twine>=1.10.0,<6',
@@ -190,19 +165,55 @@ filename = "sdv/__init__.py"
 search = "__version__ = '{current_version}'"
 replace = "__version__ = '{new_version}'"

-[tool.isort]
-line_length = 99
-lines_between_types = 0
-multi_line_output = 4
-use_parentheses = true
-
-[tool.pydocstyle]
-convention = 'google'
-add-ignore = ['D105', 'D107', 'D407', 'D417']
-
 [tool.pytest.ini_options]
 addopts = "--ignore=pyproject.toml"

 [build-system]
 requires = ['setuptools', 'wheel']
 build-backend = 'setuptools.build_meta'
+
+[tool.ruff]
+preview = true
+line-length = 100
+indent-width = 4
+src = ["sdv"]
+exclude = [
+    "docs",
+    ".tox",
+    ".git",
+    "__pycache__",
+    ".ipynb_checkpoints"
+]
+
+[tool.ruff.lint]
+select = [
+    # Pyflakes
+    "F",
+    # Pycodestyle
+    "E",
+    "W",
+    "D200",
+    # isort
+    "I001",
+]
+ignore = [
+    "E501",
+    "D107",  # Missing docstring in __init__
+    "D417",  # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449
+]
+
+[tool.ruff.format]
+quote-style = "single"
+indent-style = "space"
+preview = true
+docstring-code-format = true
+docstring-code-line-length = "dynamic"
+
+[tool.ruff.lint.isort]
+known-first-party = ["sdv"]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401", "E402", "F403", "F405", "E501", "I001"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
\ No newline at end of file
diff --git a/scripts/release_notes_generator.py b/scripts/release_notes_generator.py
index 09adc9109..8e7ba6ffb 100644
--- a/scripts/release_notes_generator.py
+++ b/scripts/release_notes_generator.py
@@ -13,7 +13,7 @@
     'maintenance': 'Maintenance',
     'customer success': 'Customer Success',
     'documentation': 'Documentation',
-    'misc': 'Miscellaneous'
+    'misc':
'Miscellaneous', } ISSUE_LABELS = [ 'documentation', @@ -21,7 +21,7 @@ 'internal', 'bug', 'feature request', - 'customer success' + 'customer success', ] NEW_LINE = '\n' GITHUB_URL = 'https://api.github.com/repos/sdv-dev/sdv' @@ -30,14 +30,8 @@ def _get_milestone_number(milestone_title): url = f'{GITHUB_URL}/milestones' - headers = { - 'Authorization': f'Bearer {GITHUB_TOKEN}' - } - query_params = { - 'milestone': milestone_title, - 'state': 'all', - 'per_page': 100 - } + headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'} + query_params = {'milestone': milestone_title, 'state': 'all', 'per_page': 100} response = requests.get(url, headers=headers, params=query_params) body = response.json() if response.status_code != 200: @@ -52,17 +46,12 @@ def _get_milestone_number(milestone_title): def _get_issues_by_milestone(milestone): - headers = { - 'Authorization': f'Bearer {GITHUB_TOKEN}' - } + headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'} # get milestone number milestone_number = _get_milestone_number(milestone) url = f'{GITHUB_URL}/issues' page = 1 - query_params = { - 'milestone': milestone_number, - 'state': 'all' - } + query_params = {'milestone': milestone_number, 'state': 'all'} issues = [] while True: query_params['page'] = page diff --git a/sdv/__init__.py b/sdv/__init__.py index e1b0460aa..92c9298d5 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -16,8 +16,21 @@ from types import ModuleType from sdv import ( - constraints, data_processing, datasets, evaluation, io, lite, logging, metadata, metrics, - multi_table, sampling, sequential, single_table, version) + constraints, + data_processing, + datasets, + evaluation, + io, + lite, + logging, + metadata, + metrics, + multi_table, + sampling, + sequential, + single_table, + version, +) __all__ = [ 'constraints', @@ -33,7 +46,7 @@ 'sampling', 'sequential', 'single_table', - 'version' + 'version', ] diff --git a/sdv/_utils.py b/sdv/_utils.py index 5db3955ec..9138dbcce 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -1,4 +1,5 @@ """Miscellaneous utility functions.""" + import operator import uuid import warnings @@ -122,11 +123,7 @@ def _validate_datetime_format(column, datetime_format): Series of booleans, with True if the value matches the format, False if not. """ pandas_datetime_format = datetime_format.replace('%-', '%') - datetime_column = pd.to_datetime( - column, - errors='coerce', - format=pandas_datetime_format - ) + datetime_column = pd.to_datetime(column, errors='coerce', format=pandas_datetime_format) valid = pd.isna(column) | ~pd.isna(datetime_column) return set(column[~valid]) @@ -253,7 +250,7 @@ def check_sdv_versions_and_warn(synthesizer): public_missmatch = current_public_version != fitted_public_version enterprise_missmatch = current_enterprise_version != fitted_enterprise_version - if (public_missmatch or enterprise_missmatch): + if public_missmatch or enterprise_missmatch: static_message = ( 'The latest bug fixes and features may not be available for this synthesizer. ' 'To see these enhancements, create and train a new synthesizer on this version.' @@ -346,23 +343,18 @@ def check_synthesizer_version(synthesizer, is_fit_method=False, compare_operator static_message = 'Downgrading your SDV version is not supported.' if is_fit_method: static_message = ( - 'Fitting this synthesizer again is not supported. ' - 'Please create a new synthesizer.' + 'Fitting this synthesizer again is not supported. ' 'Please create a new synthesizer.' 
) fit_public_version = getattr(synthesizer, '_fitted_sdv_version', None) fit_enterprise_version = getattr(synthesizer, '_fitted_sdv_enterprise_version', None) is_public_lower = _compare_versions( - current_public_version, - fit_public_version, - compare_operator + current_public_version, fit_public_version, compare_operator ) is_enterprise_lower = _compare_versions( - current_enterprise_version, - fit_enterprise_version, - compare_operator + current_enterprise_version, fit_enterprise_version, compare_operator ) if is_public_lower and is_enterprise_lower: diff --git a/sdv/constraints/__init__.py b/sdv/constraints/__init__.py index 2fb12bedd..cf29042c1 100644 --- a/sdv/constraints/__init__.py +++ b/sdv/constraints/__init__.py @@ -1,8 +1,19 @@ """SDV Constraints module.""" + from sdv.constraints.base import Constraint from sdv.constraints.tabular import ( - FixedCombinations, FixedIncrements, Inequality, Negative, OneHotEncoding, Positive, Range, - ScalarInequality, ScalarRange, Unique, create_custom_constraint_class) + FixedCombinations, + FixedIncrements, + Inequality, + Negative, + OneHotEncoding, + Positive, + Range, + ScalarInequality, + ScalarRange, + Unique, + create_custom_constraint_class, +) __all__ = [ 'create_custom_constraint_class', @@ -16,5 +27,5 @@ 'Negative', 'Positive', 'OneHotEncoding', - 'Unique' + 'Unique', ] diff --git a/sdv/constraints/base.py b/sdv/constraints/base.py index 4fa889ea3..447d55611 100644 --- a/sdv/constraints/base.py +++ b/sdv/constraints/base.py @@ -13,7 +13,10 @@ from sdv._utils import _format_invalid_values_string, _groupby_list from sdv.constraints.errors import ( - AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError) + AggregateConstraintsError, + ConstraintMetadataError, + MissingConstraintColumnError, +) from sdv.errors import ConstraintsNotMetError LOGGER = logging.getLogger(__name__) @@ -123,15 +126,19 @@ def _validate_inputs(cls, **kwargs): constraint = cls.__name__ article = 'an' if constraint == 'Inequality' else 'a' if missing_values: - errors.append(ValueError( - f'Missing required values {missing_values} in {article} {constraint} constraint.' - )) + errors.append( + ValueError( + f'Missing required values {missing_values} in {article} {constraint} constraint.' + ) + ) invalid_vals = set(kwargs) - set(args) if invalid_vals: - errors.append(ValueError( - f'Invalid values {invalid_vals} are present in {article} {constraint} constraint.' - )) + errors.append( + ValueError( + f'Invalid values {invalid_vals} are present in {article} {constraint} constraint.' 
+ ) + ) if errors: raise AggregateConstraintsError(errors) @@ -324,8 +331,9 @@ def filter_valid(self, table_data): valid = self.is_valid(table_data) invalid = sum(~valid) if invalid: - LOGGER.debug('%s: %s invalid rows out of %s.', - self.__class__.__name__, sum(~valid), len(valid)) + LOGGER.debug( + '%s: %s invalid rows out of %s.', self.__class__.__name__, sum(~valid), len(valid) + ) if isinstance(valid, pd.Series): return table_data[valid.to_numpy()] @@ -411,8 +419,7 @@ def _get_hyper_transformer_config(data_to_model): if dtype in ('i', 'f'): sdtypes[column_name] = 'numerical' transformers = FloatFormatter( - missing_value_replacement='mean', - missing_value_generation='from_column' + missing_value_replacement='mean', missing_value_generation='from_column' ) elif dtype == 'O': sdtypes[column_name] = 'categorical' @@ -420,14 +427,12 @@ def _get_hyper_transformer_config(data_to_model): elif dtype == 'M': sdtypes[column_name] = 'datetime' transformers[column_name] = UnixTimestampEncoder( - missing_value_replacement='mean', - missing_value_generation='from_column' + missing_value_replacement='mean', missing_value_generation='from_column' ) elif dtype == 'b': sdtypes[column_name] = 'boolean' transformers[column_name] = BinaryEncoder( - missing_value_replacement=-1, - missing_value_generation='from_column' + missing_value_replacement=-1, missing_value_generation='from_column' ) return {'sdtypes': sdtypes, 'transformers': transformers} @@ -468,8 +473,9 @@ def _reject_sample(self, num_rows, conditions): multiplier = num_rows // num_valid num_rows_missing = num_rows % num_valid remainder_rows = valid_rows.iloc[0:num_rows_missing, :] - valid_rows = pd.concat([valid_rows] * multiplier + [remainder_rows], - ignore_index=True) + valid_rows = pd.concat( + [valid_rows] * multiplier + [remainder_rows], ignore_index=True + ) break remaining = num_rows - num_valid @@ -500,9 +506,7 @@ def sample(self, table_data): Table data with additional ``constraint_columns``. 
""" condition_columns = [c for c in self.constraint_columns if c in table_data.columns] - grouped_conditions = table_data[condition_columns].groupby( - _groupby_list(condition_columns) - ) + grouped_conditions = table_data[condition_columns].groupby(_groupby_list(condition_columns)) all_sampled_rows = [] for group, dataframe in grouped_conditions: if not isinstance(group, tuple): @@ -510,8 +514,7 @@ def sample(self, table_data): transformed_condition = self._hyper_transformer.transform(dataframe).iloc[0].to_dict() sampled_rows = self._reject_sample( - num_rows=dataframe.shape[0], - conditions=transformed_condition + num_rows=dataframe.shape[0], conditions=transformed_condition ) all_sampled_rows.append(sampled_rows) diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 5eb370ceb..63fdcaa9a 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -39,16 +39,26 @@ from sdv._utils import _convert_to_timedelta, _create_unique_name, _is_datetime_type from sdv.constraints.base import Constraint from sdv.constraints.errors import ( - AggregateConstraintsError, ConstraintMetadataError, FunctionError, InvalidFunctionError) + AggregateConstraintsError, + ConstraintMetadataError, + FunctionError, + InvalidFunctionError, +) from sdv.constraints.utils import ( - cast_to_datetime64, compute_nans_column, get_datetime_diff, logit, matches_datetime_format, - revert_nans_columns, sigmoid) + cast_to_datetime64, + compute_nans_column, + get_datetime_diff, + logit, + matches_datetime_format, + revert_nans_columns, + sigmoid, +) INEQUALITY_TO_OPERATION = { '>': np.greater, '>=': np.greater_equal, '<': np.less, - '<=': np.less_equal + '<=': np.less_equal, } @@ -68,13 +78,13 @@ def _validate_inputs_custom_constraint(is_valid_fn, transform_fn=None, reverse_t raise ValueError('`reverse_transform_fn` must be a function.') -class _RecreateCustomConstraint(): +class _RecreateCustomConstraint: def __call__(self, is_valid_fn, transform_fn, reverse_transform_fn): constraint_class = _RecreateCustomConstraint() constraint_class.__class__ = create_custom_constraint_class( is_valid_fn=is_valid_fn, transform_fn=transform_fn, - reverse_transform_fn=reverse_transform_fn + reverse_transform_fn=reverse_transform_fn, ) return constraint_class @@ -125,7 +135,7 @@ def __reduce__(self): return ( _RecreateCustomConstraint(), (is_valid_fn, transform_fn, reverse_transform_fn), - self.__dict__ + self.__dict__, ) def __init__(self, column_names, **kwargs): @@ -147,7 +157,8 @@ def is_valid(self, data): valid = is_valid_fn(self.column_names, data, **self.kwargs) if len(valid) != data.shape[0]: raise InvalidFunctionError( - '`is_valid_fn` did not produce exactly 1 True/False value for each row.') + '`is_valid_fn` did not produce exactly 1 True/False value for each row.' + ) if not isinstance(valid, pd.Series): raise ValueError( @@ -176,7 +187,8 @@ def transform(self, data): transformed_data = transform_fn(self.column_names, data, **self.kwargs) if data.shape[0] != transformed_data.shape[0]: raise InvalidFunctionError( - 'Transformation did not produce the same number of rows as the original') + 'Transformation did not produce the same number of rows as the original' + ) self.reverse_transform(transformed_data.copy()) return transformed_data @@ -247,9 +259,11 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs): if invalid_columns: columns = '", "'.join(invalid_columns) - raise ConstraintMetadataError(f'Invalid columns ("{columns}") supplied to a ' - 'FixedCombinations constraint. 
This constraint only ' - 'supports boolean and categorical columns.') + raise ConstraintMetadataError( + f'Invalid columns ("{columns}") supplied to a ' + 'FixedCombinations constraint. This constraint only ' + 'supports boolean and categorical columns.' + ) def __init__(self, column_names): if len(column_names) < 2: @@ -299,10 +313,7 @@ def is_valid(self, table_data): Whether each row is valid. """ merged = table_data.merge( - self._combinations, - how='left', - on=self._columns, - indicator=self._joint_column + self._combinations, how='left', on=self._columns, indicator=self._joint_column ) return merged[self._joint_column] == 'both' @@ -441,9 +452,11 @@ def _fit(self, table_data): self._is_datetime = self._get_is_datetime() if self._is_datetime: self._low_datetime_format = self.metadata.columns[self._low_column_name].get( - 'datetime_format') + 'datetime_format' + ) self._high_datetime_format = self.metadata.columns[self._high_column_name].get( - 'datetime_format') + 'datetime_format' + ) def is_valid(self, table_data): """Check whether ``high`` is greater than ``low`` in each row. @@ -487,7 +500,7 @@ def _transform(self, table_data): high=high, low=low, high_datetime_format=self._high_datetime_format, - low_datetime_format=self._low_datetime_format + low_datetime_format=self._low_datetime_format, ) else: diff_column = high - low @@ -505,7 +518,7 @@ def _transform(self, table_data): mean_value_low = table_data[self._low_column_name].mean() table_data = table_data.fillna({ self._low_column_name: mean_value_low, - self._diff_column_name: table_data[self._diff_column_name].mean() + self._diff_column_name: table_data[self._diff_column_name].mean(), }) return table_data.drop(self._high_column_name, axis=1) @@ -572,10 +585,12 @@ def _validate_inputs(cls, **kwargs): if 'relation' in kwargs and kwargs['relation'] not in {'>', '>=', '<', '<='}: wrong_relation = {kwargs['relation']} - errors.append(ConstraintMetadataError( - f'Invalid relation value {wrong_relation} in a ScalarInequality constraint.' - " The relation must be one of: '>', '>=', '<' or '<='." - )) + errors.append( + ConstraintMetadataError( + f'Invalid relation value {wrong_relation} in a ScalarInequality constraint.' + " The relation must be one of: '>', '>=', '<' or '<='." + ) + ) if errors: raise AggregateConstraintsError(errors) @@ -836,16 +851,17 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs): middle_sdtype = metadata.columns.get(middle, {}).get('sdtype') all_datetime = high_sdtype == low_sdtype == middle_sdtype == 'datetime' all_numerical = high_sdtype == low_sdtype == middle_sdtype == 'numerical' - if not (all_datetime or all_numerical) and \ - not (high is None or low is None or middle is None): + if not (all_datetime or all_numerical) and not ( + high is None or low is None or middle is None + ): raise ConstraintMetadataError( 'A Range constraint is being applied to columns with mismatched sdtypes ' f'{[high, middle, low]}. All columns must be either numerical or datetime.' 
) - def __init__(self, low_column_name, middle_column_name, high_column_name, - strict_boundaries=True): - + def __init__( + self, low_column_name, middle_column_name, high_column_name, strict_boundaries=True + ): self.constraint_columns = (low_column_name, middle_column_name, high_column_name) self.low_column_name = low_column_name self.middle_column_name = middle_column_name @@ -881,11 +897,14 @@ def _fit(self, table_data): self._is_datetime = self._get_is_datetime() if self._is_datetime: self._low_datetime_format = self.metadata.columns[self.low_column_name].get( - 'datetime_format') + 'datetime_format' + ) self._middle_datetime_format = self.metadata.columns[self.middle_column_name].get( - 'datetime_format') + 'datetime_format' + ) self._high_datetime_format = self.metadata.columns[self.high_column_name].get( - 'datetime_format') + 'datetime_format' + ) self.low_diff_column_name = f'{self.low_column_name}#{self.middle_column_name}' self.high_diff_column_name = f'{self.middle_column_name}#{self.high_column_name}' @@ -977,7 +996,7 @@ def _transform(self, table_data): table_data = table_data.fillna({ self.low_column_name: mean_value_low, self.low_diff_column_name: table_data[self.low_diff_column_name].mean(), - self.high_diff_column_name: table_data[self.high_diff_column_name].mean() + self.high_diff_column_name: table_data[self.high_diff_column_name].mean(), }) return table_data.drop([self.middle_column_name, self.high_column_name], axis=1) @@ -1013,9 +1032,9 @@ def _reverse_transform(self, table_data): middle = pd.Series(low_diff_column + low).astype(self._dtype) table_data[self.middle_column_name] = middle - table_data[self.high_column_name] = pd.Series( - high_diff_column + middle.to_numpy() - ).astype(self._dtype) + table_data[self.high_column_name] = pd.Series(high_diff_column + middle.to_numpy()).astype( + self._dtype + ) if self.nan_column_name in table_data.columns: table_data = revert_nans_columns(table_data, self.nan_column_name) @@ -1133,12 +1152,13 @@ def _fit(self, table_data): self._is_datetime = self._get_is_datetime() self._transformed_column = self._get_diff_column_name(table_data) if self._is_datetime: - self._datetime_format = self.metadata.columns[self._column_name].get( - 'datetime_format') + self._datetime_format = self.metadata.columns[self._column_name].get('datetime_format') self._low_value = cast_to_datetime64( - self._low_value, datetime_format=self._datetime_format) + self._low_value, datetime_format=self._datetime_format + ) self._high_value = cast_to_datetime64( - self._high_value, datetime_format=self._datetime_format) + self._high_value, datetime_format=self._datetime_format + ) def is_valid(self, table_data): """Say whether the ``column_name`` is between the ``low`` and ``high`` values. @@ -1188,7 +1208,8 @@ def _transform(self, table_data): data = table_data[self._column_name] if self._is_datetime: data = cast_to_datetime64( - table_data[self._column_name], datetime_format=self._datetime_format) + table_data[self._column_name], datetime_format=self._datetime_format + ) data = logit(data, self._low_value, self._high_value) table_data[self._transformed_column] = data @@ -1257,10 +1278,12 @@ def _validate_inputs(cls, **kwargs): if 'increment_value' in kwargs and kwargs['increment_value'] <= 0: wrong_increment = {kwargs['increment_value']} - errors.append(ConstraintMetadataError( - f'Invalid increment value {wrong_increment} in a FixedIncrements constraint.' - ' Increments must be positive integers.' 
- )) + errors.append( + ConstraintMetadataError( + f'Invalid increment value {wrong_increment} in a FixedIncrements constraint.' + ' Increments must be positive integers.' + ) + ) if errors: raise AggregateConstraintsError(errors) diff --git a/sdv/data_processing/__init__.py b/sdv/data_processing/__init__.py index 4cb6ea49c..d71721a0f 100644 --- a/sdv/data_processing/__init__.py +++ b/sdv/data_processing/__init__.py @@ -2,6 +2,4 @@ from sdv.data_processing.data_processor import DataProcessor -__all__ = ( - 'DataProcessor', -) +__all__ = ('DataProcessor',) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 638a922ce..aeff9eaee 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -15,7 +15,10 @@ from sdv.constraints import Constraint from sdv.constraints.base import get_subclasses from sdv.constraints.errors import ( - AggregateConstraintsError, FunctionError, MissingConstraintColumnError) + AggregateConstraintsError, + FunctionError, + MissingConstraintColumnError, +) from sdv.data_processing.datetime_formatter import DatetimeFormatter from sdv.data_processing.errors import InvalidConstraintsError, NotFittedError from sdv.data_processing.numerical_formatter import NumericalFormatter @@ -65,17 +68,14 @@ class DataProcessor: 'M': 'datetime', } - _COLUMN_RELATIONSHIP_TO_TRANSFORMER = { - 'address': 'RandomLocationGenerator', - 'gps': 'GPSNoiser' - } + _COLUMN_RELATIONSHIP_TO_TRANSFORMER = {'address': 'RandomLocationGenerator', 'gps': 'GPSNoiser'} def _update_numerical_transformer(self, enforce_rounding, enforce_min_max_values): custom_float_formatter = rdt.transformers.FloatFormatter( missing_value_replacement='mean', missing_value_generation='random', learn_rounding_scheme=enforce_rounding, - enforce_min_max_values=enforce_min_max_values + enforce_min_max_values=enforce_min_max_values, ) self._transformers_by_sdtype.update({'numerical': custom_float_formatter}) @@ -104,8 +104,15 @@ def _detect_multi_column_transformers(self): return result - def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True, - model_kwargs=None, table_name=None, locales=['en_US']): + def __init__( + self, + metadata, + enforce_rounding=True, + enforce_min_max_values=True, + model_kwargs=None, + table_name=None, + locales=['en_US'], + ): self.metadata = metadata self._enforce_rounding = enforce_rounding self._enforce_min_max_values = enforce_min_max_values @@ -140,9 +147,7 @@ def _get_grouped_columns(self): list: A list of columns that are part of a multi column transformer. """ - return [ - col for col_tuple in self.grouped_columns_to_transformers for col in col_tuple - ] + return [col for col_tuple in self.grouped_columns_to_transformers for col in col_tuple] def get_model_kwargs(self, model_name): """Return the required model kwargs for the indicated model. @@ -381,7 +386,7 @@ def _transform_constraints(self, data, is_condition=False): 'Unable to transform %s with columns %s because they are not all available' ' in the data. 
This happens due to multiple, overlapping constraints.', constraint.__class__.__name__, - error.missing_columns + error.missing_columns, ) log_exc_stacktrace(LOGGER, error) else: @@ -392,7 +397,7 @@ def _transform_constraints(self, data, is_condition=False): '%s\nUsing the reject sampling approach instead.', constraint.__class__.__name__, constraint.column_names, - str(error) + str(error), ) log_exc_stacktrace(LOGGER, error) if is_condition: @@ -412,8 +417,7 @@ def _update_transformers_by_sdtypes(self, sdtype, transformer): self._transformers_by_sdtype[sdtype] = transformer @staticmethod - def create_anonymized_transformer(sdtype, column_metadata, cardinality_rule, - locales=['en_US']): + def create_anonymized_transformer(sdtype, column_metadata, cardinality_rule, locales=['en_US']): """Create an instance of an ``AnonymizedFaker``. Read the extra keyword arguments from the ``column_metadata`` and use them to create @@ -436,10 +440,7 @@ def create_anonymized_transformer(sdtype, column_metadata, cardinality_rule, Returns: Instance of ``rdt.transformers.pii.AnonymizedFaker``. """ - kwargs = { - 'locales': locales, - 'cardinality_rule': cardinality_rule - } + kwargs = {'locales': locales, 'cardinality_rule': cardinality_rule} for key, value in column_metadata.items(): if key not in ['pii', 'sdtype']: kwargs[key] = value @@ -484,7 +485,7 @@ def create_regex_generator(self, column_name, sdtype, column_metadata, is_numeri transformer = rdt.transformers.RegexGenerator( regex_format=regex_format, enforce_uniqueness=(column_name in self._keys), - generation_order='scrambled' + generation_order='scrambled', ) return transformer @@ -500,8 +501,7 @@ def _get_transformer_instance(self, sdtype, column_metadata): ) kwargs = { - key: value for key, value in column_metadata.items() - if key not in ['pii', 'sdtype'] + key: value for key, value in column_metadata.items() if key not in ['pii', 'sdtype'] } if sdtype == 'datetime': kwargs['enforce_min_max_values'] = self._enforce_min_max_values @@ -521,7 +521,7 @@ def _update_constraint_transformers(self, data, columns_created_by_constraints, config['transformers'][column] = rdt.transformers.FloatFormatter( missing_value_replacement='mean', missing_value_generation='random', - enforce_min_max_values=self._enforce_min_max_values + enforce_min_max_values=self._enforce_min_max_values, ) else: sdtype = self._DTYPE_TO_SDTYPE.get(dtype_kind, 'categorical') @@ -563,10 +563,7 @@ def _create_config(self, data, columns_created_by_constraints): is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype) if column_metadata.get('regex_format', False): transformers[column] = self.create_regex_generator( - column, - sdtype, - column_metadata, - is_numeric + column, sdtype, column_metadata, is_numeric ) sdtypes[column] = 'text' @@ -583,7 +580,7 @@ def _create_config(self, data, columns_created_by_constraints): provider_name=None, function_name='bothify', function_kwargs={'text': bothify_format}, - cardinality_rule=cardinality_rule + cardinality_rule=cardinality_rule, ) sdtypes[column] = 'pii' if column_metadata.get('pii') else 'text' @@ -595,17 +592,14 @@ def _create_config(self, data, columns_created_by_constraints): ) transformers[column].function_kwargs = { 'text': 'sdv-pii-?????', - 'letters': '0123456789abcdefghijklmnopqrstuvwxyz' + 'letters': '0123456789abcdefghijklmnopqrstuvwxyz', } elif pii: sdtypes[column] = 'pii' cardinality_rule = 'unique' if bool(column in self._keys) else None transformers[column] = self.create_anonymized_transformer( - sdtype, - 
column_metadata, - cardinality_rule, - self._locales + sdtype, column_metadata, cardinality_rule, self._locales ) elif sdtype in self._transformers_by_sdtype: @@ -617,14 +611,13 @@ def _create_config(self, data, columns_created_by_constraints): sdtype=sdtype, column_metadata=column_metadata, cardinality_rule='unique', - locales=self._locales + locales=self._locales, ) else: sdtypes[column] = 'categorical' transformers[column] = self._get_transformer_instance( - 'categorical', - column_metadata + 'categorical', column_metadata ) for columns, transformer in self.grouped_columns_to_transformers.items(): @@ -690,7 +683,7 @@ def _fit_formatters(self, data): self.formatters[column_name] = NumericalFormatter( enforce_rounding=self._enforce_rounding, enforce_min_max_values=self._enforce_min_max_values, - computer_representation=representation + computer_representation=representation, ) self.formatters[column_name].learn_format(data[column_name]) @@ -729,19 +722,17 @@ def prepare_for_fitting(self, data): config = self._hyper_transformer.get_config() missing_columns = columns_created_by_constraints - config.get('sdtypes').keys() if not config.get('sdtypes'): - LOGGER.info(( - 'Setting the configuration for the ``HyperTransformer`` ' - f'for table {self.table_name}' - )) + LOGGER.info( + ( + 'Setting the configuration for the ``HyperTransformer`` ' + f'for table {self.table_name}' + ) + ) config = self._create_config(constrained, columns_created_by_constraints) self._hyper_transformer.set_config(config) elif missing_columns: - config = self._update_constraint_transformers( - constrained, - missing_columns, - config - ) + config = self._update_constraint_transformers(constrained, missing_columns, config) self._hyper_transformer = rdt.HyperTransformer() self._hyper_transformer.set_config(config) @@ -761,7 +752,8 @@ def fit(self, data): constrained = self._transform_constraints(data) if constrained.empty: raise ValueError( - 'The constrained fit dataframe is empty, synthesizer will not be fitted.') + 'The constrained fit dataframe is empty, synthesizer will not be fitted.' + ) LOGGER.info(f'Fitting HyperTransformer for table {self.table_name}') self._fit_hyper_transformer(constrained) self.fitted = True @@ -806,7 +798,8 @@ def transform(self, data, is_condition=False): # Filter columns that can be transformed columns = [ - column for column in self.get_sdtypes(primary_keys=not is_condition) + column + for column in self.get_sdtypes(primary_keys=not is_condition) if column in data.columns ] LOGGER.debug(f'Transforming constraints for table {self.table_name}') @@ -839,9 +832,7 @@ def reverse_transform(self, data, reset_keys=False): raise NotFittedError() reversible_columns = [ - column - for column in self._hyper_transformer._output_columns - if column in data.columns + column for column in self._hyper_transformer._output_columns if column in data.columns ] reversed_data = data @@ -866,8 +857,7 @@ def reverse_transform(self, data, reset_keys=False): ] if missing_columns and num_rows: anonymized_data = self._hyper_transformer.create_anonymized_columns( - num_rows=num_rows, - column_names=missing_columns + num_rows=num_rows, column_names=missing_columns ) sampled_columns.extend(missing_columns) reversed_data[anonymized_data.columns] = anonymized_data[anonymized_data.notna()] @@ -890,8 +880,7 @@ def reverse_transform(self, data, reset_keys=False): # And alternate keys. Thats the reason of ensuring that the metadata column is within # The sampled columns. 
sampled_columns = [ - column for column in self.metadata.columns.keys() - if column in sampled_columns + column for column in self.metadata.columns.keys() if column in sampled_columns ] for column_name in sampled_columns: column_data = reversed_data[column_name] @@ -954,7 +943,7 @@ def to_dict(self): 'metadata': deepcopy(self.metadata.to_dict()), 'constraints_list': self.get_constraints(), 'constraints_to_reverse': constraints_to_reverse, - 'model_kwargs': deepcopy(self._model_kwargs) + 'model_kwargs': deepcopy(self._model_kwargs), } @classmethod @@ -973,7 +962,7 @@ def from_dict(cls, metadata_dict, enforce_rounding=True, enforce_min_max_values= metadata=SingleTableMetadata.load_from_dict(metadata_dict['metadata']), enforce_rounding=enforce_rounding, enforce_min_max_values=enforce_min_max_values, - model_kwargs=metadata_dict.get('model_kwargs') + model_kwargs=metadata_dict.get('model_kwargs'), ) instance._constraints_to_reverse = [ diff --git a/sdv/data_processing/datetime_formatter.py b/sdv/data_processing/datetime_formatter.py index 1ce924a49..e79b439f4 100644 --- a/sdv/data_processing/datetime_formatter.py +++ b/sdv/data_processing/datetime_formatter.py @@ -1,4 +1,5 @@ """Formatter for datetime data.""" + import pandas as pd from sdv._utils import _get_datetime_format diff --git a/sdv/data_processing/errors.py b/sdv/data_processing/errors.py index e617e0cd8..f58453478 100644 --- a/sdv/data_processing/errors.py +++ b/sdv/data_processing/errors.py @@ -13,7 +13,4 @@ def __init__(self, errors): self.errors = errors def __str__(self): - return ( - 'The provided constraint is invalid:\n' + - '\n\n'.join(map(str, self.errors)) - ) + return 'The provided constraint is invalid:\n' + '\n\n'.join(map(str, self.errors)) diff --git a/sdv/data_processing/numerical_formatter.py b/sdv/data_processing/numerical_formatter.py index a66e36887..fe9f72881 100644 --- a/sdv/data_processing/numerical_formatter.py +++ b/sdv/data_processing/numerical_formatter.py @@ -1,4 +1,5 @@ """Formatter for numerical data.""" + import logging import sys @@ -9,10 +10,10 @@ MAX_DECIMALS = sys.float_info.dig - 1 INTEGER_BOUNDS = { - 'Int8': (-2**7, 2**7 - 1), - 'Int16': (-2**15, 2**15 - 1), - 'Int32': (-2**31, 2**31 - 1), - 'Int64': (-2**63, 2**63 - 1), + 'Int8': (-(2**7), 2**7 - 1), + 'Int16': (-(2**15), 2**15 - 1), + 'Int32': (-(2**31), 2**31 - 1), + 'Int64': (-(2**63), 2**63 - 1), 'UInt8': (0, 2**8 - 1), 'UInt16': (0, 2**16 - 1), 'UInt32': (0, 2**32 - 1), @@ -43,8 +44,9 @@ class NumericalFormatter: _max_value = None _rounding_digits = None - def __init__(self, enforce_rounding=False, enforce_min_max_values=False, - computer_representation='Float'): + def __init__( + self, enforce_rounding=False, enforce_min_max_values=False, computer_representation='Float' + ): self.enforce_rounding = enforce_rounding self.enforce_min_max_values = enforce_min_max_values self.computer_representation = computer_representation diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 43c35f4ea..1ead34d39 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -72,7 +72,7 @@ def _extract_data(bytes_io, output_folder_name): os.remove(metadata_v0_filepath) os.rename( os.path.join(output_folder_name, 'metadata_v1.json'), - os.path.join(output_folder_name, METADATA_FILENAME) + os.path.join(output_folder_name, METADATA_FILENAME), ) else: diff --git a/sdv/datasets/local.py b/sdv/datasets/local.py index 38682238d..28b42c249 100644 --- a/sdv/datasets/local.py +++ b/sdv/datasets/local.py @@ -33,9 +33,7 @@ def load_csvs(folder_name, 
read_csv_parameters=None): other_files.append(filename) if other_files: - warnings.warn( - f"Ignoring incompatible files {other_files} in folder '{folder_name}'." - ) + warnings.warn(f"Ignoring incompatible files {other_files} in folder '{folder_name}'.") if not csvs: raise ValueError( diff --git a/sdv/errors.py b/sdv/errors.py index c0b785917..c9906757c 100644 --- a/sdv/errors.py +++ b/sdv/errors.py @@ -53,9 +53,8 @@ def __init__(self, errors): self.errors = errors def __str__(self): - return ( - 'The provided data does not match the metadata:\n' + - '\n\n'.join(map(str, self.errors)) + return 'The provided data does not match the metadata:\n' + '\n\n'.join( + map(str, self.errors) ) diff --git a/sdv/evaluation/multi_table.py b/sdv/evaluation/multi_table.py index 93a5325f1..25669af82 100644 --- a/sdv/evaluation/multi_table.py +++ b/sdv/evaluation/multi_table.py @@ -1,4 +1,5 @@ """Methods to compare the real and synthetic data for multi-table.""" + from sdmetrics import visualization from sdmetrics.reports.multi_table.diagnostic_report import DiagnosticReport from sdmetrics.reports.multi_table.quality_report import QualityReport @@ -87,8 +88,9 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name ) -def get_column_pair_plot(real_data, synthetic_data, metadata, - table_name, column_names, plot_type=None, sample_size=None): +def get_column_pair_plot( + real_data, synthetic_data, metadata, table_name, column_names, plot_type=None, sample_size=None +): """Get a plot of the real and synthetic data for a given column pair. Args: @@ -119,17 +121,19 @@ def get_column_pair_plot(real_data, synthetic_data, metadata, real_data = real_data[table_name] synthetic_data = synthetic_data[table_name] return single_table_visualization.get_column_pair_plot( - real_data, - synthetic_data, - metadata, - column_names, - sample_size, - plot_type + real_data, synthetic_data, metadata, column_names, sample_size, plot_type ) -def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name, - child_foreign_key, metadata, plot_type='bar'): +def get_cardinality_plot( + real_data, + synthetic_data, + child_table_name, + parent_table_name, + child_foreign_key, + metadata, + plot_type='bar', +): """Get a plot of the cardinality of the parent-child relationship. Args: @@ -160,5 +164,5 @@ def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_tab parent_table_name, child_foreign_key, parent_primary_key, - plot_type + plot_type, ) diff --git a/sdv/evaluation/single_table.py b/sdv/evaluation/single_table.py index d30e52ebd..d02b38a16 100644 --- a/sdv/evaluation/single_table.py +++ b/sdv/evaluation/single_table.py @@ -100,15 +100,13 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type= }) return visualization.get_column_plot( - real_data, - synthetic_data, - column_name, - plot_type=plot_type + real_data, synthetic_data, column_name, plot_type=plot_type ) def get_column_pair_plot( - real_data, synthetic_data, metadata, column_names, plot_type=None, sample_size=None): + real_data, synthetic_data, metadata, column_names, plot_type=None, sample_size=None +): """Get a plot of the real and synthetic data for a given column pair. 
Args: @@ -159,13 +157,9 @@ def get_column_pair_plot( sdtype = metadata.columns.get(column_name)['sdtype'] if sdtype == 'datetime': datetime_format = metadata.columns.get(column_name).get('datetime_format') - real_data[column_name] = pd.to_datetime( - real_data[column_name], - format=datetime_format - ) + real_data[column_name] = pd.to_datetime(real_data[column_name], format=datetime_format) synthetic_data[column_name] = pd.to_datetime( - synthetic_data[column_name], - format=datetime_format + synthetic_data[column_name], format=datetime_format ) require_subsample = sample_size and sample_size < min(len(real_data), len(synthetic_data)) @@ -173,9 +167,4 @@ def get_column_pair_plot( real_data = real_data.sample(n=sample_size) synthetic_data = synthetic_data.sample(n=sample_size) - return visualization.get_column_pair_plot( - real_data, - synthetic_data, - column_names, - plot_type - ) + return visualization.get_column_pair_plot(real_data, synthetic_data, column_names, plot_type) diff --git a/sdv/io/local/__init__.py b/sdv/io/local/__init__.py index bd3c2ba5b..f57292917 100644 --- a/sdv/io/local/__init__.py +++ b/sdv/io/local/__init__.py @@ -2,8 +2,4 @@ from sdv.io.local.local import BaseLocalHandler, CSVHandler, ExcelHandler -__all__ = ( - 'BaseLocalHandler', - 'CSVHandler', - 'ExcelHandler' -) +__all__ = ('BaseLocalHandler', 'CSVHandler', 'ExcelHandler') diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py index 024b4f435..e14d5de03 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -1,4 +1,5 @@ """Local file handlers.""" + import codecs import inspect import os @@ -74,8 +75,9 @@ class CSVHandler(BaseLocalHandler): If the provided encoding is not available in the system. """ - def __init__(self, sep=',', encoding='UTF', decimal='.', float_format=None, - quotechar='"', quoting=0): + def __init__( + self, sep=',', encoding='UTF', decimal='.', float_format=None, quotechar='"', quoting=0 + ): super().__init__(decimal, float_format) try: codecs.lookup(encoding) @@ -116,11 +118,7 @@ def read(self, folder_name, file_names=None): else: # Validate if the given files exist in the folder file_names = file_names - missing_files = [ - file - for file in file_names - if not (folder_path / file).exists() - ] + missing_files = [file for file in file_names if not (folder_path / file).exists()] if missing_files: raise FileNotFoundError( f"The following files do not exist in the folder: {', '.join(missing_files)}." 
@@ -137,7 +135,7 @@ def read(self, folder_name, file_names=None): 'decimal': self.decimal, 'on_bad_lines': 'warn', 'quotechar': self.quotechar, - 'quoting': self.quoting + 'quoting': self.quoting, } args = inspect.getfullargspec(pd.read_csv) @@ -147,10 +145,7 @@ def read(self, folder_name, file_names=None): for file_path in file_paths: table_name = file_path.stem # Remove file extension to get table name - data[table_name] = pd.read_csv( - file_path, - **kwargs - ) + data[table_name] = pd.read_csv(file_path, **kwargs) return data @@ -205,7 +200,7 @@ def _read_excel(self, filepath, sheet_names=None): sheet_name=sheet_name, parse_dates=False, decimal=self.decimal, - index_col=None + index_col=None, ) return data @@ -257,8 +252,7 @@ def write(self, synthetic_data, file_name, sheet_name_suffix=None, mode='w'): if temp_data.get(sheet_name) is not None: temp_data[sheet_name] = pd.concat( - [temp_data[sheet_name], synthetic_data[sheet_name]], - ignore_index=True + [temp_data[sheet_name], synthetic_data[sheet_name]], ignore_index=True ) else: @@ -270,10 +264,7 @@ def write(self, synthetic_data, file_name, sheet_name_suffix=None, mode='w'): table_name += sheet_name_suffix table_data.to_excel( - writer, - sheet_name=table_name, - float_format=self.float_format, - index=False + writer, sheet_name=table_name, float_format=self.float_format, index=False ) writer.close() diff --git a/sdv/lite/__init__.py b/sdv/lite/__init__.py index d26e87d31..a7b176062 100644 --- a/sdv/lite/__init__.py +++ b/sdv/lite/__init__.py @@ -2,6 +2,4 @@ from sdv.lite.single_table import SingleTablePreset -__all__ = ( - 'SingleTablePreset', -) +__all__ = ('SingleTablePreset',) diff --git a/sdv/lite/single_table.py b/sdv/lite/single_table.py index 5c46ca183..08232ac60 100644 --- a/sdv/lite/single_table.py +++ b/sdv/lite/single_table.py @@ -39,10 +39,7 @@ class SingleTablePreset: def _setup_fast_preset(self, metadata, locales): self._synthesizer = GaussianCopulaSynthesizer( - metadata=metadata, - default_distribution='norm', - enforce_rounding=False, - locales=locales + metadata=metadata, default_distribution='norm', enforce_rounding=False, locales=locales ) def __init__(self, metadata, name, locales=['en_US']): @@ -121,8 +118,9 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file return sampled - def sample_from_conditions(self, conditions, max_tries_per_batch=100, - batch_size=None, output_file_path=None): + def sample_from_conditions( + self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None + ): """Sample rows from this table with the given conditions. Args: @@ -144,16 +142,14 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, """ warnings.warn(DEPRECATION_MSG, FutureWarning) sampled = self._synthesizer.sample_from_conditions( - conditions, - max_tries_per_batch, - batch_size, - output_file_path + conditions, max_tries_per_batch, batch_size, output_file_path ) return sampled - def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, - batch_size=None, output_file_path=None): + def sample_remaining_columns( + self, known_columns, max_tries_per_batch=100, batch_size=None, output_file_path=None + ): """Sample rows from this table. 
Args: @@ -175,10 +171,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, """ warnings.warn(DEPRECATION_MSG, FutureWarning) sampled = self._synthesizer.sample_remaining_columns( - known_columns, - max_tries_per_batch, - batch_size, - output_file_path + known_columns, max_tries_per_batch, batch_size, output_file_path ) return sampled @@ -215,10 +208,12 @@ def load(cls, filepath): def list_available_presets(cls, out=sys.stdout): """List the available presets and their descriptions.""" warnings.warn(DEPRECATION_MSG, FutureWarning) - out.write(f'Available presets:\n{PRESETS}\n\n' - 'Supply the desired preset using the `name` parameter.\n\n' - 'Have any requests for custom presets? Contact the SDV team to learn ' - 'more an SDV Premium license.\n') + out.write( + f'Available presets:\n{PRESETS}\n\n' + 'Supply the desired preset using the `name` parameter.\n\n' + 'Have any requests for custom presets? Contact the SDV team to learn ' + 'more an SDV Premium license.\n' + ) def __repr__(self): """Represent single table preset instance as text. diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index 2c5d10e88..a4355423c 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -2,11 +2,14 @@ from sdv.logging.logger import get_sdv_logger from sdv.logging.utils import ( - disable_single_table_logger, get_sdv_logger_config, load_logfile_dataframe) + disable_single_table_logger, + get_sdv_logger_config, + load_logfile_dataframe, +) __all__ = ( 'disable_single_table_logger', 'get_sdv_logger', 'get_sdv_logger_config', - 'load_logfile_dataframe' + 'load_logfile_dataframe', ) diff --git a/sdv/logging/logger.py b/sdv/logging/logger.py index c75f03550..7ce51854c 100644 --- a/sdv/logging/logger.py +++ b/sdv/logging/logger.py @@ -1,4 +1,5 @@ """SDV Logger.""" + import csv import logging import os @@ -15,8 +16,14 @@ def __init__(self, filename=None): super().__init__() self.output = StringIO() headers = [ - 'LEVEL', 'EVENT', 'TIMESTAMP', 'SYNTHESIZER CLASS NAME', 'SYNTHESIZER ID', - 'TOTAL NUMBER OF TABLES', 'TOTAL NUMBER OF ROWS', 'TOTAL NUMBER OF COLUMNS' + 'LEVEL', + 'EVENT', + 'TIMESTAMP', + 'SYNTHESIZER CLASS NAME', + 'SYNTHESIZER ID', + 'TOTAL NUMBER OF TABLES', + 'TOTAL NUMBER OF ROWS', + 'TOTAL NUMBER OF COLUMNS', ] self.writer = csv.DictWriter(self.output, fieldnames=headers) if filename: diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index e6c86e3ea..97ff09fb0 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -60,7 +60,13 @@ def load_logfile_dataframe(logfile): Path to the SDV log CSV file. 
""" column_names = [ - 'LEVEL', 'EVENT', 'TIMESTAMP', 'SYNTHESIZER CLASS NAME', 'SYNTHESIZER ID', - 'TOTAL NUMBER OF TABLES', 'TOTAL NUMBER OF ROWS', 'TOTAL NUMBER OF COLUMNS' + 'LEVEL', + 'EVENT', + 'TIMESTAMP', + 'SYNTHESIZER CLASS NAME', + 'SYNTHESIZER ID', + 'TOTAL NUMBER OF TABLES', + 'TOTAL NUMBER OF ROWS', + 'TOTAL NUMBER OF COLUMNS', ] return pd.read_csv(logfile, names=column_names) diff --git a/sdv/metadata/__init__.py b/sdv/metadata/__init__.py index ad452169c..71d689727 100644 --- a/sdv/metadata/__init__.py +++ b/sdv/metadata/__init__.py @@ -10,5 +10,5 @@ 'MetadataNotFittedError', 'MultiTableMetadata', 'SingleTableMetadata', - 'visualization' + 'visualization', ) diff --git a/sdv/metadata/metadata_upgrader.py b/sdv/metadata/metadata_upgrader.py index 95b83b36a..f9cbe2c97 100644 --- a/sdv/metadata/metadata_upgrader.py +++ b/sdv/metadata/metadata_upgrader.py @@ -3,8 +3,16 @@ import warnings from sdv.constraints import ( - FixedCombinations, Inequality, Negative, OneHotEncoding, Positive, Range, ScalarInequality, - ScalarRange, Unique) + FixedCombinations, + Inequality, + Negative, + OneHotEncoding, + Positive, + Range, + ScalarInequality, + ScalarRange, + Unique, +) def _upgrade_columns_and_keys(old_metadata): @@ -75,7 +83,7 @@ def _upgrade_positive_negative(old_constraint): new_constraint = { 'constraint_name': constraint_name, 'column_name': column, - 'strict_boundaries': strict + 'strict_boundaries': strict, } new_constraints.append(new_constraint) @@ -85,7 +93,7 @@ def _upgrade_positive_negative(old_constraint): def _upgrade_unique_combinations(old_constraint): new_constraint = { 'constraint_name': FixedCombinations.__name__, - 'column_names': old_constraint.get('columns') + 'column_names': old_constraint.get('columns'), } return [new_constraint] @@ -117,7 +125,7 @@ def _upgrade_greater_than(old_constraint): 'constraint_name': Inequality.__name__, 'high_column_name': high if high_is_string else high[0], 'low_column_name': low if low_is_string else low[0], - 'strict_boundaries': strict + 'strict_boundaries': strict, } new_constraints.append(new_constraint) @@ -128,7 +136,7 @@ def _upgrade_greater_than(old_constraint): 'constraint_name': ScalarInequality.__name__, 'column_name': column, 'relation': '>' if strict else '>=', - 'value': low + 'value': low, } new_constraints.append(new_constraint) @@ -139,7 +147,7 @@ def _upgrade_greater_than(old_constraint): 'constraint_name': ScalarInequality.__name__, 'column_name': column, 'relation': '<' if strict else '<=', - 'value': high + 'value': high, } new_constraints.append(new_constraint) @@ -160,7 +168,7 @@ def _upgrade_between(old_constraint): 'column_name': constraint_column, 'low_value': low, 'high_value': high, - 'strict_boundaries': strict + 'strict_boundaries': strict, } new_constraints.append(new_constraint) @@ -169,13 +177,13 @@ def _upgrade_between(old_constraint): 'constraint_name': Inequality.__name__, 'low_column_name': low, 'high_column_name': constraint_column, - 'strict_boundaries': strict + 'strict_boundaries': strict, } scalar_constraint = { 'constraint_name': ScalarInequality.__name__, 'column_name': constraint_column, 'relation': '<' if strict else '<=', - 'value': high + 'value': high, } new_constraints.append(inequality_constraint) new_constraints.append(scalar_constraint) @@ -185,13 +193,13 @@ def _upgrade_between(old_constraint): 'constraint_name': Inequality.__name__, 'low_column_name': constraint_column, 'high_column_name': high, - 'strict_boundaries': strict + 'strict_boundaries': strict, } scalar_constraint 
= { 'constraint_name': ScalarInequality.__name__, 'column_name': constraint_column, 'relation': '>' if strict else '>=', - 'value': low + 'value': low, } new_constraints.append(inequality_constraint) new_constraints.append(scalar_constraint) @@ -202,7 +210,7 @@ def _upgrade_between(old_constraint): 'low_column_name': low, 'middle_column_name': constraint_column, 'high_column_name': high, - 'strict_boundaries': strict + 'strict_boundaries': strict, } new_constraints.append(new_constraint) @@ -212,7 +220,7 @@ def _upgrade_between(old_constraint): def _upgrade_one_hot_encoding(old_constraint): new_constraint = { 'constraint_name': OneHotEncoding.__name__, - 'column_names': old_constraint.get('columns') + 'column_names': old_constraint.get('columns'), } return [new_constraint] @@ -220,7 +228,7 @@ def _upgrade_one_hot_encoding(old_constraint): def _upgrade_unique(old_constraint): new_constraint = { 'constraint_name': Unique.__name__, - 'column_names': old_constraint.get('columns') + 'column_names': old_constraint.get('columns'), } return [new_constraint] diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index f9c3b6893..2606b990a 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -18,7 +18,10 @@ from sdv.metadata.single_table import SingleTableMetadata from sdv.metadata.utils import read_json, validate_file_does_not_exist from sdv.metadata.visualization import ( - create_columns_node, create_summarized_columns_node, visualize_graph) + create_columns_node, + create_summarized_columns_node, + visualize_graph, +) LOGGER = logging.getLogger(__name__) MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata') @@ -48,8 +51,9 @@ def _reset_updated_flag(self): self._multi_table_updated = False - def _validate_missing_relationship_keys(self, parent_table_name, parent_primary_key, - child_table_name, child_foreign_key): + def _validate_missing_relationship_keys( + self, parent_table_name, parent_primary_key, child_table_name, child_foreign_key + ): parent_table = self.tables.get(parent_table_name) child_table = self.tables.get(child_table_name) if parent_table.primary_key is None: @@ -87,14 +91,17 @@ def _validate_no_missing_tables_in_relationship(parent_table_name, child_table_n if missing_table_names: if len(missing_table_names) == 1: raise InvalidMetadataError( - f'Relationship contains an unknown table {missing_table_names}.') + f'Relationship contains an unknown table {missing_table_names}.' + ) else: raise InvalidMetadataError( - f'Relationship contains unknown tables {missing_table_names}.') + f'Relationship contains unknown tables {missing_table_names}.' + ) @staticmethod - def _validate_relationship_key_length(parent_table_name, parent_primary_key, - child_table_name, child_foreign_key): + def _validate_relationship_key_length( + parent_table_name, parent_primary_key, child_table_name, child_foreign_key + ): pk_len = len(set(_cast_to_iterable(parent_primary_key))) fk_len = len(set(_cast_to_iterable(child_foreign_key))) if pk_len != fk_len: @@ -104,8 +111,9 @@ def _validate_relationship_key_length(parent_table_name, parent_primary_key, f'length {fk_len}.' 
) - def _validate_relationship_sdtypes(self, parent_table_name, parent_primary_key, - child_table_name, child_foreign_key): + def _validate_relationship_sdtypes( + self, parent_table_name, parent_primary_key, child_table_name, child_foreign_key + ): parent_table_columns = self.tables.get(parent_table_name).columns child_table_columns = self.tables.get(child_table_name).columns parent_primary_key = _cast_to_iterable(parent_primary_key) @@ -117,8 +125,9 @@ def _validate_relationship_sdtypes(self, parent_table_name, parent_primary_key, 'is invalid. The primary and foreign key columns are not the same type.' ) - def _validate_circular_relationships(self, parent, children=None, - parents=None, child_map=None, errors=None): + def _validate_circular_relationships( + self, parent, children=None, parents=None, child_map=None, errors=None + ): """Validate that there is no circular relationship in the metadata.""" parents = set() if parents is None else parents if children is None: @@ -137,7 +146,7 @@ def _validate_circular_relationships(self, parent, children=None, children=child_map.get(child, set()), child_map=child_map, parents=parents, - errors=errors + errors=errors, ) def _validate_child_map_circular_relationship(self, child_map): @@ -161,43 +170,37 @@ def _validate_foreign_child_key(self, child_table_name, parent_table_name, child 'with a non-primary key.' ) - def _validate_relationship_does_not_exist(self, parent_table_name, parent_primary_key, - child_table_name, child_foreign_key): + def _validate_relationship_does_not_exist( + self, parent_table_name, parent_primary_key, child_table_name, child_foreign_key + ): for relationship in self.relationships: already_exists = ( - relationship['parent_table_name'] == parent_table_name and - relationship['parent_primary_key'] == parent_primary_key and - relationship['child_table_name'] == child_table_name and - relationship['child_foreign_key'] == child_foreign_key + relationship['parent_table_name'] == parent_table_name + and relationship['parent_primary_key'] == parent_primary_key + and relationship['child_table_name'] == child_table_name + and relationship['child_foreign_key'] == child_foreign_key ) if already_exists: raise InvalidMetadataError('This relationship has already been added.') - def _validate_relationship(self, parent_table_name, child_table_name, - parent_primary_key, child_foreign_key): + def _validate_relationship( + self, parent_table_name, child_table_name, parent_primary_key, child_foreign_key + ): self._validate_no_missing_tables_in_relationship( - parent_table_name, child_table_name, self.tables.keys()) + parent_table_name, child_table_name, self.tables.keys() + ) self._validate_missing_relationship_keys( - parent_table_name, - parent_primary_key, - child_table_name, - child_foreign_key + parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) self._validate_relationship_key_length( - parent_table_name, - parent_primary_key, - child_table_name, - child_foreign_key + parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) self._validate_foreign_child_key(child_table_name, parent_table_name, child_foreign_key) self._validate_relationship_sdtypes( - parent_table_name, - parent_primary_key, - child_table_name, - child_foreign_key + parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) def _get_parent_map(self): @@ -222,8 +225,10 @@ def _get_foreign_keys(self, parent_table_name, child_table_name): """Get all foreign keys for the parent table.""" foreign_keys = [] for 
relation in self.relationships: - if parent_table_name == relation['parent_table_name'] and\ - child_table_name == relation['child_table_name']: + if ( + parent_table_name == relation['parent_table_name'] + and child_table_name == relation['child_table_name'] + ): foreign_keys.append(deepcopy(relation['child_foreign_key'])) return foreign_keys @@ -236,8 +241,9 @@ def _get_all_foreign_keys(self, table_name): return foreign_keys - def add_relationship(self, parent_table_name, child_table_name, - parent_primary_key, child_foreign_key): + def add_relationship( + self, parent_table_name, child_table_name, parent_primary_key, child_foreign_key + ): """Add a relationship between two tables. Args: @@ -264,15 +270,13 @@ def add_relationship(self, parent_table_name, child_table_name, - ``InvalidMetadataError`` if ``child_foreign_key`` is a primary key. """ self._validate_relationship( - parent_table_name, child_table_name, parent_primary_key, child_foreign_key) + parent_table_name, child_table_name, parent_primary_key, child_foreign_key + ) child_map = self._get_child_map() child_map[parent_table_name].add(child_table_name) self._validate_relationship_does_not_exist( - parent_table_name, - parent_primary_key, - child_table_name, - child_foreign_key + parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) self._validate_child_map_circular_relationship(child_map) @@ -295,8 +299,10 @@ def remove_relationship(self, parent_table_name, child_table_name): """ relationships_to_remove = [] for relation in self.relationships: - if (relation['parent_table_name'] == parent_table_name and - relation['child_table_name'] == child_table_name): + if ( + relation['parent_table_name'] == parent_table_name + and relation['child_table_name'] == child_table_name + ): relationships_to_remove.append(relation) if not relationships_to_remove: @@ -331,8 +337,9 @@ def remove_primary_key(self, table_name): parent_table = relationship['parent_table_name'] child_table = relationship['child_table_name'] foreign_key = relationship['child_foreign_key'] - if ((child_table == table_name and foreign_key == primary_key) or - parent_table == table_name): + if ( + child_table == table_name and foreign_key == primary_key + ) or parent_table == table_name: other_table = child_table if parent_table == table_name else parent_table info_msg = ( f"Relationship between '{table_name}' and '{other_table}' removed because " @@ -487,7 +494,8 @@ def _validate_all_tables_connected(self, parent_map, child_map): ) raise InvalidMetadataError( - f'The relationships in the dataset are disjointed. {table_msg}') + f'The relationships in the dataset are disjointed. 
{table_msg}' + ) def _detect_relationships(self): """Automatically detect relationships between tables.""" @@ -502,15 +510,12 @@ def _detect_relationships(self): self.update_column(child_candidate, primary_key, sdtype='id') self.add_relationship( - parent_candidate, - child_candidate, - primary_key, - primary_key + parent_candidate, child_candidate, primary_key, primary_key ) except InvalidMetadataError: - self.update_column(child_candidate, - primary_key, - sdtype=original_foreign_key_sdtype) + self.update_column( + child_candidate, primary_key, sdtype=original_foreign_key_sdtype + ) continue def detect_table_from_dataframe(self, table_name, data): @@ -642,8 +647,7 @@ def set_sequence_index(self, table_name, column_name): warnings.warn('Sequential modeling is not yet supported on SDV Multi Table models.') self.tables[table_name].set_sequence_index(column_name) - def _validate_column_relationships_foreign_keys( - self, table_column_relationships, foreign_keys): + def _validate_column_relationships_foreign_keys(self, table_column_relationships, foreign_keys): """Validate that a table's column relationships do not use any foreign keys. Args: @@ -676,8 +680,9 @@ def add_column_relationship(self, table_name, relationship_type, column_names): """ self._validate_table_exists(table_name) foreign_keys = self._get_all_foreign_keys(table_name) - relationships = [{'type': relationship_type, 'column_names': column_names}] + \ - self.tables[table_name].column_relationships + relationships = [{'type': relationship_type, 'column_names': column_names}] + self.tables[ + table_name + ].column_relationships self._validate_column_relationships_foreign_keys(relationships, foreign_keys) self.tables[table_name].add_column_relationship(relationship_type, column_names) @@ -701,7 +706,8 @@ def _validate_single_table(self, errors): errors.append('\n') title = f'Table: {table_name}' error = str(error).replace( - 'The following errors were found in the metadata:\n', title) + 'The following errors were found in the metadata:\n', title + ) errors.append(error) try: @@ -734,7 +740,8 @@ def validate(self): child_map = self._get_child_map() self._append_relationships_errors( - errors, self._validate_child_map_circular_relationship, child_map) + errors, self._validate_child_map_circular_relationship, child_map + ) if errors: raise InvalidMetadataError( 'The metadata is not valid' + '\n'.join(str(e) for e in errors) @@ -907,8 +914,9 @@ def get_table_metadata(self, table_name): self._validate_table_exists(table_name) return deepcopy(self.tables[table_name]) - def visualize(self, show_table_details='full', show_relationship_labels=True, - output_filepath=None): + def visualize( + self, show_table_details='full', show_relationship_labels=True, output_filepath=None + ): """Create a visualization of the multi-table dataset. Args: @@ -929,7 +937,8 @@ def visualize(self, show_table_details='full', show_relationship_labels=True, """ if show_table_details not in (None, True, False, 'full', 'summarized'): raise ValueError( - "'show_table_details' parameter should be 'full', 'summarized' or None.") + "'show_table_details' parameter should be 'full', 'summarized' or None." 
+ ) if isinstance(show_table_details, bool): if show_table_details: @@ -954,14 +963,14 @@ def visualize(self, show_table_details='full', show_relationship_labels=True, for table_name, table_meta in self.tables.items(): nodes[table_name] = { 'columns': create_columns_node(table_meta.columns), - 'primary_key': f'Primary key: {table_meta.primary_key}' + 'primary_key': f'Primary key: {table_meta.primary_key}', } elif show_table_details == 'summarized': for table_name, table_meta in self.tables.items(): nodes[table_name] = { 'columns': create_summarized_columns_node(table_meta.columns), - 'primary_key': f'Primary key: {table_meta.primary_key}' + 'primary_key': f'Primary key: {table_meta.primary_key}', } elif show_table_details is None: @@ -988,9 +997,9 @@ def visualize(self, show_table_details='full', show_relationship_labels=True, foreign_keys = r'\l'.join(info.get('foreign_keys', [])) keys = r'\l'.join([info['primary_key'], foreign_keys]) if foreign_keys: - label = fr"{{{table}|{info['columns']}\l|{keys}\l}}" + label = rf"{{{table}|{info['columns']}\l|{keys}\l}}" else: - label = fr"{{{table}|{info['columns']}\l|{keys}}}" + label = rf"{{{table}|{info['columns']}\l|{keys}}}" else: label = f'{table}' @@ -1023,9 +1032,8 @@ def _set_metadata_dict(self, metadata): for relationship in metadata.get('relationships', []): type_safe_relationships = { - key: str(value) - if not isinstance(value, str) - else value for key, value in relationship.items() + key: str(value) if not isinstance(value, str) else value + for key, value in relationship.items() } self.relationships.append(type_safe_relationships) @@ -1070,7 +1078,7 @@ def save_to_json(self, filepath): datetime.datetime.now(), len(self.tables), total_columns, - len(self.relationships) + len(self.relationships), ) with open(filepath, 'w', encoding='utf-8') as metadata_file: json.dump(metadata, metadata_file, indent=4) @@ -1127,7 +1135,7 @@ def _convert_relationships(cls, old_metadata): 'parent_table_name': parent, 'parent_primary_key': tables.get(parent).get('primary_key'), 'child_table_name': table, - 'child_foreign_key': foreign_key + 'child_foreign_key': foreign_key, } for table in tables for parent in list(parents[table]) @@ -1159,7 +1167,7 @@ def upgrade_metadata(cls, filepath): metadata_dict = { 'tables': tables_metadata, 'relationships': relationships, - 'METADATA_SPEC_VERSION': cls.METADATA_SPEC_VERSION + 'METADATA_SPEC_VERSION': cls.METADATA_SPEC_VERSION, } metadata = cls.load_from_dict(metadata_dict) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index a91a02064..41d87fb70 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -13,15 +13,25 @@ from rdt.transformers.pii.anonymization import SDTYPE_ANONYMIZERS, is_faker_function from sdv._utils import ( - _cast_to_iterable, _format_invalid_values_string, _get_datetime_format, _is_boolean_type, - _is_datetime_type, _is_numerical_type, _load_data_from_csv, _validate_datetime_format) + _cast_to_iterable, + _format_invalid_values_string, + _get_datetime_format, + _is_boolean_type, + _is_datetime_type, + _is_numerical_type, + _load_data_from_csv, + _validate_datetime_format, +) from sdv.errors import InvalidDataError from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.utils import read_json, validate_file_does_not_exist from sdv.metadata.visualization import ( - create_columns_node, create_summarized_columns_node, visualize_graph) + 
create_columns_node, + create_summarized_columns_node, + visualize_graph, +) LOGGER = logging.getLogger(__name__) SINGLETABLEMETADATA_LOGGER = get_sdv_logger('SingleTableMetadata') @@ -45,8 +55,15 @@ class SingleTableMetadata: } _NUMERICAL_REPRESENTATIONS = frozenset([ - 'Float', 'Int64', 'Int32', 'Int16', 'Int8', - 'UInt64', 'UInt32', 'UInt16', 'UInt8', + 'Float', + 'Int64', + 'Int32', + 'Int16', + 'Int8', + 'UInt64', + 'UInt32', + 'UInt16', + 'UInt8', ]) _KEYS = frozenset([ 'columns', @@ -55,7 +72,7 @@ class SingleTableMetadata: 'sequence_key', 'sequence_index', 'column_relationships', - 'METADATA_SPEC_VERSION' + 'METADATA_SPEC_VERSION', ]) _REFERENCE_TO_SDTYPE = { @@ -104,7 +121,8 @@ class SingleTableMetadata: } _SDTYPES_WITH_SUBSTRINGS = dict( - set(_REFERENCE_TO_SDTYPE.items()) - set(_SDTYPES_WITHOUT_SUBSTRINGS.items())) + set(_REFERENCE_TO_SDTYPE.items()) - set(_SDTYPES_WITHOUT_SUBSTRINGS.items()) + ) _COLUMN_RELATIONSHIP_TYPES = { 'address': AddressValidator.validate, @@ -155,8 +173,9 @@ def _validate_categorical(column_name, **kwargs): f"Unknown ordering method '{order_by}' provided for categorical column " f"'{column_name}'. Ordering method must be 'numerical_value' or 'alphabetical'." ) - if (isinstance(order, list) and not len(order)) or\ - (not isinstance(order, list) and order is not None): + if (isinstance(order, list) and not len(order)) or ( + not isinstance(order, list) and order is not None + ): raise InvalidMetadataError( f"Invalid order value provided for categorical column '{column_name}'. " "The 'order' must be a list with 1 or more elements." @@ -323,9 +342,7 @@ def update_columns(self, column_names, **kwargs): errors = [] has_sdtype_key = 'sdtype' in kwargs if has_sdtype_key: - kwargs_without_sdtype = { - key: value for key, value in kwargs.items() if key != 'sdtype' - } + kwargs_without_sdtype = {key: value for key, value in kwargs.items() if key != 'sdtype'} unexpected_kwargs = self._get_unexpected_kwargs( kwargs['sdtype'], **kwargs_without_sdtype ) @@ -452,10 +469,14 @@ def _detect_pii_column(self, column_name): # tokenize the column name based on (1) symbols and (2) camelCase tokens = self._tokenize_column_name(column_name) - return next(( - sdtype for reference, sdtype in self._SDTYPES_WITH_SUBSTRINGS.items() - if reference in tokens - ), None) + return next( + ( + sdtype + for reference, sdtype in self._SDTYPES_WITH_SUBSTRINGS.items() + if reference in tokens + ), + None, + ) def _determine_sdtype_for_numbers(self, data): """Determine the sdtype for a numerical column. @@ -628,21 +649,21 @@ def _validate_keys_sdtype(self, keys, key_type): """Validate that each key is of type 'id' or a valid Faker function.""" bad_keys = set() for key in keys: - if not (self.columns[key]['sdtype'] == 'id' or - is_faker_function(self.columns[key]['sdtype'])): + if not ( + self.columns[key]['sdtype'] == 'id' + or is_faker_function(self.columns[key]['sdtype']) + ): bad_keys.add(key) if bad_keys: raise InvalidMetadataError( - f"The {key_type}_keys {sorted(bad_keys)} must be type 'id' or " - 'another PII type.' + f"The {key_type}_keys {sorted(bad_keys)} must be type 'id' or " 'another PII type.' 
) def _validate_key(self, column_name, key_type): """Validate the primary and sequence keys.""" if column_name is not None: if not self._validate_key_datatype(column_name): - raise InvalidMetadataError( - f"'{key_type}_key' must be a string.") + raise InvalidMetadataError(f"'{key_type}_key' must be a string.") keys = {column_name} if isinstance(column_name, str) else set(column_name) invalid_ids = keys - set(self.columns) @@ -704,11 +725,10 @@ def set_sequence_key(self, column_name): self.sequence_key = column_name def _validate_alternate_keys(self, column_names): - if not isinstance(column_names, list) or \ - not all(self._validate_key_datatype(column_name) for column_name in column_names): - raise InvalidMetadataError( - "'alternate_keys' must be a list of strings." - ) + if not isinstance(column_names, list) or not all( + self._validate_key_datatype(column_name) for column_name in column_names + ): + raise InvalidMetadataError("'alternate_keys' must be a list of strings.") keys = set() for column_name in column_names: @@ -759,7 +779,8 @@ def _validate_sequence_index(self, column_name): sdtype = self.columns[column_name].get('sdtype') if sdtype not in ['datetime', 'numerical']: raise InvalidMetadataError( - "The sequence_index must be of type 'datetime' or 'numerical'.") + "The sequence_index must be of type 'datetime' or 'numerical'." + ) def set_sequence_index(self, column_name): """Set the metadata sequence index. @@ -817,9 +838,7 @@ def _validate_column_relationship(self, relationship): if column not in self.columns: errors.append(f"Column '{column}' not in metadata.") elif self.primary_key == column: - errors.append( - f"Cannot use primary key '{column}' in column relationship." - ) + errors.append(f"Cannot use primary key '{column}' in column relationship.") columns_to_sdtypes = { column: self.columns.get(column, {}).get('sdtype') for column in column_names @@ -855,8 +874,7 @@ def _validate_column_relationship_with_others(self, column_relationship, other_r List of other column relationships to compare against. """ for other_relationship in other_relationships: - repeated_columns = set( - other_relationship.get('column_names', [])) & set( + repeated_columns = set(other_relationship.get('column_names', [])) & set( column_relationship['column_names'] ) if repeated_columns: @@ -885,12 +903,10 @@ def _validate_all_column_relationships(self, column_relationships): for idx, relationship in enumerate(column_relationships): if set(relationship.keys()) != valid_relationship_keys: unknown_keys = set(relationship.keys()).difference(valid_relationship_keys) - raise InvalidMetadataError( - f'Relationship has invalid keys {unknown_keys}.' 
- ) + raise InvalidMetadataError(f'Relationship has invalid keys {unknown_keys}.') self._validate_column_relationship_with_others( - relationship, column_relationships[idx + 1:] + relationship, column_relationships[idx + 1 :] ) # Validate each individual relationship @@ -912,8 +928,8 @@ def _validate_all_column_relationships(self, column_relationships): if errors: raise InvalidMetadataError( - 'Column relationships have following errors:\n' + - '\n'.join([str(e) for e in errors]) + 'Column relationships have following errors:\n' + + '\n'.join([str(e) for e in errors]) ) def add_column_relationship(self, relationship_type, column_names): @@ -957,9 +973,7 @@ def validate(self): # Validate column relationships self._append_error( - errors, - self._validate_all_column_relationships, - self.column_relationships + errors, self._validate_all_column_relationships, self.column_relationships ) if errors: @@ -974,7 +988,8 @@ def _validate_metadata_matches_data(self, columns): missing_data_columns = set(columns).difference(metadata_columns) if missing_data_columns: errors.append( - f'The columns {sorted(missing_data_columns)} are not present in the metadata.') + f'The columns {sorted(missing_data_columns)} are not present in the metadata.' + ) missing_metadata_columns = set(metadata_columns).difference(columns) if missing_metadata_columns: @@ -1084,7 +1099,7 @@ def _validate_column_data(self, column, sdtype_warnings): invalid_values = self._get_invalid_column_values( column.sample(num_samples_to_validate), - lambda x: pd.isna(x) | _is_datetime_type(x) + lambda x: pd.isna(x) | _is_datetime_type(x), ) if datetime_format is None and column.dtype == 'O': @@ -1173,28 +1188,28 @@ def visualize(self, show_table_details='full', output_filepath=None): raise ValueError("'show_table_details' should be 'full' or 'summarized'.") if show_table_details == 'full': - node = fr'{create_columns_node(self.columns)}\l' + node = rf'{create_columns_node(self.columns)}\l' elif show_table_details == 'summarized': - node = fr'{create_summarized_columns_node(self.columns)}\l' + node = rf'{create_summarized_columns_node(self.columns)}\l' keys_node = '' if self.primary_key: - keys_node = fr'{keys_node}Primary key: {self.primary_key}\l' + keys_node = rf'{keys_node}Primary key: {self.primary_key}\l' if self.sequence_key: - keys_node = fr'{keys_node}Sequence key: {self.sequence_key}\l' + keys_node = rf'{keys_node}Sequence key: {self.sequence_key}\l' if self.sequence_index: - keys_node = fr'{keys_node}Sequence index: {self.sequence_index}\l' + keys_node = rf'{keys_node}Sequence index: {self.sequence_index}\l' if self.alternate_keys: - alternate_keys = [fr'    • {key}\l' for key in self.alternate_keys] + alternate_keys = [rf'    • {key}\l' for key in self.alternate_keys] alternate_keys = ''.join(alternate_keys) - keys_node = fr'{keys_node}Alternate keys:\l {alternate_keys}' + keys_node = rf'{keys_node}Alternate keys:\l {alternate_keys}' if keys_node != '': - node = fr'{node}|{keys_node}' + node = rf'{node}|{keys_node}' node = {'': f'{{{node}}}'} return visualize_graph(node, [], output_filepath) @@ -1220,7 +1235,7 @@ def save_to_json(self, filepath): ' Total number of columns: %s' ' Total number of relationships: 0', datetime.now(), - len(self.columns) + len(self.columns), ) with open(filepath, 'w', encoding='utf-8') as metadata_file: json.dump(metadata, metadata_file, indent=4) @@ -1245,9 +1260,8 @@ def load_from_dict(cls, metadata_dict): if value: if key == 'columns': value = { - str(key) - if not isinstance(key, str) - else key: col 
for key, col in value.items() + str(key) if not isinstance(key, str) else key: col + for key, col in value.items() } setattr(instance, f'{key}', value) diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py index cfb337e00..595bc111f 100644 --- a/sdv/metadata/visualization.py +++ b/sdv/metadata/visualization.py @@ -20,10 +20,7 @@ def create_columns_node(columns): str: String representing the node that will be printed for the given columns. """ - columns = [ - fr"{name} : {meta.get('sdtype')}" - for name, meta in columns.items() - ] + columns = [rf"{name} : {meta.get('sdtype')}" for name, meta in columns.items()] return r'\l'.join(columns) @@ -46,10 +43,7 @@ def create_summarized_columns_node(columns): count_dict = dict(sorted(count_dict.items())) columns = ['Columns'] - columns.extend([ - fr'    • {sdtype} : {count}' - for sdtype, count in count_dict.items() - ]) + columns.extend([rf'    • {sdtype} : {count}' for sdtype, count in count_dict.items()]) return r'\l'.join(columns) @@ -101,11 +95,7 @@ def visualize_graph(nodes, edges, filepath=None): digraph = graphviz.Digraph( 'Metadata', format=graphviz_extension, - node_attr={ - 'shape': 'Mrecord', - 'fillcolor': 'lightgoldenrod1', - 'style': 'filled' - }, + node_attr={'shape': 'Mrecord', 'fillcolor': 'lightgoldenrod1', 'style': 'filled'}, ) for name, label in nodes.items(): diff --git a/sdv/metrics/demos.py b/sdv/metrics/demos.py index 721579f21..04857ea4a 100644 --- a/sdv/metrics/demos.py +++ b/sdv/metrics/demos.py @@ -2,4 +2,5 @@ This subpackage exists only to enable importing the sdmetrics demos as part of sdv. """ + from sdmetrics.demos import * # noqa diff --git a/sdv/metrics/relational.py b/sdv/metrics/relational.py index e056640dd..4517afb67 100644 --- a/sdv/metrics/relational.py +++ b/sdv/metrics/relational.py @@ -2,4 +2,5 @@ This subpackage exists only to enable importing sdmetrics as part of sdv. """ + from sdmetrics.multi_table import * # noqa diff --git a/sdv/metrics/tabular.py b/sdv/metrics/tabular.py index ae15cafdc..89f8eff03 100644 --- a/sdv/metrics/tabular.py +++ b/sdv/metrics/tabular.py @@ -2,4 +2,5 @@ This subpackage exists only to enable importing sdmetrics as part of sdv. """ + from sdmetrics.single_table import * # noqa diff --git a/sdv/metrics/timeseries.py b/sdv/metrics/timeseries.py index 775ab793c..5f6f66bed 100644 --- a/sdv/metrics/timeseries.py +++ b/sdv/metrics/timeseries.py @@ -2,4 +2,5 @@ This subpackage exists only to enable importing sdmetrics as part of sdv. 
""" + from sdmetrics.timeseries import * # noqa diff --git a/sdv/multi_table/__init__.py b/sdv/multi_table/__init__.py index b33f38ce3..ffcf98dee 100644 --- a/sdv/multi_table/__init__.py +++ b/sdv/multi_table/__init__.py @@ -2,6 +2,4 @@ from sdv.multi_table.hma import HMASynthesizer -__all__ = ( - 'HMASynthesizer', -) +__all__ = ('HMASynthesizer',) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index f4891b9c0..c789ee3c5 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -1,4 +1,5 @@ """Base Multi Table Synthesizer class.""" + import contextlib import datetime import inspect @@ -13,10 +14,17 @@ from sdv import version from sdv._utils import ( - _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, - generate_synthesizer_id) + _validate_foreign_keys_not_null, + check_sdv_versions_and_warn, + check_synthesizer_version, + generate_synthesizer_id, +) from sdv.errors import ( - ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError) + ConstraintsNotMetError, + InvalidDataError, + SamplingError, + SynthesizerInputError, +) from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -64,9 +72,7 @@ def _initialize_models(self): for table_name, table_metadata in self.metadata.tables.items(): synthesizer_parameters = self._table_parameters.get(table_name, {}) self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, - locales=self.locales, - **synthesizer_parameters + metadata=table_metadata, locales=self.locales, **synthesizer_parameters ) self._table_synthesizers[table_name]._data_processor.table_name = table_name @@ -124,7 +130,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): 'EVENT': 'Instance', 'TIMESTAMP': datetime.datetime.now(), 'SYNTHESIZER CLASS NAME': self.__class__.__name__, - 'SYNTHESIZER ID': self._synthesizer_id + 'SYNTHESIZER ID': self._synthesizer_id, }) def set_address_columns(self, table_name, column_names, anonymization_level='full'): @@ -162,7 +168,7 @@ def get_table_parameters(self, table_name): else: table_params = { 'synthesizer_name': type(table_synthesizer).__name__, - 'synthesizer_parameters': table_synthesizer.get_parameters() + 'synthesizer_parameters': table_synthesizer.get_parameters(), } return table_params @@ -193,8 +199,7 @@ def set_table_parameters(self, table_name, table_parameters): the table's synthesizer. 
""" self._table_synthesizers[table_name] = self._synthesizer( - metadata=self.metadata.tables[table_name], - **table_parameters + metadata=self.metadata.tables[table_name], **table_parameters ) self._table_parameters[table_name].update(deepcopy(table_parameters)) @@ -263,10 +268,7 @@ def _assign_table_transformers(self, synthesizer, table_name, table_data): """Update the ``synthesizer`` to ignore the foreign keys while preprocessing the data.""" synthesizer.auto_assign_transformers(table_data) foreign_key_columns = self.metadata._get_all_foreign_keys(table_name) - column_name_to_transformers = { - column_name: None - for column_name in foreign_key_columns - } + column_name_to_transformers = {column_name: None for column_name in foreign_key_columns} synthesizer.update_transformers(column_name_to_transformers) def auto_assign_transformers(self, data): @@ -408,7 +410,7 @@ def fit_processed_data(self, processed_data): 'SYNTHESIZER ID': self._synthesizer_id, 'TOTAL NUMBER OF TABLES': len(processed_data), 'TOTAL NUMBER OF ROWS': total_rows, - 'TOTAL NUMBER OF COLUMNS': total_columns + 'TOTAL NUMBER OF COLUMNS': total_columns, }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) @@ -442,7 +444,7 @@ def fit(self, data): 'SYNTHESIZER ID': self._synthesizer_id, 'TOTAL NUMBER OF TABLES': len(data), 'TOTAL NUMBER OF ROWS': total_rows, - 'TOTAL NUMBER OF COLUMNS': total_columns + 'TOTAL NUMBER OF COLUMNS': total_columns, }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) @@ -479,9 +481,10 @@ def sample(self, scale=1.0): 'sampling synthetic data.' ) - if not type(scale) in (float, int) or not scale > 0: + if type(scale) not in (float, int) or not scale > 0: raise SynthesizerInputError( - f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.") + f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0." + ) with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) @@ -504,7 +507,7 @@ def sample(self, scale=1.0): 'SYNTHESIZER ID': self._synthesizer_id, 'TOTAL NUMBER OF TABLES': len(sampled_data), 'TOTAL NUMBER OF ROWS': total_rows, - 'TOTAL NUMBER OF COLUMNS': total_columns + 'TOTAL NUMBER OF COLUMNS': total_columns, }) return sampled_data @@ -548,9 +551,7 @@ def get_loss_values(self, table_name): Dataframe of loss values per epoch """ if table_name not in self._table_synthesizers: - raise ValueError( - f"Table '{table_name}' is not present in the metadata." - ) + raise ValueError(f"Table '{table_name}' is not present in the metadata.") synthesizer = self._table_synthesizers[table_name] if hasattr(synthesizer, 'get_loss_values'): @@ -657,7 +658,7 @@ def get_info(self): 'creation_date': self._creation_date, 'is_fit': self._fitted, 'last_fit_date': self._fitted_date, - 'fitted_sdv_version': self._fitted_sdv_version + 'fitted_sdv_version': self._fitted_sdv_version, } if self._fitted_sdv_enterprise_version: info['fitted_sdv_enterprise_version'] = self._fitted_sdv_enterprise_version diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index dd142b420..d7a4712db 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -31,15 +31,13 @@ class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): Whether to print progress for fitting or not. 
""" - DEFAULT_SYNTHESIZER_KWARGS = { - 'default_distribution': 'beta' - } + DEFAULT_SYNTHESIZER_KWARGS = {'default_distribution': 'beta'} DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS = { 'beta': 4, 'truncnorm': 4, 'gamma': 3, 'norm': 2, - 'uniform': 2 + 'uniform': 2, } @staticmethod @@ -52,14 +50,16 @@ def _get_num_data_columns(metadata): """ columns_per_table = {} for table_name, table in metadata.tables.items(): - columns_per_table[table_name] = \ - sum([1 for col in table.columns.values() if col['sdtype'] != 'id']) + columns_per_table[table_name] = sum([ + 1 for col in table.columns.values() if col['sdtype'] != 'id' + ]) return columns_per_table @classmethod - def _get_num_extended_columns(cls, metadata, table_name, - parent_table, columns_per_table, distributions=None): + def _get_num_extended_columns( + cls, metadata, table_name, parent_table, columns_per_table, distributions=None + ): """Get the number of columns that will be generated for table_name. A table generates, for each foreign key: @@ -93,8 +93,9 @@ def _get_num_extended_columns(cls, metadata, table_name, return num_correlation_columns + num_rows_columns + num_parameters_columns @classmethod - def _estimate_columns_traversal(cls, metadata, table_name, - columns_per_table, visited, distributions=None): + def _estimate_columns_traversal( + cls, metadata, table_name, columns_per_table, visited, distributions=None + ): """Given a table, estimate how many columns each parent will model. This method recursively models the children of a table all the way to the leaf nodes. @@ -111,9 +112,8 @@ def _estimate_columns_traversal(cls, metadata, table_name, if child_name not in visited: cls._estimate_columns_traversal(metadata, child_name, columns_per_table, visited) - columns_per_table[table_name] += \ - cls._get_num_extended_columns( - metadata, child_name, table_name, columns_per_table, distributions + columns_per_table[table_name] += cls._get_num_extended_columns( + metadata, child_name, table_name, columns_per_table, distributions ) visited.add(table_name) @@ -161,10 +161,8 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._default_parameters = {} self.verbose = verbose BaseHierarchicalSampler.__init__( - self, - self.metadata, - self._table_synthesizers, - self._table_sizes) + self, self.metadata, self._table_synthesizers, self._table_sizes + ) self._print_estimate_warning() def set_table_parameters(self, table_name, table_parameters): @@ -239,16 +237,13 @@ def _print_estimate_warning(self): self._print( 'PerformanceAlert: Using the HMASynthesizer on this metadata ' 'schema is not recommended. To model this data, HMA will ' - f'generate a large number of columns. ({total_est_cols} columns)\n\n') + f'generate a large number of columns. 
({total_est_cols} columns)\n\n' + ) self._print( pd.DataFrame( - print_table, - columns=[ - 'Table Name', - '# Columns in Metadata', - 'Est # Columns' - ] - ).to_string(index=False) + '\n' + print_table, columns=['Table Name', '# Columns in Metadata', 'Est # Columns'] + ).to_string(index=False) + + '\n' ) self._print( "We recommend simplifying your metadata schema using 'sdv.utils.poc.simplify_sch" @@ -313,8 +308,7 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row.index = f'__{child_name}__{foreign_key}__' + row.index else: synthesizer = self._synthesizer( - table_meta, - **self._table_parameters[child_name] + table_meta, **self._table_parameters[child_name] ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) row = synthesizer._get_parameters() @@ -322,11 +316,7 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row.index = f'__{child_name}__{foreign_key}__' + row.index if scale_columns is None: - scale_columns = [ - column - for column in row.index - if column.endswith('scale') - ] + scale_columns = [column for column in row.index if column.endswith('scale')] if len(child_rows) == 1: row.loc[scale_columns] = None @@ -385,10 +375,7 @@ def _augment_table(self, table, tables, table_name): f"Tables '{table_name}' and '{child_name}' ('{foreign_key}')" ) extension = self._get_extension( - child_name, - child_table.copy(), - foreign_key, - progress_bar_desc + child_name, child_table.copy(), foreign_key, progress_bar_desc ) for column in extension.columns: extension[column] = extension[column].astype(float) @@ -396,7 +383,8 @@ def _augment_table(self, table, tables, table_name): extension[column] = extension[column].fillna(1e-6) self.extended_columns[child_name][column] = FloatFormatter( - enforce_min_max_values=True) + enforce_min_max_values=True + ) self.extended_columns[child_name][column].fit(extension, column) table = table.merge(extension, how='left', right_index=True, left_index=True) @@ -459,22 +447,26 @@ def _model_tables(self, augmented_data): Dictionary mapping each table name to an augmented ``pandas.DataFrame``. 
""" augmented_data_to_model = [ - (table_name, table) - for table_name, table in augmented_data.items() + (table_name, table) for table_name, table in augmented_data.items() ] self._print(text='\n', end='') pbar_args = self._get_pbar_args(desc='Modeling Tables') for table_name, table in tqdm(augmented_data_to_model, **pbar_args): keys = self._pop_foreign_keys(table, table_name) self._clear_nans(table) - LOGGER.info('Fitting %s for table %s; shape: %s', self._synthesizer.__name__, - table_name, table.shape) + LOGGER.info( + 'Fitting %s for table %s; shape: %s', + self._synthesizer.__name__, + table_name, + table.shape, + ) if not table.empty: self._table_synthesizers[table_name].fit_processed_data(table) table_parameters = self._table_synthesizers[table_name]._get_parameters() self._default_parameters[table_name] = { - parameter: value for parameter, value in table_parameters.items() + parameter: value + for parameter, value in table_parameters.items() if 'univariates' in parameter } @@ -495,22 +487,20 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): """ prefix = f'__{table_name}__{foreign_key}__' keys = [key for key in parent_row.keys() if key.startswith(prefix)] - new_keys = {key: key[len(prefix):] for key in keys} + new_keys = {key: key[len(prefix) :] for key in keys} flat_parameters = parent_row[keys].astype(float).fillna(1e-6) num_rows_key = f'{prefix}num_rows' if num_rows_key in flat_parameters: num_rows = flat_parameters[num_rows_key] - flat_parameters[num_rows_key] = min( - self._max_child_rows[num_rows_key], - round(num_rows) - ) + flat_parameters[num_rows_key] = min(self._max_child_rows[num_rows_key], round(num_rows)) flat_parameters = flat_parameters.to_dict() for parameter_name, parameter in flat_parameters.items(): float_formatter = self.extended_columns[table_name][parameter_name] flat_parameters[parameter_name] = np.clip( # this should be revisited in GH#1769 - parameter, float_formatter._min_value, float_formatter._max_value) + parameter, float_formatter._min_value, float_formatter._max_value + ) return {new_keys[key]: value for key, value in flat_parameters.items()} @@ -521,10 +511,7 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.tables[child_name] - synthesizer = self._synthesizer( - table_meta, - **self._table_parameters[child_name] - ) + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) synthesizer._set_parameters(parameters, default_parameters) synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor @@ -611,17 +598,11 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): if transformed.index.name: table_rows = table_rows.set_index(transformed.index.name) - table_rows = pd.concat( - [transformed, table_rows.drop(columns=transformed.columns)], - axis=1 - ) + table_rows = pd.concat([transformed, table_rows.drop(columns=transformed.columns)], axis=1) for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() - synthesizer = self._synthesizer( - table_meta, - **self._table_parameters[table_name] - ) + synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name]) synthesizer._set_parameters(parameters) try: likelihoods[parent_id] = synthesizer._get_likelihood(table_rows) @@ -669,6 +650,6 @@ def 
_add_foreign_key_columns(self, child_table, parent_table, child_name, parent parent_table=parent_table, child_name=child_name, parent_name=parent_name, - foreign_key=foreign_key + foreign_key=foreign_key, ) child_table[foreign_key] = parent_ids.to_numpy() diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 3f6fd9526..e339bcca2 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -1,4 +1,5 @@ """Utility functions for the MultiTable models.""" + import math import warnings from collections import defaultdict @@ -197,9 +198,9 @@ def _get_num_column_to_drop(metadata, child_table, max_col_per_relationships): num_modelable_column = sum([len(value) for value in modelable_columns.values()]) num_cols_to_drop = math.ceil( - num_modelable_column + num_column_parameter - np.sqrt( - num_column_parameter ** 2 + 1 + 2 * max_col_per_relationships - ) + num_modelable_column + + num_column_parameter + - np.sqrt(num_column_parameter**2 + 1 + 2 * max_col_per_relationships) ) return num_cols_to_drop, modelable_columns @@ -221,11 +222,7 @@ def _get_columns_to_drop_child(metadata, child_table, max_col_per_relationships) for sdtype, frequency in sdtypes_frequency.items(): num_col_to_drop_per_sdtype = round(num_col_to_drop * frequency) columns_to_drop.extend( - np.random.choice( - modelable_columns[sdtype], - num_col_to_drop_per_sdtype, - replace=False - ) + np.random.choice(modelable_columns[sdtype], num_col_to_drop_per_sdtype, replace=False) ) return columns_to_drop @@ -242,9 +239,7 @@ def _simplify_child(metadata, child_table, max_col_per_relationships): max_col_per_relationships (int): Maximum number of columns to model per relationship. """ - columns_to_drop = _get_columns_to_drop_child( - metadata, child_table, max_col_per_relationships - ) + columns_to_drop = _get_columns_to_drop_child(metadata, child_table, max_col_per_relationships) columns = metadata.tables[child_table].columns for column in columns_to_drop: del columns[column] @@ -312,9 +307,7 @@ def _simplify_metadata(metadata): tables_to_keep = set(children) | set(grandchildren) | {root_to_keep} table_to_drop = set(simplified_metadata.tables.keys()) - tables_to_keep - _simplify_relationships_and_tables( - simplified_metadata, table_to_drop - ) + _simplify_relationships_and_tables(simplified_metadata, table_to_drop) if grandchildren: _simplify_grandchildren(simplified_metadata, grandchildren) @@ -324,9 +317,7 @@ def _simplify_metadata(metadata): return simplified_metadata num_data_column = HMASynthesizer._get_num_data_columns(simplified_metadata) - _simplify_children( - simplified_metadata, children, root_to_keep, num_data_column - ) + _simplify_children(simplified_metadata, children, root_to_keep, num_data_column) simplified_metadata.validate() return simplified_metadata @@ -369,7 +360,7 @@ def _print_simplified_schema_summary(data_before, data_after): '# Columns (Before)': [len(data_before[table].columns) for table in tables], '# Columns (After)': [ len(data_after[table].columns) if table in data_after else 0 for table in tables - ] + ], }) message.append(summary.to_string(index=False)) print('\n'.join(message)) # noqa: T001 @@ -403,7 +394,8 @@ def _get_rows_to_drop(data, metadata): relationships_parent = _get_relationships_for_parent(relationships, parent_table) parent_column = metadata.tables[parent_table].primary_key valid_parent_idx = [ - idx for idx in data[parent_table].index + idx + for idx in data[parent_table].index if idx not in table_to_idx_to_drop[parent_table] ] valid_parent_values = 
set(data[parent_table].loc[valid_parent_idx, parent_column]) @@ -412,18 +404,18 @@ def _get_rows_to_drop(data, metadata): child_column = relationship['child_foreign_key'] is_nan = data[child_table][child_column].isna() - invalid_values = set( - data[child_table].loc[~is_nan, child_column] - ) - valid_parent_values + invalid_values = ( + set(data[child_table].loc[~is_nan, child_column]) - valid_parent_values + ) invalid_rows = data[child_table][ data[child_table][child_column].isin(invalid_values) ] idx_to_drop = set(invalid_rows.index) if idx_to_drop: - table_to_idx_to_drop[child_table] = table_to_idx_to_drop[ - child_table - ].union(idx_to_drop) + table_to_idx_to_drop[child_table] = table_to_idx_to_drop[child_table].union( + idx_to_drop + ) relationships = [rel for rel in relationships if rel not in relationships_parent] @@ -436,9 +428,7 @@ def _get_nan_fk_indices_table(data, relationships, table): relationships_for_table = _get_relationships_for_child(relationships, table) for relationship in relationships_for_table: child_column = relationship['child_foreign_key'] - idx_with_nan_foreign_key.update( - data[table][data[table][child_column].isna()].index - ) + idx_with_nan_foreign_key.update(data[table][data[table][child_column].isna()].index) return idx_with_nan_foreign_key @@ -449,9 +439,7 @@ def _drop_rows(data, metadata, drop_missing_values): idx_to_drop = table_to_idx_to_drop[table] data[table] = data[table].drop(idx_to_drop) if drop_missing_values: - idx_with_nan_fk = _get_nan_fk_indices_table( - data, metadata.relationships, table - ) + idx_with_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table) data[table] = data[table].drop(idx_with_nan_fk) if data[table].empty: @@ -526,8 +514,9 @@ def _get_primary_keys_referenced(data, metadata): return primary_keys_referenced -def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_before, - dereferenced_pk_parent): +def _subsample_parent( + parent_table, parent_primary_key, parent_pk_referenced_before, dereferenced_pk_parent +): """Subsample the parent table. The strategy here is to: @@ -596,8 +585,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced): pk_referenced_before = primary_keys_referenced[parent] dereferenced_primary_keys = pk_referenced_before - pk_referenced[parent] data[parent] = _subsample_parent( - data[parent], parent_primary_key, pk_referenced_before, - dereferenced_primary_keys + data[parent], parent_primary_key, pk_referenced_before, dereferenced_primary_keys ) if dereferenced_primary_keys: primary_keys_referenced[parent] = pk_referenced[parent] @@ -667,7 +655,7 @@ def _print_subsample_summary(data_before, data_after): '# Rows (Before)': [len(data_before[table]) for table in tables], '# Rows (After)': [ len(data_after[table]) if table in data_after else 0 for table in tables - ] + ], }) subsample_rows = 100 * (1 - summary['# Rows (After)'].sum() / summary['# Rows (Before)'].sum()) message = [f'Success! Your subset has {round(subsample_rows)}% less rows than the original.\n'] diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 96ae041fb..7ba175904 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -1,4 +1,5 @@ """Hierarchical Samplers.""" + import logging import warnings @@ -7,7 +8,7 @@ LOGGER = logging.getLogger(__name__) -class BaseHierarchicalSampler(): +class BaseHierarchicalSampler: """Hierarchical sampler mixin. 
Args: @@ -108,8 +109,9 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num if previous is None: sampled_data[child_name] = sampled_rows else: - sampled_data[child_name] = pd.concat( - [previous, sampled_rows]).reset_index(drop=True) + sampled_data[child_name] = pd.concat([previous, sampled_rows]).reset_index( + drop=True + ) def _enforce_table_size(self, child_name, table_name, scale, sampled_data): """Ensure the child table has the same size as in the real data times the scale factor. @@ -155,8 +157,10 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): # If the number of rows is already at the maximum, skip # The exception is when the smallest value is already at the maximum, # in which case we ignore the boundary - if sampled_data[table_name].loc[i, num_rows_key] >= max_rows and \ - sampled_data[table_name][num_rows_key].min() < max_rows: + if ( + sampled_data[table_name].loc[i, num_rows_key] >= max_rows + and sampled_data[table_name][num_rows_key].min() < max_rows + ): break sampled_data[table_name].loc[i, num_rows_key] += 1 @@ -168,8 +172,10 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): # If the number of rows is already at the minimum, skip # The exception is when the highest value is already at the minimum, # in which case we ignore the boundary - if sampled_data[table_name].loc[i, num_rows_key] <= min_rows and \ - sampled_data[table_name][num_rows_key].max() > min_rows: + if ( + sampled_data[table_name].loc[i, num_rows_key] <= min_rows + and sampled_data[table_name][num_rows_key].max() > min_rows + ): break sampled_data[table_name].loc[i, num_rows_key] -= 1 @@ -198,7 +204,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): child_name=child_name, parent_name=table_name, parent_row=row, - sampled_data=sampled_data + sampled_data=sampled_data, ) if child_name not in sampled_data: # No child rows sampled, force row creation @@ -215,14 +221,10 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): parent_name=table_name, parent_row=parent_row, sampled_data=sampled_data, - num_rows=1 + num_rows=1, ) - self._sample_children( - table_name=child_name, - sampled_data=sampled_data, - scale=scale - ) + self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale) def _finalize(self, sampled_data): """Remove extra columns from sampled tables and apply finishing touches. @@ -300,10 +302,7 @@ def _sample(self, scale=1.0): # is used to recreate the child tables, so the rest can be skipped. if (parent_name, child_name) not in added_relationships: self._add_foreign_key_columns( - sampled_data[child_name], - sampled_data[parent_name], - child_name, - parent_name + sampled_data[child_name], sampled_data[parent_name], child_name, parent_name ) added_relationships.add((parent_name, child_name)) diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py index b8cfd8203..22d0359b4 100644 --- a/sdv/sampling/independent_sampler.py +++ b/sdv/sampling/independent_sampler.py @@ -1,10 +1,11 @@ """Independent Samplers.""" + import logging LOGGER = logging.getLogger(__name__) -class BaseIndependentSampler(): +class BaseIndependentSampler: """Independent sampler mixin. Args: @@ -67,17 +68,17 @@ def _connect_tables(self, sampled_data): A dictionary mapping table names to the sampled tables (pd.DataFrame). 
""" queue = [ - table - for table in self.metadata.tables - if not self.metadata._get_parent_map()[table] + table for table in self.metadata.tables if not self.metadata._get_parent_map()[table] ] while queue: parent = queue.pop(0) for child in self.metadata._get_child_map()[parent]: self._add_foreign_key_columns( - sampled_data[child], sampled_data[parent], child, parent) + sampled_data[child], sampled_data[parent], child, parent + ) if set(self.metadata._get_all_foreign_keys(child)).issubset( - set(sampled_data[child].columns)): + set(sampled_data[child].columns) + ): queue.append(child) def _finalize(self, sampled_data): @@ -95,7 +96,6 @@ def _finalize(self, sampled_data): """ final_data = {} for table_name, table_rows in sampled_data.items(): - synthesizer = self._table_synthesizers.get(table_name) metadata = synthesizer.get_metadata() dtypes = synthesizer._data_processor._dtypes @@ -147,8 +147,12 @@ def _sample(self, scale=1.0): for table in self.metadata.tables: num_rows = int(self._table_sizes[table] * scale) synthesizer = self._table_synthesizers[table] - self._sample_table(synthesizer=synthesizer, table_name=table, num_rows=num_rows, - sampled_data=sampled_data) + self._sample_table( + synthesizer=synthesizer, + table_name=table, + num_rows=num_rows, + sampled_data=sampled_data, + ) self._connect_tables(sampled_data) return self._finalize(sampled_data) diff --git a/sdv/sampling/tabular.py b/sdv/sampling/tabular.py index a7421253c..b3bb3ac6e 100644 --- a/sdv/sampling/tabular.py +++ b/sdv/sampling/tabular.py @@ -1,7 +1,7 @@ """SDV Condition class for sampling.""" -class Condition(): +class Condition: """Condition class. This class represents a condition that is used for sampling. diff --git a/sdv/sequential/__init__.py b/sdv/sequential/__init__.py index 42d2707ae..6c8507d35 100644 --- a/sdv/sequential/__init__.py +++ b/sdv/sequential/__init__.py @@ -2,6 +2,4 @@ from sdv.sequential.par import PARSynthesizer -__all__ = ( - 'PARSynthesizer', -) +__all__ = ('PARSynthesizer',) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 0321cd59d..ccd8a974f 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -63,11 +63,7 @@ class PARSynthesizer(LossValuesMixin, BaseSynthesizer): Whether to print progress to console or not. 
""" - _model_sdtype_transformers = { - 'categorical': None, - 'numerical': None, - 'boolean': None - } + _model_sdtype_transformers = {'categorical': None, 'numerical': None, 'boolean': None} def _get_context_metadata(self): context_columns_dict = {} @@ -84,9 +80,19 @@ def _get_context_metadata(self): context_metadata_dict = {'columns': context_columns_dict} return SingleTableMetadata.load_from_dict(context_metadata_dict) - def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False, - locales=['en_US'], context_columns=None, segment_size=None, epochs=128, - sample_size=1, cuda=True, verbose=False): + def __init__( + self, + metadata, + enforce_min_max_values=True, + enforce_rounding=False, + locales=['en_US'], + context_columns=None, + segment_size=None, + epochs=128, + sample_size=1, + cuda=True, + verbose=False, + ): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, @@ -117,7 +123,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False self._context_synthesizer = GaussianCopulaSynthesizer( metadata=context_metadata, enforce_min_max_values=enforce_min_max_values, - enforce_rounding=enforce_rounding + enforce_rounding=enforce_rounding, ) def get_parameters(self): @@ -165,7 +171,8 @@ def add_constraints(self, constraints): if col in constraint_cols: raise SynthesizerInputError( 'The PARSynthesizer cannot accommodate multiple constraints ' - 'that overlap on the same columns.') + 'that overlap on the same columns.' + ) constraint_cols.append(col) all_context = all(col in context_set for col in constraint_cols) @@ -176,7 +183,8 @@ def add_constraints(self, constraints): else: raise SynthesizerInputError( 'The PARSynthesizer cannot accommodate constraints ' - 'with a mix of context and non-context columns.') + 'with a mix of context and non-context columns.' + ) def load_custom_constraint_classes(self, filepath, class_names): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" @@ -192,10 +200,12 @@ def _validate_context_columns(self, data): for sequence_key_value, data_values in data.groupby(_groupby_list(self._sequence_key)): for context_column in self.context_columns: if len(data_values[context_column].unique()) > 1: - errors.append(( - f"Context column '{context_column}' is changing inside sequence " - f'({self._sequence_key}={sequence_key_value}).' - )) + errors.append( + ( + f"Context column '{context_column}' is changing inside sequence " + f'({self._sequence_key}={sequence_key_value}).' 
+ ) + ) return errors @@ -211,9 +221,12 @@ def _transform_sequence_index(self, data): if all(sequence_index[self._sequence_key].nunique() == 1): sequence_index_sequence = sequence_index[[self._sequence_index]].diff().bfill() else: - sequence_index_sequence = sequence_index.groupby(self._sequence_key).apply( - lambda x: x[self._sequence_index].diff().bfill() - ).droplevel(1).reset_index() + sequence_index_sequence = ( + sequence_index.groupby(self._sequence_key) + .apply(lambda x: x[self._sequence_index].diff().bfill()) + .droplevel(1) + .reset_index() + ) if all(sequence_index_sequence[self._sequence_index].isna()): fill_value = 0 @@ -222,18 +235,13 @@ def _transform_sequence_index(self, data): sequence_index_sequence = sequence_index_sequence.fillna(fill_value) data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy() - data = data.merge( - sequence_index_context, - left_on=self._sequence_key, - right_index=True) + data = data.merge(sequence_index_context, left_on=self._sequence_key, right_index=True) - self.extended_columns[self._sequence_index] = FloatFormatter( - enforce_min_max_values=True) + self.extended_columns[self._sequence_index] = FloatFormatter(enforce_min_max_values=True) self.extended_columns[self._sequence_index].fit( - sequence_index_sequence, self._sequence_index) - self._extra_context_columns[f'{self._sequence_index}.context'] = { - 'sdtype': 'numerical' - } + sequence_index_sequence, self._sequence_index + ) + self._extra_context_columns[f'{self._sequence_index}.context'] = {'sdtype': 'numerical'} return data @@ -298,7 +306,8 @@ def update_transformers(self, column_name_to_transformer): """ if set(column_name_to_transformer).intersection(set(self.context_columns)): raise SynthesizerInputError( - 'Transformers for context columns are not allowed to be updated.') + 'Transformers for context columns are not allowed to be updated.' 
+ ) super().update_transformers(column_name_to_transformer) @@ -307,8 +316,7 @@ def _fit_context_model(self, transformed): context_metadata: SingleTableMetadata = self._get_context_metadata() if self.context_columns or self._extra_context_columns: context_cols = ( - self._sequence_key + self.context_columns + - list(self._extra_context_columns.keys()) + self._sequence_key + self.context_columns + list(self._extra_context_columns.keys()) ) context = transformed[context_cols] else: @@ -321,7 +329,7 @@ def _fit_context_model(self, transformed): self._context_synthesizer = GaussianCopulaSynthesizer( context_metadata, enforce_min_max_values=self._context_synthesizer.enforce_min_max_values, - enforce_rounding=self._context_synthesizer.enforce_rounding + enforce_rounding=self._context_synthesizer.enforce_rounding, ) context = context.groupby(self._sequence_key).first().reset_index() self._context_synthesizer.fit(context) @@ -333,9 +341,9 @@ def _fit_sequence_columns(self, timeseries_data): self._data_columns = [ column for column in timeseries_data.columns - if column not in ( - self._sequence_key + self.context_columns + - list(self._extra_context_columns.keys()) + if column + not in ( + self._sequence_key + self.context_columns + list(self._extra_context_columns.keys()) ) ] @@ -345,7 +353,7 @@ def _fit_sequence_columns(self, timeseries_data): self.context_columns + list(self._extra_context_columns.keys()), self.segment_size, self._sequence_index, - drop_sequence_index=False + drop_sequence_index=False, ) data_types = [] context_types = [] @@ -424,8 +432,7 @@ def _sample_from_par(self, context, sequence_length=None): # Reformat as a DataFrame sequence_df = pd.DataFrame( - dict(zip(self._data_columns, sequence)), - columns=self._data_columns + dict(zip(self._data_columns, sequence)), columns=self._data_columns ) sequence_df[self._sequence_key] = sequence_key_values context_columns = self.context_columns + list(self._extra_context_columns.keys()) @@ -459,9 +466,7 @@ def sample(self, num_sequences, sequence_length=None): """ if self._sequence_key: context_columns = self._context_synthesizer._sample_with_progress_bar( - num_sequences, - output_file_path='disable', - show_progress_bar=False + num_sequences, output_file_path='disable', show_progress_bar=False ) else: @@ -493,12 +498,14 @@ def sample_sequential_columns(self, context_columns, sequence_length=None): 'to sample new sequences.' 
) - condition_columns = list(set.intersection( - set(context_columns.columns), set(self._context_synthesizer._model.columns) - )) - condition_columns = context_columns[condition_columns].to_dict('records') - context = self._context_synthesizer.sample_from_conditions( - [Condition(conditions) for conditions in condition_columns] + condition_columns = list( + set.intersection( + set(context_columns.columns), set(self._context_synthesizer._model.columns) + ) ) + condition_columns = context_columns[condition_columns].to_dict('records') + context = self._context_synthesizer.sample_from_conditions([ + Condition(conditions) for conditions in condition_columns + ]) context.update(context_columns) return self._sample(context, sequence_length) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index f0d669d2a..135e37f3b 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -20,11 +20,19 @@ from sdv import version from sdv._utils import ( - _groupby_list, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) + _groupby_list, + check_sdv_versions_and_warn, + check_synthesizer_version, + generate_synthesizer_id, +) from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor from sdv.errors import ( - ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError) + ConstraintsNotMetError, + InvalidDataError, + SamplingError, + SynthesizerInputError, +) from sdv.logging import get_sdv_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path @@ -86,8 +94,9 @@ def _check_metadata_updated(self): ' in future SDV versions.' ) - def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US']): + def __init__( + self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=['en_US'] + ): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata self.metadata.validate() @@ -121,7 +130,8 @@ def set_address_columns(self, column_names, anonymization_level='full'): """Set the address multi-column transformer.""" warnings.warn( '`set_address_columns` is deprecated. Please add these columns directly to your' - ' metadata using `add_column_relationship`.', DeprecationWarning + ' metadata using `add_column_relationship`.', + DeprecationWarning, ) def _validate_metadata(self, data): @@ -199,7 +209,8 @@ def _validate_transformers(self, column_name_to_transformer): # If columns were set, the transformer was fitted if transformer.columns: raise SynthesizerInputError( - f"Transformer for column '{column}' has already been fit on data.") + f"Transformer for column '{column}' has already been fit on data." + ) def _warn_for_update_transformers(self, column_name_to_transformer): """Raise warnings for update_transformers. 
@@ -352,7 +363,7 @@ def get_info(self): 'creation_date': self._creation_date, 'is_fit': self._fitted, 'last_fit_date': self._fitted_date, - 'fitted_sdv_version': self._fitted_sdv_version + 'fitted_sdv_version': self._fitted_sdv_version, } if self._fitted_sdv_enterprise_version is not None: info['fitted_sdv_enterprise_version'] = self._fitted_sdv_enterprise_version @@ -423,7 +434,7 @@ def fit_processed_data(self, processed_data): 'SYNTHESIZER ID': self._synthesizer_id, 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': len(processed_data), - 'TOTAL NUMBER OF COLUMNS': len(processed_data.columns) + 'TOTAL NUMBER OF COLUMNS': len(processed_data.columns), }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) @@ -449,7 +460,7 @@ def fit(self, data): 'SYNTHESIZER ID': self._synthesizer_id, 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': len(data), - 'TOTAL NUMBER OF COLUMNS': len(data.columns) + 'TOTAL NUMBER OF COLUMNS': len(data.columns), }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) @@ -575,8 +586,15 @@ def _filter_conditions(sampled, conditions, float_rtol): return sampled - def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, - float_rtol=0.1, previous_rows=None, keep_extra_columns=False): + def _sample_rows( + self, + num_rows, + conditions=None, + transformed_conditions=None, + float_rtol=0.1, + previous_rows=None, + keep_extra_columns=False, + ): """Sample rows with the given conditions. Input conditions is taken both in the raw input format, which will be used @@ -619,7 +637,6 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns if self._model and need_sample: - if conditions is None: raw_sampled = self._sample(num_rows) else: @@ -653,9 +670,17 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, sampled = self._data_processor.reverse_transform(sampled) return sampled, num_rows - def _sample_batch(self, batch_size, max_tries=100, - conditions=None, transformed_conditions=None, float_rtol=0.01, - progress_bar=None, output_file_path=None, keep_extra_columns=False): + def _sample_batch( + self, + batch_size, + max_tries=100, + conditions=None, + transformed_conditions=None, + float_rtol=0.01, + progress_bar=None, + output_file_path=None, + keep_extra_columns=False, + ): """Sample a batch of rows with the given conditions. This will enter a reject-sampling loop in which rows will be sampled until @@ -717,7 +742,7 @@ def _sample_batch(self, batch_size, max_tries=100, transformed_conditions, float_rtol, sampled, - keep_extra_columns + keep_extra_columns, ) num_new_valid_rows = num_valid - prev_num_valid @@ -742,7 +767,8 @@ def _sample_batch(self, batch_size, max_tries=100, if remaining > 0: LOGGER.info( - f'{remaining} valid rows remaining. Resampling {num_rows_to_sample} rows') + f'{remaining} valid rows remaining. 
Resampling {num_rows_to_sample} rows' + ) counter += 1 @@ -766,16 +792,25 @@ def _make_condition_dfs(conditions): for condition in conditions: column_values = condition.get_column_values() condition_dataframes[tuple(column_values.keys())].append( - pd.DataFrame(column_values, index=range(condition.get_num_rows()))) + pd.DataFrame(column_values, index=range(condition.get_num_rows())) + ) return [ pd.concat(condition_list, ignore_index=True) for condition_list in condition_dataframes.values() ] - def _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditions=None, - transformed_conditions=None, float_rtol=0.01, progress_bar=None, - output_file_path=None): + def _sample_in_batches( + self, + num_rows, + batch_size, + max_tries_per_batch, + conditions=None, + transformed_conditions=None, + float_rtol=0.01, + progress_bar=None, + output_file_path=None, + ): sampled = [] batch_size = batch_size if num_rows > batch_size else num_rows for step in range(math.ceil(num_rows / batch_size)): @@ -793,10 +828,18 @@ def _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditio sampled = pd.concat(sampled, ignore_index=True) if len(sampled) > 0 else pd.DataFrame() return sampled.head(num_rows) - def _conditionally_sample_rows(self, dataframe, condition, transformed_condition, - max_tries_per_batch=None, batch_size=None, float_rtol=0.01, - graceful_reject_sampling=True, progress_bar=None, - output_file_path=None): + def _conditionally_sample_rows( + self, + dataframe, + condition, + transformed_condition, + max_tries_per_batch=None, + batch_size=None, + float_rtol=0.01, + graceful_reject_sampling=True, + progress_bar=None, + output_file_path=None, + ): batch_size = batch_size or len(dataframe) sampled_rows = self._sample_in_batches( num_rows=len(dataframe), @@ -806,16 +849,15 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition transformed_conditions=transformed_condition, float_rtol=float_rtol, progress_bar=progress_bar, - output_file_path=output_file_path + output_file_path=output_file_path, ) if len(sampled_rows) > 0: - sampled_rows[COND_IDX] = dataframe[COND_IDX].to_numpy()[:len(sampled_rows)] + sampled_rows[COND_IDX] = dataframe[COND_IDX].to_numpy()[: len(sampled_rows)] elif not graceful_reject_sampling: user_msg = ( - 'Unable to sample any rows for the given conditions ' - f"'{transformed_condition}'. " + 'Unable to sample any rows for the given conditions ' f"'{transformed_condition}'. " ) if hasattr(self, '_model') and isinstance(self._model, GaussianMultivariate): user_msg = user_msg + ( @@ -833,8 +875,14 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition return sampled_rows - def _sample_with_progress_bar(self, num_rows, max_tries_per_batch=100, batch_size=None, - output_file_path=None, show_progress_bar=True): + def _sample_with_progress_bar( + self, + num_rows, + max_tries_per_batch=100, + batch_size=None, + output_file_path=None, + show_progress_bar=True, + ): if num_rows is None: raise ValueError('You must specify the number of rows to sample (e.g. 
num_rows=100).') @@ -853,7 +901,7 @@ def _sample_with_progress_bar(self, num_rows, max_tries_per_batch=100, batch_siz batch_size=batch_size, max_tries_per_batch=max_tries_per_batch, progress_bar=progress_bar, - output_file_path=output_file_path + output_file_path=output_file_path, ) except (Exception, KeyboardInterrupt) as error: @@ -895,7 +943,7 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file max_tries_per_batch, batch_size, output_file_path, - show_progress_bar=show_progress_bar + show_progress_bar=show_progress_bar, ) original_columns = getattr(self, '_original_columns', pd.Index([])) @@ -909,14 +957,14 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file 'SYNTHESIZER ID': self._synthesizer_id, 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': len(sampled_data), - 'TOTAL NUMBER OF COLUMNS': len(sampled_data.columns) - + 'TOTAL NUMBER OF COLUMNS': len(sampled_data.columns), }) return sampled_data - def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, - progress_bar=None, output_file_path=None): + def _sample_with_conditions( + self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None + ): """Sample rows with conditions. Args: @@ -959,8 +1007,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, condition_df = dataframe.iloc[0].to_frame().T try: transformed_condition = self._data_processor.transform( - condition_df, - is_condition=True + condition_df, is_condition=True ) except ConstraintsNotMetError as error: raise ConstraintsNotMetError( @@ -968,8 +1015,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, ) from error transformed_conditions = pd.concat( - [transformed_condition] * len(dataframe), - ignore_index=True + [transformed_condition] * len(dataframe), ignore_index=True ) transformed_columns = list(transformed_conditions.columns) if not transformed_conditions.empty: @@ -1021,8 +1067,10 @@ def _validate_conditions_unseen_columns(self, conditions): """Validate the user-passed conditions.""" for column in conditions.columns: if column not in self._data_processor.get_sdtypes(): - raise ValueError(f"Unexpected column name '{column}'. " - f'Use a column name that was present in the original data.') + raise ValueError( + f"Unexpected column name '{column}'. " + f'Use a column name that was present in the original data.' + ) @staticmethod def _raise_condition_with_nans(): @@ -1038,8 +1086,9 @@ def _validate_conditions(self, conditions): if condition_dataframe.isna().any().any(): self._raise_condition_with_nans() - def sample_from_conditions(self, conditions, max_tries_per_batch=100, - batch_size=None, output_file_path=None): + def sample_from_conditions( + self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None + ): """Sample rows from this table with the given conditions. 
Args: @@ -1069,7 +1118,8 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, output_file_path = validate_file_path(output_file_path) num_rows = functools.reduce( - lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0) + lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0 + ) conditions = self._make_condition_dfs(conditions) self._validate_conditions(conditions) @@ -1089,12 +1139,13 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True) is_reject_sampling = bool( - hasattr(self, '_model') and not isinstance(self._model, GaussianMultivariate)) + hasattr(self, '_model') and not isinstance(self._model, GaussianMultivariate) + ) check_num_rows( num_rows=len(sampled), expected_num_rows=num_rows, is_reject_sampling=is_reject_sampling, - max_tries_per_batch=max_tries_per_batch + max_tries_per_batch=max_tries_per_batch, ) except (Exception, KeyboardInterrupt) as error: @@ -1113,8 +1164,9 @@ def _validate_known_columns(self, conditions): 'Rows with any missing values will not be created.' ) - def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, - batch_size=None, output_file_path=None): + def sample_remaining_columns( + self, known_columns, max_tries_per_batch=100, batch_size=None, output_file_path=None + ): """Sample remaining rows from already known columns. Args: @@ -1150,16 +1202,18 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, with tqdm.tqdm(total=len(known_columns)) as progress_bar: progress_bar.set_description('Sampling remaining columns') sampled = self._sample_with_conditions( - known_columns, max_tries_per_batch, batch_size, progress_bar, output_file_path) + known_columns, max_tries_per_batch, batch_size, progress_bar, output_file_path + ) - is_reject_sampling = (hasattr(self, '_model') and not isinstance( - self._model, copulas.multivariate.GaussianMultivariate)) + is_reject_sampling = hasattr(self, '_model') and not isinstance( + self._model, copulas.multivariate.GaussianMultivariate + ) check_num_rows( num_rows=len(sampled), expected_num_rows=len(known_columns), is_reject_sampling=is_reject_sampling, - max_tries_per_batch=max_tries_per_batch + max_tries_per_batch=max_tries_per_batch, ) except (Exception, KeyboardInterrupt) as error: diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index c9309b45c..19ca50b2e 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -1,4 +1,5 @@ """Combination of GaussianCopula transformation and GANs.""" + import logging from copy import deepcopy @@ -7,7 +8,9 @@ from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer from sdv.single_table.utils import ( - log_numerical_distributions_error, validate_numerical_distributions) + log_numerical_distributions_error, + validate_numerical_distributions, +) LOGGER = logging.getLogger(__name__) @@ -116,13 +119,29 @@ class CopulaGANSynthesizer(CTGANSynthesizer): _gaussian_normalizer_hyper_transformer = None - def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], embedding_dim=128, generator_dim=(256, 256), - discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, - discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, - discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, 
numerical_distributions=None, default_distribution=None): - + def __init__( + self, + metadata, + enforce_min_max_values=True, + enforce_rounding=True, + locales=['en_US'], + embedding_dim=128, + generator_dim=(256, 256), + discriminator_dim=(256, 256), + generator_lr=2e-4, + generator_decay=1e-6, + discriminator_lr=2e-4, + discriminator_decay=1e-6, + batch_size=500, + discriminator_steps=1, + log_frequency=True, + verbose=False, + epochs=300, + pac=10, + cuda=True, + numerical_distributions=None, + default_distribution=None, + ): super().__init__( metadata, enforce_min_max_values=enforce_min_max_values, @@ -164,10 +183,7 @@ def _create_gaussian_normalizer_config(self, processed_data): sdtype = columns.get(column, {}).get('sdtype') if column in columns and sdtype not in ['categorical', 'boolean']: sdtypes[column] = 'numerical' - distribution = self._numerical_distributions.get( - column, - self._default_distribution - ) + distribution = self._numerical_distributions.get(column, self._default_distribution) transformers[column] = rdt.transformers.GaussianNormalizer( missing_value_generation='from_column', @@ -188,7 +204,8 @@ def _fit(self, processed_data): Data to be learned. """ log_numerical_distributions_error( - self.numerical_distributions, processed_data.columns, LOGGER) + self.numerical_distributions, processed_data.columns, LOGGER + ) gaussian_normalizer_config = self._create_gaussian_normalizer_config(processed_data) self._gaussian_normalizer_hyper_transformer = rdt.HyperTransformer() @@ -239,12 +256,11 @@ def get_learned_distributions(self): learned_params = deepcopy(transformer._univariate.to_dict()) learned_params.pop('type') distribution = self.numerical_distributions.get( - column_name, - self.default_distribution + column_name, self.default_distribution ) learned_distributions[column_name] = { 'distribution': distribution, - 'learned_parameters': learned_params + 'learned_parameters': learned_params, } return learned_distributions diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 4fc213949..43aca493a 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -1,4 +1,5 @@ """Wrappers around copulas models.""" + import inspect import logging import warnings @@ -15,8 +16,11 @@ from sdv.errors import NonParametricError from sdv.single_table.base import BaseSingleTableSynthesizer from sdv.single_table.utils import ( - flatten_dict, log_numerical_distributions_error, unflatten_dict, - validate_numerical_distributions) + flatten_dict, + log_numerical_distributions_error, + unflatten_dict, + validate_numerical_distributions, +) LOGGER = logging.getLogger(__name__) @@ -90,8 +94,15 @@ def get_distribution_class(cls, distribution): return cls._DISTRIBUTIONS[distribution] - def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], numerical_distributions=None, default_distribution=None): + def __init__( + self, + metadata, + enforce_min_max_values=True, + enforce_rounding=True, + locales=['en_US'], + numerical_distributions=None, + default_distribution=None, + ): super().__init__( metadata, enforce_min_max_values=enforce_min_max_values, @@ -117,18 +128,18 @@ def _fit(self, processed_data): Data to be learned. 
""" log_numerical_distributions_error( - self.numerical_distributions, processed_data.columns, LOGGER) + self.numerical_distributions, processed_data.columns, LOGGER + ) self._num_rows = len(processed_data) numerical_distributions = deepcopy(self._numerical_distributions) for column in processed_data.columns: if column not in numerical_distributions: numerical_distributions[column] = self._numerical_distributions.get( - column, self._default_distribution) + column, self._default_distribution + ) - self._model = multivariate.GaussianMultivariate( - distribution=numerical_distributions - ) + self._model = multivariate.GaussianMultivariate(distribution=numerical_distributions) with warnings.catch_warnings(): warnings.filterwarnings('ignore', module='scipy') @@ -202,7 +213,7 @@ def get_learned_distributions(self): learned_params.pop('type') learned_distributions[column] = { 'distribution': distribution, - 'learned_parameters': learned_params + 'learned_parameters': learned_params, } return learned_distributions @@ -233,7 +244,7 @@ def _get_parameters(self): correlation = [] for index, row in enumerate(params['correlation'][1:]): - correlation.append(row[:index + 1]) + correlation.append(row[: index + 1]) params['correlation'] = correlation params['univariates'] = dict(zip(params.pop('columns'), params['univariates'])) @@ -366,9 +377,7 @@ def _rebuild_gaussian_copula(self, model_parameters, default_params=None): univariate = default_params['univariates'][column] univariate['type'] = univariate_type else: - LOGGER.debug( - f"Column '{column}' has invalid parameters." - ) + LOGGER.debug(f"Column '{column}' has invalid parameters.") else: LOGGER.debug(f"Univariate for col '{column}' does not have _argcheck method.") diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index edeacb35c..e23932190 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -1,4 +1,5 @@ """Wrapper around CTGAN model.""" + import numpy as np import pandas as pd import plotly.express as px @@ -74,11 +75,12 @@ def get_loss_values_plot(self, title='CTGAN loss function'): # Create a pretty chart using Plotly Express fig = px.line( - loss_df, x='Epoch', + loss_df, + x='Epoch', y=['Generator Loss', 'Discriminator Loss'], color_discrete_map={ 'Generator Loss': visualization.PlotConfig.DATACEBO_DARK, - 'Discriminator Loss': visualization.PlotConfig.DATACEBO_GREEN + 'Discriminator Loss': visualization.PlotConfig.DATACEBO_GREEN, }, ) fig.update_layout( @@ -86,7 +88,7 @@ def get_loss_values_plot(self, title='CTGAN loss function'): legend_title_text='', legend_orientation='v', plot_bgcolor=visualization.PlotConfig.BACKGROUND_COLOR, - font={'size': visualization.PlotConfig.FONT_SIZE} + font={'size': visualization.PlotConfig.FONT_SIZE}, ) fig.update_layout(title=title, xaxis_title='Epoch', yaxis_title='Loss') return fig @@ -145,18 +147,29 @@ class CTGANSynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): If ``False``, do not use cuda at all. 
""" - _model_sdtype_transformers = { - 'categorical': None, - 'boolean': None - } - - def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], embedding_dim=128, generator_dim=(256, 256), - discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, - discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, - discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True): - + _model_sdtype_transformers = {'categorical': None, 'boolean': None} + + def __init__( + self, + metadata, + enforce_min_max_values=True, + enforce_rounding=True, + locales=['en_US'], + embedding_dim=128, + generator_dim=(256, 256), + discriminator_dim=(256, 256), + generator_lr=2e-4, + generator_decay=1e-6, + discriminator_lr=2e-4, + discriminator_decay=1e-6, + batch_size=500, + discriminator_steps=1, + log_frequency=True, + verbose=False, + epochs=300, + pac=10, + cuda=True, + ): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, @@ -193,7 +206,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, 'verbose': verbose, 'epochs': epochs, 'pac': pac, - 'cuda': cuda + 'cuda': cuda, } def _estimate_num_columns(self, data): @@ -271,9 +284,7 @@ def _fit(self, processed_data): transformers = self._data_processor._hyper_transformer.field_transformers discrete_columns = detect_discrete_columns( - self.get_metadata(), - processed_data, - transformers + self.get_metadata(), processed_data, transformers ) self._model = CTGAN(**self._model_kwargs) self._model.fit(processed_data, discrete_columns=discrete_columns) @@ -333,16 +344,23 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): If ``False``, do not use cuda at all. 
""" - _model_sdtype_transformers = { - 'categorical': None, - 'boolean': None - } - - def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, verbose=False, epochs=300, loss_factor=2, - cuda=True): - + _model_sdtype_transformers = {'categorical': None, 'boolean': None} + + def __init__( + self, + metadata, + enforce_min_max_values=True, + enforce_rounding=True, + embedding_dim=128, + compress_dims=(128, 128), + decompress_dims=(128, 128), + l2scale=1e-5, + batch_size=500, + verbose=False, + epochs=300, + loss_factor=2, + cuda=True, + ): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, @@ -367,7 +385,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, 'verbose': verbose, 'epochs': epochs, 'loss_factor': loss_factor, - 'cuda': cuda + 'cuda': cuda, } def _fit(self, processed_data): @@ -381,9 +399,7 @@ def _fit(self, processed_data): transformers = self._data_processor._hyper_transformer.field_transformers discrete_columns = detect_discrete_columns( - self.get_metadata(), - processed_data, - transformers + self.get_metadata(), processed_data, transformers ) self._model = TVAE(**self._model_kwargs) self._model.fit(processed_data, discrete_columns=discrete_columns) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 1ecb68267..527f845ec 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -64,7 +64,7 @@ def detect_discrete_columns(metadata, data, transformers): is_float = not is_int num_values = len(column_data) num_categories = column_data.nunique() - threshold = max(10, num_values * .1) + threshold = max(10, num_values * 0.1) has_many_categories = num_categories > threshold if is_float or (is_int and has_many_categories): continue @@ -95,9 +95,7 @@ def handle_sampling_error(output_file_path, sampling_error): error_msg = None if output_file_path is not None: - error_msg = ( - f'Error: Sampling terminated. Partial results are stored in {output_file_path}.' - ) + error_msg = f'Error: Sampling terminated. Partial results are stored in {output_file_path}.' else: error_msg = ( 'Error: Sampling terminated. No results were saved due to unspecified ' @@ -132,7 +130,7 @@ def check_num_rows(num_rows, expected_num_rows, is_reject_sampling, max_tries_pe """ if num_rows < expected_num_rows: if num_rows == 0: - user_msg = ('Unable to sample any rows for the given conditions. ') + user_msg = 'Unable to sample any rows for the given conditions. ' if is_reject_sampling: user_msg = user_msg + ( f'Try increasing `max_tries_per_batch` (currently: {max_tries_per_batch}). 
' diff --git a/sdv/utils/__init__.py b/sdv/utils/__init__.py index 981fca3a5..38e8db044 100644 --- a/sdv/utils/__init__.py +++ b/sdv/utils/__init__.py @@ -2,6 +2,4 @@ from sdv.utils.utils import drop_unknown_references -__all__ = ( - 'drop_unknown_references', -) +__all__ = ('drop_unknown_references',) diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index 40f3944e9..682895bad 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -1,12 +1,18 @@ """POC functions to use HMASynthesizer succesfully.""" + import warnings from sdv.errors import InvalidDataError from sdv.metadata.errors import InvalidMetadataError from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS from sdv.multi_table.utils import ( - _get_total_estimated_columns, _print_simplified_schema_summary, _print_subsample_summary, - _simplify_data, _simplify_metadata, _subsample_data) + _get_total_estimated_columns, + _print_simplified_schema_summary, + _print_subsample_summary, + _simplify_data, + _simplify_metadata, + _subsample_data, +) from sdv.utils.utils import drop_unknown_references as utils_drop_unknown_references @@ -14,7 +20,8 @@ def drop_unknown_references(data, metadata): """Wrap the drop_unknown_references function from the utils module.""" warnings.warn( "Please access the 'drop_unknown_references' function directly from the sdv.utils module" - 'instead of sdv.utils.poc.', FutureWarning + 'instead of sdv.utils.poc.', + FutureWarning, ) return utils_drop_unknown_references(data, metadata) @@ -113,9 +120,7 @@ def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True): except InvalidDataError as error: raise InvalidDataError([error_message]) from error - error_message_num_rows = ( - '``num_rows`` must be a positive integer.' - ) + error_message_num_rows = '``num_rows`` must be a positive integer.' 
if not isinstance(num_rows, (int, float)) or num_rows != int(num_rows): raise ValueError(error_message_num_rows) diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index 5b3589b1f..f6e5db7c0 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -1,4 +1,5 @@ """Utils module.""" + import sys from copy import deepcopy @@ -36,7 +37,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr 'Table Name': table_names, '# Rows (Original)': [len(data[table]) for table in table_names], '# Invalid Rows': [0] * len(table_names), - '# Rows (New)': [len(data[table]) for table in table_names] + '# Rows (New)': [len(data[table]) for table in table_names], }) metadata.validate() try: @@ -45,9 +46,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr _validate_foreign_keys_not_null(metadata, data) if verbose: - sys.stdout.write( - '\n'.join([success_message, '', summary_table.to_string(index=False)]) - ) + sys.stdout.write('\n'.join([success_message, '', summary_table.to_string(index=False)])) return data except (InvalidDataError, SynthesizerInputError): @@ -58,8 +57,6 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr len(data[table]) - len(result[table]) for table in table_names ] summary_table['# Rows (New)'] = [len(result[table]) for table in table_names] - sys.stdout.write('\n'.join([ - success_message, '', summary_table.to_string(index=False) - ])) + sys.stdout.write('\n'.join([success_message, '', summary_table.to_string(index=False)])) return result diff --git a/sdv/version/__init__.py b/sdv/version/__init__.py index b15ffd036..b54ebaf8d 100644 --- a/sdv/version/__init__.py +++ b/sdv/version/__init__.py @@ -1,4 +1,5 @@ """SDV versions.""" + from importlib.metadata import version public = version('sdv') diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index c2b017843..000000000 --- a/setup.cfg +++ /dev/null @@ -1,22 +0,0 @@ -[flake8] -max-line-length = 99 -inline-quotes = single -exclude = docs, .tox, .git, __pycache__, .ipynb_checkpoints -extend-ignore = - # Missing docstring in magic method - D105, - # Missing docstring in __init__ - D107, - # Use arithmetic operator instead of method - PD005, - # String literal formatting using f-string - SFS3, - # TokenError: unterminated string literal - E902, - # Mutable default arg of type List - M511, - # Logging and IO shadowing python's builtins - A005 - -[aliases] -test = pytest diff --git a/tasks.py b/tasks.py index 97130a21c..45bbfa117 100644 --- a/tasks.py +++ b/tasks.py @@ -12,12 +12,7 @@ from packaging.requirements import Requirement from packaging.version import Version -COMPARISONS = { - '>=': operator.ge, - '>': operator.gt, - '<': operator.lt, - '<=': operator.le -} +COMPARISONS = {'>=': operator.ge, '>': operator.gt, '<': operator.lt, '<=': operator.le} if not hasattr(inspect, 'getargspec'): @@ -54,15 +49,22 @@ def _get_minimum_versions(dependencies, python_version): continue # Skip this dependency if the marker does not apply to the current Python version if req.name not in min_versions: - min_version = next((spec.version for spec in req.specifier if spec.operator in ('>=', '==')), None) + min_version = next( + (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), None + ) if min_version: min_versions[req.name] = f'{req.name}=={min_version}' elif '@' not in min_versions[req.name]: existing_version = Version(min_versions[req.name].split('==')[1]) - new_version = next((spec.version for spec in req.specifier if 
spec.operator in ('>=', '==')), existing_version) + new_version = next( + (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), + existing_version, + ) if new_version > existing_version: - min_versions[req.name] = f'{req.name}=={new_version}' # Change when a valid newer version is found + min_versions[req.name] = ( + f'{req.name}=={new_version}' # Change when a valid newer version is found + ) return list(min_versions.values()) @@ -77,7 +79,8 @@ def install_minimum(c): minimum_versions = _get_minimum_versions(dependencies, python_version) if minimum_versions: - c.run(f'python -m pip install {" ".join(minimum_versions)}') + install_deps = ' '.join(minimum_versions) + c.run(f'python -m pip install {install_deps}') @task @@ -107,19 +110,20 @@ def readme(c): def tutorials(c): for ipynb_file in glob.glob('tutorials/*.ipynb') + glob.glob('tutorials/**/*.ipynb'): if '.ipynb_checkpoints' not in ipynb_file: - c.run(( - 'jupyter nbconvert --execute --ExecutePreprocessor.timeout=3600 ' - f'--to=html --stdout {ipynb_file}' - ), hide='out') + c.run( + ( + 'jupyter nbconvert --execute --ExecutePreprocessor.timeout=3600 ' + f'--to=html --stdout {ipynb_file}' + ), + hide='out', + ) @task def lint(c): check_dependencies(c) - c.run('flake8 sdv') - c.run('flake8 tests --ignore=D,SFS2') - c.run('isort -c sdv tests') - c.run('pydocstyle sdv') + c.run('ruff check .') + c.run('ruff format . --check') def remove_readonly(func, path, _): diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index 358415625..58e8f7c2a 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -1,13 +1,20 @@ """Integration tests for the ``DataProcessor``.""" + import itertools import re import numpy as np import pandas as pd import pytest +from pandas.api.types import is_float_dtype from rdt.transformers import ( - AnonymizedFaker, BinaryEncoder, FloatFormatter, IDGenerator, UniformEncoder, - UnixTimestampEncoder) + AnonymizedFaker, + BinaryEncoder, + FloatFormatter, + IDGenerator, + UniformEncoder, + UnixTimestampEncoder, +) from sdv._utils import _get_datetime_format from sdv.data_processing import DataProcessor @@ -238,8 +245,7 @@ def test_prepare_for_fitting(self): """ # Setup data, metadata = download_demo( - modality='single_table', - dataset_name='student_placements_pii' + modality='single_table', dataset_name='student_placements_pii' ) dp = DataProcessor(metadata) @@ -266,7 +272,7 @@ def test_prepare_for_fitting(self): 'high_spec': UniformEncoder, 'high_perc': FloatFormatter, 'work_experience': UniformEncoder, - 'degree_perc': FloatFormatter + 'degree_perc': FloatFormatter, } for column_name, transformer_class in expected_transformers.items(): if transformer_class is not None: @@ -281,10 +287,7 @@ def test_prepare_for_fitting(self): def test_reverse_transform_with_formatters(self): """End to end test using formatters.""" # Setup - data, metadata = download_demo( - modality='single_table', - dataset_name='student_placements' - ) + data, metadata = download_demo(modality='single_table', dataset_name='student_placements') dp = DataProcessor(metadata) # Run @@ -316,19 +319,14 @@ def test_reverse_transform_with_formatters(self): assert start_date_data_format == reversed_start_date_format end_date_data_format = _get_datetime_format(data['end_date'][~data['end_date'].isna()][0]) - reversed_end_date = reverse_transformed['end_date'][ - 
~reverse_transformed['end_date'].isna() - ] + reversed_end_date = reverse_transformed['end_date'][~reverse_transformed['end_date'].isna()] reversed_end_date_format = _get_datetime_format(reversed_end_date.iloc[0]) assert end_date_data_format == reversed_end_date_format def test_refit_hypertransformer(self): """Test data processor re-fits _hyper_transformer.""" # Setup - data, metadata = download_demo( - modality='single_table', - dataset_name='student_placements' - ) + data, metadata = download_demo(modality='single_table', dataset_name='student_placements') dp = DataProcessor(metadata) # Run @@ -342,7 +340,7 @@ def test_refit_hypertransformer(self): dp.fit(data) transformed = dp.transform(data) - assert all(transformed.dtypes == float) + assert all([is_float_dtype(dtype) for dtype in transformed.dtypes]) def test_localized_anonymized_columns(self): """Test data processor uses the default locale for anonymized columns.""" @@ -362,10 +360,7 @@ def test_categorical_column_with_numbers(self): """Test that UniformEncoder is assigned for categorical columns defined with numbers.""" # Setup data = pd.DataFrame({ - 'category_col': [ - 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 1, 2, - 1, 1, 2, 1, 2, 2 - ], + 'category_col': [1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2], 'numerical_col': np.random.rand(20), }) @@ -381,12 +376,9 @@ def test_categorical_column_with_numbers(self): assert isinstance(dp._hyper_transformer.field_transformers['category_col'], UniformEncoder) def test_update_transformers_id_generator(self): - """ Test that updating to transformer to id generator is valid""" + """Test that updating to transformer to id generator is valid""" # Setup - data = pd.DataFrame({ - 'user_id': list(range(4)), - 'user_cat': ['a', 'b', 'c', 'd'] - }) + data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']}) metadata = SingleTableMetadata() metadata.detect_from_dataframe(data) metadata.update_column('user_id', sdtype='id') diff --git a/tests/integration/dataset.py b/tests/integration/dataset.py index 88d70c2ac..aea8ec211 100644 --- a/tests/integration/dataset.py +++ b/tests/integration/dataset.py @@ -4,15 +4,8 @@ def load_multi_foreign_key(): - parent = pd.DataFrame({ - 'parent_id': range(10), - 'value': range(10) - }) - child = pd.DataFrame({ - 'parent_1_id': range(10), - 'parent_2_id': range(10), - 'value': range(10) - }) + parent = pd.DataFrame({'parent_id': range(10), 'value': range(10)}) + child = pd.DataFrame({'parent_1_id': range(10), 'parent_2_id': range(10), 'value': range(10)}) metadata = Metadata() metadata.add_table('parent', parent, primary_key='parent_id') diff --git a/tests/integration/datasets/test_demo.py b/tests/integration/datasets/test_demo.py index ae8c3058d..7fac34495 100644 --- a/tests/integration/datasets/test_demo.py +++ b/tests/integration/datasets/test_demo.py @@ -1,4 +1,3 @@ - import pandas as pd from pandas.api.types import is_integer_dtype @@ -13,20 +12,32 @@ def test_get_available_demos_single_table(): # Assert expected_table = pd.DataFrame({ 'dataset_name': [ - 'adult', 'alarm', 'census', - 'child', 'covtype', 'expedia_hotel_logs', - 'insurance', 'intrusion', 'news' + 'adult', + 'alarm', + 'census', + 'child', + 'covtype', + 'expedia_hotel_logs', + 'insurance', + 'intrusion', + 'news', ], 'size_MB': [ - '3.907448', '4.520128', '98.165608', - '3.200128', '255.645408', '0.200128', - '3.340128', '162.039016', '18.712096' + '3.907448', + '4.520128', + '98.165608', + '3.200128', + '255.645408', + '0.200128', + '3.340128', + '162.039016', + 
'18.712096', ], - 'num_tables': ['1'] * 9 + 'num_tables': ['1'] * 9, }) expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2) expected_table['num_tables'] = pd.to_numeric(expected_table['num_tables']) - assert (is_integer_dtype(tables_info['num_tables'])) + assert is_integer_dtype(tables_info['num_tables']) assert len(expected_table.merge(tables_info)) == len(expected_table) @@ -38,38 +49,172 @@ def test_get_available_demos_multi_table(): # Assert expected_table = pd.DataFrame({ 'dataset_name': [ - 'Accidents_v1', 'Atherosclerosis_v1', 'AustralianFootball_v1', - 'Biodegradability_v1', 'Bupa_v1', 'CORA_v1', 'Carcinogenesis_v1', - 'Chess_v1', 'Countries_v1', 'DCG_v1', 'Dunur_v1', 'Elti_v1', - 'FNHK_v1', 'Facebook_v1', 'Hepatitis_std_v1', 'Mesh_v1', - 'Mooney_Family_v1', 'MuskSmall_v1', 'NBA_v1', 'NCAA_v1', - 'PTE_v1', 'Pima_v1', 'PremierLeague_v1', 'Pyrimidine_v1', - 'SAP_v1', 'SAT_v1', 'SalesDB_v1', 'Same_gen_v1', - 'Student_loan_v1', 'Telstra_v1', 'Toxicology_v1', 'Triazine_v1', - 'TubePricing_v1', 'UTube_v1', 'UW_std_v1', 'WebKP_v1', - 'airbnb-simplified', 'financial_v1', 'ftp_v1', 'genes_v1', - 'got_families', 'imdb_MovieLens_v1', 'imdb_ijs_v1', 'imdb_small_v1', - 'legalActs_v1', 'mutagenesis_v1', 'nations_v1', 'restbase_v1', - 'rossmann', 'trains_v1', 'university_v1', 'walmart', 'world_v1' + 'Accidents_v1', + 'Atherosclerosis_v1', + 'AustralianFootball_v1', + 'Biodegradability_v1', + 'Bupa_v1', + 'CORA_v1', + 'Carcinogenesis_v1', + 'Chess_v1', + 'Countries_v1', + 'DCG_v1', + 'Dunur_v1', + 'Elti_v1', + 'FNHK_v1', + 'Facebook_v1', + 'Hepatitis_std_v1', + 'Mesh_v1', + 'Mooney_Family_v1', + 'MuskSmall_v1', + 'NBA_v1', + 'NCAA_v1', + 'PTE_v1', + 'Pima_v1', + 'PremierLeague_v1', + 'Pyrimidine_v1', + 'SAP_v1', + 'SAT_v1', + 'SalesDB_v1', + 'Same_gen_v1', + 'Student_loan_v1', + 'Telstra_v1', + 'Toxicology_v1', + 'Triazine_v1', + 'TubePricing_v1', + 'UTube_v1', + 'UW_std_v1', + 'WebKP_v1', + 'airbnb-simplified', + 'financial_v1', + 'ftp_v1', + 'genes_v1', + 'got_families', + 'imdb_MovieLens_v1', + 'imdb_ijs_v1', + 'imdb_small_v1', + 'legalActs_v1', + 'mutagenesis_v1', + 'nations_v1', + 'restbase_v1', + 'rossmann', + 'trains_v1', + 'university_v1', + 'walmart', + 'world_v1', ], 'size_MB': [ - '296.202744', '7.916808', '32.534832', '0.692008', '0.059144', '1.987328', '1.642592', - '0.403784', '10.52272', '0.321536', '0.020224', '0.054912', '141.560872', '1.481056', - '0.809472', '0.101856', '0.121784', '0.646752', '0.16632', '29.137896', '1.31464', - '0.160896', '17.37664', '0.038144', '196.479272', '0.500224', '325.19768', '0.056176', - '0.180256', '5.503512', '1.495496', '0.156496', '15.414536', '0.135912', '0.0576', - '1.9718', '293.14392', '94.718016', '5.45568', '0.440016', '0.001', '55.253264', - '259.140656', '0.205728', '186.132944', '0.618088', '0.540336', '1.01452', '73.328504', - '0.00644', '0.009632', '14.642184', '0.295032' + '296.202744', + '7.916808', + '32.534832', + '0.692008', + '0.059144', + '1.987328', + '1.642592', + '0.403784', + '10.52272', + '0.321536', + '0.020224', + '0.054912', + '141.560872', + '1.481056', + '0.809472', + '0.101856', + '0.121784', + '0.646752', + '0.16632', + '29.137896', + '1.31464', + '0.160896', + '17.37664', + '0.038144', + '196.479272', + '0.500224', + '325.19768', + '0.056176', + '0.180256', + '5.503512', + '1.495496', + '0.156496', + '15.414536', + '0.135912', + '0.0576', + '1.9718', + '293.14392', + '94.718016', + '5.45568', + '0.440016', + '0.001', + '55.253264', + '259.140656', + '0.205728', + '186.132944', + '0.618088', 
+ '0.540336', + '1.01452', + '73.328504', + '0.00644', + '0.009632', + '14.642184', + '0.295032', ], 'num_tables': [ - '3', '4', '4', '5', '9', '3', '6', '2', '4', '2', '17', '11', '3', '2', - '7', '29', '68', '2', '4', '9', '38', '9', '4', '2', '4', '36', '4', - '4', '10', '5', '4', '2', '20', '2', '4', '3', '2', '8', '2', '3', '3', - '7', '7', '7', '5', '3', '3', '3', '2', '2', '5', '3', '3' - ] + '3', + '4', + '4', + '5', + '9', + '3', + '6', + '2', + '4', + '2', + '17', + '11', + '3', + '2', + '7', + '29', + '68', + '2', + '4', + '9', + '38', + '9', + '4', + '2', + '4', + '36', + '4', + '4', + '10', + '5', + '4', + '2', + '20', + '2', + '4', + '3', + '2', + '8', + '2', + '3', + '3', + '7', + '7', + '7', + '5', + '3', + '3', + '3', + '2', + '2', + '5', + '3', + '3', + ], }) expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2) expected_table['num_tables'] = pd.to_numeric(expected_table['num_tables']) - assert (is_integer_dtype(tables_info['num_tables'])) + assert is_integer_dtype(tables_info['num_tables']) assert len(expected_table.merge(tables_info, on='dataset_name')) == len(expected_table) diff --git a/tests/integration/datasets/test_local.py b/tests/integration/datasets/test_local.py index 135375372..ed7dfb8d6 100644 --- a/tests/integration/datasets/test_local.py +++ b/tests/integration/datasets/test_local.py @@ -6,33 +6,27 @@ @pytest.fixture def data(): - parent = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - 'B': [0.434, 0.312, 0.212, 0.339, 0.491] - }) - - child = pd.DataFrame(data={ - 'parent_id': [0, 1, 2, 2, 5], - 'C': ['Yes', 'No', 'Maye', 'No', 'No'] - }) - - grandchild = pd.DataFrame(data={ - 'child_id': [0, 1, 2, 3, 4], - 'D': [0.434, 0.312, 0.212, 0.339, 0.491] - }) - - grandchild2 = pd.DataFrame(data={ - 'child_id': [0, 1, 2, 3, 4], - 'E': [0.434, 0.312, 0.212, 0.339, 0.491] - }) - - return { - 'parent': parent, - 'child': child, - 'grandchild': grandchild, - 'grandchild2': grandchild2 - } + parent = pd.DataFrame( + data={ + 'id': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + 'B': [0.434, 0.312, 0.212, 0.339, 0.491], + } + ) + + child = pd.DataFrame( + data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maye', 'No', 'No']} + ) + + grandchild = pd.DataFrame( + data={'child_id': [0, 1, 2, 3, 4], 'D': [0.434, 0.312, 0.212, 0.339, 0.491]} + ) + + grandchild2 = pd.DataFrame( + data={'child_id': [0, 1, 2, 3, 4], 'E': [0.434, 0.312, 0.212, 0.339, 0.491]} + ) + + return {'parent': parent, 'child': child, 'grandchild': grandchild, 'grandchild2': grandchild2} def test_save_csvs(data, tmpdir): diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py index 769465477..8d3771164 100644 --- a/tests/integration/evaluation/test_multi_table.py +++ b/tests/integration/evaluation/test_multi_table.py @@ -1,4 +1,3 @@ - import pandas as pd from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic @@ -31,16 +30,16 @@ def test_evaluation(): 'id': {'sdtype': 'id'}, 'col': {'sdtype': 'numerical'}, }, - } + }, }, 'relationships': [ { 'parent_table_name': 'table1', 'parent_primary_key': 'id', 'child_table_name': 'table2', - 'child_foreign_key': 'id' + 'child_foreign_key': 'id', } - ] + ], }) # Run and Assert @@ -53,6 +52,6 @@ def test_evaluation(): report.get_properties(), pd.DataFrame({ 'Property': ['Data Validity', 'Data Structure', 'Relationship Validity'], - 'Score': [1., 1., 1.], - }) + 'Score': [1.0, 1.0, 1.0], + }), ) diff --git 
a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py index c4d1bd48f..5b9ee9123 100644 --- a/tests/integration/evaluation/test_single_table.py +++ b/tests/integration/evaluation/test_single_table.py @@ -1,4 +1,3 @@ - import pandas as pd from sdv.datasets.demo import download_demo @@ -27,18 +26,15 @@ def test_evaluation(): report.get_properties(), pd.DataFrame({ 'Property': ['Data Validity', 'Data Structure'], - 'Score': [1., 1.], - }) + 'Score': [1.0, 1.0], + }), ) def test_column_pair_plot_sample_size_parameter(): """Test the sample_size parameter for the column pair plot.""" # Setup - real_data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(real_data) synthetic_data = synthesizer.sample(len(real_data)) @@ -49,7 +45,7 @@ def test_column_pair_plot_sample_size_parameter(): synthetic_data=synthetic_data, column_names=['room_rate', 'amenities_fee'], metadata=metadata, - sample_size=40 + sample_size=40, ) # Assert diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py index c133e1373..2fa7e00c3 100644 --- a/tests/integration/io/local/test_local.py +++ b/tests/integration/io/local/test_local.py @@ -5,13 +5,12 @@ class TestCSVHandler: - def test_integration_write_and_read(self, tmpdir): """Test end to end the write and read methods of ``CSVHandler``.""" # Prepare synthetic data synthetic_data = { 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), - 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}), } # Write synthetic data to CSV files @@ -36,13 +35,12 @@ def test_integration_write_and_read(self, tmpdir): class TestExcelHandler: - def test_integration_write_and_read(self, tmpdir): """Test end to end the write and read methods of ``ExcelHandler``.""" # Prepare synthetic data synthetic_data = { 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), - 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}), } # Write synthetic data to xslx files @@ -70,7 +68,7 @@ def test_integration_write_and_read_append_mode(self, tmpdir): # Prepare synthetic data synthetic_data = { 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), - 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}), } # Write synthetic data to xslx files @@ -97,12 +95,10 @@ def test_integration_write_and_read_append_mode(self, tmpdir): # Check if the dataframes match the original synthetic data expected_table_one = pd.concat( - [synthetic_data['table1'], synthetic_data['table1']], - ignore_index=True + [synthetic_data['table1'], synthetic_data['table1']], ignore_index=True ) expected_table_two = pd.concat( - [synthetic_data['table2'], synthetic_data['table2']], - ignore_index=True + [synthetic_data['table2'], synthetic_data['table2']], ignore_index=True ) pd.testing.assert_frame_equal(data['table1'], expected_table_one) pd.testing.assert_frame_equal(data['table2'], expected_table_two) diff --git a/tests/integration/lite/test_single_table.py b/tests/integration/lite/test_single_table.py index 0358433be..90c66cb5d 100644 --- 
a/tests/integration/lite/test_single_table.py +++ b/tests/integration/lite/test_single_table.py @@ -16,11 +16,7 @@ def test_sample(): metadata.detect_from_dataframe(data) preset = SingleTablePreset(metadata, name='FAST_ML') preset.fit(data) - samples = preset.sample( - num_rows=10, - max_tries_per_batch=20, - batch_size=5 - ) + samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5) # Assert assert samples['a'].all() in [1, 2, 3, np.nan] @@ -30,10 +26,7 @@ def test_sample(): def test_sample_with_constraints(): """Test sampling for the ``SingleTablePreset``.""" # Setup - data = pd.DataFrame({ - 'a': [1, 2, 3], - 'b': [4, 5, 6] - }) + data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run metadata = SingleTableMetadata() @@ -42,19 +35,12 @@ def test_sample_with_constraints(): constraints = [ { 'constraint_class': 'Inequality', - 'constraint_parameters': { - 'low_column_name': 'a', - 'high_column_name': 'b' - } + 'constraint_parameters': {'low_column_name': 'a', 'high_column_name': 'b'}, } ] preset.add_constraints(constraints) preset.fit(data) - samples = preset.sample( - num_rows=10, - max_tries_per_batch=20, - batch_size=5 - ) + samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5) # Assert assert len(samples) == 10 @@ -68,10 +54,7 @@ def test_warnings_are_shown(): "functionality, please use the 'GaussianCopulaSynthesizer'." ) # Setup - data = pd.DataFrame({ - 'a': [1, 2, 3], - 'b': [4, 5, 6] - }) + data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run metadata = SingleTableMetadata() @@ -83,10 +66,7 @@ def test_warnings_are_shown(): constraints = [ { 'constraint_class': 'Inequality', - 'constraint_parameters': { - 'low_column_name': 'a', - 'high_column_name': 'b' - } + 'constraint_parameters': {'low_column_name': 'a', 'high_column_name': 'b'}, } ] with pytest.warns(FutureWarning, match=warn_message): @@ -96,11 +76,7 @@ def test_warnings_are_shown(): preset.fit(data) with pytest.warns(FutureWarning, match=warn_message): - samples = preset.sample( - num_rows=10, - max_tries_per_batch=20, - batch_size=5 - ) + samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5) # Assert assert len(samples) == 10 diff --git a/tests/integration/metadata/test_multi_table.py b/tests/integration/metadata/test_multi_table.py index 8a7b5d56d..543ba9191 100644 --- a/tests/integration/metadata/test_multi_table.py +++ b/tests/integration/metadata/test_multi_table.py @@ -18,11 +18,7 @@ def test_multi_table_metadata(): result = instance.to_dict() # Assert - assert result == { - 'tables': {}, - 'relationships': [], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' - } + assert result == {'tables': {}, 'relationships': [], 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'} assert instance.tables == {} assert instance.relationships == [] @@ -30,6 +26,7 @@ def test_multi_table_metadata(): @patch('rdt.transformers') def test_add_column_relationship(mock_rdt_transformers): """Test ``add_column_relationship`` method.""" + # Setup class RandomLocationGeneratorMock: @classmethod @@ -64,14 +61,14 @@ def test_remove_primary_key(): 'parent_table_name': 'upravna_enota', 'parent_primary_key': 'id_upravna_enota', 'child_table_name': 'nesreca', - 'child_foreign_key': 'upravna_enota' + 'child_foreign_key': 'upravna_enota', }, { 'parent_table_name': 'upravna_enota', 'parent_primary_key': 'id_upravna_enota', 'child_table_name': 'oseba', - 'child_foreign_key': 'upravna_enota' - } + 'child_foreign_key': 'upravna_enota', + }, ] assert metadata.tables['nesreca'].primary_key is None 
assert metadata.relationships == expected_relationships @@ -87,47 +84,30 @@ def test_upgrade_metadata(tmp_path): 'upravna_enota': { 'type': 'id', 'subtype': 'integer', - 'ref': { - 'table': 'upravna_enota', - 'field': 'id_upravna_enota' - } - }, - 'id_nesreca': { - 'type': 'id', - 'subtype': 'integer' + 'ref': {'table': 'upravna_enota', 'field': 'id_upravna_enota'}, }, + 'id_nesreca': {'type': 'id', 'subtype': 'integer'}, }, - 'primary_key': 'id_nesreca' + 'primary_key': 'id_nesreca', }, 'oseba': { 'fields': { 'upravna_enota': { 'type': 'id', 'subtype': 'integer', - 'ref': { - 'table': 'upravna_enota', - 'field': 'id_upravna_enota' - } + 'ref': {'table': 'upravna_enota', 'field': 'id_upravna_enota'}, }, 'id_nesreca': { 'type': 'id', 'subtype': 'integer', - 'ref': { - 'table': 'nesreca', - 'field': 'id_nesreca' - } + 'ref': {'table': 'nesreca', 'field': 'id_nesreca'}, }, }, }, 'upravna_enota': { - 'fields': { - 'id_upravna_enota': { - 'type': 'id', - 'subtype': 'integer' - } - }, - 'primary_key': 'id_upravna_enota' - } + 'fields': {'id_upravna_enota': {'type': 'id', 'subtype': 'integer'}}, + 'primary_key': 'id_upravna_enota', + }, } } filepath = tmp_path / 'old.json' @@ -145,41 +125,41 @@ def test_upgrade_metadata(tmp_path): 'primary_key': 'id_nesreca', 'columns': { 'upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'}, - 'id_nesreca': {'sdtype': 'id', 'regex_format': r'\d{30}'} - } + 'id_nesreca': {'sdtype': 'id', 'regex_format': r'\d{30}'}, + }, }, 'oseba': { 'columns': { 'upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'}, - 'id_nesreca': {'sdtype': 'id', 'regex_format': r'\d{30}'} + 'id_nesreca': {'sdtype': 'id', 'regex_format': r'\d{30}'}, } }, 'upravna_enota': { 'primary_key': 'id_upravna_enota', - 'columns': {'id_upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'}} - } + 'columns': {'id_upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'}}, + }, }, 'relationships': [ { 'parent_table_name': 'upravna_enota', 'parent_primary_key': 'id_upravna_enota', 'child_table_name': 'nesreca', - 'child_foreign_key': 'upravna_enota' + 'child_foreign_key': 'upravna_enota', }, { 'parent_table_name': 'upravna_enota', 'parent_primary_key': 'id_upravna_enota', 'child_table_name': 'oseba', - 'child_foreign_key': 'upravna_enota' + 'child_foreign_key': 'upravna_enota', }, { 'parent_table_name': 'nesreca', 'parent_primary_key': 'id_nesreca', 'child_table_name': 'oseba', - 'child_foreign_key': 'id_nesreca' - } + 'child_foreign_key': 'id_nesreca', + }, ], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } assert new_metadata['METADATA_SPEC_VERSION'] == expected_metadata['METADATA_SPEC_VERSION'] assert new_metadata['tables'] == expected_metadata['tables'] @@ -190,10 +170,7 @@ def test_upgrade_metadata(tmp_path): def test_detect_from_dataframes(): """Test the ``detect_from_dataframes`` method.""" # Setup - real_data, _ = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') metadata = MultiTableMetadata() @@ -215,9 +192,9 @@ def test_detect_from_dataframes(): 'city': {'sdtype': 'city', 'pii': True}, 'state': {'sdtype': 'administrative_unit', 'pii': True}, 'rating': {'sdtype': 'numerical'}, - 'classification': {'sdtype': 'categorical'} + 'classification': {'sdtype': 'categorical'}, }, - 'primary_key': 'hotel_id' + 'primary_key': 'hotel_id', }, 'guests': { 'columns': { @@ -230,20 +207,20 @@ def test_detect_from_dataframes(): 'checkout_date': 
{'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, 'room_rate': {'sdtype': 'numerical'}, 'billing_address': {'sdtype': 'unknown', 'pii': True}, - 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True} + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, }, - 'primary_key': 'guest_email' - } + 'primary_key': 'guest_email', + }, }, 'relationships': [ { 'parent_table_name': 'hotels', 'child_table_name': 'guests', 'parent_primary_key': 'hotel_id', - 'child_foreign_key': 'hotel_id' + 'child_foreign_key': 'hotel_id', } ], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } assert metadata.to_dict() == expected_metadata @@ -251,10 +228,7 @@ def test_detect_from_dataframes(): def test_detect_from_csvs(tmp_path): """Test the ``detect_from_csvs`` method.""" # Setup - real_data, _ = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') metadata = MultiTableMetadata() @@ -280,9 +254,9 @@ def test_detect_from_csvs(tmp_path): 'city': {'sdtype': 'city', 'pii': True}, 'state': {'sdtype': 'administrative_unit', 'pii': True}, 'rating': {'sdtype': 'numerical'}, - 'classification': {'sdtype': 'categorical'} + 'classification': {'sdtype': 'categorical'}, }, - 'primary_key': 'hotel_id' + 'primary_key': 'hotel_id', }, 'guests': { 'columns': { @@ -295,20 +269,20 @@ def test_detect_from_csvs(tmp_path): 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, 'room_rate': {'sdtype': 'numerical'}, 'billing_address': {'sdtype': 'unknown', 'pii': True}, - 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True} + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, }, - 'primary_key': 'guest_email' - } + 'primary_key': 'guest_email', + }, }, 'relationships': [ { 'parent_table_name': 'hotels', 'child_table_name': 'guests', 'parent_primary_key': 'hotel_id', - 'child_foreign_key': 'hotel_id' + 'child_foreign_key': 'hotel_id', } ], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } assert metadata.to_dict() == expected_metadata @@ -317,10 +291,7 @@ def test_detect_from_csvs(tmp_path): def test_detect_table_from_csv(tmp_path): """Test the ``detect_table_from_csv`` method.""" # Setup - real_data, _ = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') metadata = MultiTableMetadata() @@ -355,13 +326,13 @@ def test_detect_table_from_csv(tmp_path): 'city': {'sdtype': 'categorical'}, 'state': {'sdtype': 'categorical'}, 'rating': {'sdtype': 'numerical'}, - 'classification': {'sdtype': 'categorical'} + 'classification': {'sdtype': 'categorical'}, }, - 'primary_key': 'hotel_id' + 'primary_key': 'hotel_id', } }, 'relationships': [], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } assert metadata.to_dict() == expected_metadata @@ -400,10 +371,8 @@ def test_get_table_metadata(): 'id_nesreca': {'sdtype': 'id'}, 'nesreca_val': {'sdtype': 'numerical'}, 'latitude': {'sdtype': 'latitude', 'pii': True}, - 'longitude': {'sdtype': 'longitude', 'pii': True} + 'longitude': {'sdtype': 'longitude', 'pii': True}, }, - 'column_relationships': [ - {'type': 'gps', 'column_names': ['latitude', 'longitude']} - ] + 'column_relationships': [{'type': 'gps', 'column_names': ['latitude', 'longitude']}], } assert table_metadata.to_dict() == expected_metadata 
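The metadata integration tests above exercise SDV's detection API end to end (detect from dataframes or CSVs, then compare `to_dict()` against an expected structure). As a hedged illustration only — reusing the same `download_demo` dataset and `MultiTableMetadata` methods that appear in these tests, with nothing assumed beyond them — a typical detect-and-inspect round trip looks roughly like this:

    from sdv.datasets.demo import download_demo
    from sdv.metadata import MultiTableMetadata

    # Load the same demo dataset the tests use: a dict of table name -> DataFrame.
    real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')

    # Detect column sdtypes, primary keys and relationships from the dataframes.
    metadata = MultiTableMetadata()
    metadata.detect_from_dataframes(real_data)

    # Confirm the detected metadata is internally consistent, then inspect it.
    metadata.validate()
    print(metadata.to_dict())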
diff --git a/tests/integration/metadata/test_single_table.py b/tests/integration/metadata/test_single_table.py index 4aaaa68ea..a5a903c80 100644 --- a/tests/integration/metadata/test_single_table.py +++ b/tests/integration/metadata/test_single_table.py @@ -1,4 +1,5 @@ """Integration tests for Single Table Metadata.""" + import json import re from unittest.mock import patch @@ -22,9 +23,7 @@ def test_single_table_metadata(): result = instance.to_dict() # Assert - assert result == { - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' - } + assert result == {'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1'} assert instance.columns == {} assert instance._version == 'SINGLE_TABLE_V1' assert instance.primary_key is None @@ -36,6 +35,7 @@ def test_single_table_metadata(): @patch('rdt.transformers') def test_add_column_relationship(mock_rdt_transformers): """Test ``add_column_relationship`` method.""" + # Setup class RandomLocationGeneratorMock: @classmethod @@ -54,9 +54,7 @@ def _validate_sdtypes(cls, columns_to_sdtypes): # Assert instance.validate() - assert instance.column_relationships == [ - {'type': 'address', 'column_names': ['col2', 'col3']} - ] + assert instance.column_relationships == [{'type': 'address', 'column_names': ['col2', 'col3']}] def test_add_column_relationship_existing_column_in_relationship(): @@ -79,9 +77,7 @@ def test_add_column_relationship_existing_column_in_relationship(): ' Columns cannot be part of multiple relationships.' ) with pytest.raises(InvalidMetadataError, match=expected_message): - instance.add_column_relationship( - relationship_type='address', column_names=['col2', 'col4'] - ) + instance.add_column_relationship(relationship_type='address', column_names=['col2', 'col4']) @patch('rdt.transformers') @@ -90,6 +86,7 @@ def test_validate(mock_rdt_transformers): Ensure the method doesn't crash for a valid metadata. 
""" + # Setup class RandomLocationGeneratorMock: @classmethod @@ -116,13 +113,20 @@ def _validate_sdtypes(cls, columns_to_sdtypes): @patch('rdt.transformers') def test_validate_errors(mock_rdt_transformers): """Test ``SingleTableMetadata.validate`` raises the correct errors.""" + # Setup class RandomLocationGeneratorMock: @classmethod def _validate_sdtypes(cls, columns_to_sdtypes): valid_sdtypes = ( - 'country_code', 'administrative_unit', 'city', 'postcode', 'street_address', - 'secondary_address', 'state', 'state_abbr' + 'country_code', + 'administrative_unit', + 'city', + 'postcode', + 'street_address', + 'secondary_address', + 'state', + 'state_abbr', ) bad_columns = [] for column_name, sdtypes in columns_to_sdtypes.items(): @@ -156,7 +160,7 @@ def _validate_sdtypes(cls, columns_to_sdtypes): instance.column_relationships = [ {'type': 'address', 'column_names': ['col11']}, {'type': 'address', 'column_names': ['col1', 'col2']}, - {'type': 'fake_relationship', 'column_names': ['col3', 'col4']} + {'type': 'fake_relationship', 'column_names': ['col3', 'col4']}, ] err_msg = re.escape( @@ -193,43 +197,17 @@ def test_upgrade_metadata(tmp_path): # Setup old_metadata = { 'fields': { - 'start_date': { - 'type': 'datetime', - 'format': '%Y-%m-%d' - }, - 'end_date': { - 'type': 'datetime', - 'format': '%Y-%m-%d' - }, - 'salary': { - 'type': 'numerical', - 'subtype': 'integer' - }, - 'duration': { - 'type': 'categorical' - }, - 'student_id': { - 'type': 'id', - 'subtype': 'integer' - }, - 'high_perc': { - 'type': 'numerical', - 'subtype': 'float' - }, - 'placed': { - 'type': 'boolean' - }, - 'ssn': { - 'type': 'id', - 'subtype': 'integer' - }, - 'drivers_license': { - 'type': 'id', - 'subtype': 'string', - 'regex': 'regex' - } + 'start_date': {'type': 'datetime', 'format': '%Y-%m-%d'}, + 'end_date': {'type': 'datetime', 'format': '%Y-%m-%d'}, + 'salary': {'type': 'numerical', 'subtype': 'integer'}, + 'duration': {'type': 'categorical'}, + 'student_id': {'type': 'id', 'subtype': 'integer'}, + 'high_perc': {'type': 'numerical', 'subtype': 'float'}, + 'placed': {'type': 'boolean'}, + 'ssn': {'type': 'id', 'subtype': 'integer'}, + 'drivers_license': {'type': 'id', 'subtype': 'string', 'regex': 'regex'}, }, - 'primary_key': 'student_id' + 'primary_key': 'student_id', } filepath = tmp_path / 'old.json' old_metadata_file = open(filepath, 'w') @@ -242,44 +220,19 @@ def test_upgrade_metadata(tmp_path): # Assert expected_metadata = { 'columns': { - 'start_date': { - 'sdtype': 'datetime', - 'datetime_format': '%Y-%m-%d' - }, - 'end_date': { - 'sdtype': 'datetime', - 'datetime_format': '%Y-%m-%d' - }, - 'salary': { - 'sdtype': 'numerical', - 'computer_representation': 'Int64' - }, - 'duration': { - 'sdtype': 'categorical' - }, - 'student_id': { - 'sdtype': 'id', - 'regex_format': r'\d{30}' - }, - 'high_perc': { - 'sdtype': 'numerical', - 'computer_representation': 'Float' - }, - 'placed': { - 'sdtype': 'boolean' - }, - 'ssn': { - 'sdtype': 'id', - 'regex_format': r'\d{30}' - }, - 'drivers_license': { - 'sdtype': 'id', - 'regex_format': 'regex' - } + 'start_date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'end_date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'salary': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'duration': {'sdtype': 'categorical'}, + 'student_id': {'sdtype': 'id', 'regex_format': r'\d{30}'}, + 'high_perc': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'placed': {'sdtype': 'boolean'}, + 'ssn': {'sdtype': 'id', 'regex_format': 
r'\d{30}'}, + 'drivers_license': {'sdtype': 'id', 'regex_format': 'regex'}, }, 'primary_key': 'student_id', 'alternate_keys': ['ssn', 'drivers_license'], - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } assert new_metadata == expected_metadata @@ -287,10 +240,7 @@ def test_upgrade_metadata(tmp_path): def test_validate_unknown_sdtype(): """Test ``validate`` method works with ``unknown`` sdtype.""" # Setup - data, _ = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') metadata = SingleTableMetadata() metadata.detect_from_dataframe(data['hotels']) @@ -306,9 +256,9 @@ def test_validate_unknown_sdtype(): 'city': {'sdtype': 'city', 'pii': True}, 'state': {'sdtype': 'administrative_unit', 'pii': True}, 'rating': {'sdtype': 'numerical'}, - 'classification': {'sdtype': 'unknown', 'pii': True} + 'classification': {'sdtype': 'unknown', 'pii': True}, }, - 'primary_key': 'hotel_id' + 'primary_key': 'hotel_id', } assert metadata.to_dict() == expected_metadata @@ -322,7 +272,7 @@ def test_detect_from_dataframe_with_none_nan_and_nat(): 'f_nan_data': [float('nan')] * 100, 'none_data': [None] * 100, 'np_nan_data': [np.nan] * 100, - 'pd_nat_data': [pd.NaT] * 100 + 'pd_nat_data': [pd.NaT] * 100, }) stm = SingleTableMetadata() @@ -344,7 +294,6 @@ def test_detect_from_dataframe_with_pii_names(): 'addr_line_1': [1, 2, 3], 'First Name': [1, 2, 3], 'guest_email': [1, 2, 3], - }) metadata = SingleTableMetadata() @@ -359,8 +308,8 @@ def test_detect_from_dataframe_with_pii_names(): 'USER PHONE NUMBER': {'sdtype': 'phone_number', 'pii': True}, 'addr_line_1': {'sdtype': 'street_address', 'pii': True}, 'First Name': {'sdtype': 'first_name', 'pii': True}, - 'guest_email': {'sdtype': 'email', 'pii': True} - } + 'guest_email': {'sdtype': 'email', 'pii': True}, + }, } assert metadata.to_dict() == expected_metadata @@ -372,11 +321,13 @@ def test_detect_from_dataframe_with_pii_non_unique(): The metadata should not detect any primray key. 
""" # Setup - data = pd.DataFrame(data={ - 'Age': [int(i) for i in np.random.uniform(low=0, high=100, size=100)], - 'Sex': np.random.choice(['Male', 'Female'], size=100), - 'latitude': [round(i, 2) for i in np.random.uniform(low=-90, high=+90, size=50)] * 2 - }) + data = pd.DataFrame( + data={ + 'Age': [int(i) for i in np.random.uniform(low=0, high=100, size=100)], + 'Sex': np.random.choice(['Male', 'Female'], size=100), + 'latitude': [round(i, 2) for i in np.random.uniform(low=-90, high=+90, size=50)] * 2, + } + ) metadata = SingleTableMetadata() # Run @@ -397,7 +348,7 @@ def test_update_columns(): 'col3': {'sdtype': 'categorical'}, 'col4': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'col5': {'sdtype': 'unknown'}, - 'col6': {'sdtype': 'email', 'pii': True} + 'col6': {'sdtype': 'email', 'pii': True}, } }) @@ -405,7 +356,7 @@ def test_update_columns(): metadata.update_columns( ['col1', 'col3', 'col4', 'col5', 'col6'], sdtype='numerical', - computer_representation='Int64' + computer_representation='Int64', ) # Assert @@ -417,8 +368,8 @@ def test_update_columns(): 'col3': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, 'col4': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, 'col5': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, - 'col6': {'sdtype': 'numerical', 'computer_representation': 'Int64'} - } + 'col6': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + }, } assert metadata.to_dict() == expected_metadata @@ -433,20 +384,18 @@ def test_update_columns_invalid_kwargs_combination(): 'col3': {'sdtype': 'categorical'}, 'col4': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'col5': {'sdtype': 'unknown'}, - 'col6': {'sdtype': 'email', 'pii': True} + 'col6': {'sdtype': 'email', 'pii': True}, } }) # Run / Assert - expected_message = re.escape( - "Invalid values '(pii)' for 'numerical' sdtype." 
- ) + expected_message = re.escape("Invalid values '(pii)' for 'numerical' sdtype.") with pytest.raises(InvalidMetadataError, match=expected_message): metadata.update_columns( ['col1', 'col3', 'col4', 'col5', 'col6'], sdtype='numerical', computer_representation='Int64', - pii=True + pii=True, ) @@ -460,20 +409,18 @@ def test_update_columns_metadata(): 'col3': {'sdtype': 'categorical'}, 'col4': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'col5': {'sdtype': 'unknown'}, - 'col6': {'sdtype': 'email', 'pii': True} + 'col6': {'sdtype': 'email', 'pii': True}, } }) # Run - metadata.update_columns_metadata( - { - 'col1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, - 'col3': {'sdtype': 'email', 'pii': True}, - 'col4': {'sdtype': 'phone_number', 'pii': False}, - 'col5': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'col6': {'sdtype': 'unknown'} - } - ) + metadata.update_columns_metadata({ + 'col1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'col3': {'sdtype': 'email', 'pii': True}, + 'col4': {'sdtype': 'phone_number', 'pii': False}, + 'col5': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'col6': {'sdtype': 'unknown'}, + }) # Assert expected_metadata = { @@ -484,8 +431,8 @@ def test_update_columns_metadata(): 'col3': {'sdtype': 'email', 'pii': True}, 'col4': {'sdtype': 'phone_number', 'pii': False}, 'col5': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'col6': {'sdtype': 'unknown'} - } + 'col6': {'sdtype': 'unknown'}, + }, } assert metadata.to_dict() == expected_metadata @@ -500,7 +447,7 @@ def test_update_columns_metadata_invalid_kwargs_combination(): 'col3': {'sdtype': 'categorical'}, 'col4': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'col5': {'sdtype': 'unknown'}, - 'col6': {'sdtype': 'email', 'pii': True} + 'col6': {'sdtype': 'email', 'pii': True}, } }) @@ -509,15 +456,12 @@ def test_update_columns_metadata_invalid_kwargs_combination(): 'The following errors were found when updating columns:\n\n' "Invalid values '(pii)' for numerical column 'col1'.\n" "Invalid values '(pii)' for numerical column 'col2'." 
- ) with pytest.raises(InvalidMetadataError, match=expected_message): - metadata.update_columns_metadata( - { - 'col1': {'sdtype': 'numerical', 'computer_representation': 'Int64', 'pii': True}, - 'col2': {'pii': True} - } - ) + metadata.update_columns_metadata({ + 'col1': {'sdtype': 'numerical', 'computer_representation': 'Int64', 'pii': True}, + 'col2': {'pii': True}, + }) def test_column_relationship_validation(): @@ -527,14 +471,11 @@ def test_column_relationship_validation(): 'columns': { 'user_city': {'sdtype': 'city'}, 'user_zip': {'sdtype': 'postcode'}, - 'user_value': {'sdtype': 'unknown'} + 'user_value': {'sdtype': 'unknown'}, }, 'column_relationships': [ - { - 'type': 'address', - 'column_names': ['user_city', 'user_zip', 'user_value'] - } - ] + {'type': 'address', 'column_names': ['user_city', 'user_zip', 'user_value']} + ], }) expected_message = re.escape( diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index 0777aa441..07cb870b6 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -32,7 +32,8 @@ def test_visualize_graph_for_multi_table(): metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...') metadata.add_relationship( - '1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...') + '1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...' + ) model = HMASynthesizer(metadata) # Run diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 60488e054..a82049200 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -23,7 +23,6 @@ class TestHMASynthesizer: - def test_hma(self): """End to end integration tests with ``HMASynthesizer``. @@ -46,8 +45,7 @@ def test_hma(self): for table_name, table in normal_sample.items(): assert all(table.columns == data[table_name].columns) - for normal_table, increased_table in zip( - normal_sample.values(), increased_sample.values()): + for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()): assert increased_table.size > normal_table.size def test_hma_reset_sampling(self): @@ -111,7 +109,7 @@ def test_get_info(self): 'creation_date': today, 'is_fit': False, 'last_fit_date': None, - 'fitted_sdv_version': None + 'fitted_sdv_version': None, } # Run @@ -125,7 +123,7 @@ def test_get_info(self): 'creation_date': today, 'is_fit': True, 'last_fit_date': today, - 'fitted_sdv_version': version + 'fitted_sdv_version': version, } def test_hma_set_table_parameters(self): @@ -134,7 +132,7 @@ def test_hma_set_table_parameters(self): Validate that the ``set_table_parameters`` sets new parameters to the synthesizers. 
""" # Setup - data, metadata = download_demo('multi_table', 'got_families') + _data, metadata = download_demo('multi_table', 'got_families') hmasynthesizer = HMASynthesizer(metadata) # Run @@ -150,7 +148,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, } families_params = hmasynthesizer.get_table_parameters('families') assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -159,7 +157,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, } char_families_params = hmasynthesizer.get_table_parameters('character_families') assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -168,13 +166,14 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, } assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' assert hmasynthesizer._table_synthesizers['families'].default_distribution == 'uniform' - assert hmasynthesizer._table_synthesizers['character_families'].default_distribution == \ - 'norm' + assert ( + hmasynthesizer._table_synthesizers['character_families'].default_distribution == 'norm' + ) def get_custom_constraint_data_and_metadata(self): """Return data and metadata for the custom constraint tests.""" @@ -203,7 +202,7 @@ def get_custom_constraint_data_and_metadata(self): parent_primary_key='primary_key', parent_table_name='parent', child_foreign_key='user_id', - child_table_name='child' + child_table_name='child', ) return parent_data, child_data, metadata @@ -216,9 +215,7 @@ def test_hma_custom_constraint(self): constraint = { 'table_name': 'parent', 'constraint_class': 'MyConstraint', - 'constraint_parameters': { - 'column_names': ['numerical_col'] - } + 'constraint_parameters': {'column_names': ['numerical_col']}, } synthesizer.add_custom_constraint_class(MyConstraint, 'MyConstraint') @@ -229,7 +226,7 @@ def test_hma_custom_constraint(self): # Assert Processed Data np.testing.assert_equal( processed_data['parent']['numerical_col'].array, - (parent_data['numerical_col'] ** 2.0).array + (parent_data['numerical_col'] ** 2.0).array, ) # Run - Fit the model @@ -251,17 +248,13 @@ def test_hma_custom_constraint_2_tables(self): constraint_parent = { 'table_name': 'parent', 'constraint_class': 'MyConstraint', - 'constraint_parameters': { - 'column_names': ['numerical_col'] - } + 'constraint_parameters': {'column_names': ['numerical_col']}, } constraint_child = { 'table_name': 'child', 'constraint_class': 'MyConstraint', - 'constraint_parameters': { - 'column_names': ['numerical_col_2'] - } + 'constraint_parameters': {'column_names': ['numerical_col_2']}, } synthesizer.add_custom_constraint_class(MyConstraint, 'MyConstraint') @@ -272,11 +265,11 @@ def test_hma_custom_constraint_2_tables(self): # Assert Processed Data np.testing.assert_equal( processed_data['parent']['numerical_col'].array, - (parent_data['numerical_col'] ** 2.0).array + (parent_data['numerical_col'] ** 2.0).array, ) np.testing.assert_equal( processed_data['child']['numerical_col_2'].array, - (child_data['numerical_col_2'] ** 2.0).array + (child_data['numerical_col_2'] ** 2.0).array, ) # Run - Fit the model @@ -296,13 +289,10 @@ def 
test_hma_custom_constraint_loaded_from_file(self): constraint = { 'table_name': 'parent', 'constraint_class': 'MyConstraint', - 'constraint_parameters': { - 'column_names': ['numerical_col'] - } + 'constraint_parameters': {'column_names': ['numerical_col']}, } synthesizer.load_custom_constraint_classes( - 'tests/integration/single_table/custom_constraints.py', - ['MyConstraint'] + 'tests/integration/single_table/custom_constraints.py', ['MyConstraint'] ) # Run @@ -312,7 +302,7 @@ def test_hma_custom_constraint_loaded_from_file(self): # Assert Processed Data np.testing.assert_equal( processed_data['parent']['numerical_col'].array, - (parent_data['numerical_col'] ** 2.0).array + (parent_data['numerical_col'] ** 2.0).array, ) # Run - Fit the model @@ -325,22 +315,20 @@ def test_hma_custom_constraint_loaded_from_file(self): def test_hma_with_inequality_constraint(self): """Test that when new columns are created by the constraint this still works.""" # Setup - parent_table = pd.DataFrame(data={ - 'id': [1, 2, 3, 4, 5], - 'column': [1.2, 2.1, 2.2, 2.1, 1.4] - }) + parent_table = pd.DataFrame( + data={'id': [1, 2, 3, 4, 5], 'column': [1.2, 2.1, 2.2, 2.1, 1.4]} + ) - child_table = pd.DataFrame(data={ - 'id': [1, 2, 3, 4, 5], - 'parent_id': [1, 1, 3, 2, 1], - 'low_column': [1, 3, 3, 1, 2], - 'high_column': [2, 4, 5, 2, 4] - }) + child_table = pd.DataFrame( + data={ + 'id': [1, 2, 3, 4, 5], + 'parent_id': [1, 1, 3, 2, 1], + 'low_column': [1, 3, 3, 1, 2], + 'high_column': [2, 4, 5, 2, 4], + } + ) - data = { - 'parent_table': parent_table, - 'child_table': child_table - } + data = {'parent_table': parent_table, 'child_table': child_table} metadata = MultiTableMetadata() metadata.detect_table_from_dataframe(table_name='parent_table', data=parent_table) @@ -356,7 +344,7 @@ def test_hma_with_inequality_constraint(self): parent_table_name='parent_table', child_table_name='child_table', parent_primary_key='id', - child_foreign_key='parent_id' + child_foreign_key='parent_id', ) constraint = { @@ -364,8 +352,8 @@ def test_hma_with_inequality_constraint(self): 'table_name': 'child_table', 'constraint_parameters': { 'low_column_name': 'low_column', - 'high_column_name': 'high_column' - } + 'high_column_name': 'high_column', + }, } synthesizer = HMASynthesizer(metadata) @@ -421,25 +409,15 @@ def test_save_and_load(self, tmp_path): def test_hma_primary_key_and_foreign_key_only(self): """Test that ``HMASynthesizer`` can handle tables with primary and foreign keys only.""" # Setup - users = pd.DataFrame({ - 'user_id': [1, 2, 3], - 'user_name': ['John', 'Doe', 'Johanna'] - }) - sessions = pd.DataFrame({ - 'session_id': ['a', 'b', 'c'], - 'clicks': [10, 20, 30] - }) + users = pd.DataFrame({'user_id': [1, 2, 3], 'user_name': ['John', 'Doe', 'Johanna']}) + sessions = pd.DataFrame({'session_id': ['a', 'b', 'c'], 'clicks': [10, 20, 30]}) games = pd.DataFrame({ 'game_id': ['a1', 'b2', 'c3'], 'session_id': ['a', 'b', 'c'], - 'user_id': [1, 2, 3] + 'user_id': [1, 2, 3], }) - data = { - 'users': users, - 'sessions': sessions, - 'games': games - } + data = {'users': users, 'sessions': sessions, 'games': games} metadata = MultiTableMetadata() for table_name, table in data.items(): @@ -480,10 +458,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): """ # Loading the demo data - real_data, metadata = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + real_data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels') # Creating a Synthesizer synthesizer = 
HMASynthesizer(metadata) @@ -502,19 +477,14 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): assert synthetic_data['guests'][column].isin(real_data['guests'][column]).sum() == 0 # Evaluate Real vs Synthetic Data - quality_report = evaluate_quality( - real_data, - synthetic_data, - metadata, - verbose=False - ) + quality_report = evaluate_quality(real_data, synthetic_data, metadata, verbose=False) column_plot = get_column_plot( real_data=real_data, synthetic_data=synthetic_data, column_name='has_rewards', table_name='guests', - metadata=metadata + metadata=metadata, ) column_pair_plot = get_column_pair_plot( @@ -522,7 +492,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): synthetic_data=synthetic_data, column_names=['room_rate', 'room_type'], table_name='guests', - metadata=metadata + metadata=metadata, ) # Assert @@ -546,15 +516,13 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): loaded_synthesizer.sample() # HMA Customization - custom_synthesizer = HMASynthesizer( - metadata - ) + custom_synthesizer = HMASynthesizer(metadata) custom_synthesizer.set_table_parameters( table_name='hotels', table_parameters={ 'default_distribution': 'truncnorm', - } + }, ) custom_synthesizer.fit(real_data) @@ -565,7 +533,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): 'a', 'b', 'loc', - 'scale' + 'scale', ] assert learned_distributions['rating']['distribution'] == 'truncnorm' @@ -578,9 +546,7 @@ def test_use_own_data_using_hma(self, tmp_path): # Setup data_folder = tmp_path / 'datasets' download_demo( - modality='multi_table', - dataset_name='fake_hotels', - output_folder_name=data_folder + modality='multi_table', dataset_name='fake_hotels', output_folder_name=data_folder ) # Run - load CSVs @@ -592,14 +558,8 @@ def test_use_own_data_using_hma(self, tmp_path): # Metadata metadata = MultiTableMetadata() - metadata.detect_table_from_dataframe( - table_name='guests', - data=datasets['guests'] - ) - metadata.detect_table_from_dataframe( - table_name='hotels', - data=datasets['hotels'] - ) + metadata.detect_table_from_dataframe(table_name='guests', data=datasets['guests']) + metadata.detect_table_from_dataframe(table_name='hotels', data=datasets['hotels']) # Assert - detected metadata correctly for table in metadata.tables: @@ -610,25 +570,19 @@ def test_use_own_data_using_hma(self, tmp_path): table_name='guests', column_name='checkin_date', sdtype='datetime', - datetime_format='%d %b %Y' + datetime_format='%d %b %Y', ) metadata.update_column( table_name='guests', column_name='checkout_date', sdtype='datetime', - datetime_format='%d %b %Y' + datetime_format='%d %b %Y', ) metadata.update_column( - table_name='hotels', - column_name='hotel_id', - sdtype='id', - regex_format='HID_[0-9]{3,4}' + table_name='hotels', column_name='hotel_id', sdtype='id', regex_format='HID_[0-9]{3,4}' ) metadata.update_column( - table_name='guests', - column_name='hotel_id', - sdtype='id', - regex_format='HID_[0-9]{3,4}' + table_name='guests', column_name='hotel_id', sdtype='id', regex_format='HID_[0-9]{3,4}' ) metadata.update_column( table_name='hotels', @@ -646,40 +600,25 @@ def test_use_own_data_using_hma(self, tmp_path): sdtype='categorical', ) metadata.update_column( - table_name='guests', - column_name='guest_email', - sdtype='email', - pii=True + table_name='guests', column_name='guest_email', sdtype='email', pii=True ) metadata.update_column( - table_name='guests', - column_name='billing_address', - sdtype='address', - pii=True + table_name='guests', 
column_name='billing_address', sdtype='address', pii=True ) metadata.update_column( table_name='guests', column_name='credit_card_number', sdtype='credit_card_number', - pii=True - ) - metadata.set_primary_key( - table_name='hotels', - column_name='hotel_id' - ) - metadata.set_primary_key( - table_name='guests', - column_name='guest_email' - ) - metadata.add_alternate_keys( - table_name='guests', - column_names=['credit_card_number'] + pii=True, ) + metadata.set_primary_key(table_name='hotels', column_name='hotel_id') + metadata.set_primary_key(table_name='guests', column_name='guest_email') + metadata.add_alternate_keys(table_name='guests', column_names=['credit_card_number']) metadata.add_relationship( parent_table_name='hotels', child_table_name='guests', parent_primary_key='hotel_id', - child_foreign_key='hotel_id' + child_foreign_key='hotel_id', ) # Assert - check updated metadata @@ -733,7 +672,7 @@ def test_progress_bar_print(self, capsys): r'Preprocess Tables:', r'Learning relationships:', r"\(1/2\) Tables 'characters' and 'character_families' \('character_id'\):", - r"\(2/2\) Tables 'families' and 'character_families' \('family_id'\):" + r"\(2/2\) Tables 'families' and 'character_families' \('family_id'\):", ] # Run @@ -750,15 +689,12 @@ def test_progress_bar_print(self, capsys): def test_warning_message_too_many_cols(self, capsys): """Test that a warning appears if there are more than 1000 expected columns""" # Setup - (_, metadata) = download_demo( - modality='multi_table', - dataset_name='NBA_v1' - ) + (_, metadata) = download_demo(modality='multi_table', dataset_name='NBA_v1') key_phrases = [ r'PerformanceAlert:', r'large number of columns.', - r'contact us at info@sdv.dev for enterprise solutions.' + r'contact us at info@sdv.dev for enterprise solutions.', ] # Run @@ -770,10 +706,7 @@ def test_warning_message_too_many_cols(self, capsys): for pattern in key_phrases: match = re.search(pattern, captured.out + captured.err) assert match is not None - (_, small_metadata) = download_demo( - modality='multi_table', - dataset_name='trains_v1' - ) + (_, small_metadata) = download_demo(modality='multi_table', dataset_name='trains_v1') # Run HMASynthesizer(small_metadata) @@ -788,20 +721,23 @@ def test_warning_message_too_many_cols(self, capsys): def test_hma_three_linear_nodes(self): """Test it works on a simple 'grandparent-parent-child' dataset.""" # Setup - grandparent = pd.DataFrame(data={ - 'grandparent_ID': [0, 1, 2, 3, 4], - 'data': ['0', '1', '2', '3', '4'] - }) - parent = pd.DataFrame(data={ - 'parent_ID': ['a', 'b', 'c', 'd', 'e'], - 'grandparent_ID': [0, 0, 1, 1, 3], - 'data': [True, False, False, False, True] - }) - child = pd.DataFrame(data={ - 'child_ID': ['00', '01', '02', '03', '04'], - 'parent_ID': ['b', 'b', 'a', 'e', 'e'], - 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No'] - }) + grandparent = pd.DataFrame( + data={'grandparent_ID': [0, 1, 2, 3, 4], 'data': ['0', '1', '2', '3', '4']} + ) + parent = pd.DataFrame( + data={ + 'parent_ID': ['a', 'b', 'c', 'd', 'e'], + 'grandparent_ID': [0, 0, 1, 1, 3], + 'data': [True, False, False, False, True], + } + ) + child = pd.DataFrame( + data={ + 'child_ID': ['00', '01', '02', '03', '04'], + 'parent_ID': ['b', 'b', 'a', 'e', 'e'], + 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No'], + } + ) data = {'grandparent': grandparent, 'parent': parent, 'child': child} metadata = MultiTableMetadata.load_from_dict({ 'tables': { @@ -809,40 +745,40 @@ def test_hma_three_linear_nodes(self): 'primary_key': 'grandparent_ID', 'columns': { 'grandparent_ID': 
{'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, 'parent': { 'primary_key': 'parent_ID', 'columns': { 'parent_ID': {'sdtype': 'id'}, 'grandparent_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, 'child': { 'primary_key': 'child_ID', 'columns': { 'child_ID': {'sdtype': 'id'}, 'parent_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } - } + 'data': {'sdtype': 'categorical'}, + }, + }, }, 'relationships': [ { 'parent_table_name': 'grandparent', 'parent_primary_key': 'grandparent_ID', 'child_table_name': 'parent', - 'child_foreign_key': 'grandparent_ID' + 'child_foreign_key': 'grandparent_ID', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'parent_ID', 'child_table_name': 'child', - 'child_foreign_key': 'parent_ID' - } - ] + 'child_foreign_key': 'parent_ID', + }, + ], }) synthesizer = HMASynthesizer(metadata) @@ -864,61 +800,61 @@ def test_hma_three_linear_nodes(self): def test_hma_one_parent_two_children(self): """Test it works on a simple 'child-parent-child' dataset.""" # Setup - parent = pd.DataFrame(data={ - 'parent_ID': [0, 1, 2, 3, 4], - 'data': ['0', '1', '2', '3', '4'] - }) - child1 = pd.DataFrame(data={ - 'child_ID': ['a', 'b', 'c', 'd', 'e'], - 'parent_ID': [0, 0, 1, 1, 3], - 'data': [True, False, False, False, True] - }) - child2 = pd.DataFrame(data={ - 'child_ID': ['00', '01', '02', '03', '04'], - 'parent_ID': [0, 1, 2, 3, 4], - 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No'] - }) + parent = pd.DataFrame( + data={'parent_ID': [0, 1, 2, 3, 4], 'data': ['0', '1', '2', '3', '4']} + ) + child1 = pd.DataFrame( + data={ + 'child_ID': ['a', 'b', 'c', 'd', 'e'], + 'parent_ID': [0, 0, 1, 1, 3], + 'data': [True, False, False, False, True], + } + ) + child2 = pd.DataFrame( + data={ + 'child_ID': ['00', '01', '02', '03', '04'], + 'parent_ID': [0, 1, 2, 3, 4], + 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No'], + } + ) data = {'parent': parent, 'child1': child1, 'child2': child2} metadata = MultiTableMetadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'parent_ID', - 'columns': { - 'parent_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'columns': {'parent_ID': {'sdtype': 'id'}, 'data': {'sdtype': 'categorical'}}, }, 'child1': { 'primary_key': 'child_ID', 'columns': { 'child_ID': {'sdtype': 'id'}, 'parent_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, 'child2': { 'primary_key': 'child_ID', 'columns': { 'child_ID': {'sdtype': 'id'}, 'parent_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } - } + 'data': {'sdtype': 'categorical'}, + }, + }, }, 'relationships': [ { 'parent_table_name': 'parent', 'parent_primary_key': 'parent_ID', 'child_table_name': 'child1', - 'child_foreign_key': 'parent_ID' + 'child_foreign_key': 'parent_ID', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'parent_ID', 'child_table_name': 'child2', - 'child_foreign_key': 'parent_ID' + 'child_foreign_key': 'parent_ID', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) @@ -940,36 +876,30 @@ def test_hma_one_parent_two_children(self): def test_hma_two_parents_one_child(self): """Test it works on a simple 'parent-child-parent' dataset.""" # Setup - child = pd.DataFrame(data={ - 'child_ID': ['a', 'b', 'c', 'd', 'e'], - 'parent_ID1': [0, 1, 2, 3, 3], - 'parent_ID2': [0, 1, 2, 3, 4], - 'data': ['0', '1', '2', '3', '4'] - }) - parent1 = pd.DataFrame(data={ - 'parent_ID1': [0, 1, 2, 3, 4], - 
'data': [True, False, False, False, True] - }) - parent2 = pd.DataFrame(data={ - 'parent_ID2': [0, 1, 2, 3, 4], - 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No'] - }) + child = pd.DataFrame( + data={ + 'child_ID': ['a', 'b', 'c', 'd', 'e'], + 'parent_ID1': [0, 1, 2, 3, 3], + 'parent_ID2': [0, 1, 2, 3, 4], + 'data': ['0', '1', '2', '3', '4'], + } + ) + parent1 = pd.DataFrame( + data={'parent_ID1': [0, 1, 2, 3, 4], 'data': [True, False, False, False, True]} + ) + parent2 = pd.DataFrame( + data={'parent_ID2': [0, 1, 2, 3, 4], 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No']} + ) data = {'parent1': parent1, 'child': child, 'parent2': parent2} metadata = MultiTableMetadata.load_from_dict({ 'tables': { 'parent1': { 'primary_key': 'parent_ID1', - 'columns': { - 'parent_ID1': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'columns': {'parent_ID1': {'sdtype': 'id'}, 'data': {'sdtype': 'categorical'}}, }, 'parent2': { 'primary_key': 'parent_ID2', - 'columns': { - 'parent_ID2': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'columns': {'parent_ID2': {'sdtype': 'id'}, 'data': {'sdtype': 'categorical'}}, }, 'child': { 'primary_key': 'child_ID', @@ -977,8 +907,8 @@ def test_hma_two_parents_one_child(self): 'child_ID': {'sdtype': 'id'}, 'parent_ID1': {'sdtype': 'id'}, 'parent_ID2': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, }, 'relationships': [ @@ -986,15 +916,15 @@ def test_hma_two_parents_one_child(self): 'parent_table_name': 'parent1', 'parent_primary_key': 'parent_ID1', 'child_table_name': 'child', - 'child_foreign_key': 'parent_ID1' + 'child_foreign_key': 'parent_ID1', }, { 'parent_table_name': 'parent2', 'parent_primary_key': 'parent_ID2', 'child_table_name': 'child', - 'child_foreign_key': 'parent_ID2' + 'child_foreign_key': 'parent_ID2', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) @@ -1024,68 +954,66 @@ def test_hma_two_lineages_one_grandchild(self): gc """ # Setup - root1 = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'data': [True, False, False, False, True] - }) - root2 = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'data': [True, False, False, False, True] - }) - child1 = pd.DataFrame(data={ - 'child_ID': ['a', 'b', 'c', 'd', 'e'], - 'root1_ID': [0, 1, 2, 3, 3], - 'data': [True, False, False, False, True] - }) - child2 = pd.DataFrame(data={ - 'child_ID': ['a', 'b', 'c', 'd', 'e'], - 'root2_ID': [0, 1, 2, 3, 4], - 'data': [True, False, False, False, True] - }) - grandchild = pd.DataFrame(data={ - 'grandchild_ID': ['a', 'b', 'c', 'd', 'e'], - 'child1_ID': ['a', 'b', 'c', 'd', 'e'], - 'child2_ID': ['a', 'b', 'c', 'd', 'e'], - 'data': [True, False, False, False, True] - }) + root1 = pd.DataFrame( + data={'id': [0, 1, 2, 3, 4], 'data': [True, False, False, False, True]} + ) + root2 = pd.DataFrame( + data={'id': [0, 1, 2, 3, 4], 'data': [True, False, False, False, True]} + ) + child1 = pd.DataFrame( + data={ + 'child_ID': ['a', 'b', 'c', 'd', 'e'], + 'root1_ID': [0, 1, 2, 3, 3], + 'data': [True, False, False, False, True], + } + ) + child2 = pd.DataFrame( + data={ + 'child_ID': ['a', 'b', 'c', 'd', 'e'], + 'root2_ID': [0, 1, 2, 3, 4], + 'data': [True, False, False, False, True], + } + ) + grandchild = pd.DataFrame( + data={ + 'grandchild_ID': ['a', 'b', 'c', 'd', 'e'], + 'child1_ID': ['a', 'b', 'c', 'd', 'e'], + 'child2_ID': ['a', 'b', 'c', 'd', 'e'], + 'data': [True, False, False, False, True], + } + ) data = { 'root1': root1, 'root2': root2, 'child1': child1, 'child2': child2, - 'grandchild': grandchild + 
'grandchild': grandchild, } metadata = MultiTableMetadata.load_from_dict({ 'tables': { 'root1': { 'primary_key': 'id', - 'columns': { - 'id': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'columns': {'id': {'sdtype': 'id'}, 'data': {'sdtype': 'categorical'}}, }, 'root2': { 'primary_key': 'id', - 'columns': { - 'id': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'columns': {'id': {'sdtype': 'id'}, 'data': {'sdtype': 'categorical'}}, }, 'child1': { 'primary_key': 'child_ID', 'columns': { 'child_ID': {'sdtype': 'id'}, 'root1_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, 'child2': { 'primary_key': 'child_ID', 'columns': { 'child_ID': {'sdtype': 'id'}, 'root2_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, 'grandchild': { 'primary_key': 'grandchild_ID', @@ -1093,8 +1021,8 @@ def test_hma_two_lineages_one_grandchild(self): 'grandchild_ID': {'sdtype': 'id'}, 'child1_ID': {'sdtype': 'id'}, 'child2_ID': {'sdtype': 'id'}, - 'data': {'sdtype': 'categorical'} - } + 'data': {'sdtype': 'categorical'}, + }, }, }, 'relationships': [ @@ -1102,27 +1030,27 @@ def test_hma_two_lineages_one_grandchild(self): 'parent_table_name': 'root1', 'parent_primary_key': 'id', 'child_table_name': 'child1', - 'child_foreign_key': 'root1_ID' + 'child_foreign_key': 'root1_ID', }, { 'parent_table_name': 'root2', 'parent_primary_key': 'id', 'child_table_name': 'child2', - 'child_foreign_key': 'root2_ID' + 'child_foreign_key': 'root2_ID', }, { 'parent_table_name': 'child1', 'parent_primary_key': 'child_ID', 'child_table_name': 'grandchild', - 'child_foreign_key': 'child1_ID' + 'child_foreign_key': 'child1_ID', }, { 'parent_table_name': 'child2', 'parent_primary_key': 'child_ID', 'child_table_name': 'grandchild', - 'child_foreign_key': 'child2_ID' + 'child_foreign_key': 'child2_ID', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) @@ -1150,11 +1078,7 @@ def test_hma_numerical_distributions(self): # Run synthesizer.set_table_parameters( table_name='guests', - table_parameters={ - 'numerical_distributions': { - 'amenities_fee': 'beta' - } - } + table_parameters={'numerical_distributions': {'amenities_fee': 'beta'}}, ) synthesizer.fit(data) samples = synthesizer.sample(scale=1) @@ -1167,10 +1091,7 @@ def test_hma_numerical_distributions(self): def test_get_learned_distributions_error_msg(self): """Ensure the error message is correct when calling ``get_learned_distributions``.""" # Setup - data, metadata = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels') synth = HMASynthesizer(metadata) # Run @@ -1193,8 +1114,7 @@ def test__get_likelihoods(self): sampled_data = {} sampled_data['characters'] = hmasynthesizer._sample_rows( - hmasynthesizer._table_synthesizers['characters'], - len(data['characters']) + hmasynthesizer._table_synthesizers['characters'], len(data['characters']) ) hmasynthesizer._sample_children('characters', sampled_data) @@ -1203,7 +1123,7 @@ def test__get_likelihoods(self): sampled_data['character_families'], sampled_data['characters'].set_index('character_id'), 'character_families', - 'character_id' + 'character_id', ) # Assert @@ -1222,7 +1142,7 @@ def test__extract_parameters(self): '__sessions__user_id__a': -1, '__sessions__user_id__b': 1000, '__sessions__user_id__loc': 0.5, - '__sessions__user_id__scale': -0.25 + '__sessions__user_id__scale': -0.25, }) 
instance = HMASynthesizer(MultiTableMetadata()) instance.extended_columns = { @@ -1231,11 +1151,11 @@ def test__extract_parameters(self): '__sessions__user_id__a': FloatFormatter(enforce_min_max_values=True), '__sessions__user_id__b': FloatFormatter(enforce_min_max_values=True), '__sessions__user_id__loc': FloatFormatter(enforce_min_max_values=True), - '__sessions__user_id__scale': FloatFormatter(enforce_min_max_values=True) + '__sessions__user_id__scale': FloatFormatter(enforce_min_max_values=True), } } for col, float_formatter in instance.extended_columns['sessions'].items(): - float_formatter.fit(pd.DataFrame({col: [0., 100.]}), col) + float_formatter.fit(pd.DataFrame({col: [0.0, 100.0]}), col) instance._max_child_rows = {'__sessions__user_id__num_rows': 10} @@ -1243,13 +1163,7 @@ def test__extract_parameters(self): result = instance._extract_parameters(parent_row, 'sessions', 'user_id') # Assert - expected_result = { - 'a': 0., - 'b': 100., - 'loc': 0.5, - 'num_rows': 10., - 'scale': 0. - } + expected_result = {'a': 0.0, 'b': 100.0, 'loc': 0.5, 'num_rows': 10.0, 'scale': 0.0} assert result == expected_result def test__recreate_child_synthesizer_with_default_parameters(self): @@ -1261,31 +1175,23 @@ def test__recreate_child_synthesizer_with_default_parameters(self): f'{prefix}univariates__brand__a': 100, f'{prefix}univariates__brand__b': 10, f'{prefix}univariates__brand__loc': 0.5, - f'{prefix}univariates__brand__scale': -0.25 + f'{prefix}univariates__brand__scale': -0.25, }) metadata = MultiTableMetadata.load_from_dict({ 'tables': { - 'users': { - 'columns': { - 'user_id': {'sdtype': 'id'} - }, - 'primary_key': 'user_id' - }, + 'users': {'columns': {'user_id': {'sdtype': 'id'}}, 'primary_key': 'user_id'}, 'sessions': { - 'columns': { - 'user_id': {'sdtype': 'id'}, - 'brand': {'sdtype': 'categorical'} - } - } + 'columns': {'user_id': {'sdtype': 'id'}, 'brand': {'sdtype': 'categorical'}} + }, }, 'relationships': [ { 'parent_table_name': 'users', 'child_table_name': 'sessions', 'parent_primary_key': 'user_id', - 'child_foreign_key': 'user_id' + 'child_foreign_key': 'user_id', } - ] + ], }) instance = HMASynthesizer(metadata) instance.set_table_parameters('sessions', {'default_distribution': 'truncnorm'}) @@ -1296,18 +1202,18 @@ def test__recreate_child_synthesizer_with_default_parameters(self): f'{prefix}univariates__brand__a': FloatFormatter(enforce_min_max_values=True), f'{prefix}univariates__brand__b': FloatFormatter(enforce_min_max_values=True), f'{prefix}univariates__brand__loc': FloatFormatter(enforce_min_max_values=True), - f'{prefix}univariates__brand__scale': FloatFormatter(enforce_min_max_values=True) + f'{prefix}univariates__brand__scale': FloatFormatter(enforce_min_max_values=True), } } for col, float_formatter in instance.extended_columns['sessions'].items(): - float_formatter.fit(pd.DataFrame({col: [0., 100.]}), col) + float_formatter.fit(pd.DataFrame({col: [0.0, 100.0]}), col) instance._default_parameters = { 'sessions': { 'univariates__brand__a': 5, 'univariates__brand__b': 84, 'univariates__brand__loc': 1, - 'univariates__brand__scale': 1 + 'univariates__brand__scale': 1, } } @@ -1320,7 +1226,7 @@ def test__recreate_child_synthesizer_with_default_parameters(self): 'univariates__brand__b': 84, 'univariates__brand__loc': 1, 'univariates__brand__scale': 1, - 'num_rows': 10 + 'num_rows': 10, } assert child_synthesizer._get_parameters() == expected_result @@ -1367,7 +1273,8 @@ def test_metadata_updated_no_warning(self, tmp_path): # Run 3 instance = 
HMASynthesizer(metadata_detect) metadata_detect.update_column( - table_name='characters', column_name='age', sdtype='categorical') + table_name='characters', column_name='age', sdtype='categorical' + ) file_name = tmp_path / 'multitable_2.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: @@ -1429,36 +1336,31 @@ def test_null_foreign_keys(self): parent_table_name='parent_table', child_table_name='child_table1', parent_primary_key='id', - child_foreign_key='fk' + child_foreign_key='fk', ) metadata.add_relationship( parent_table_name='parent_table', child_table_name='child_table2', parent_primary_key='id', - child_foreign_key='fk1' + child_foreign_key='fk1', ) metadata.add_relationship( parent_table_name='parent_table', child_table_name='child_table2', parent_primary_key='id', - child_foreign_key='fk2' + child_foreign_key='fk2', ) data = { - 'parent_table': pd.DataFrame({ - 'id': [1, 2, 3] - }), - 'child_table1': pd.DataFrame({ - 'id': [1, 2, 3], - 'fk': [1, 2, np.nan] - }), + 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), + 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], - 'fk2': [1, 2, np.nan] - }) + 'fk2': [1, 2, np.nan], + }), } synthesizer = HMASynthesizer(metadata) @@ -1481,22 +1383,22 @@ def test_null_foreign_keys(self): parametrization = [ - ('update_column', { - 'table_name': 'departure', 'column_name': 'city', 'sdtype': 'categorical' - }), + ('update_column', {'table_name': 'departure', 'column_name': 'city', 'sdtype': 'categorical'}), ('set_primary_key', {'table_name': 'arrival', 'column_name': 'id_flight'}), ( - 'add_column_relationship', { + 'add_column_relationship', + { 'table_name': 'departure', 'relationship_type': 'address', - 'column_names': ['city', 'country'] - } + 'column_names': ['city', 'country'], + }, ), ('add_alternate_keys', {'table_name': 'departure', 'column_names': ['city', 'country']}), ('set_sequence_key', {'table_name': 'departure', 'column_name': 'city'}), - ('add_column', { - 'table_name': 'departure', 'column_name': 'postal_code', 'sdtype': 'postal_code' - }), + ( + 'add_column', + {'table_name': 'departure', 'column_name': 'postal_code', 'sdtype': 'postal_code'}, + ), ] @@ -1514,7 +1416,7 @@ def test_metadata_updated_warning(method, kwargs): 'id': {'sdtype': 'id'}, 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, - 'country': {'sdtype': 'country_code'} + 'country': {'sdtype': 'country_code'}, }, }, 'arrival': { @@ -1524,7 +1426,7 @@ def test_metadata_updated_warning(method, kwargs): 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, 'country': {'sdtype': 'country'}, - 'id_flight': {'sdtype': 'id'} + 'id_flight': {'sdtype': 'id'}, }, }, }, @@ -1533,9 +1435,9 @@ def test_metadata_updated_warning(method, kwargs): 'parent_table_name': 'departure', 'parent_primary_key': 'id', 'child_table_name': 'arrival', - 'child_foreign_key': 'id' + 'child_foreign_key': 'id', } - ] + ], }) expected_message = re.escape( "We strongly recommend saving the metadata using 'save_to_json' for replicability" @@ -1564,7 +1466,7 @@ def test_save_and_load_with_downgraded_version(tmp_path): 'id': {'sdtype': 'id'}, 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, - 'country': {'sdtype': 'country'} + 'country': {'sdtype': 'country'}, }, }, 'arrival': { @@ -1574,7 +1476,7 @@ def test_save_and_load_with_downgraded_version(tmp_path): 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, 'country': 
{'sdtype': 'country'}, - 'id_flight': {'sdtype': 'id'} + 'id_flight': {'sdtype': 'id'}, }, }, }, @@ -1583,9 +1485,9 @@ def test_save_and_load_with_downgraded_version(tmp_path): 'parent_table_name': 'departure', 'parent_primary_key': 'id', 'child_table_name': 'arrival', - 'child_foreign_key': 'id' + 'child_foreign_key': 'id', } - ] + ], }) instance = HMASynthesizer(metadata) @@ -1615,7 +1517,7 @@ def test_fit_raises_version_error(): 'id': {'sdtype': 'id'}, 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, - 'country': {'sdtype': 'country'} + 'country': {'sdtype': 'country'}, }, }, 'arrival': { @@ -1625,7 +1527,7 @@ def test_fit_raises_version_error(): 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, 'country': {'sdtype': 'country'}, - 'id_flight': {'sdtype': 'id'} + 'id_flight': {'sdtype': 'id'}, }, }, }, @@ -1634,9 +1536,9 @@ def test_fit_raises_version_error(): 'parent_table_name': 'departure', 'parent_primary_key': 'id', 'child_table_name': 'arrival', - 'child_foreign_key': 'id' + 'child_foreign_key': 'id', } - ] + ], }) instance = HMASynthesizer(metadata) @@ -1671,7 +1573,7 @@ def test_hma_relationship_validity(): def test_hma_not_fit_raises_sampling_error(): """Test that ``HMA`` will raise a ``SamplingError`` if it wasn't fit.""" # Setup - data, metadata = download_demo('multi_table', 'Dunur_v1') + _data, metadata = download_demo('multi_table', 'Dunur_v1') synthesizer = HMASynthesizer(metadata) # Run and Assert @@ -1712,7 +1614,7 @@ def test_fit_and_sample_numerical_col_names(): 'parent_table_name': '0', 'parent_primary_key': 1, 'child_table_name': '1', - 'child_foreign_key': 2 + 'child_foreign_key': 2, } ] metadata = MultiTableMetadata.load_from_dict(metadata_dict) @@ -1741,10 +1643,7 @@ def test_detect_from_dataframe_numerical_col(): 2: [2, 3, 4], 'categorical_col': ['a', 'b', 'a'], }) - child_data = pd.DataFrame({ - 3: [1000, 1001, 1000], - 4: [1, 2, 3] - }) + child_data = pd.DataFrame({3: [1000, 1001, 1000], 4: [1, 2, 3]}) data = { 'parent_data': parent_data, 'child_data': child_data, @@ -1761,7 +1660,7 @@ def test_detect_from_dataframe_numerical_col(): parent_primary_key='1', parent_table_name='parent_data', child_foreign_key='3', - child_table_name='child_data' + child_table_name='child_data', ) test_metadata = MultiTableMetadata() @@ -1775,7 +1674,7 @@ def test_detect_from_dataframe_numerical_col(): parent_primary_key='1', parent_table_name='parent_data', child_foreign_key='3', - child_table_name='child_data' + child_table_name='child_data', ) # Run @@ -1799,10 +1698,7 @@ def test_table_name_logging(caplog): 'parent_id': [1, 2, 3, 4, 5, 6], 'col': ['a', 'b', 'a', 'b', 'a', 'b'], }) - child_data = pd.DataFrame({ - 'id': [1, 2, 3, 4, 5, 6], - 'parent_id': [1, 2, 3, 4, 5, 6] - }) + child_data = pd.DataFrame({'id': [1, 2, 3, 4, 5, 6], 'parent_id': [1, 2, 3, 4, 5, 6]}) data = { 'parent_data': parent_data, 'child_data': child_data, @@ -1844,10 +1740,7 @@ def test_disjointed_tables(): def test_small_sample(): """Test that the sample function still works with a small scale""" # Setup - data, metadata = download_demo( - modality='multi_table', - dataset_name='fake_hotels' - ) + data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels') synthesizer = HMASynthesizer(metadata) synthesizer.fit(data) @@ -1859,7 +1752,7 @@ def test_small_sample(): with pytest.warns(Warning, match=warn_msg): synthetic_data = synthesizer.sample(scale=0.01) - assert (len(synthetic_data['hotels']) == 1) - assert (len(synthetic_data['guests']) >= len(data['guests']) * 
.01) + assert len(synthetic_data['hotels']) == 1 + assert len(synthetic_data['guests']) >= len(data['guests']) * 0.01 assert synthetic_data['hotels'].columns.tolist() == data['hotels'].columns.tolist() assert synthetic_data['guests'].columns.tolist() == data['guests'].columns.tolist() diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index fc86bec73..25f867960 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -19,7 +19,7 @@ def _get_par_data_and_metadata(): 'date': [date, date, date, date], 'column2': ['b', 'a', 'a', 'c'], 'entity': [1, 1, 2, 2], - 'context': ['a', 'a', 'b', 'b'] + 'context': ['a', 'a', 'b', 'b'], }) metadata = SingleTableMetadata() metadata.detect_from_dataframe(data) @@ -128,30 +128,25 @@ def test_sythesize_sequences(tmp_path): * Save and Load. """ # Setup - real_data, metadata = download_demo( - modality='sequential', - dataset_name='nasdaq100_2019' - ) + real_data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019') assert real_data[real_data['Symbol'] == 'AMZN']['Sector'].unique() - synthesizer = PARSynthesizer( - metadata, - epochs=5, - context_columns=['Sector', 'Industry'] - ) + synthesizer = PARSynthesizer(metadata, epochs=5, context_columns=['Sector', 'Industry']) custom_synthesizer = PARSynthesizer( - metadata, - epochs=5, - context_columns=['Sector', 'Industry'], - verbose=True + metadata, epochs=5, context_columns=['Sector', 'Industry'], verbose=True + ) + scenario_context = pd.DataFrame( + data={ + 'Symbol': ['COMPANY-A', 'COMPANY-B', 'COMPANY-C', 'COMPANY-D', 'COMPANY-E'], + 'Sector': ['Technology'] * 2 + ['Consumer Services'] * 3, + 'Industry': [ + 'Computer Manufacturing', + 'Computer Software: Prepackaged Software', + 'Hotels/Resorts', + 'Restaurants', + 'Clothing/Shoe/Accessory Stores', + ], + } ) - scenario_context = pd.DataFrame(data={ - 'Symbol': ['COMPANY-A', 'COMPANY-B', 'COMPANY-C', 'COMPANY-D', 'COMPANY-E'], - 'Sector': ['Technology'] * 2 + ['Consumer Services'] * 3, - 'Industry': [ - 'Computer Manufacturing', 'Computer Software: Prepackaged Software', - 'Hotels/Resorts', 'Restaurants', 'Clothing/Shoe/Accessory Stores' - ] - }) # Run - Fit synthesizer.fit(real_data) @@ -161,8 +156,7 @@ def test_sythesize_sequences(tmp_path): synthetic_data = synthesizer.sample(num_sequences=10) custom_synthetic_data = custom_synthesizer.sample(num_sequences=3, sequence_length=2) custom_synthetic_data_conditional = custom_synthesizer.sample_sequential_columns( - context_columns=scenario_context, - sequence_length=2 + context_columns=scenario_context, sequence_length=2 ) # Save and Load @@ -182,7 +176,7 @@ def test_sythesize_sequences(tmp_path): 'Computer Software: Prepackaged Software', 'Hotels/Resorts', 'Restaurants', - 'Clothing/Shoe/Accessory Stores' + 'Clothing/Shoe/Accessory Stores', ] assert industries in custom_synthetic_data_conditional['Industry'].unique() @@ -201,7 +195,10 @@ def test_sythesize_sequences(tmp_path): def test_par_subset_of_data(): """Test it when the data index is not continuous GH#1973.""" # download data - data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019',) + data, metadata = download_demo( + modality='sequential', + dataset_name='nasdaq100_2019', + ) # modify the data by choosing a subset of it data_subset = data.copy() @@ -212,7 +209,8 @@ def test_par_subset_of_data(): for i, symbol in enumerate(symbols): symbol_mask = data_subset['Symbol'] == symbol data_subset = 
data_subset.drop( - data_subset[symbol_mask].sample(frac=i / (2 * len(symbols))).index) + data_subset[symbol_mask].sample(frac=i / (2 * len(symbols))).index + ) # now run PAR synthesizer = PARSynthesizer(metadata, epochs=5, verbose=True) @@ -242,7 +240,7 @@ def test_par_subset_of_data_simplified(): 'sdtype': 'datetime', }, }, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', }) synthesizer = PARSynthesizer(metadata, epochs=0) @@ -258,24 +256,14 @@ def test_par_missing_sequence_index(): """Test if PAR Synthesizer can run without a sequence key""" # Setup metadata_dict = { - 'columns': { - 'value': { - 'sdtype': 'numerical' - }, - 'e_id': { - 'sdtype': 'id' - } - }, + 'columns': {'value': {'sdtype': 'numerical'}, 'e_id': {'sdtype': 'id'}}, 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', - 'sequence_key': 'e_id' + 'sequence_key': 'e_id', } metadata = SingleTableMetadata().load_from_dict(metadata_dict) - data = pd.DataFrame({ - 'value': [10, 20, 30], - 'e_id': [1, 2, 3] - }) + data = pd.DataFrame({'value': [10, 20, 30], 'e_id': [1, 2, 3]}) # Run synthesizer = PARSynthesizer(metadata) @@ -290,38 +278,22 @@ def test_par_missing_sequence_index(): def test_constraints_on_par(): """Test if only simple constraints work on PARSynthesizer.""" # Setup - real_data, metadata = download_demo( - modality='sequential', - dataset_name='nasdaq100_2019' - ) + real_data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019') - synthesizer = PARSynthesizer( - metadata, - epochs=5, - context_columns=['Sector', 'Industry'] - ) + synthesizer = PARSynthesizer(metadata, epochs=5, context_columns=['Sector', 'Industry']) market_constraint = { 'constraint_class': 'Positive', - 'constraint_parameters': { - 'column_name': 'MarketCap', - 'strict_boundaries': True - } + 'constraint_parameters': {'column_name': 'MarketCap', 'strict_boundaries': True}, } volume_constraint = { 'constraint_class': 'Positive', - 'constraint_parameters': { - 'column_name': 'Volume', - 'strict_boundaries': True - } + 'constraint_parameters': {'column_name': 'Volume', 'strict_boundaries': True}, } context_constraint = { 'constraint_class': 'Mock', - 'constraint_parameters': { - 'column_name': 'Sector', - 'strict_boundaries': True - } + 'constraint_parameters': {'column_name': 'Sector', 'strict_boundaries': True}, } # Run @@ -349,28 +321,41 @@ def test_par_unique_sequence_index_with_enforce_min_max(): test_id = list(range(10)) s_key = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] visits = [ - '2021-01-01', '2021-01-03', '2021-01-05', '2021-01-07', '2021-01-09', - '2021-09-11', '2021-09-17', '2021-10-01', '2021-10-08', '2021-11-01' + '2021-01-01', + '2021-01-03', + '2021-01-05', + '2021-01-07', + '2021-01-09', + '2021-09-11', + '2021-09-17', + '2021-10-01', + '2021-10-08', + '2021-11-01', ] pre_date = [ - '2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05', - '2021-04-01', '2021-04-02', '2021-04-03', '2021-04-04', '2021-04-05' + '2020-01-01', + '2020-01-02', + '2020-01-03', + '2020-01-04', + '2020-01-05', + '2021-04-01', + '2021-04-02', + '2021-04-03', + '2021-04-04', + '2021-04-05', ] - test_df = pd.DataFrame({ - 'id': test_id, - 's_key': s_key, - 'visits': visits, - 'pre_date': pre_date - }) + test_df = pd.DataFrame({'id': test_id, 's_key': s_key, 'visits': visits, 'pre_date': pre_date}) test_df[['visits', 'pre_date']] = test_df[['visits', 'pre_date']].apply( - pd.to_datetime, format='%Y-%m-%d', errors='coerce') + pd.to_datetime, format='%Y-%m-%d', errors='coerce' + ) metadata = 
SingleTableMetadata() metadata.detect_from_dataframe(test_df) metadata.update_column(column_name='s_key', sdtype='id') metadata.set_sequence_key('s_key') metadata.set_sequence_index('visits') - synthesizer = PARSynthesizer(metadata, enforce_min_max_values=True, - enforce_rounding=False, epochs=100, verbose=True) + synthesizer = PARSynthesizer( + metadata, enforce_min_max_values=True, enforce_rounding=False, epochs=100, verbose=True + ) # Run synthesizer.fit(test_df) diff --git a/tests/integration/single_table/custom_constraints.py b/tests/integration/single_table/custom_constraints.py index 537f16953..cbdf68c14 100644 --- a/tests/integration/single_table/custom_constraints.py +++ b/tests/integration/single_table/custom_constraints.py @@ -20,11 +20,7 @@ def reverse_transform(column_names, data): return data -MyConstraint = create_custom_constraint_class( - is_valid, - transform, - reverse_transform -) +MyConstraint = create_custom_constraint_class(is_valid, transform, reverse_transform) def amenities_is_valid(column_names, data): @@ -42,10 +38,7 @@ def amenities_transform(column_names, data): boolean_column = column_names[0] numerical_column = column_names[1] typical_value = data[numerical_column].median() - data[numerical_column] = data[numerical_column].mask( - data[boolean_column], - typical_value - ) + data[numerical_column] = data[numerical_column].mask(data[boolean_column], typical_value) return data @@ -60,7 +53,5 @@ def amenities_reverse_transform(column_names, data): IfTrueThenZero = create_custom_constraint_class( - amenities_is_valid, - amenities_transform, - amenities_reverse_transform + amenities_is_valid, amenities_transform, amenities_reverse_transform ) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index a9579775a..dc9c328d5 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -15,22 +15,20 @@ from sdv.metadata import SingleTableMetadata from sdv.sampling import Condition from sdv.single_table import ( - CopulaGANSynthesizer, CTGANSynthesizer, GaussianCopulaSynthesizer, TVAESynthesizer) + CopulaGANSynthesizer, + CTGANSynthesizer, + GaussianCopulaSynthesizer, + TVAESynthesizer, +) from sdv.single_table.base import BaseSingleTableSynthesizer METADATA = SingleTableMetadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { - 'column1': { - 'sdtype': 'numerical' - }, - 'column2': { - 'sdtype': 'numerical' - }, - 'column3': { - 'sdtype': 'numerical' - } - } + 'column1': {'sdtype': 'numerical'}, + 'column2': {'sdtype': 'numerical'}, + 'column3': {'sdtype': 'numerical'}, + }, }) SYNTHESIZERS = [ @@ -46,17 +44,11 @@ def test_conditional_sampling_graceful_reject_sampling_true_dict(synthesizer): data = pd.DataFrame({ 'column1': list(range(100)), 'column2': list(range(100)), - 'column3': list(range(100)) + 'column3': list(range(100)), }) synthesizer.fit(data) - conditions = [ - Condition({ - 'column1': 28, - 'column2': 37, - 'column3': 93 - }) - ] + conditions = [Condition({'column1': 28, 'column2': 37, 'column3': 93})] with pytest.raises(ValueError): # noqa: PT011 synthesizer.sample_from_conditions(conditions=conditions) @@ -67,15 +59,11 @@ def test_conditional_sampling_graceful_reject_sampling_true_dataframe(synthesize data = pd.DataFrame({ 'column1': list(range(100)), 'column2': list(range(100)), - 'column3': list(range(100)) + 'column3': list(range(100)), }) synthesizer.fit(data) - conditions = pd.DataFrame({ - 'column1': [28], - 'column2': [37], 
- 'column3': [93] - }) + conditions = pd.DataFrame({'column1': [28], 'column2': [37], 'column3': [93]}) with pytest.raises(ValueError, match='a'): synthesizer.sample_remaining_columns(conditions) @@ -99,7 +87,7 @@ def test_sample_from_conditions_with_batch_size(): data = pd.DataFrame({ 'column1': list(range(100)), 'column2': list(range(100)), - 'column3': list(range(100)) + 'column3': list(range(100)), }) metadata = SingleTableMetadata() @@ -109,10 +97,7 @@ def test_sample_from_conditions_with_batch_size(): model = GaussianCopulaSynthesizer(metadata) model.fit(data) - conditions = [ - Condition({'column1': 10}, num_rows=100), - Condition({'column1': 50}, num_rows=10) - ] + conditions = [Condition({'column1': 10}, num_rows=100), Condition({'column1': 50}, num_rows=10)] # Run sampled_data = model.sample_from_conditions(conditions, batch_size=50) @@ -128,7 +113,7 @@ def test_sample_from_conditions_negative_float(): data = pd.DataFrame({ 'column1': [-float(i) for i in range(100)], 'column2': list(range(100)), - 'column3': list(range(100)) + 'column3': list(range(100)), }) metadata = SingleTableMetadata() @@ -139,30 +124,24 @@ def test_sample_from_conditions_negative_float(): model = GaussianCopulaSynthesizer(metadata) model.fit(data) conditions = [ - Condition({'column1': -10.}, num_rows=100), - Condition({'column1': -50}, num_rows=10) + Condition({'column1': -10.0}, num_rows=100), + Condition({'column1': -50}, num_rows=10), ] # Run sampled_data = model.sample_from_conditions(conditions) # Assert - expected = pd.Series([-10.] * 100 + [-50.] * 10, name='column1') + expected = pd.Series([-10.0] * 100 + [-50.0] * 10, name='column1') pd.testing.assert_series_equal(sampled_data['column1'], expected) def test_sample_from_conditions_with_nans(): """Test it crashes when condition has nans (GH#1758).""" # Setup - data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = GaussianCopulaSynthesizer(metadata) - my_condition = Condition( - num_rows=250, - column_values={'room_type': None, 'has_rewards': False} - ) + my_condition = Condition(num_rows=250, column_values={'room_type': None, 'has_rewards': False}) # Run synthesizer.fit(data) @@ -179,15 +158,11 @@ def test_sample_from_conditions_with_nans(): def test_sample_remaining_columns_with_all_nans(): """Test it crashes when every condition row has a nan (GH#1758).""" # Setup - data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = GaussianCopulaSynthesizer(metadata) - known_columns = pd.DataFrame(data={ - 'has_rewards': [np.nan, False, True], - 'amenities_fee': [5.00, np.nan, None] - }) + known_columns = pd.DataFrame( + data={'has_rewards': [np.nan, False, True], 'amenities_fee': [5.00, np.nan, None]} + ) # Run synthesizer.fit(data) @@ -204,23 +179,18 @@ def test_sample_remaining_columns_with_all_nans(): def test_sample_remaining_columns_with_some_nans(): """Test it warns when some of the condition rows contain nans (GH#1758).""" # Setup - data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = GaussianCopulaSynthesizer(metadata) - known_columns = pd.DataFrame(data={ - 'has_rewards': [True, False, np.nan], - 
'amenities_fee': [5.00, np.nan, None] - }) + known_columns = pd.DataFrame( + data={'has_rewards': [True, False, np.nan], 'amenities_fee': [5.00, np.nan, None]} + ) # Run synthesizer.fit(data) # Assert warn_msg = ( - 'Missing values are not yet supported. ' - 'Rows with any missing values will not be created.' + 'Missing values are not yet supported. ' 'Rows with any missing values will not be created.' ) with pytest.warns(UserWarning, match=warn_msg): synthesizer.sample_remaining_columns(known_columns=known_columns) @@ -229,10 +199,7 @@ def test_sample_remaining_columns_with_some_nans(): def test_sample_keys_are_scrambled(): """Test that the keys are scrambled in the sampled data.""" # Setup - data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') metadata.update_column('guest_email', sdtype='id', regex_format='[A-Z]{3}') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) @@ -255,12 +222,12 @@ def test_multiple_fits(): data_1 = pd.DataFrame({ 'city': ['LA', 'SF', 'CHI', 'LA', 'LA'], 'state': ['CA', 'CA', 'IL', 'CA', 'CA'], - 'measurement': [27.123, 28.756, 26.908, 21.002, 30.987] + 'measurement': [27.123, 28.756, 26.908, 21.002, 30.987], }) data_2 = pd.DataFrame({ 'city': ['LA', 'LA', 'CHI', 'LA', 'LA'], 'state': ['CA', 'CA', 'IL', 'CA', 'CA'], - 'measurement': [27.1, 28.7, 26.9, 21.2, 30.9] + 'measurement': [27.1, 28.7, 26.9, 21.2, 30.9], }) metadata = SingleTableMetadata() metadata.add_column('city', sdtype='categorical') @@ -268,9 +235,7 @@ def test_multiple_fits(): metadata.add_column('measurement', sdtype='numerical') constraint = { 'constraint_class': 'FixedCombinations', - 'constraint_parameters': { - 'column_names': ['city', 'state'] - } + 'constraint_parameters': {'column_names': ['city', 'state']}, } model = GaussianCopulaSynthesizer(metadata) model.add_constraints([constraint]) @@ -291,7 +256,7 @@ def test_sampling(synthesizer): data = pd.DataFrame({ 'column1': list(range(100)), 'column2': list(range(100)), - 'column3': list(range(100)) + 'column3': list(range(100)), }) synthesizer.fit(data) @@ -310,20 +275,11 @@ def test_sampling_reset_sampling(synthesizer): metadata = SingleTableMetadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { - 'column1': { - 'sdtype': 'numerical' - }, - 'column2': { - 'sdtype': 'address' - }, - 'column3': { - 'sdtype': 'email' - }, - 'column4': { - 'sdtype': 'ssn', - 'pii': True - } - } + 'column1': {'sdtype': 'numerical'}, + 'column2': {'sdtype': 'address'}, + 'column3': {'sdtype': 'email'}, + 'column4': {'sdtype': 'ssn', 'pii': True}, + }, }) data = pd.DataFrame({ 'column1': list(range(100)), @@ -357,11 +313,7 @@ def test_config_creation_doesnt_raise_error(): # Run test_metadata.detect_from_dataframe(test_data) - test_metadata.update_column( - column_name='address_col', - sdtype='address', - pii=False - ) + test_metadata.update_column(column_name='address_col', sdtype='address', pii=False) synthesizer = GaussianCopulaSynthesizer(test_metadata) synthesizer.fit(test_data) @@ -383,7 +335,8 @@ def test_transformers_correctly_auto_assigned(): metadata.set_primary_key('primary_key') metadata.update_column(column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( - metadata, enforce_min_max_values=False, enforce_rounding=False) + metadata, enforce_min_max_values=False, enforce_rounding=False + ) # Run synthesizer.auto_assign_transformers(data) 
@@ -410,33 +363,29 @@ def test_transformers_correctly_auto_assigned(): def test_modeling_with_complex_datetimes(): """Test that models work with datetimes passed as strings or ints with complex format.""" # Setup - data = pd.DataFrame(data={ - 'string_column': [ - '20220902110443000000', - '20220916230356000000', - '20220826173917000000', - '20220826212135000000', - '20220929111311000000' - ], - 'int_column': [ - 20220902110443000000, - 20220916230356000000, - 20220826173917000000, - 20220826212135000000, - 20220929111311000000 - ] - }) + data = pd.DataFrame( + data={ + 'string_column': [ + '20220902110443000000', + '20220916230356000000', + '20220826173917000000', + '20220826212135000000', + '20220929111311000000', + ], + 'int_column': [ + 20220902110443000000, + 20220916230356000000, + 20220826173917000000, + 20220826212135000000, + 20220929111311000000, + ], + } + ) test_metadata = { 'columns': { - 'string_column': { - 'sdtype': 'datetime', - 'datetime_format': '%Y%m%d%H%M%S%f' - }, - 'int_column': { - 'sdtype': 'datetime', - 'datetime_format': '%Y%m%d%H%M%S%f' - } + 'string_column': {'sdtype': 'datetime', 'datetime_format': '%Y%m%d%H%M%S%f'}, + 'int_column': {'sdtype': 'datetime', 'datetime_format': '%Y%m%d%H%M%S%f'}, } } @@ -460,11 +409,13 @@ def test_auto_assign_transformers_and_update_with_pii(): still assign the expected transformer to it. """ # Setup - data = pd.DataFrame(data={ - 'id': ['N', 'A', 'K', 'F', 'P'], - 'numerical': [1, 2, 3, 2, 1], - 'name': ['A', 'A', 'B', 'B', 'B'] - }) + data = pd.DataFrame( + data={ + 'id': ['N', 'A', 'K', 'F', 'P'], + 'numerical': [1, 2, 3, 2, 1], + 'name': ['A', 'A', 'B', 'B', 'B'], + } + ) metadata = SingleTableMetadata() metadata.detect_from_dataframe(data) @@ -491,11 +442,13 @@ def test_auto_assign_transformers_and_update_with_pii(): def test_refitting_a_model(): """Test that refitting a model resets the sampling state of the generators.""" # Setup - data = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'numerical': [1, 2, 3, 2, 1], - 'name': ['A', 'A', 'B', 'B', 'B'] - }) + data = pd.DataFrame( + data={ + 'id': [0, 1, 2, 3, 4], + 'numerical': [1, 2, 3, 2, 1], + 'name': ['A', 'A', 'B', 'B', 'B'], + } + ) metadata = SingleTableMetadata() metadata.detect_from_dataframe(data) @@ -537,7 +490,7 @@ def test_get_info(): 'creation_date': today, 'is_fit': False, 'last_fit_date': None, - 'fitted_sdv_version': None + 'fitted_sdv_version': None, } # Run @@ -551,7 +504,7 @@ def test_get_info(): 'creation_date': today, 'is_fit': True, 'last_fit_date': today, - 'fitted_sdv_version': version + 'fitted_sdv_version': version, } @@ -715,10 +668,8 @@ def test_metadata_updated_warning_detect(mock__fit): ('update_column', {'column_name': 'col 1', 'sdtype': 'categorical'}), ('set_primary_key', {'column_name': 'col 1'}), ( - 'add_column_relationship', { - 'relationship_type': 'address', - 'column_names': ['city', 'country'] - } + 'add_column_relationship', + {'relationship_type': 'address', 'column_names': ['city', 'country']}, ), ('add_alternate_keys', {'column_names': ['col 1', 'col 2']}), ('set_sequence_key', {'column_name': 'col 1'}), @@ -793,9 +744,7 @@ def test_fit_and_sample_numerical_col_names(synthesizer_class): # Setup num_rows = 50 num_cols = 10 - values = { - i: np.random.randint(0, 100, size=num_rows) for i in range(num_cols) - } + values = {i: np.random.randint(0, 100, size=num_rows) for i in range(num_cols)} data = pd.DataFrame(values) metadata = SingleTableMetadata() metadata_dict = {'columns': {}} diff --git 
a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 151fb8d83..5a70ffea8 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -24,10 +24,7 @@ def _isinstance_side_effect(*args, **kwargs): return isinstance(args[0], args[1]) -DEMO_DATA, DEMO_METADATA = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' -) +DEMO_DATA, DEMO_METADATA = download_demo(modality='single_table', dataset_name='fake_hotel_guests') @pytest.fixture @@ -72,7 +69,7 @@ def test_fit_with_unique_constraint_on_data_with_only_index_column(): 'C', 'D', 'E', - ] + ], }) metadata = SingleTableMetadata() @@ -83,9 +80,7 @@ def test_fit_with_unique_constraint_on_data_with_only_index_column(): model = GaussianCopulaSynthesizer(metadata) constraint = { 'constraint_class': 'Unique', - 'constraint_parameters': { - 'column_names': ['index'] - } + 'constraint_parameters': {'column_names': ['index']}, } model.add_constraints([constraint]) @@ -138,7 +133,7 @@ def test_fit_with_unique_constraint_on_data_which_has_index_column(): 'C3', 'D4', 'E5', - ] + ], }) metadata = SingleTableMetadata() @@ -150,9 +145,7 @@ def test_fit_with_unique_constraint_on_data_which_has_index_column(): model = GaussianCopulaSynthesizer(metadata) constraint = { 'constraint_class': 'Unique', - 'constraint_parameters': { - 'column_names': ['test_column'] - } + 'constraint_parameters': {'column_names': ['test_column']}, } model.add_constraints([constraint]) @@ -198,7 +191,7 @@ def test_fit_with_unique_constraint_on_data_subset(): 'C', 'D', 'E', - ] + ], }) metadata = SingleTableMetadata() @@ -209,9 +202,7 @@ def test_fit_with_unique_constraint_on_data_subset(): test_df = test_df.iloc[[1, 3, 4]] constraint = { 'constraint_class': 'Unique', - 'constraint_parameters': { - 'column_names': ['test_column'] - } + 'constraint_parameters': {'column_names': ['test_column']}, } model = GaussianCopulaSynthesizer(metadata) model.add_constraints([constraint]) @@ -228,17 +219,19 @@ def test_fit_with_unique_constraint_on_data_subset(): def test_conditional_sampling_with_constraints(): """Test constraints with conditional sampling. 
GH#1737""" # Setup - data = pd.DataFrame(data={ - 'A': [round(i, 2) for i in np.random.uniform(low=0, high=10, size=100)], - 'B': [round(i) for i in np.random.uniform(low=0, high=10, size=100)], - 'C': np.random.choice(['Yes', 'No', 'Maybe'], size=100) - }) + data = pd.DataFrame( + data={ + 'A': [round(i, 2) for i in np.random.uniform(low=0, high=10, size=100)], + 'B': [round(i) for i in np.random.uniform(low=0, high=10, size=100)], + 'C': np.random.choice(['Yes', 'No', 'Maybe'], size=100), + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'numerical'}, 'B': {'sdtype': 'numerical'}, - 'C': {'sdtype': 'categorical'} + 'C': {'sdtype': 'categorical'}, } }) @@ -250,8 +243,8 @@ def test_conditional_sampling_with_constraints(): 'column_name': 'B', 'low_value': 0, 'high_value': 10, - 'strict_boundaries': False - } + 'strict_boundaries': False, + }, } my_condition = Condition(num_rows=250, column_values={'B': 1}) @@ -269,8 +262,7 @@ def test_conditional_sampling_with_constraints(): @patch('sdv.single_table.base.isinstance') -@patch('sdv.single_table.copulas.multivariate.GaussianMultivariate', - spec_set=GaussianMultivariate) +@patch('sdv.single_table.copulas.multivariate.GaussianMultivariate', spec_set=GaussianMultivariate) def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstance_mock): """Test that the ``sample`` method handles constraints with conditions. @@ -295,7 +287,7 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc data = pd.DataFrame({ 'city': ['LA', 'SF', 'CHI', 'LA', 'LA'], 'state': ['CA', 'CA', 'IL', 'CA', 'CA'], - 'age': [27, 28, 26, 21, 30] + 'age': [27, 28, 26, 21, 30], }) metadata = SingleTableMetadata() @@ -307,20 +299,12 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc constraint = { 'constraint_class': 'FixedCombinations', - 'constraint_parameters': { - 'column_names': ['city', 'state'] - } + 'constraint_parameters': {'column_names': ['city', 'state']}, } model.add_constraints([constraint]) sampled_numeric_data = [ - pd.DataFrame({ - 'city#state': [0.1, 1, 0.75, 0.25, 0.25], - 'age': [30, 30, 30, 30, 30] - }), - pd.DataFrame({ - 'city#state': [0.75], - 'age': [30] - }), + pd.DataFrame({'city#state': [0.1, 1, 0.75, 0.25, 0.25], 'age': [30, 30, 30, 30, 30]}), + pd.DataFrame({'city#state': [0.75], 'age': [30]}), ] gm_mock.return_value.sample.side_effect = sampled_numeric_data model.fit(data) @@ -334,7 +318,7 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc expected_data = pd.DataFrame({ 'city': ['LA', 'SF', 'LA', 'LA', 'SF'], 'state': ['CA', 'CA', 'CA', 'CA', 'CA'], - 'age': [30, 30, 30, 30, 30] + 'age': [30, 30, 30, 30, 30], }) sample_calls = model._model.sample.mock_calls assert len(sample_calls) == 2 @@ -355,19 +339,14 @@ def test_custom_constraints_from_file(tmpdir): metadata.detect_from_dataframe(data) metadata.update_column(column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( - metadata, - enforce_min_max_values=False, - enforce_rounding=False + metadata, enforce_min_max_values=False, enforce_rounding=False ) synthesizer.load_custom_constraint_classes( - 'tests/integration/single_table/custom_constraints.py', - ['MyConstraint'] + 'tests/integration/single_table/custom_constraints.py', ['MyConstraint'] ) constraint = { 'constraint_class': 'MyConstraint', - 'constraint_parameters': { - 'column_names': ['numerical_col'] - } + 'constraint_parameters': {'column_names': 
['numerical_col']}, } # Run @@ -404,16 +383,12 @@ def test_custom_constraints_from_object(tmpdir): metadata.detect_from_dataframe(data) metadata.update_column(column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( - metadata, - enforce_min_max_values=False, - enforce_rounding=False + metadata, enforce_min_max_values=False, enforce_rounding=False ) synthesizer.add_custom_constraint_class(MyConstraint, 'MyConstraint') constraint = { 'constraint_class': 'MyConstraint', - 'constraint_parameters': { - 'column_names': ['numerical_col'] - } + 'constraint_parameters': {'column_names': ['numerical_col']}, } # Run @@ -445,8 +420,8 @@ def test_synthesizer_with_inequality_constraint(demo_data, demo_metadata): 'constraint_class': 'Inequality', 'constraint_parameters': { 'low_column_name': 'checkin_date', - 'high_column_name': 'checkout_date' - } + 'high_column_name': 'checkout_date', + }, } synthesizer.add_constraints([checkin_lessthan_checkout]) @@ -456,23 +431,23 @@ def test_synthesizer_with_inequality_constraint(demo_data, demo_metadata): sampled = synthesizer.sample(num_rows=500) synthesizer.validate(sampled) _sampled = sampled[~sampled['checkout_date'].isna()] - assert all( - pd.to_datetime(_sampled['checkin_date']) < pd.to_datetime(_sampled['checkout_date']) - ) + assert all(pd.to_datetime(_sampled['checkin_date']) < pd.to_datetime(_sampled['checkout_date'])) def test_inequality_constraint_with_datetimes_and_nones(): """Test that the ``Inequality`` constraint works with ``None`` and ``datetime``.""" # Setup - data = pd.DataFrame(data={ - 'A': [None, None, '2020-01-02', '2020-03-04'] * 2, - 'B': [None, '2021-03-04', '2021-12-31', None] * 2 - }) + data = pd.DataFrame( + data={ + 'A': [None, None, '2020-01-02', '2020-03-04'] * 2, + 'B': [None, '2021-03-04', '2021-12-31', None] * 2, + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, } }) @@ -481,10 +456,7 @@ def test_inequality_constraint_with_datetimes_and_nones(): synth.add_constraints([ { 'constraint_class': 'Inequality', - 'constraint_parameters': { - 'low_column_name': 'A', - 'high_column_name': 'B' - } + 'constraint_parameters': {'low_column_name': 'A', 'high_column_name': 'B'}, } ]) synth.validate(data) @@ -496,13 +468,29 @@ def test_inequality_constraint_with_datetimes_and_nones(): # Assert expected_sampled = pd.DataFrame({ 'A': [ - '2020-01-02', '2020-01-02', '2020-01-02', '2020-01-02', '2020-01-02', '2020-01-02', - '2020-01-02', '2020-01-02', '2020-01-02', np.nan + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + np.nan, ], 'B': [ - np.nan, '2021-12-30', '2021-12-30', '2021-12-30', np.nan, - '2021-12-30', np.nan, '2021-12-30', np.nan, '2021-12-30' - ] + np.nan, + '2021-12-30', + '2021-12-30', + '2021-12-30', + np.nan, + '2021-12-30', + np.nan, + '2021-12-30', + np.nan, + '2021-12-30', + ], }) pd.testing.assert_frame_equal(expected_sampled, sampled) @@ -510,15 +498,17 @@ def test_inequality_constraint_with_datetimes_and_nones(): def test_scalar_inequality_constraint_with_datetimes_and_nones(): """Test that the ``ScalarInequality`` constraint works with ``None`` and ``datetime``.""" # Setup - data = pd.DataFrame(data={ - 'A': [None, None, '2020-01-02', '2020-03-04'], - 'B': [None, '2021-03-04', '2021-12-31', None] - 
}) + data = pd.DataFrame( + data={ + 'A': [None, None, '2020-01-02', '2020-03-04'], + 'B': [None, '2021-03-04', '2021-12-31', None], + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, } }) @@ -527,11 +517,7 @@ def test_scalar_inequality_constraint_with_datetimes_and_nones(): synth.add_constraints([ { 'constraint_class': 'ScalarInequality', - 'constraint_parameters': { - 'column_name': 'A', - 'relation': '>=', - 'value': '2019-01-01' - } + 'constraint_parameters': {'column_name': 'A', 'relation': '>=', 'value': '2019-01-01'}, } ]) synth.validate(data) @@ -555,7 +541,7 @@ def test_scalar_inequality_constraint_with_datetimes_and_nones(): 2: '2021-07-26', 3: '2021-07-02', 4: '2021-06-06', - } + }, }) pd.testing.assert_frame_equal(expected_sampled, sampled) @@ -563,15 +549,17 @@ def test_scalar_inequality_constraint_with_datetimes_and_nones(): def test_scalar_range_constraint_with_datetimes_and_nones(): """Test that the ``ScalarRange`` constraint works with ``None`` and ``datetime``.""" # Setup - data = pd.DataFrame(data={ - 'A': [None, None, '2020-01-02', '2020-03-04'], - 'B': [None, '2021-03-04', '2021-12-31', None] - }) + data = pd.DataFrame( + data={ + 'A': [None, None, '2020-01-02', '2020-03-04'], + 'B': [None, '2021-03-04', '2021-12-31', None], + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, } }) @@ -584,8 +572,8 @@ def test_scalar_range_constraint_with_datetimes_and_nones(): 'column_name': 'A', 'low_value': '2019-10-30', 'high_value': '2020-03-04', - 'strict_boundaries': False - } + 'strict_boundaries': False, + }, } ]) synth.validate(data) @@ -619,7 +607,7 @@ def test_scalar_range_constraint_with_datetimes_and_nones(): 7: '2021-06-19', 8: np.nan, 9: np.nan, - } + }, }) pd.testing.assert_frame_equal(expected_sampled, sampled) @@ -627,17 +615,19 @@ def test_scalar_range_constraint_with_datetimes_and_nones(): def test_range_constraint_with_datetimes_and_nones(): """Test that the ``Range`` constraint works with ``None`` and ``datetime``.""" # Setup - data = pd.DataFrame(data={ - 'A': [None, None, '2020-01-02', '2020-03-04'], - 'B': [None, '2021-03-04', '2021-12-31', None], - 'C': [None, '2022-03-04', '2022-12-31', None], - }) + data = pd.DataFrame( + data={ + 'A': [None, None, '2020-01-02', '2020-03-04'], + 'B': [None, '2021-03-04', '2021-12-31', None], + 'C': [None, '2022-03-04', '2022-12-31', None], + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'C': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + 'C': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, } }) @@ -650,8 +640,8 @@ def test_range_constraint_with_datetimes_and_nones(): 'low_column_name': 'A', 'middle_column_name': 'B', 'high_column_name': 'C', - 'strict_boundaries': False - } + 'strict_boundaries': False, + }, } ]) synth.validate(data) @@ -663,17 +653,41 @@ def test_range_constraint_with_datetimes_and_nones(): # Assert expected_sampled = pd.DataFrame({ 'A': [ - '2020-01-02', '2020-01-02', '2020-01-02', '2020-01-02', '2020-01-02', - '2020-01-02', '2020-01-02', '2020-01-02', 
'2020-01-02', '2020-01-02' + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', + '2020-01-02', ], 'B': [ - np.nan, np.nan, '2021-12-30', np.nan, np.nan, '2021-12-30', np.nan, - np.nan, np.nan, '2021-12-30' + np.nan, + np.nan, + '2021-12-30', + np.nan, + np.nan, + '2021-12-30', + np.nan, + np.nan, + np.nan, + '2021-12-30', ], 'C': [ - np.nan, np.nan, '2022-12-30', np.nan, np.nan, '2022-12-30', np.nan, - np.nan, np.nan, '2022-12-30' - ] + np.nan, + np.nan, + '2022-12-30', + np.nan, + np.nan, + '2022-12-30', + np.nan, + np.nan, + np.nan, + '2022-12-30', + ], }) pd.testing.assert_frame_equal(expected_sampled, sampled) @@ -681,10 +695,7 @@ def test_range_constraint_with_datetimes_and_nones(): def test_inequality_constraint_all_possible_nans_configurations(): """Test that the inequality constraint works with all possible NaN configurations.""" # Setup - data = pd.DataFrame(data={ - 'A': [0, 1, np.nan, np.nan, 2], - 'B': [2, np.nan, 3, np.nan, 3] - }) + data = pd.DataFrame(data={'A': [0, 1, np.nan, np.nan, 2], 'B': [2, np.nan, 3, np.nan, 3]}) metadata = SingleTableMetadata.load_from_dict({ 'columns': { @@ -694,17 +705,12 @@ def test_inequality_constraint_all_possible_nans_configurations(): }) synthesizer = GaussianCopulaSynthesizer(metadata) - synthesizer.add_constraints( - [ - { - 'constraint_class': 'Inequality', - 'constraint_parameters': { - 'low_column_name': 'A', - 'high_column_name': 'B' - } - } - ] - ) + synthesizer.add_constraints([ + { + 'constraint_class': 'Inequality', + 'constraint_parameters': {'low_column_name': 'A', 'high_column_name': 'B'}, + } + ]) # Run synthesizer.fit(data) @@ -720,17 +726,19 @@ def test_inequality_constraint_all_possible_nans_configurations(): def test_range_constraint_all_possible_nans_configurations(): """Test that the range constraint works with all possible NaN configurations.""" # Setup - data = pd.DataFrame(data={ - 'low': [1, 4, np.nan, 0, 4, np.nan, np.nan, 5, np.nan], - 'middle': [2, 5, 3, np.nan, 5, np.nan, 5, np.nan, np.nan], - 'high': [3, 7, 8, 4, np.nan, 9, np.nan, np.nan, np.nan] - }) + data = pd.DataFrame( + data={ + 'low': [1, 4, np.nan, 0, 4, np.nan, np.nan, 5, np.nan], + 'middle': [2, 5, 3, np.nan, 5, np.nan, 5, np.nan, np.nan], + 'high': [3, 7, 8, 4, np.nan, 9, np.nan, np.nan, np.nan], + } + ) metadata_dict = { 'columns': { 'low': {'sdtype': 'numerical'}, 'middle': {'sdtype': 'numerical'}, - 'high': {'sdtype': 'numerical'} + 'high': {'sdtype': 'numerical'}, } } @@ -742,8 +750,8 @@ def test_range_constraint_all_possible_nans_configurations(): 'constraint_parameters': { 'low_column_name': 'low', 'middle_column_name': 'middle', - 'high_column_name': 'high' - } + 'high_column_name': 'high', + }, } # Run @@ -779,6 +787,7 @@ def test_range_constraint_all_possible_nans_configurations(): def test_custom_constraint_with_key(): """Test that a custom constraint can work with a primary key.""" + # Setup def is_valid(column_names, data): return data['key'] == data['letter'] + '_' + data['number'] @@ -794,16 +803,14 @@ def reverse_transform(column_names, data): return data custom_constraint = create_custom_constraint_class( - is_valid_fn=is_valid, - transform_fn=transform, - reverse_transform_fn=reverse_transform + is_valid_fn=is_valid, transform_fn=transform, reverse_transform_fn=reverse_transform ) data = pd.DataFrame({ 'key': ['a_1', 'b_2', 'c_3'], 'letter': ['a', 'b', 'c'], 'number': ['1', '2', '3'], - 'other': [7, 8, 9] + 'other': [7, 8, 9], }) metadata = 
SingleTableMetadata() metadata.detect_from_dataframe(data) @@ -816,7 +823,7 @@ def reverse_transform(column_names, data): 'constraint_class': 'custom', 'constraint_parameters': { 'column_names': ['letter', 'number'], - } + }, } synth.add_constraints([id_must_match]) @@ -844,8 +851,8 @@ def test_timezone_aware_constraints(): 'constraint_parameters': { 'low_column_name': 'col1', 'high_column_name': 'col2', - 'strict_boundaries': True - } + 'strict_boundaries': True, + }, } # Run @@ -878,31 +885,29 @@ def reverse_transform(column_names, data): return data custom_constraint = create_custom_constraint_class( - is_valid_fn=is_valid, - transform_fn=transform, - reverse_transform_fn=reverse_transform + is_valid_fn=is_valid, transform_fn=transform, reverse_transform_fn=reverse_transform ) synth.add_custom_constraint_class(custom_constraint, 'custom') checkin_checkout_constraint = { 'constraint_class': 'Inequality', 'constraint_parameters': { 'low_column_name': 'checkin_date', - 'high_column_name': 'checkout_date' - } + 'high_column_name': 'checkout_date', + }, } error_constraint = { 'constraint_class': 'custom', 'constraint_parameters': { 'column_names': ['room_rate'], - } + }, } overlapped_constraint = { 'constraint_class': 'ScalarInequality', 'constraint_parameters': { 'column_name': 'checkout_date', 'relation': '>', - 'value': '01 Jan 1990' - } + 'value': '01 Jan 1990', + }, } synth.add_constraints( constraints=[checkin_checkout_constraint, error_constraint, overlapped_constraint] @@ -917,7 +922,7 @@ def reverse_transform(column_names, data): "Unable to transform ScalarInequality with columns ['checkout_date'] because they are not " 'all available in the data. This happens due to multiple, overlapping constraints.', "Unable to transform CustomConstraint with columns ['room_rate'] due to an error in " - 'transform: \nTransform error\nUsing the reject sampling approach instead.' + 'transform: \nTransform error\nUsing the reject sampling approach instead.', ] log_messages = [record[2] for record in caplog.record_tuples] for log in expected_logs: @@ -926,6 +931,7 @@ def reverse_transform(column_names, data): def test_aggregate_constraint_errors(demo_data, demo_metadata): """Test that if there are multiple constraint errors, they are raised together.""" + # Setup class BadConstraint(Constraint): def __init__(self, column_name): @@ -937,11 +943,11 @@ def _transform(self, table_data): synth = GaussianCopulaSynthesizer(demo_metadata) bad_constraint1 = { 'constraint_class': 'BadConstraint', - 'constraint_parameters': {'column_name': 'room_rate'} + 'constraint_parameters': {'column_name': 'room_rate'}, } bad_constraint2 = { 'constraint_class': 'BadConstraint', - 'constraint_parameters': {'column_name': 'checkin_date'} + 'constraint_parameters': {'column_name': 'checkin_date'}, } synth.add_constraints(constraints=[bad_constraint1, bad_constraint2]) @@ -954,14 +960,16 @@ def _transform(self, table_data): def test_constraint_datetime_check(): """Test datetime columns are correctly identified in constraints. 
GH#1692""" # Setup - data = pd.DataFrame(data={ - 'low_col': ['21 Sep, 15', '23 Aug, 14', '29 May, 12'], - 'high_col': ['02 Nov, 15', '12 Oct, 14', '08 Jul, 12'] - }) + data = pd.DataFrame( + data={ + 'low_col': ['21 Sep, 15', '23 Aug, 14', '29 May, 12'], + 'high_col': ['02 Nov, 15', '12 Oct, 14', '08 Jul, 12'], + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'low_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, - 'high_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'} + 'high_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, } }) my_constraint = { @@ -969,8 +977,8 @@ def test_constraint_datetime_check(): 'constraint_parameters': { 'low_column_name': 'low_col', 'high_column_name': 'high_col', - 'strict_boundaries': False - } + 'strict_boundaries': False, + }, } # Run @@ -985,6 +993,6 @@ def test_constraint_datetime_check(): # Assert expected_dataframe = pd.DataFrame({ 'low_col': ['18 Jul, 15', '09 Aug, 15', '24 Jun, 15'], - 'high_col': ['05 Sep, 15', '26 Sep, 15', '12 Aug, 15'] + 'high_col': ['05 Sep, 15', '26 Sep, 15', '12 Aug, 15'], }) pd.testing.assert_frame_equal(samples, expected_dataframe) diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index 8fea5b621..a2553c90e 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -4,8 +4,13 @@ import pandas as pd import pytest from rdt.transformers import ( - AnonymizedFaker, CustomLabelEncoder, FloatFormatter, IDGenerator, LabelEncoder, - PseudoAnonymizedFaker) + AnonymizedFaker, + CustomLabelEncoder, + FloatFormatter, + IDGenerator, + LabelEncoder, + PseudoAnonymizedFaker, +) from sdv.datasets.demo import download_demo from sdv.errors import ConstraintsNotMetError @@ -22,10 +27,7 @@ def test_synthesize_table_gaussian_copula(tmp_path): synthesizer customization. 
""" # Setup - real_data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = GaussianCopulaSynthesizer(metadata) custom_synthesizer = GaussianCopulaSynthesizer( metadata, @@ -33,20 +35,18 @@ def test_synthesize_table_gaussian_copula(tmp_path): numerical_distributions={ 'checkin_date': 'uniform', 'checkout_date': 'uniform', - 'room_rate': 'gaussian_kde' - } + 'room_rate': 'gaussian_kde', + }, ) sensitive_columns = ['guest_email', 'billing_address', 'credit_card_number'] model_path = tmp_path / 'synthesizer.pkl' suite_guests_with_rewards = Condition( - num_rows=250, - column_values={'room_type': 'SUITE', 'has_rewards': True} + num_rows=250, column_values={'room_type': 'SUITE', 'has_rewards': True} ) suite_guests_without_rewards = Condition( - num_rows=250, - column_values={'room_type': 'SUITE', 'has_rewards': False} + num_rows=250, column_values={'room_type': 'SUITE', 'has_rewards': False} ) # Run - fit @@ -54,24 +54,20 @@ def test_synthesize_table_gaussian_copula(tmp_path): synthetic_data = synthesizer.sample(num_rows=500) # Run - evaluate - quality_report = evaluate_quality( - real_data, - synthetic_data, - metadata - ) + quality_report = evaluate_quality(real_data, synthetic_data, metadata) column_plot = get_column_plot( real_data=real_data, synthetic_data=synthetic_data, column_name='room_rate', - metadata=metadata + metadata=metadata, ) pair_plot = get_column_pair_plot( real_data=real_data, synthetic_data=synthetic_data, column_names=['room_rate', 'room_type'], - metadata=metadata + metadata=metadata, ) # Run - save model @@ -81,21 +77,16 @@ def test_synthesize_table_gaussian_copula(tmp_path): custom_synthesizer.fit(real_data) synthetic_data_customized = custom_synthesizer.sample(num_rows=500) learned_distributions = custom_synthesizer.get_learned_distributions() - custom_quality_report = evaluate_quality( - real_data, - synthetic_data_customized, - metadata - ) + custom_quality_report = evaluate_quality(real_data, synthetic_data_customized, metadata) custom_column_plot = get_column_plot( real_data=real_data, synthetic_data=synthetic_data_customized, column_name='room_rate', - metadata=metadata + metadata=metadata, + ) + simulated_synthetic_data = custom_synthesizer.sample_from_conditions( + conditions=[suite_guests_with_rewards, suite_guests_without_rewards] ) - simulated_synthetic_data = custom_synthesizer.sample_from_conditions(conditions=[ - suite_guests_with_rewards, - suite_guests_without_rewards - ]) # Assert - fit assert set(real_data.columns) == set(synthetic_data.columns) @@ -125,7 +116,7 @@ def test_synthesize_table_gaussian_copula(tmp_path): 'a', 'b', 'loc', - 'scale' + 'scale', ] assert learned_distributions['has_rewards']['distribution'] == 'truncnorm' assert set(real_data.columns) == set(simulated_synthetic_data.columns) @@ -144,17 +135,14 @@ def test_adding_constraints(tmp_path): * Save, load and sample from the model storing both custom and pre-defined constraints. 
""" # Setup - real_data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') checkin_lessthan_checkout = { 'constraint_class': 'Inequality', 'constraint_parameters': { 'low_column_name': 'checkin_date', - 'high_column_name': 'checkout_date' - } + 'high_column_name': 'checkout_date', + }, } synthesizer = GaussianCopulaSynthesizer(metadata) @@ -172,30 +160,26 @@ def test_adding_constraints(tmp_path): # Load custom constraint class synthesizer.load_custom_constraint_classes( - 'tests/integration/single_table/custom_constraints.py', - ['IfTrueThenZero'] + 'tests/integration/single_table/custom_constraints.py', ['IfTrueThenZero'] ) rewards_member_no_fee = { 'constraint_class': 'IfTrueThenZero', - 'constraint_parameters': { - 'column_names': ['has_rewards', 'amenities_fee'] - } + 'constraint_parameters': {'column_names': ['has_rewards', 'amenities_fee']}, } synthesizer.add_constraints([rewards_member_no_fee]) # Re-Fit the model synthesizer.preprocess(real_data) - synthesizer.update_transformers({ - 'checkin_date#checkout_date.nan_component': LabelEncoder() - }) + synthesizer.update_transformers({'checkin_date#checkout_date.nan_component': LabelEncoder()}) synthesizer.fit(real_data) synthetic_data_custom_constraint = synthesizer.sample(500) # Assert validation = synthetic_data_custom_constraint[synthetic_data_custom_constraint['has_rewards']] assert validation['amenities_fee'].sum() == 0.0 - assert isinstance(synthesizer.get_transformers()['checkin_date#checkout_date.nan_component'], - LabelEncoder) + assert isinstance( + synthesizer.get_transformers()['checkin_date#checkout_date.nan_component'], LabelEncoder + ) # Save and Load model_path = tmp_path / 'synthesizer.pkl' @@ -226,33 +210,22 @@ def test_custom_processing_anonymization(): * Anonymization and pseudo-anonymization """ # Setup - real_data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = GaussianCopulaSynthesizer(metadata) transformers_synthesizer = GaussianCopulaSynthesizer(metadata) anonymization_synthesizer = GaussianCopulaSynthesizer(metadata) - room_type_transformer = CustomLabelEncoder( - order=['BASIC', 'DELUXE', 'SUITE'], - add_noise=True - ) + room_type_transformer = CustomLabelEncoder(order=['BASIC', 'DELUXE', 'SUITE'], add_noise=True) amenities_fee_transformer = FloatFormatter( - learn_rounding_scheme=True, - enforce_min_max_values=True, - missing_value_replacement=0.00 + learn_rounding_scheme=True, enforce_min_max_values=True, missing_value_replacement=0.00 ) sensitive_columns = ['guest_email', 'billing_address', 'credit_card_number'] guest_email_transformer = AnonymizedFaker( - provider_name='misc', - function_name='uuid4', - enforce_uniqueness=True + provider_name='misc', function_name='uuid4', enforce_uniqueness=True ) billing_address_transformer = PseudoAnonymizedFaker( - provider_name='address', - function_name='address' + provider_name='address', function_name='address' ) # Run - Pre-process data @@ -264,7 +237,7 @@ def test_custom_processing_anonymization(): transformers_synthesizer.preprocess(real_data) transformers_synthesizer.update_transformers({ 'room_type': room_type_transformer, - 'amenities_fee': amenities_fee_transformer + 'amenities_fee': amenities_fee_transformer, }) transformers_synthesizer.fit(real_data) @@ -272,7 
+245,7 @@ def test_custom_processing_anonymization(): anonymization_synthesizer.preprocess(real_data) anonymization_synthesizer.update_transformers({ 'guest_email': guest_email_transformer, - 'billing_address': billing_address_transformer + 'billing_address': billing_address_transformer, }) anonymization_synthesizer.fit(real_data) anonymized_sample = anonymization_synthesizer.sample(num_rows=100) @@ -302,10 +275,7 @@ def test_update_transformers_with_id_generator(): # Setup min_value_id = 5 sample_num = 20 - data = pd.DataFrame({ - 'user_id': list(range(4)), - 'user_cat': ['a', 'b', 'c', 'd'] - }) + data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']}) stm = SingleTableMetadata() stm.detect_from_dataframe(data) @@ -332,10 +302,7 @@ def test_update_transformers_with_id_generator(): def test_validate_with_failing_constraint(): """Validate that the ``constraint`` are raising errors if there is an error during validate.""" # Setup - real_data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') real_data['checkin_date'][0] = real_data['checkout_date'][1] gc = GaussianCopulaSynthesizer(metadata) @@ -343,12 +310,10 @@ def test_validate_with_failing_constraint(): 'constraint_class': 'Inequality', 'constraint_parameters': { 'low_column_name': 'checkin_date', - 'high_column_name': 'checkout_date' - } + 'high_column_name': 'checkout_date', + }, } - gc.add_constraints([ - checkin_lessthan_checkout - ]) + gc.add_constraints([checkin_lessthan_checkout]) error_msg = ( "Data is not valid for the 'Inequality' constraint:" @@ -364,18 +329,16 @@ def test_validate_with_failing_constraint(): def test_numerical_columns_gets_pii(): """Test that the synthesizer works when a ``numerical`` column gets converted to ``PII``.""" # Setup - data = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'city': [0, 0, 0, 0, 0], - 'numerical': [21, 22, 23, 24, 25] - }) + data = pd.DataFrame( + data={'id': [0, 1, 2, 3, 4], 'city': [0, 0, 0, 0, 0], 'numerical': [21, 22, 23, 24, 25]} + ) metadata = SingleTableMetadata.load_from_dict({ 'primary_key': 'id', 'columns': { 'id': {'sdtype': 'id'}, 'city': {'sdtype': 'city'}, - 'numerical': {'sdtype': 'numerical'} - } + 'numerical': {'sdtype': 'numerical'}, + }, }) synth = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') synth.fit(data) @@ -395,7 +358,7 @@ def test_numerical_columns_gets_pii(): 6: 795819284, 7: 607278621, 8: 783746695, - 9: 162118876 + 9: 162118876, }, 'city': { 0: 'Danielfort', @@ -407,9 +370,9 @@ def test_numerical_columns_gets_pii(): 6: 'Ryanfort', 7: 'West Stephenland', 8: 'Davidland', - 9: 'Port Christopher' + 9: 'Port Christopher', }, - 'numerical': {0: 22, 1: 24, 2: 22, 3: 23, 4: 22, 5: 24, 6: 23, 7: 24, 8: 24, 9: 24} + 'numerical': {0: 22, 1: 24, 2: 22, 3: 23, 4: 22, 5: 24, 6: 23, 7: 24, 8: 24, 9: 24}, }) pd.testing.assert_frame_equal(expected_sampled, sampled) @@ -419,8 +382,26 @@ def test_categorical_column_with_numbers(): # Setup data = pd.DataFrame({ 'category_col': [ - 1, 2, 1, 2, 1, 2, np.nan, 1, 1, np.nan, 2, 2, np.nan, 2, - 1, 1, np.nan, 1, 2, 2 + 1, + 2, + 1, + 2, + 1, + 2, + np.nan, + 1, + 1, + np.nan, + 2, + 2, + np.nan, + 2, + 1, + 1, + np.nan, + 1, + 2, + 2, ], 'numerical_col': np.random.rand(20), }) diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py index 9ce456081..4e0218985 100644 --- 
a/tests/integration/single_table/test_ctgan.py +++ b/tests/integration/single_table/test_ctgan.py @@ -23,7 +23,7 @@ def test__estimate_num_columns(): metadata.add_column('datetime', sdtype='datetime') metadata.add_column('boolean', sdtype='boolean') data = pd.DataFrame({ - 'numerical': [.1, .2, .3], + 'numerical': [0.1, 0.2, 0.3], 'datetime': ['2020-01-01', '2020-01-02', '2020-01-03'], 'categorical': ['a', 'b', 'b'], 'categorical2': ['a', 'b', 'b'], @@ -54,15 +54,9 @@ def test_synthesize_table_ctgan(tmp_path): Tests quality reports, anonymization, and customizing the synthesizer. """ # Setup - real_data, metadata = download_demo( - modality='single_table', - dataset_name='fake_hotel_guests' - ) + real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') synthesizer = CTGANSynthesizer(metadata) - custom_synthesizer = CTGANSynthesizer( - metadata, - epochs=100 - ) + custom_synthesizer = CTGANSynthesizer(metadata, epochs=100) sensitive_columns = ['guest_email', 'billing_address', 'credit_card_number'] model_path = tmp_path / 'synthesizer.pkl' @@ -71,24 +65,20 @@ def test_synthesize_table_ctgan(tmp_path): synthetic_data = synthesizer.sample(num_rows=500) # Run - evaluate - quality_report = evaluate_quality( - real_data, - synthetic_data, - metadata - ) + quality_report = evaluate_quality(real_data, synthetic_data, metadata) column_plot = get_column_plot( real_data=real_data, synthetic_data=synthetic_data, column_name='room_type', - metadata=metadata + metadata=metadata, ) pair_plot = get_column_pair_plot( real_data=real_data, synthetic_data=synthetic_data, column_names=['room_rate', 'room_type'], - metadata=metadata + metadata=metadata, ) # Run - save model @@ -97,11 +87,7 @@ def test_synthesize_table_ctgan(tmp_path): # Run - custom synthesizer custom_synthesizer.fit(real_data) synthetic_data_customized = custom_synthesizer.sample(num_rows=500) - custom_quality_report = evaluate_quality( - real_data, - synthetic_data_customized, - metadata - ) + custom_quality_report = evaluate_quality(real_data, synthetic_data_customized, metadata) # Assert - fit assert set(real_data.columns) == set(synthetic_data.columns) @@ -141,16 +127,18 @@ def test_categoricals_are_not_preprocessed(): for different data types. 
""" # Setup - data = pd.DataFrame(data={ - 'age': [56, 61, 36, 52, 42], - 'therapy': [True, False, True, False, True], - 'alcohol': ['medium', 'medium', 'low', 'high', 'low'], - }) + data = pd.DataFrame( + data={ + 'age': [56, 61, 36, 52, 42], + 'therapy': [True, False, True, False, True], + 'alcohol': ['medium', 'medium', 'low', 'high', 'low'], + } + ) metadata = SingleTableMetadata.load_from_dict({ 'columns': { 'age': {'sdtype': 'numerical'}, 'therapy': {'sdtype': 'boolean'}, - 'alcohol': {'sdtype': 'categorical'} + 'alcohol': {'sdtype': 'categorical'}, } }) @@ -188,8 +176,8 @@ def test_categorical_metadata_with_int_data(): 'columns': { 'A': {'sdtype': 'categorical'}, 'B': {'sdtype': 'numerical'}, - 'C': {'sdtype': 'categorical'} - } + 'C': {'sdtype': 'categorical'}, + }, } metadata = SingleTableMetadata.load_from_dict(metadata_dict) @@ -255,13 +243,9 @@ def test_ctgansynthesizer_with_constraints_generating_categorical_values(): my_synthesizer = CTGANSynthesizer(metadata) my_constraint = { 'constraint_class': 'FixedCombinations', - 'constraint_parameters': { - 'column_names': ['high_spec', 'degree_type'] - } + 'constraint_parameters': {'column_names': ['high_spec', 'degree_type']}, } - my_synthesizer.add_constraints(constraints=[ - my_constraint - ]) + my_synthesizer.add_constraints(constraints=[my_constraint]) # Run my_synthesizer.fit(data) @@ -274,17 +258,16 @@ def test_ctgansynthesizer_with_constraints_generating_categorical_values(): def test_ctgan_with_dropped_columns(): """Test CTGANSynthesizer doesn't crash when applied to columns that will be dropped. GH#1741""" # Setup - data = pd.DataFrame(data={ - 'user_id': ['100', '101', '102', '103', '104'], - 'user_ssn': ['111-11-1111', '222-22-2222', '333-33-3333', '444-44-4444', '555-55-5555'] - }) + data = pd.DataFrame( + data={ + 'user_id': ['100', '101', '102', '103', '104'], + 'user_ssn': ['111-11-1111', '222-22-2222', '333-33-3333', '444-44-4444', '555-55-5555'], + } + ) metadata_dict = { 'primary_key': 'user_id', - 'columns': { - 'user_id': {'sdtype': 'id'}, - 'user_ssn': {'sdtype': 'ssn'} - } + 'columns': {'user_id': {'sdtype': 'id'}, 'user_ssn': {'sdtype': 'ssn'}}, } metadata = SingleTableMetadata.load_from_dict(metadata_dict) @@ -300,16 +283,19 @@ def test_ctgan_with_dropped_columns(): assert all(id_val.startswith('sdv-id-') for id_val in samples['user_id']) pd.testing.assert_series_equal( samples['user_id'], - pd.Series([ - 'sdv-id-IOsBJZ', - 'sdv-id-CFcIuA', - 'sdv-id-prYgtc', - 'sdv-id-yrTTYM', - 'sdv-id-kLtfIW', - 'sdv-id-nCFkOi', - 'sdv-id-kKQXYV', - 'sdv-id-aPRybP', - 'sdv-id-RHPiGX', - 'sdv-id-SJNtGY' - ], name='user_id') + pd.Series( + [ + 'sdv-id-IOsBJZ', + 'sdv-id-CFcIuA', + 'sdv-id-prYgtc', + 'sdv-id-yrTTYM', + 'sdv-id-kLtfIW', + 'sdv-id-nCFkOi', + 'sdv-id-kKQXYV', + 'sdv-id-aPRybP', + 'sdv-id-RHPiGX', + 'sdv-id-SJNtGY', + ], + name='user_id', + ), ) diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 8f4c5fc37..b74ac7905 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -15,53 +15,44 @@ @pytest.fixture def metadata(): - return MultiTableMetadata.load_from_dict( - { - 'tables': { - 'parent': { - 'columns': { - 'id': {'sdtype': 'id'}, - 'A': {'sdtype': 'categorical'}, - 'B': {'sdtype': 'numerical'} - }, - 'primary_key': 'id' + return MultiTableMetadata.load_from_dict({ + 'tables': { + 'parent': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'A': {'sdtype': 'categorical'}, + 'B': {'sdtype': 'numerical'}, }, - 'child': { - 'columns': { - 
'parent_id': {'sdtype': 'id'}, - 'C': {'sdtype': 'categorical'} - } - } + 'primary_key': 'id', }, - 'relationships': [ - { - 'parent_table_name': 'parent', - 'child_table_name': 'child', - 'parent_primary_key': 'id', - 'child_foreign_key': 'parent_id' - } - ] - } - ) + 'child': {'columns': {'parent_id': {'sdtype': 'id'}, 'C': {'sdtype': 'categorical'}}}, + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id', + 'child_foreign_key': 'parent_id', + } + ], + }) @pytest.fixture def data(): - parent = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - 'B': [0.434, 0.312, 0.212, 0.339, 0.491] - }) + parent = pd.DataFrame( + data={ + 'id': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + 'B': [0.434, 0.312, 0.212, 0.339, 0.491], + } + ) - child = pd.DataFrame(data={ - 'parent_id': [0, 1, 2, 2, 5], - 'C': ['Yes', 'No', 'Maye', 'No', 'No'] - }) + child = pd.DataFrame( + data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maye', 'No', 'No']} + ) - return { - 'parent': parent, - 'child': child - } + return {'parent': parent, 'child': child} def test_simplify_schema(capsys): @@ -152,13 +143,13 @@ def test_simplify_schema_big_demo_datasets(): 'NBA_v1', 'NCAA_v1', 'PremierLeague_v1', - 'financial_v1' + 'financial_v1', ] for dataset in list_datasets: real_data, metadata = download_demo('multi_table', dataset) # Run - data_simplify, metadata_simplify = simplify_schema(real_data, metadata) + _data_simplify, metadata_simplify = simplify_schema(real_data, metadata) # Assert estimate_column_before = _get_total_estimated_columns(metadata) @@ -174,7 +165,7 @@ def test_simplify_schema_big_demo_datasets(): ('MuskSmall_v1', 'molecule', 'conformation', 50, 150), ('NBA_v1', 'Team', 'Actions', 10, 200), ('NCAA_v1', 'tourney_slots', 'tourney_compact_results', 1000, 1000), - ] + ], ) def test_get_random_subset(dataset_name, main_table_1, main_table_2, num_rows_1, num_rows_2): """Test ``get_random_subset`` end to end. 
diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 35b0f28d2..5139ab0c1 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -12,53 +12,44 @@ @pytest.fixture def metadata(): - return MultiTableMetadata.load_from_dict( - { - 'tables': { - 'parent': { - 'columns': { - 'id': {'sdtype': 'id'}, - 'A': {'sdtype': 'categorical'}, - 'B': {'sdtype': 'numerical'} - }, - 'primary_key': 'id' + return MultiTableMetadata.load_from_dict({ + 'tables': { + 'parent': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'A': {'sdtype': 'categorical'}, + 'B': {'sdtype': 'numerical'}, }, - 'child': { - 'columns': { - 'parent_id': {'sdtype': 'id'}, - 'C': {'sdtype': 'categorical'} - } - } + 'primary_key': 'id', }, - 'relationships': [ - { - 'parent_table_name': 'parent', - 'child_table_name': 'child', - 'parent_primary_key': 'id', - 'child_foreign_key': 'parent_id' - } - ] - } - ) + 'child': {'columns': {'parent_id': {'sdtype': 'id'}, 'C': {'sdtype': 'categorical'}}}, + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id', + 'child_foreign_key': 'parent_id', + } + ], + }) @pytest.fixture def data(): - parent = pd.DataFrame(data={ - 'id': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - 'B': [0.434, 0.312, 0.212, 0.339, 0.491] - }) + parent = pd.DataFrame( + data={ + 'id': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + 'B': [0.434, 0.312, 0.212, 0.339, 0.491], + } + ) - child = pd.DataFrame(data={ - 'parent_id': [0, 1, 2, 2, 5], - 'C': ['Yes', 'No', 'Maye', 'No', 'No'] - }) + child = pd.DataFrame( + data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maye', 'No', 'No']} + ) - return { - 'parent': parent, - 'child': child - } + return {'parent': parent, 'child': child} def test_drop_unknown_references(metadata, data, capsys): @@ -142,9 +133,7 @@ def test_drop_unknown_references_not_drop_missing_values(metadata, data): data['child'].loc[3, 'parent_id'] = np.nan # Run - cleaned_data = drop_unknown_references( - data, metadata, drop_missing_values=False, verbose=False - ) + cleaned_data = drop_unknown_references(data, metadata, drop_missing_values=False, verbose=False) # Assert pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 35cff9bd9..d088673e0 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,4 +1,5 @@ """Tests for the ``tasks.py`` file.""" + from tasks import _get_minimum_versions @@ -15,7 +16,7 @@ def test_get_minimum_versions(): "pandas>=1.2.0,<2;python_version<'3.10'", "pandas>=1.3.0,<2;python_version>='3.10'", 'humanfriendly>=8.2,<11', - 'pandas @ git+https://github.com/pandas-dev/pandas.git@master' + 'pandas @ git+https://github.com/pandas-dev/pandas.git@master', ] # Run diff --git a/tests/unit/constraints/test_base.py b/tests/unit/constraints/test_base.py index 402136796..146e171c9 100644 --- a/tests/unit/constraints/test_base.py +++ b/tests/unit/constraints/test_base.py @@ -1,4 +1,5 @@ """Tests for the sdv.constraints.base module.""" + import re from unittest.mock import Mock, patch @@ -7,10 +8,18 @@ from copulas.univariate import GaussianUnivariate from sdv.constraints.base import ( - ColumnsModel, Constraint, _get_qualified_name, _module_contains_callable_name, get_subclasses, - import_object) + ColumnsModel, + Constraint, + _get_qualified_name, + _module_contains_callable_name, + get_subclasses, + import_object, +) from 
sdv.constraints.errors import ( - AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError) + AggregateConstraintsError, + ConstraintMetadataError, + MissingConstraintColumnError, +) from sdv.constraints.tabular import FixedCombinations from sdv.errors import ConstraintsNotMetError @@ -99,6 +108,7 @@ def test_get_subclasses(): Output: - Dict of the subclasses of the class: ``Child`` and ``GrandChild`` classes. """ + # Setup class Parent: pass @@ -113,10 +123,7 @@ class GrandChild(Child): subclasses = get_subclasses(Parent) # Assert - expected_subclasses = { - 'Child': Child, - 'GrandChild': GrandChild - } + expected_subclasses = {'Child': Child, 'GrandChild': GrandChild} assert subclasses == expected_subclasses @@ -157,8 +164,7 @@ def test_import_object_function(): assert imported is import_object -class TestConstraint(): - +class TestConstraint: def test__validate_inputs(self): """Test the ``_validate_inputs`` method. @@ -179,8 +185,11 @@ def test__validate_inputs(self): @patch('sdv.constraints.base.Constraint._validate_metadata_columns') @patch('sdv.constraints.base.Constraint._validate_metadata_specific_to_constraint') def test__validate_metadata( - self, validate_metadata_columns_mock, validate_metadata_specific_to_constraint_mock, - validate_inputs_mock): + self, + validate_metadata_columns_mock, + validate_metadata_specific_to_constraint_mock, + validate_inputs_mock, + ): """Test the ``_validate_metadata`` method. The method should compile the error messages returned from ``_validate_inputs``, @@ -221,9 +230,7 @@ def test_fit(self): - Table data (pandas.DataFrame) """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 3] - }) + table_data = pd.DataFrame({'a': [1, 2, 3]}) instance = Constraint() instance._fit = Mock() instance._validate_data_meets_constraint = Mock() @@ -249,10 +256,7 @@ def test__validate_data_meets_constraints(self): - No error """ # Setup - data = pd.DataFrame({ - 'a': [0, 1, 2], - 'b': [3, 4, 5] - }, index=[0, 1, 2]) + data = pd.DataFrame({'a': [0, 1, 2], 'b': [3, 4, 5]}, index=[0, 1, 2]) constraint = Constraint() constraint.constraint_columns = ['a', 'b'] constraint.is_valid = Mock() @@ -277,10 +281,10 @@ def test__validate_data_meets_constraints_invalid_input(self): - A ``ConstraintsNotMetError`` is thrown """ # Setup - data = pd.DataFrame({ - 'a': [0, 1, 2, 3, 4, 5, 6, 7], - 'b': [3, 4, 5, 6, 7, 8, 9, 10] - }, index=[0, 1, 2, 3, 4, 5, 6, 7]) + data = pd.DataFrame( + {'a': [0, 1, 2, 3, 4, 5, 6, 7], 'b': [3, 4, 5, 6, 7, 8, 9, 10]}, + index=[0, 1, 2, 3, 4, 5, 6, 7], + ) constraint = Constraint() constraint.constraint_columns = ['a', 'b'] is_valid_result = pd.Series([True, False, True, False, False, False, False, False]) @@ -308,10 +312,7 @@ def test__validate_data_meets_constraints_missing_cols(self): - No error """ # Setup - data = pd.DataFrame({ - 'a': [0, 1, 2], - 'b': [3, 4, 5] - }, index=[0, 1, 2]) + data = pd.DataFrame({'a': [0, 1, 2], 'b': [3, 4, 5]}, index=[0, 1, 2]) constraint = Constraint() constraint.constraint_columns = ['a', 'b', 'c'] constraint.is_valid = Mock() @@ -500,9 +501,7 @@ def test_is_valid(self): - Series of ``True`` values (pandas.Series) """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 3] - }) + table_data = pd.DataFrame({'a': [1, 2, 3]}) # Run instance = Constraint() @@ -524,9 +523,7 @@ def test_filter_valid(self): - Table data, with only the valid rows (pandas.DataFrame) """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 3] - }) + table_data = pd.DataFrame({'a': [1, 2, 3]}) constraint_mock = Mock() 
constraint_mock.is_valid.return_value = pd.Series([True, True, False]) @@ -535,9 +532,7 @@ def test_filter_valid(self): out = Constraint.filter_valid(constraint_mock, table_data) # Assert - expected_out = pd.DataFrame({ - 'a': [1, 2] - }) + expected_out = pd.DataFrame({'a': [1, 2]}) pd.testing.assert_frame_equal(expected_out, out) def test_filter_valid_with_invalid_index(self): @@ -554,9 +549,7 @@ def test_filter_valid_with_invalid_index(self): - Table data, with only the valid rows (pandas.DataFrame) """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 3] - }) + table_data = pd.DataFrame({'a': [1, 2, 3]}) constraint_mock = Mock() is_valid = pd.Series([True, True, False]) @@ -567,9 +560,7 @@ def test_filter_valid_with_invalid_index(self): out = Constraint.filter_valid(constraint_mock, table_data) # Assert - expected_out = pd.DataFrame({ - 'a': [1, 2] - }) + expected_out = pd.DataFrame({'a': [1, 2]}) pd.testing.assert_frame_equal(expected_out, out) def test_from_dict_fqn(self): @@ -588,7 +579,7 @@ def test_from_dict_fqn(self): 'constraint_class': 'sdv.constraints.tabular.FixedCombinations', 'constraint_parameters': { 'column_names': ['a', 'b'], - } + }, } # Run @@ -614,7 +605,7 @@ def test_from_dict_subclass(self): 'constraint_class': 'FixedCombinations', 'constraint_parameters': { 'column_names': ['a', 'b'], - } + }, } # Run @@ -643,13 +634,12 @@ def test_to_dict(self): 'constraint_class': 'sdv.constraints.tabular.FixedCombinations', 'constraint_parameters': { 'column_names': ['a', 'b'], - } + }, } assert constraint_dict == expected_dict class TestColumnsModel: - def test___init__(self): """Test the ``__init__`` method of ``ColumnsModel``. @@ -701,8 +691,9 @@ def test___init__list(self): @patch('sdv.constraints.base.OneHotEncoder') @patch('sdv.constraints.base.UnixTimestampEncoder') @patch('sdv.constraints.base.BinaryEncoder') - def test__get_hyper_transformer_config(self, mock_binaryencoder, mock_unixtimestampencoder, - mock_onehotencoder, mock_floatformatter): + def test__get_hyper_transformer_config( + self, mock_binaryencoder, mock_unixtimestampencoder, mock_onehotencoder, mock_floatformatter + ): """Test the ``_get_hyper_transformer_config``. 
Test that the method ``_get_hyper_transformer_config`` returns the expected @@ -743,16 +734,15 @@ def test__get_hyper_transformer_config(self, mock_binaryencoder, mock_unixtimest 'amount': 'numerical', 'name': 'categorical', 'joindate': 'datetime', - 'is_valid': 'boolean' + 'is_valid': 'boolean', }, 'transformers': { 'age': age_float_formatter, 'amount': amount_float_formatter, 'name': mock_onehotencoder, 'joindate': mock_unixtimestampencoder.return_value, - 'is_valid': mock_binaryencoder.return_value - } - + 'is_valid': mock_binaryencoder.return_value, + }, } @patch('sdv.constraints.base.GaussianMultivariate') @@ -783,7 +773,7 @@ def test_fit(self, mock_hyper_transformer, mock_gaussian_multivariate): table_data = pd.DataFrame({ 'age': [1, 2, 3, 4], 'age_when_joined': [5, 6, 7, 8], - 'retirement': ['a', 'b', 'c', 'd'] + 'retirement': ['a', 'b', 'c', 'd'], }) mock_hyper_transformer.return_value.fit_transform.return_value = 'transformed_data' @@ -825,22 +815,15 @@ def test__reject_sample(self): # Setup constraint = Mock() constraint.is_valid.side_effect = lambda x: pd.Series( - [True for _ in range(len(x))], - index=x.index + [True for _ in range(len(x))], index=x.index ) instance = ColumnsModel(constraint, ['a', 'b']) instance._hyper_transformer = Mock() instance._model = Mock() transformed_conditions = [pd.DataFrame([[1], [1], [1], [1], [1]], columns=['b'])] instance._model.sample.side_effect = [ - pd.DataFrame({ - 'a': [1, 1], - 'b': [2, 3] - }), - pd.DataFrame({ - 'a': [1, 1, 1, 1], - 'b': [4, 5, 6, 7] - }) + pd.DataFrame({'a': [1, 1], 'b': [2, 3]}), + pd.DataFrame({'a': [1, 1, 1, 1], 'b': [4, 5, 6, 7]}), ] instance._hyper_transformer.transform.side_effect = transformed_conditions instance._hyper_transformer.reverse_transform = lambda x: x @@ -849,10 +832,7 @@ def test__reject_sample(self): transformed_data = instance._reject_sample(num_rows=5, conditions={'b': 1}) # Assert - expected_result = pd.DataFrame({ - 'a': [1, 1, 1, 1, 1], - 'b': [2, 3, 4, 5, 6] - }) + expected_result = pd.DataFrame({'a': [1, 1, 1, 1, 1], 'b': [2, 3, 4, 5, 6]}) model_calls = instance._model.sample.mock_calls assert len(model_calls) == 2 instance._model.sample.assert_any_call(num_rows=5, conditions={'b': 1}) @@ -860,10 +840,7 @@ def test__reject_sample(self): pd.testing.assert_frame_equal(transformed_data, expected_result) expected_call_1 = pd.DataFrame({'a': [1, 1], 'b': [2, 3]}) - expected_call_2 = pd.DataFrame({ - 'a': [1, 1, 1, 1], - 'b': [4, 5, 6, 7] - }) + expected_call_2 = pd.DataFrame({'a': [1, 1, 1, 1], 'b': [4, 5, 6, 7]}) pd.testing.assert_frame_equal(expected_call_1, constraint.is_valid.call_args_list[0][0][0]) pd.testing.assert_frame_equal(expected_call_2, constraint.is_valid.call_args_list[1][0][0]) @@ -889,18 +866,14 @@ def test__reject_sampling_duplicates_valid_rows(self): # Setup constraint = Mock() constraint.is_valid.side_effect = lambda x: pd.Series( - [True for _ in range(len(x))], - index=x.index, dtype=bool + [True for _ in range(len(x))], index=x.index, dtype=bool ) instance = ColumnsModel(constraint, ['a', 'b']) instance._hyper_transformer = Mock() instance._model = Mock() transformed_conditions = [pd.DataFrame([[1], [1], [1], [1], [1]], columns=['b'])] instance._model.sample.side_effect = [pd.DataFrame()] * 100 + [ - pd.DataFrame({ - 'a': [1, 1], - 'b': [2, 3] - }) + pd.DataFrame({'a': [1, 1], 'b': [2, 3]}) ] instance._hyper_transformer.transform.side_effect = transformed_conditions instance._hyper_transformer.reverse_transform = lambda x: x @@ -909,10 +882,7 @@ def 
test__reject_sampling_duplicates_valid_rows(self): transformed_data = instance._reject_sample(num_rows=5, conditions={'b': 1}) # Assert - expected_result = pd.DataFrame({ - 'a': [1, 1, 1, 1, 1], - 'b': [2, 3, 2, 3, 2] - }) + expected_result = pd.DataFrame({'a': [1, 1, 1, 1, 1], 'b': [2, 3, 2, 3, 2]}) model_calls = instance._model.sample.mock_calls assert len(model_calls) == 101 instance._model.sample.assert_any_call(num_rows=5, conditions={'b': 1}) @@ -940,18 +910,14 @@ def test__reject_sampling_no_valid_rows(self): # Setup constraint = Mock() constraint.is_valid.side_effect = lambda x: pd.Series( - [False for _ in range(len(x))], - index=x.index, dtype=bool + [False for _ in range(len(x))], index=x.index, dtype=bool ) instance = ColumnsModel(constraint, ['a', 'b']) instance._hyper_transformer = Mock() instance._model = Mock() transformed_conditions = [pd.DataFrame([[1], [1], [1], [1], [1]], columns=['b'])] instance._model.sample.side_effect = [pd.DataFrame()] * 100 + [ - pd.DataFrame({ - 'a': [1, 1], - 'b': [2, 3] - }) + pd.DataFrame({'a': [1, 1], 'b': [2, 3]}) ] instance._hyper_transformer.transform.side_effect = transformed_conditions instance._hyper_transformer.reverse_transform = lambda x: x @@ -989,10 +955,7 @@ def test_sample(self): instance._hyper_transformer.reverse_transform = lambda x: x instance._reject_sample = Mock() instance._reject_sample.side_effect = [ - pd.DataFrame({ - 'a': [1, 1, 1, 1, 1], - 'b': [2, 3, 4, 5, 6] - }) + pd.DataFrame({'a': [1, 1, 1, 1, 1], 'b': [2, 3, 4, 5, 6]}) ] # Run @@ -1000,10 +963,7 @@ def test_sample(self): transformed_data = instance.sample(data) # Assert - expected_result = pd.DataFrame({ - 'a': [1, 1, 1, 1, 1], - 'b': [2, 3, 4, 5, 6] - }) + expected_result = pd.DataFrame({'a': [1, 1, 1, 1, 1], 'b': [2, 3, 4, 5, 6]}) instance._reject_sample.assert_any_call(num_rows=5, conditions={'b': 1}) pd.testing.assert_frame_equal(transformed_data, expected_result) diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py index 56d714cce..aa50ac941 100644 --- a/tests/unit/constraints/test_tabular.py +++ b/tests/unit/constraints/test_tabular.py @@ -8,14 +8,30 @@ import numpy as np import pandas as pd import pytest +from pandas.api.types import is_float_dtype from sdv.constraints.errors import ( - AggregateConstraintsError, ConstraintMetadataError, FunctionError, InvalidFunctionError, - MissingConstraintColumnError) + AggregateConstraintsError, + ConstraintMetadataError, + FunctionError, + InvalidFunctionError, + MissingConstraintColumnError, +) from sdv.constraints.tabular import ( - FixedCombinations, FixedIncrements, Inequality, Negative, OneHotEncoding, Positive, Range, - ScalarInequality, ScalarRange, Unique, _RecreateCustomConstraint, - _validate_inputs_custom_constraint, create_custom_constraint_class) + FixedCombinations, + FixedIncrements, + Inequality, + Negative, + OneHotEncoding, + Positive, + Range, + ScalarInequality, + ScalarRange, + Unique, + _RecreateCustomConstraint, + _validate_inputs_custom_constraint, + create_custom_constraint_class, +) def dummy_transform_table(table_data): @@ -54,8 +70,7 @@ def dummy_is_valid_column(column_data): return [True] * len(column_data) -class TestCreateCustomConstraint(): - +class TestCreateCustomConstraint: @patch('sdv.constraints.tabular.create_custom_constraint_class') def test___recreatecustomconstraint___call__(self, create_custom_constraint_mock): """Test that custom constraints are recreated properly.""" @@ -71,17 +86,13 @@ class MockClass: 
create_custom_constraint_mock.return_value = MockClass # Run - recreated_class = class_recreator( - dummy_is_valid, - dummy_transform, - dummy_reverse_transform - ) + recreated_class = class_recreator(dummy_is_valid, dummy_transform, dummy_reverse_transform) # Assert create_custom_constraint_mock.assert_called_once_with( is_valid_fn=dummy_is_valid, transform_fn=dummy_transform, - reverse_transform_fn=dummy_reverse_transform + reverse_transform_fn=dummy_reverse_transform, ) assert isinstance(recreated_class, MockClass) @@ -93,9 +104,7 @@ def test__validate_inputs(self): Raises: - List of ValueErrors """ - err_msg = ( - "Missing required values {'column_names'} in a CustomConstraint constraint." - ) + err_msg = "Missing required values {'column_names'} in a CustomConstraint constraint." # Run / Assert constraint = create_custom_constraint_class(sorted, sorted, sorted) with pytest.raises(AggregateConstraintsError, match=err_msg): @@ -156,7 +165,8 @@ def test__validate_inputs_custom_constraint_transform_not_callable(self): # Run / Assert with pytest.raises(ValueError, match=err_msg): _validate_inputs_custom_constraint( - is_valid_fn=sorted, transform_fn='a', reverse_transform_fn=sorted) + is_valid_fn=sorted, transform_fn='a', reverse_transform_fn=sorted + ) def test__validate_inputs_custom_constraint_reverse_transform_not_callable(self): """Test validation when ``reverse_transform_fn`` is not callable. @@ -172,7 +182,8 @@ def test__validate_inputs_custom_constraint_reverse_transform_not_callable(self) # Run / Assert with pytest.raises(ValueError, match=err_msg): _validate_inputs_custom_constraint( - is_valid_fn=sorted, transform_fn=sorted, reverse_transform_fn=10) + is_valid_fn=sorted, transform_fn=sorted, reverse_transform_fn=10 + ) def test__validate_metadata_columns(self): """Test the ``_validate_metadata_columns`` method. @@ -262,7 +273,7 @@ def test_create_custom_constraint_class_is_valid(self): lambda _, x: pd.Series([True if x_i >= 0 else False for x_i in x['col']]) ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run valid_out = custom_constraint.is_valid(data) @@ -286,7 +297,7 @@ def test_create_custom_constraint_class_is_valid_wrong_shape(self): lambda _, x: pd.Series([True, True, True]) ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run err_msg = '`is_valid_fn` did not produce exactly 1 True/False value for each row.' 
@@ -309,7 +320,7 @@ def test_create_custom_constraint_class_is_valid_not_a_series(self): lambda _, x: [True if x_i >= 0 else False for x_i in x['col']] ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run err_msg = ( @@ -330,6 +341,7 @@ def test_create_custom_constraint_class_transform(self): Output: - pd.DataFrame of transformed values """ + # Setup def test_is_valid(*_): return pd.Series([True] * 5) @@ -343,13 +355,13 @@ def test_reverse_transform(dummy, data): custom_constraint = create_custom_constraint_class( test_is_valid, test_transform, test_reverse_transform )('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run transform_out = custom_constraint.transform(data) # Assert - expected_out = pd.DataFrame({'col': [100, 1, 0, 9, .25]}) + expected_out = pd.DataFrame({'col': [100, 1, 0, 9, 0.25]}) pd.testing.assert_frame_equal(transform_out, expected_out) def test_create_custom_constraint_class_transform_not_defined(self): @@ -363,11 +375,9 @@ def test_create_custom_constraint_class_transform_not_defined(self): - Original data """ # Setup - custom_constraint = create_custom_constraint_class( - lambda _, x: pd.Series([True] * 5) - ) + custom_constraint = create_custom_constraint_class(lambda _, x: pd.Series([True] * 5)) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run out = custom_constraint.transform(data) @@ -389,10 +399,10 @@ def test_create_custom_constraint_class_transform_wrong_shape(self): custom_constraint = create_custom_constraint_class( lambda _, x: pd.Series([True] * 5), lambda _, x: pd.DataFrame({'col': [1, 2, 3]}), - sorted + sorted, ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run err_msg = 'Transformation did not produce the same number of rows as the original' @@ -412,12 +422,10 @@ def test_create_custom_constraint_class_incorrect_transform(self): """ # Setup custom_constraint = create_custom_constraint_class( - lambda _, x: pd.Series([True] * 5), - lambda _: pd.DataFrame({'col': [1, 2, 3]}), - sorted + lambda _, x: pd.Series([True] * 5), lambda _: pd.DataFrame({'col': [1, 2, 3]}), sorted ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run with pytest.raises(FunctionError): @@ -436,18 +444,16 @@ def test_create_custom_constraint_class_reverse_transform(self): """ # Setup custom_constraint = create_custom_constraint_class( - sorted, - sorted, - lambda _, x: pd.DataFrame({'col': x['col'] ** 2}) + sorted, sorted, lambda _, x: pd.DataFrame({'col': x['col'] ** 2}) ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run transformed_out = custom_constraint.reverse_transform(data) # Assert - expected_out = pd.DataFrame({'col': [100, 1, 0, 9, .25]}) + expected_out = pd.DataFrame({'col': [100, 1, 0, 9, 0.25]}) pd.testing.assert_frame_equal(transformed_out, expected_out) def test_create_custom_constraint_class_reverse_transform_not_defined(self): @@ -462,7 +468,7 @@ def test_create_custom_constraint_class_reverse_transform_not_defined(self): """ # Setup custom_constraint = 
create_custom_constraint_class(sorted)('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run out = custom_constraint.reverse_transform(data) @@ -483,12 +489,10 @@ def test_create_custom_constraint_class_reverse_transform_wrong_shape(self): """ # Setup custom_constraint = create_custom_constraint_class( - sorted, - sorted, - lambda _, x: pd.DataFrame({'col': [1, 2, 3]}) + sorted, sorted, lambda _, x: pd.DataFrame({'col': [1, 2, 3]}) ) custom_constraint = custom_constraint('col') - data = pd.DataFrame({'col': [-10, 1, 0, 3, -.5]}) + data = pd.DataFrame({'col': [-10, 1, 0, 3, -0.5]}) # Run err_msg = 'Reverse transform did not produce the same number of rows as the original.' @@ -508,9 +512,7 @@ def test_create_custom_constraint_class___reduce__(self): reverse_transfom_fn = Mock() custom_constraint = create_custom_constraint_class( - is_valid_fn, - transform_fn, - reverse_transfom_fn + is_valid_fn, transform_fn, reverse_transfom_fn ) custom_constraint = custom_constraint(['col']) @@ -525,12 +527,11 @@ def test_create_custom_constraint_class___reduce__(self): 'metadata': None, 'column_names': ['col'], 'constraint_columns': ('col',), - 'kwargs': {} + 'kwargs': {}, } -class TestFixedCombinations(): - +class TestFixedCombinations: def test__validate_inputs(self): """Test the ``FixedCombinations._validate_inputs`` method. @@ -601,8 +602,7 @@ def test__validate_metadata_specific_to_constraint(self): # Run FixedCombinations._validate_metadata_specific_to_constraint( - metadata, - column_names=['a', 'b'] + metadata, column_names=['a', 'b'] ) def test__validate_metadata_specific_to_constraint_incorrect_types(self): @@ -618,8 +618,7 @@ def test__validate_metadata_specific_to_constraint_incorrect_types(self): ) with pytest.raises(ConstraintMetadataError, match=error_message): FixedCombinations._validate_metadata_specific_to_constraint( - metadata, - column_names=['a', 'b'] + metadata, column_names=['a', 'b'] ) def test___init__(self): @@ -682,10 +681,7 @@ def test__fit(self): instance._fit(table_data) # Asserts - expected_combinations = pd.DataFrame({ - 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] - }) + expected_combinations = pd.DataFrame({'b': ['d', 'e', 'f'], 'c': ['g', 'h', 'i']}) assert instance._separator == '##' assert instance._joint_column == 'b##c' pd.testing.assert_frame_equal(instance._combinations, expected_combinations) @@ -707,7 +703,7 @@ def test_is_valid_true(self): table_data = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) columns = ['b', 'c'] instance = FixedCombinations(column_names=columns) @@ -736,7 +732,7 @@ def test_is_valid_false(self): table_data = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) columns = ['b', 'c'] instance = FixedCombinations(column_names=columns) @@ -746,7 +742,7 @@ def test_is_valid_false(self): incorrect_table = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['D', 'E', 'F'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) out = instance.is_valid(incorrect_table) @@ -772,7 +768,7 @@ def test_is_valid_non_string_true(self): 'a': ['a', 'b', 'c'], 'b': [1, 2, 3], 'c': ['g', 'h', 'i'], - 'd': [2.4, 1.23, 5.6] + 'd': [2.4, 1.23, 5.6], }) columns = ['b', 'c', 'd'] instance = FixedCombinations(column_names=columns) @@ -802,7 +798,7 @@ def test_is_valid_non_string_false(self): 'a': ['a', 'b', 'c'], 'b': [1, 2, 3], 'c': ['g', 'h', 'i'], - 'd': [2.4, 1.23, 5.6] + 'd': [2.4, 1.23, 5.6], 
}) columns = ['b', 'c', 'd'] instance = FixedCombinations(column_names=columns) @@ -813,7 +809,7 @@ def test_is_valid_non_string_false(self): 'a': ['a', 'b', 'c'], 'b': [6, 7, 8], 'c': ['g', 'h', 'i'], - 'd': [2.4, 1.23, 5.6] + 'd': [2.4, 1.23, 5.6], }) out = instance.is_valid(incorrect_table) @@ -838,7 +834,7 @@ def test_transform(self): table_data = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) columns = ['b', 'c'] instance = FixedCombinations(column_names=columns) @@ -871,7 +867,7 @@ def test_transform_non_string(self): 'a': ['a', 'b', 'c'], 'b': [1, 2, 3], 'c': ['g', 'h', 'i'], - 'd': [2.4, 1.23, 5.6] + 'd': [2.4, 1.23, 5.6], }) columns = ['b', 'c', 'd'] instance = FixedCombinations(column_names=columns) @@ -901,7 +897,7 @@ def test_transform_not_all_columns_provided(self): table_data = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) columns = ['b', 'c'] instance = FixedCombinations(column_names=columns) @@ -928,7 +924,7 @@ def test_reverse_transform(self): table_data = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) columns = ['b', 'c'] instance = FixedCombinations(column_names=columns) @@ -944,7 +940,7 @@ def test_reverse_transform(self): expected_out = pd.DataFrame({ 'a': ['a', 'b', 'c'], 'b': ['d', 'e', 'f'], - 'c': ['g', 'h', 'i'] + 'c': ['g', 'h', 'i'], }) pd.testing.assert_frame_equal(expected_out, out) @@ -966,7 +962,7 @@ def test_reverse_transform_non_string(self): 'a': ['a', 'b', 'c'], 'b': [1, 2, 3], 'c': ['g', 'h', 'i'], - 'd': [2.4, 1.23, 5.6] + 'd': [2.4, 1.23, 5.6], }) columns = ['b', 'c', 'd'] instance = FixedCombinations(column_names=columns) @@ -983,13 +979,12 @@ def test_reverse_transform_non_string(self): 'a': ['a', 'b', 'c'], 'b': [1, 2, 3], 'c': ['g', 'h', 'i'], - 'd': [2.4, 1.23, 5.6] + 'd': [2.4, 1.23, 5.6], }) pd.testing.assert_frame_equal(expected_out, out) -class TestInequality(): - +class TestInequality: def test__validate_inputs(self): """Test the ``Inequality._validate_inputs`` method. @@ -1051,7 +1046,8 @@ def test__validate_metadata_columns_raises_error(self): ) with pytest.raises(ConstraintMetadataError, match=error_message): Inequality._validate_metadata_columns( - metadata, low_column_name='a', high_column_name='c') + metadata, low_column_name='a', high_column_name='c' + ) def test__validate_metadata_specific_to_constraint_datetime(self): """Test the ``_validate_metadata_specific_to_constraint`` with datetimes. 
@@ -1068,9 +1064,7 @@ def test__validate_metadata_specific_to_constraint_datetime(self): # Run Inequality._validate_metadata_specific_to_constraint( - metadata, - high_column_name='a', - low_column_name='b' + metadata, high_column_name='a', low_column_name='b' ) def test__validate_metadata_specific_to_constraint_datetime_error(self): @@ -1093,9 +1087,7 @@ def test__validate_metadata_specific_to_constraint_datetime_error(self): ) with pytest.raises(ConstraintMetadataError, match=error_message): Inequality._validate_metadata_specific_to_constraint( - metadata, - high_column_name='a', - low_column_name='b' + metadata, high_column_name='a', low_column_name='b' ) def test__validate_metadata_specific_to_constraint_numerical(self): @@ -1113,9 +1105,7 @@ def test__validate_metadata_specific_to_constraint_numerical(self): # Run Inequality._validate_metadata_specific_to_constraint( - metadata, - high_column_name='a', - low_column_name='b' + metadata, high_column_name='a', low_column_name='b' ) def test__validate_metadata_specific_to_constraint_numerical_error(self): @@ -1138,9 +1128,7 @@ def test__validate_metadata_specific_to_constraint_numerical_error(self): ) with pytest.raises(ConstraintMetadataError, match=error_message): Inequality._validate_metadata_specific_to_constraint( - metadata, - high_column_name='a', - low_column_name='b' + metadata, high_column_name='a', low_column_name='b' ) def test__validate_init_inputs_incorrect_column(self): @@ -1238,10 +1226,7 @@ def test__get_is_datetime_incorrect_data(self): # Setup instance = Inequality(low_column_name='a', high_column_name='b') instance.metadata = Mock() - instance.metadata.columns = { - 'a': {'sdtype': 'datetime'}, - 'b': {'sdtype': 'categorical'} - } + instance.metadata.columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'categorical'}} # Run / Assert err_msg = 'Both high and low must be datetime.' @@ -1257,10 +1242,7 @@ def test__validate_columns_exist_incorrect_columns(self): - Table with given data. """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 4], - 'b': [4, 5, 6] - }) + table_data = pd.DataFrame({'a': [1, 2, 4], 'b': [4, 5, 6]}) instance = Inequality(low_column_name='a', high_column_name='c') # Run / Assert @@ -1281,10 +1263,7 @@ def test__fit(self): - _dtype should be a list of integer dtypes. """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 4], - 'b': [4, 5, 6] - }) + table_data = pd.DataFrame({'a': [1, 2, 4], 'b': [4, 5, 6]}) instance = Inequality(low_column_name='a', high_column_name='b') instance._validate_columns_exist = Mock() instance._get_is_datetime = Mock(return_value='abc') @@ -1314,16 +1293,10 @@ def test__fit_floats(self): - _dtype should be a float dtype. """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 4], - 'b': [4., 5., 6.] 
- }) + table_data = pd.DataFrame({'a': [1, 2, 4], 'b': [4.0, 5.0, 6.0]}) instance = Inequality(low_column_name='a', high_column_name='b') instance.metadata = Mock() - instance.metadata.columns = { - 'a': {'sdtype': 'datetime'}, - 'b': {'sdtype': 'datetime'} - } + instance.metadata.columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'datetime'}} # Run instance._fit(table_data) @@ -1344,14 +1317,11 @@ def test__fit_datetime(self): # Setup table_data = pd.DataFrame({ 'a': pd.to_datetime(['2020-01-01']), - 'b': pd.to_datetime(['2020-01-02']) + 'b': pd.to_datetime(['2020-01-02']), }) instance = Inequality(low_column_name='a', high_column_name='b') instance.metadata = Mock() - instance.metadata.columns = { - 'a': {'sdtype': 'datetime'}, - 'b': {'sdtype': 'datetime'} - } + instance.metadata.columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'datetime'}} # Run instance._fit(table_data) @@ -1377,7 +1347,7 @@ def test_is_valid(self): table_data = pd.DataFrame({ 'a': [1, np.nan, 3, 4, None, 6, 8, 0], 'b': [4, 2, 2, 4, np.nan, -6, 10, float('nan')], - 'c': [7, 8, 9, 10, 11, 12, 13, 14] + 'c': [7, 8, 9, 10, 11, 12, 13, 14], }) out = instance.is_valid(table_data) @@ -1403,7 +1373,7 @@ def test_is_valid_strict_boundaries_true(self): table_data = pd.DataFrame({ 'a': [1, np.nan, 3, 4, None, 6, 8, 0], 'b': [4, 2, 2, 4, np.nan, -6, 10, float('nan')], - 'c': [7, 8, 9, 10, 11, 12, 13, 14] + 'c': [7, 8, 9, 10, 11, 12, 13, 14], }) out = instance.is_valid(table_data) @@ -1429,7 +1399,7 @@ def test_is_valid_datetimes(self): table_data = pd.DataFrame({ 'a': [datetime(2020, 5, 17), datetime(2021, 9, 1), np.nan], 'b': [datetime(2020, 5, 18), datetime(2020, 9, 2), datetime(2020, 9, 2)], - 'c': [7, 8, 9] + 'c': [7, 8, 9], }) out = instance.is_valid(table_data) @@ -1448,7 +1418,7 @@ def test_is_valid_datetime_objects(self): table_data = pd.DataFrame({ 'a': ['2020-5-17', '2021-9-1', np.nan], 'b': ['2020-5-18', '2020-9-2', '2020-9-2'], - 'c': [7, 8, 9] + 'c': [7, 8, 9], }) out = instance.is_valid(table_data) @@ -1490,18 +1460,16 @@ def test__transform(self): pd.testing.assert_frame_equal(out, expected_out) def test__transform_with_nans(self): - # Setup instance = Inequality(low_column_name='a', high_column_name='b') instance._diff_column_name = 'a#b' table_data_with_nans = pd.DataFrame({ - 'a': [1, np.nan, 3, np.nan], 'b': [np.nan, 2, 4, np.nan] + 'a': [1, np.nan, 3, np.nan], + 'b': [np.nan, 2, 4, np.nan], }) - table_data_without_nans = pd.DataFrame({ - 'a': [1, 2, 3], 'b': [2, 3, 4] - }) + table_data_without_nans = pd.DataFrame({'a': [1, 2, 3], 'b': [2, 3, 4]}) # Run output_with_nans = instance._transform(table_data_with_nans) @@ -1509,9 +1477,9 @@ def test__transform_with_nans(self): # Assert expected_output_with_nans = pd.DataFrame({ - 'a': [1., 2., 3., 2.], + 'a': [1.0, 2.0, 3.0, 2.0], 'a#b': [np.log(2)] * 4, - 'a#b.nan_component': ['b', 'a', 'None', 'a, b'] + 'a#b.nan_component': ['b', 'a', 'None', 'a, b'], }) expected_output_without_nans = pd.DataFrame({ @@ -1716,7 +1684,7 @@ def test_reverse_transform_datetime(self): expected_out = pd.DataFrame({ 'a': pd.to_datetime(['2020-01-01T00:00:00', '2020-01-02T00:00:00']), 'c': [1, 2], - 'b': pd.to_datetime(['2020-01-01T00:00:01', '2020-01-02T00:00:01']) + 'b': pd.to_datetime(['2020-01-01T00:00:01', '2020-01-02T00:00:01']), }) pd.testing.assert_frame_equal(out, expected_out) @@ -1744,14 +1712,13 @@ def test_reverse_transform_datetime_dtype_is_object(self): expected_out = pd.DataFrame({ 'a': ['2020-01-01T00:00:00', '2020-01-02T00:00:00'], 'c': [1, 2], - 'b': 
[pd.Timestamp('2020-01-01 00:00:01'), pd.Timestamp('2020-01-02 00:00:01')] + 'b': [pd.Timestamp('2020-01-01 00:00:01'), pd.Timestamp('2020-01-02 00:00:01')], }) expected_out['b'] = expected_out['b'].astype(np.dtype('O')) pd.testing.assert_frame_equal(out, expected_out) -class TestScalarInequality(): - +class TestScalarInequality: def test__validate_inputs(self): """Test the ``ScalarInequality._validate_inputs`` method. @@ -1763,7 +1730,8 @@ def test__validate_inputs(self): # Run / Assert with pytest.raises(AggregateConstraintsError) as error: ScalarInequality._validate_inputs( - not_high_column=None, not_low_column=None, relation='+') + not_high_column=None, not_low_column=None, relation='+' + ) err_msg = ( r'Missing required values {(.*)} in a ScalarInequality constraint.' @@ -1836,10 +1804,7 @@ def test__validate_metadata_specific_to_constraint_numerical(self): # Run ScalarInequality._validate_metadata_specific_to_constraint( - metadata, - column_name='a', - relation='>', - value=7 + metadata, column_name='a', relation='>', value=7 ) def test__validate_metadata_specific_to_constraint_numerical_error(self): @@ -1862,10 +1827,7 @@ def test__validate_metadata_specific_to_constraint_numerical_error(self): error_message = "'value' must be an int or float." with pytest.raises(ConstraintMetadataError, match=error_message): ScalarInequality._validate_metadata_specific_to_constraint( - metadata, - column_name='a', - relation='>', - value='7' + metadata, column_name='a', relation='>', value='7' ) @patch('sdv.constraints.tabular.matches_datetime_format') @@ -1889,10 +1851,7 @@ def test__validate_metadata_specific_to_constraint_datetime(self, datetime_forma # Run ScalarInequality._validate_metadata_specific_to_constraint( - metadata, - column_name='a', - relation='>', - value='1/1/2020' + metadata, column_name='a', relation='>', value='1/1/2020' ) @patch('sdv.constraints.tabular.matches_datetime_format') @@ -1921,10 +1880,7 @@ def test__validate_metadata_specific_to_constraint_datetime_error(self, datetime error_message = "'value' must be a datetime string of the right format" with pytest.raises(ConstraintMetadataError, match=error_message): ScalarInequality._validate_metadata_specific_to_constraint( - metadata, - column_name='a', - relation='>', - value='1-1-2020' + metadata, column_name='a', relation='>', value='1-1-2020' ) def test__validate_metadata_specific_to_constraint_bad_type(self): @@ -1954,10 +1910,7 @@ def test__validate_metadata_specific_to_constraint_bad_type(self): ) with pytest.raises(ConstraintMetadataError, match=error_message): ScalarInequality._validate_metadata_specific_to_constraint( - metadata, - column_name='a', - relation='>', - value=7 + metadata, column_name='a', relation='>', value=7 ) def test__validate_init_inputs_incorrect_column(self): @@ -2088,10 +2041,7 @@ def test__validate_columns_exist_incorrect_columns(self): - Table with given data. """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 4], - 'b': [4, 5, 6] - }) + table_data = pd.DataFrame({'a': [1, 2, 4], 'b': [4, 5, 6]}) instance = ScalarInequality(column_name='c', value=5, relation='>') # Run / Assert @@ -2112,10 +2062,7 @@ def test__fit(self): - _dtype should be a list of integer dtypes. 
""" # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 4], - 'b': [4, 5, 6] - }) + table_data = pd.DataFrame({'a': [1, 2, 4], 'b': [4, 5, 6]}) instance = ScalarInequality(column_name='b', value=3, relation='>') instance._validate_columns_exist = Mock() instance._get_is_datetime = Mock(return_value=False) @@ -2140,10 +2087,7 @@ def test__fit_floats(self): - _dtype should be a float dtype. """ # Setup - table_data = pd.DataFrame({ - 'a': [1, 2, 4], - 'b': [4., 5., 6.] - }) + table_data = pd.DataFrame({'a': [1, 2, 4], 'b': [4.0, 5.0, 6.0]}) instance = ScalarInequality(column_name='b', value=10, relation='>') instance.metadata = MagicMock() @@ -2166,13 +2110,15 @@ def test__fit_datetime(self): # Setup table_data = pd.DataFrame({ 'a': pd.to_datetime(['2020-01-01']), - 'b': pd.to_datetime(['2020-01-02']) + 'b': pd.to_datetime(['2020-01-02']), }) instance = ScalarInequality(column_name='b', value='2020-01-01', relation='>') - instance.metadata = Mock(columns={ - 'a': {'sdtype': 'datetime'}, - 'b': {'sdtype': 'datetime'}, - }) + instance.metadata = Mock( + columns={ + 'a': {'sdtype': 'datetime'}, + 'b': {'sdtype': 'datetime'}, + } + ) # Run instance._fit(table_data) @@ -2217,15 +2163,12 @@ def test_is_valid_datetimes(self): - False should be returned for the strictly invalid rows and True for the rest. """ # Setup - instance = ScalarInequality( - column_name='b', - value='8/31/2021', - relation='>=') + instance = ScalarInequality(column_name='b', value='8/31/2021', relation='>=') # Run table_data = pd.DataFrame({ 'b': [datetime(2021, 8, 30), datetime(2021, 8, 31), datetime(2021, 9, 2), np.nan], - 'c': [7, 8, 9, 10] + 'c': [7, 8, 9, 10], }) out = instance.is_valid(table_data) @@ -2240,18 +2183,14 @@ def test_is_valid_datetimes_as_object(self): ``value`` or the row contains nan, even when the ``datetime`` is passed as an object. """ # Setup - instance = ScalarInequality( - column_name='b', - value='8/31/2021', - relation='>=' - ) + instance = ScalarInequality(column_name='b', value='8/31/2021', relation='>=') instance._dtype = np.dtype('O') instance._is_datetime = True # Run table_data = pd.DataFrame({ 'b': ['2021, 8, 30', '2021, 8, 31', '2021, 9, 2', np.nan], - 'c': [7, 8, 9, 10] + 'c': [7, 8, 9, 10], }) out = instance.is_valid(table_data) @@ -2306,10 +2245,7 @@ def test__transform_datetime(self): in the ``column_name``'s place. """ # Setup - instance = ScalarInequality( - column_name='a', - value='2020-01-01T00:00:00', - relation='>') + instance = ScalarInequality(column_name='a', value='2020-01-01T00:00:00', relation='>') instance._diff_column_name = 'a#' instance._is_datetime = True @@ -2423,10 +2359,7 @@ def test_reverse_transform_datetime(self): and the diff column dropped. 
""" # Setup - instance = ScalarInequality( - column_name='a', - value='2020-01-01T00:00:00', - relation='>=') + instance = ScalarInequality(column_name='a', value='2020-01-01T00:00:00', relation='>=') instance._dtype = np.dtype('=', - 'value': 10 - }, - { - 'constraint_name': 'ScalarInequality', - 'column_name': 'b', - 'relation': '>=', - 'value': 5 - }, - { - 'constraint_name': 'ScalarInequality', - 'column_name': 'c', - 'relation': '>=', - 'value': 5 - }, - { - 'constraint_name': 'ScalarInequality', - 'column_name': 'f', - 'relation': '<=', - 'value': 10 - }, - { - 'constraint_name': 'ScalarInequality', - 'column_name': 'c', - 'relation': '<=', - 'value': 5 - }, - { - 'constraint_name': 'ScalarInequality', - 'column_name': 'd', - 'relation': '<=', - 'value': 5 - } + 'strict_boundaries': True, + }, + {'constraint_name': 'ScalarInequality', 'column_name': 'a', 'relation': '>=', 'value': 10}, + {'constraint_name': 'ScalarInequality', 'column_name': 'b', 'relation': '>=', 'value': 5}, + {'constraint_name': 'ScalarInequality', 'column_name': 'c', 'relation': '>=', 'value': 5}, + {'constraint_name': 'ScalarInequality', 'column_name': 'f', 'relation': '<=', 'value': 10}, + {'constraint_name': 'ScalarInequality', 'column_name': 'c', 'relation': '<=', 'value': 5}, + {'constraint_name': 'ScalarInequality', 'column_name': 'd', 'relation': '<=', 'value': 5}, ] assert len(expected_constraints) == len(new_constraints) @@ -205,15 +169,15 @@ def test__upgrade_constraints_greater_than_error(warnings_mock): 'scalar': None, 'high': ['a', 'b'], 'low': 'c', - 'strict': True + 'strict': True, }, { 'constraint': 'sdv.constraints.tabular.GreaterThan', 'scalar': None, 'high': 'a', 'low': ['b', 'c'], - 'strict': True - } + 'strict': True, + }, ] old_metadata = {'constraints': old_constraints} @@ -230,7 +194,7 @@ def test__upgrade_constraints_greater_than_error(warnings_mock): call( "Unable to upgrade the GreaterThan constraint specified for 'high' 'a' " "and 'low' ['b', 'c']. Manually add Inequality constraints to capture this logic." 
- ) + ), ]) @@ -260,7 +224,7 @@ def test__upgrade_constraints_between(): 'low_is_scalar': True, 'low': 5, 'high': 10, - 'strict': True + 'strict': True, }, { 'constraint': 'sdv.constraints.tabular.Between', @@ -269,7 +233,7 @@ def test__upgrade_constraints_between(): 'low_is_scalar': False, 'low': 'a', 'high': 'b', - 'strict': True + 'strict': True, }, { 'constraint': 'sdv.constraints.tabular.Between', @@ -277,7 +241,7 @@ def test__upgrade_constraints_between(): 'high_is_scalar': True, 'low_is_scalar': False, 'low': 'a', - 'high': 10 + 'high': 10, }, { 'constraint': 'sdv.constraints.tabular.Between', @@ -285,8 +249,8 @@ def test__upgrade_constraints_between(): 'high_is_scalar': False, 'low_is_scalar': True, 'low': 5, - 'high': 'b' - } + 'high': 'b', + }, ] old_metadata = {'constraints': old_constraints} @@ -300,20 +264,20 @@ def test__upgrade_constraints_between(): 'column_name': 'z', 'low_value': 5, 'high_value': 10, - 'strict_boundaries': True + 'strict_boundaries': True, }, { 'constraint_name': 'Range', 'middle_column_name': 'z', 'low_column_name': 'a', 'high_column_name': 'b', - 'strict_boundaries': True + 'strict_boundaries': True, }, { 'constraint_name': 'Inequality', 'low_column_name': 'a', 'high_column_name': 'z', - 'strict_boundaries': False + 'strict_boundaries': False, }, { 'constraint_name': 'ScalarInequality', @@ -325,14 +289,14 @@ def test__upgrade_constraints_between(): 'constraint_name': 'Inequality', 'low_column_name': 'z', 'high_column_name': 'b', - 'strict_boundaries': False + 'strict_boundaries': False, }, { 'constraint_name': 'ScalarInequality', 'column_name': 'z', 'relation': '>=', 'value': 5, - } + }, ] assert len(expected_constraints) == len(new_constraints) for constraint in expected_constraints: @@ -362,46 +326,14 @@ def test__upgrade_constraints_positive_and_negative(): """ # Setup old_constraints = [ - { - 'constraint': 'sdv.constraints.tabular.Positive', - 'columns': 'a', - 'strict': True - }, - { - 'constraint': 'sdv.constraints.tabular.Positive', - 'columns': ['b', 'c'], - 'strict': True - }, - { - 'constraint': 'sdv.constraints.tabular.Positive', - 'columns': 'd', - 'strict': False - }, - { - 'constraint': 'sdv.constraints.tabular.Positive', - 'columns': ['e', 'f'], - 'strict': False - }, - { - 'constraint': 'sdv.constraints.tabular.Negative', - 'columns': 'a', - 'strict': True - }, - { - 'constraint': 'sdv.constraints.tabular.Negative', - 'columns': ['b', 'c'], - 'strict': True - }, - { - 'constraint': 'sdv.constraints.tabular.Negative', - 'columns': 'd', - 'strict': False - }, - { - 'constraint': 'sdv.constraints.tabular.Negative', - 'columns': ['e', 'f'], - 'strict': False - }, + {'constraint': 'sdv.constraints.tabular.Positive', 'columns': 'a', 'strict': True}, + {'constraint': 'sdv.constraints.tabular.Positive', 'columns': ['b', 'c'], 'strict': True}, + {'constraint': 'sdv.constraints.tabular.Positive', 'columns': 'd', 'strict': False}, + {'constraint': 'sdv.constraints.tabular.Positive', 'columns': ['e', 'f'], 'strict': False}, + {'constraint': 'sdv.constraints.tabular.Negative', 'columns': 'a', 'strict': True}, + {'constraint': 'sdv.constraints.tabular.Negative', 'columns': ['b', 'c'], 'strict': True}, + {'constraint': 'sdv.constraints.tabular.Negative', 'columns': 'd', 'strict': False}, + {'constraint': 'sdv.constraints.tabular.Negative', 'columns': ['e', 'f'], 'strict': False}, ] old_metadata = {'constraints': old_constraints} @@ -410,66 +342,18 @@ def test__upgrade_constraints_positive_and_negative(): # Assert expected_constraints = [ - { - 
'constraint_name': 'Positive', - 'column_name': 'a', - 'strict_boundaries': True - }, - { - 'constraint_name': 'Positive', - 'column_name': 'b', - 'strict_boundaries': True - }, - { - 'constraint_name': 'Positive', - 'column_name': 'c', - 'strict_boundaries': True - }, - { - 'constraint_name': 'Positive', - 'column_name': 'd', - 'strict_boundaries': False - }, - { - 'constraint_name': 'Positive', - 'column_name': 'e', - 'strict_boundaries': False - }, - { - 'constraint_name': 'Positive', - 'column_name': 'f', - 'strict_boundaries': False - }, - { - 'constraint_name': 'Negative', - 'column_name': 'a', - 'strict_boundaries': True - }, - { - 'constraint_name': 'Negative', - 'column_name': 'b', - 'strict_boundaries': True - }, - { - 'constraint_name': 'Negative', - 'column_name': 'c', - 'strict_boundaries': True - }, - { - 'constraint_name': 'Negative', - 'column_name': 'd', - 'strict_boundaries': False - }, - { - 'constraint_name': 'Negative', - 'column_name': 'e', - 'strict_boundaries': False - }, - { - 'constraint_name': 'Negative', - 'column_name': 'f', - 'strict_boundaries': False - } + {'constraint_name': 'Positive', 'column_name': 'a', 'strict_boundaries': True}, + {'constraint_name': 'Positive', 'column_name': 'b', 'strict_boundaries': True}, + {'constraint_name': 'Positive', 'column_name': 'c', 'strict_boundaries': True}, + {'constraint_name': 'Positive', 'column_name': 'd', 'strict_boundaries': False}, + {'constraint_name': 'Positive', 'column_name': 'e', 'strict_boundaries': False}, + {'constraint_name': 'Positive', 'column_name': 'f', 'strict_boundaries': False}, + {'constraint_name': 'Negative', 'column_name': 'a', 'strict_boundaries': True}, + {'constraint_name': 'Negative', 'column_name': 'b', 'strict_boundaries': True}, + {'constraint_name': 'Negative', 'column_name': 'c', 'strict_boundaries': True}, + {'constraint_name': 'Negative', 'column_name': 'd', 'strict_boundaries': False}, + {'constraint_name': 'Negative', 'column_name': 'e', 'strict_boundaries': False}, + {'constraint_name': 'Negative', 'column_name': 'f', 'strict_boundaries': False}, ] assert len(expected_constraints) == len(new_constraints) for constraint in expected_constraints: @@ -492,18 +376,9 @@ def test__upgrade_constraints_simple_constraints(): """ # Setup old_constraints = [ - { - 'constraint': 'sdv.constraints.tabular.UniqueCombinations', - 'columns': ['a', 'b'] - }, - { - 'constraint': 'sdv.constraints.tabular.OneHotEncoding', - 'columns': ['c', 'd'] - }, - { - 'constraint': 'sdv.constraints.tabular.Unique', - 'columns': ['e', 'f'] - }, + {'constraint': 'sdv.constraints.tabular.UniqueCombinations', 'columns': ['a', 'b']}, + {'constraint': 'sdv.constraints.tabular.OneHotEncoding', 'columns': ['c', 'd']}, + {'constraint': 'sdv.constraints.tabular.Unique', 'columns': ['e', 'f']}, ] old_metadata = {'constraints': old_constraints} @@ -512,18 +387,9 @@ def test__upgrade_constraints_simple_constraints(): # Assert expected_constraints = [ - { - 'constraint_name': 'FixedCombinations', - 'column_names': ['a', 'b'] - }, - { - 'constraint_name': 'OneHotEncoding', - 'column_names': ['c', 'd'] - }, - { - 'constraint_name': 'Unique', - 'column_names': ['e', 'f'] - }, + {'constraint_name': 'FixedCombinations', 'column_names': ['a', 'b']}, + {'constraint_name': 'OneHotEncoding', 'column_names': ['c', 'd']}, + {'constraint_name': 'Unique', 'column_names': ['e', 'f']}, ] assert len(expected_constraints) == len(new_constraints) for constraint in expected_constraints: @@ -552,18 +418,10 @@ def 
test__upgrade_constraints_constraint_has_no_upgrade(warnings_mock): """ # Setup old_constraints = [ - { - 'constraint': 'sdv.constraints.tabular.Rounding' - }, - { - 'constraint': 'sdv.constraints.tabular.ColumnFormula' - }, - { - 'constraint': 'sdv.constraints.tabular.CustomConstraint' - }, - { - 'constraint': 'Fake' - } + {'constraint': 'sdv.constraints.tabular.Rounding'}, + {'constraint': 'sdv.constraints.tabular.ColumnFormula'}, + {'constraint': 'sdv.constraints.tabular.CustomConstraint'}, + {'constraint': 'Fake'}, ] old_metadata = {'constraints': old_constraints} @@ -589,7 +447,7 @@ def test__upgrade_constraints_constraint_has_no_upgrade(warnings_mock): call( 'Unable to upgrade the Fake constraint. Please add in the constraint ' 'using the new Constraints API.' - ) + ), ]) @@ -608,49 +466,22 @@ def test_convert_metadata(): # Setup old_metadata = { 'fields': { - 'start_date': { - 'type': 'datetime', - 'format': '%Y-%m-%d' - }, - 'end_date': { - 'type': 'datetime', - 'format': '%Y-%m-%d' - }, - 'salary': { - 'type': 'numerical', - 'subtype': 'integer' - }, - 'duration': { - 'type': 'categorical' - }, - 'student_id': { - 'type': 'id', - 'subtype': 'integer' - }, - 'high_perc': { - 'type': 'numerical', - 'subtype': 'float' - }, - 'placed': { - 'type': 'boolean' - }, - 'ssn': { - 'type': 'categorical', - 'pii': True, - 'pii_category': 'ssn' - }, + 'start_date': {'type': 'datetime', 'format': '%Y-%m-%d'}, + 'end_date': {'type': 'datetime', 'format': '%Y-%m-%d'}, + 'salary': {'type': 'numerical', 'subtype': 'integer'}, + 'duration': {'type': 'categorical'}, + 'student_id': {'type': 'id', 'subtype': 'integer'}, + 'high_perc': {'type': 'numerical', 'subtype': 'float'}, + 'placed': {'type': 'boolean'}, + 'ssn': {'type': 'categorical', 'pii': True, 'pii_category': 'ssn'}, 'credit_card': { 'type': 'categorical', 'pii': True, - 'pii_category': ['credit_card_number', 'visa'] + 'pii_category': ['credit_card_number', 'visa'], }, - 'drivers_license': { - 'type': 'id', - 'subtype': 'string', - 'regex': 'regex' - } + 'drivers_license': {'type': 'id', 'subtype': 'string', 'regex': 'regex'}, }, - 'primary_key': 'student_id' + 'primary_key': 'student_id', } # Run @@ -659,47 +490,19 @@ def test_convert_metadata(): # Assert expected_metadata = { 'columns': { - 'start_date': { - 'sdtype': 'datetime', - 'datetime_format': '%Y-%m-%d' - }, - 'end_date': { - 'sdtype': 'datetime', - 'datetime_format': '%Y-%m-%d' - }, - 'salary': { - 'sdtype': 'numerical', - 'computer_representation': 'Int64' - }, - 'duration': { - 'sdtype': 'categorical' - }, - 'student_id': { - 'sdtype': 'id', - 'regex_format': r'\d{30}' - }, - 'high_perc': { - 'sdtype': 'numerical', - 'computer_representation': 'Float' - }, - 'placed': { - 'sdtype': 'boolean' - }, - 'ssn': { - 'sdtype': 'ssn', - 'pii': True - }, - 'credit_card': { - 'sdtype': 'credit_card_number', - 'pii': True - }, - 'drivers_license': { - 'sdtype': 'id', - 'regex_format': 'regex' - } + 'start_date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'end_date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'salary': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'duration': {'sdtype': 'categorical'}, + 'student_id': {'sdtype': 'id', 'regex_format': r'\d{30}'}, + 'high_perc': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'placed': {'sdtype': 'boolean'}, + 'ssn': {'sdtype': 'ssn', 'pii': True}, + 'credit_card': {'sdtype': 'credit_card_number', 'pii': True}, + 'drivers_license': {'sdtype': 'id', 'regex_format': 'regex'}, }, 
'primary_key': 'student_id', - 'alternate_keys': ['drivers_license'] + 'alternate_keys': ['drivers_license'], } assert new_metadata == expected_metadata @@ -723,36 +526,18 @@ def test_convert_metadata_with_constraints(upgrade_constraints_mock): # Setup old_metadata = { 'fields': { - 'salary': { - 'type': 'numerical', - 'subtype': 'integer' - }, - 'student_id': { - 'type': 'id', - 'subtype': 'integer' - }, + 'salary': {'type': 'numerical', 'subtype': 'integer'}, + 'student_id': {'type': 'id', 'subtype': 'integer'}, }, 'primary_key': 'student_id', 'constraints': [ - { - 'constraint': 'sdv.constraints.tabular.UniqueCombinations', - 'columns': ['a', 'b'] - }, - { - 'constraint': 'sdv.constraints.tabular.OneHotEncoding', - 'columns': ['c', 'd'] - } - ] + {'constraint': 'sdv.constraints.tabular.UniqueCombinations', 'columns': ['a', 'b']}, + {'constraint': 'sdv.constraints.tabular.OneHotEncoding', 'columns': ['c', 'd']}, + ], } new_constraints = [ - { - 'constraint_name': 'FixedCombinations', - 'column_names': ['a', 'b'] - }, - { - 'constraint_name': 'OneHotEncoding', - 'column_names': ['c', 'd'] - } + {'constraint_name': 'FixedCombinations', 'column_names': ['a', 'b']}, + {'constraint_name': 'OneHotEncoding', 'column_names': ['c', 'd']}, ] upgrade_constraints_mock.return_value = new_constraints @@ -762,15 +547,9 @@ def test_convert_metadata_with_constraints(upgrade_constraints_mock): # Assert expected_metadata = { 'columns': { - 'salary': { - 'sdtype': 'numerical', - 'computer_representation': 'Int64' - }, - 'student_id': { - 'sdtype': 'id', - 'regex_format': r'\d{30}' - } + 'salary': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'student_id': {'sdtype': 'id', 'regex_format': r'\d{30}'}, }, - 'primary_key': 'student_id' + 'primary_key': 'student_id', } assert new_metadata == expected_metadata diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index de30ae4a8..cccf244f4 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -24,36 +24,33 @@ def get_metadata(self): metadata = {} metadata['tables'] = { 'users': { - 'columns': { - 'id': {'sdtype': 'id'}, - 'country': {'sdtype': 'categorical'} - }, - 'primary_key': 'id' + 'columns': {'id': {'sdtype': 'id'}, 'country': {'sdtype': 'categorical'}}, + 'primary_key': 'id', }, 'payments': { 'columns': { 'payment_id': {'sdtype': 'id'}, 'user_id': {'sdtype': 'id'}, - 'date': {'sdtype': 'datetime'} + 'date': {'sdtype': 'datetime'}, }, - 'primary_key': 'payment_id' + 'primary_key': 'payment_id', }, 'sessions': { 'columns': { 'session_id': {'sdtype': 'id'}, 'user_id': {'sdtype': 'id'}, - 'device': {'sdtype': 'categorical'} + 'device': {'sdtype': 'categorical'}, }, - 'primary_key': 'session_id' + 'primary_key': 'session_id', }, 'transactions': { 'columns': { 'transaction_id': {'sdtype': 'id'}, 'session_id': {'sdtype': 'id'}, - 'timestamp': {'sdtype': 'datetime'} + 'timestamp': {'sdtype': 'datetime'}, }, - 'primary_key': 'transaction_id' - } + 'primary_key': 'transaction_id', + }, } metadata['relationships'] = [ @@ -74,7 +71,7 @@ def get_metadata(self): 'parent_primary_key': 'id', 'child_table_name': 'payments', 'child_foreign_key': 'user_id', - } + }, ] return MultiTableMetadata.load_from_dict(metadata) @@ -198,11 +195,7 @@ def test__validate_missing_relationship_keys_foreign_key(self): ) with pytest.raises(InvalidMetadataError, match=error_msg): MultiTableMetadata._validate_missing_relationship_keys( - instance, - parent_table_name, - parent_primary_key, 
- child_table_name, - child_foreign_key + instance, parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) def test__validate_missing_relationship_keys_primary_key(self): @@ -247,7 +240,7 @@ def test__validate_missing_relationship_keys_primary_key(self): parent_table_name, parent_primary_key, child_table_name, - child_foreign_key + child_foreign_key, ) def test__validate_no_missing_tables_in_relationship(self): @@ -268,9 +261,7 @@ def test__validate_no_missing_tables_in_relationship(self): error_msg = re.escape("Relationship contains an unknown table {'session'}.") with pytest.raises(InvalidMetadataError, match=error_msg): MultiTableMetadata._validate_no_missing_tables_in_relationship( - 'users', - 'session', - tables + 'users', 'session', tables ) def test__validate_missing_relationship_key_length(self): @@ -298,10 +289,7 @@ def test__validate_missing_relationship_key_length(self): ) with pytest.raises(InvalidMetadataError, match=error_msg): MultiTableMetadata._validate_relationship_key_length( - parent_table_name, - parent_primary_key, - child_table_name, - child_foreign_key + parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) def test__validate_relationship_sdtype(self): @@ -366,11 +354,7 @@ def test__validate_relationship_sdtype(self): ) with pytest.raises(InvalidMetadataError, match=error_msg): MultiTableMetadata._validate_relationship_sdtypes( - instance, - parent_table_name, - parent_primary_key, - child_table_name, - child_foreign_key + instance, parent_table_name, parent_primary_key, child_table_name, child_foreign_key ) def test__validate_relationship_does_not_exist(self): @@ -389,7 +373,7 @@ def test__validate_relationship_does_not_exist(self): 'child_table_name': 'transactions', 'parent_primary_key': 'id', 'child_foreign_key': 'session_id', - } + }, ] # Run and Assert @@ -399,7 +383,7 @@ def test__validate_relationship_does_not_exist(self): parent_table_name='sessions', parent_primary_key='id', child_table_name='transactions', - child_foreign_key='session_id' + child_foreign_key='session_id', ) def test__validate_circular_relationships(self): @@ -421,16 +405,14 @@ def test__validate_circular_relationships(self): relationship. 
""" # Setup - child_map = { - 'users': {'sessions', 'transactions'}, - 'sessions': {'users', 'transactions'} - } + child_map = {'users': {'sessions', 'transactions'}, 'sessions': {'users', 'transactions'}} parent = 'users' errors = [] # Run MultiTableMetadata()._validate_circular_relationships( - parent, child_map=child_map, errors=errors) + parent, child_map=child_map, errors=errors + ) # Assert assert errors == ['users'] @@ -457,15 +439,11 @@ def test__validate_child_map_circular_relationship(self): # Setup instance = MultiTableMetadata() parent_table = Mock() - instance.tables = { - 'users': parent_table, - 'sessions': Mock(), - 'transactions': Mock() - } + instance.tables = {'users': parent_table, 'sessions': Mock(), 'transactions': Mock()} child_map = { 'users': {'sessions', 'transactions'}, 'sessions': {'users'}, - 'transactions': set() + 'transactions': set(), } # Run / Assert @@ -480,9 +458,9 @@ def test__validate_child_map_circular_relationship(self): 'sdv.metadata.multi_table.MultiTableMetadata._validate_no_missing_tables_in_relationship' ) @patch('sdv.metadata.multi_table.MultiTableMetadata._validate_relationship_key_length') - def test__validate_relationship(self, - mock_validate_relationship_key_length, - mock_validate_no_missing_tables_in_relationship): + def test__validate_relationship( + self, mock_validate_relationship_key_length, mock_validate_no_missing_tables_in_relationship + ): """Test thath the ``_validate_relationship`` method. Test that when calling the ``_validate_relationship`` method, the other validation methods @@ -538,13 +516,17 @@ def test__validate_relationship(self, # Assert mock_validate_no_missing_tables_in_relationship.assert_called_once_with( - 'users', 'sessions', instance.tables.keys()) + 'users', 'sessions', instance.tables.keys() + ) instance._validate_missing_relationship_keys.assert_called_once_with( - 'users', 'id', 'sessions', 'user_id') + 'users', 'id', 'sessions', 'user_id' + ) mock_validate_relationship_key_length.assert_called_once_with( - 'users', 'id', 'sessions', 'user_id') + 'users', 'id', 'sessions', 'user_id' + ) instance._validate_relationship_sdtypes.assert_called_once_with( - 'users', 'id', 'sessions', 'user_id') + 'users', 'id', 'sessions', 'user_id' + ) def test__get_foreign_keys(self): """Test that this method returns the foreign keys for a given table name and child name.""" @@ -566,7 +548,7 @@ def test__get_all_foreign_keys(self): parent_table_name='users', parent_primary_key='id', child_table_name='transactions', - child_foreign_key='user_id' + child_foreign_key='user_id', ) # Run @@ -636,8 +618,9 @@ def test_add_relationship(self): 'child_foreign_key': 'user_id', } ] - instance._validate_child_map_circular_relationship.assert_called_once_with( - {'users': {'sessions'}}) + instance._validate_child_map_circular_relationship.assert_called_once_with({ + 'users': {'sessions'} + }) instance._validate_relationship_does_not_exist.assert_called_once_with( 'users', 'id', 'sessions', 'user_id' ) @@ -646,11 +629,7 @@ def test_add_relationship(self): def test_add_relationship_child_key_is_primary_key(self): """Test that passing a primary key as ``child_foreign_key`` crashes.""" # Setup - table = pd.DataFrame({ - 'pk': [1, 2, 3], - 'col1': [.1, .1, .2], - 'col2': ['a', 'b', 'c'] - }) + table = pd.DataFrame({'pk': [1, 2, 3], 'col1': [0.1, 0.1, 0.2], 'col2': ['a', 'b', 'c']}) metadata = MultiTableMetadata() metadata.detect_table_from_dataframe('table', table) metadata.update_column('table', 'pk', sdtype='id') @@ -700,7 +679,7 @@ def 
test_remove_relationship(self): instance.tables = { 'users': parent_table, 'sessions': child_table, - 'transactions': alternate_child_table + 'transactions': alternate_child_table, } instance.relationships = [ { @@ -726,7 +705,7 @@ def test_remove_relationship(self): 'child_table_name': 'transactions', 'parent_primary_key': 'session_id', 'child_foreign_key': 'session_id', - } + }, ] # Run @@ -745,7 +724,7 @@ def test_remove_relationship(self): 'child_table_name': 'transactions', 'parent_primary_key': 'session_id', 'child_foreign_key': 'session_id', - } + }, ] assert instance._multi_table_updated is True @@ -794,30 +773,26 @@ def test_remove_primary_key(self, logger_mock): instance = MultiTableMetadata() table = Mock() table.primary_key = 'primary_key' - instance.tables = { - 'table': table, - 'parent': Mock(), - 'child': Mock() - } + instance.tables = {'table': table, 'parent': Mock(), 'child': Mock()} instance.relationships = [ { 'parent_table_name': 'parent', 'child_table_name': 'table', 'parent_primary_key': 'pk', - 'child_foreign_key': 'primary_key' + 'child_foreign_key': 'primary_key', }, { 'parent_table_name': 'table', 'child_table_name': 'child', 'parent_primary_key': 'primary_key', - 'child_foreign_key': 'fk' + 'child_foreign_key': 'fk', }, { 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'pk', - 'child_foreign_key': 'fk' - } + 'child_foreign_key': 'fk', + }, ] # Run @@ -829,7 +804,7 @@ def test_remove_primary_key(self, logger_mock): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'pk', - 'child_foreign_key': 'fk' + 'child_foreign_key': 'fk', } ] table.remove_primary_key.assert_called_once() @@ -847,15 +822,11 @@ def test_remove_primary_key(self, logger_mock): def test__validate_column_relationships_foreign_keys(self): """Test ``_validate_column_relationships_foriegn_keys.""" # Setup - column_relationships = [ - {'type': 'bad_relationship', 'column_names': ['amount', 'owner']} - ] + column_relationships = [{'type': 'bad_relationship', 'column_names': ['amount', 'owner']}] instance = MultiTableMetadata() # Run and Assert - err_msg = ( - "Cannot use foreign keys {'owner'} in column relationship." - ) + err_msg = "Cannot use foreign keys {'owner'} in column relationship." with pytest.raises(InvalidMetadataError, match=err_msg): instance._validate_column_relationships_foreign_keys(column_relationships, ['owner']) @@ -878,7 +849,7 @@ def test_add_column_relationship(self): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'parent_id', - 'child_foreign_key': 'foreign_key' + 'child_foreign_key': 'foreign_key', } ] @@ -892,13 +863,12 @@ def test_add_column_relationship(self): mock_validate_column_relationships.assert_called_with( [ {'type': 'relationship_B', 'column_names': ['col1', 'col2', 'col3']}, - {'type': 'relationship_A', 'column_names': ['colA', 'colB']} + {'type': 'relationship_A', 'column_names': ['colA', 'colB']}, ], - ['foreign_key'] + ['foreign_key'], ) instance.tables['child'].add_column_relationship.assert_called_with( - 'relationship_B', - ['col1', 'col2', 'col3'] + 'relationship_B', ['col1', 'col2', 'col3'] ) def test__validate_single_table(self): @@ -915,11 +885,10 @@ def test__validate_single_table(self): Side Effects: - Errors has been updated with the error message for that column. """ + # Setup def validate_relationship_side_effect(*args, **kwargs): - raise InvalidMetadataError( - 'Cannot use foreign keys in column relationship.' 
- ) + raise InvalidMetadataError('Cannot use foreign keys in column relationship.') table_accounts = SingleTableMetadata.load_from_dict({ 'columns': { @@ -929,10 +898,8 @@ def validate_relationship_side_effect(*args, **kwargs): 'start_date': {'sdtype': 'datetime'}, 'owner': {'sdtype': 'id'}, }, - 'column_relationships': [ - {'type': 'bad_relationship', 'columns': ['amount', 'owner']} - ], - 'primary_key': 'branches' + 'column_relationships': [{'type': 'bad_relationship', 'columns': ['amount', 'owner']}], + 'primary_key': 'branches', }) instance = Mock() @@ -941,16 +908,13 @@ def validate_relationship_side_effect(*args, **kwargs): instance._validate_column_relationships_foreign_keys = validate_column_relationship_mock users_mock = Mock() users_mock.columns = {} - instance.tables = { - 'accounts': table_accounts, - 'users': users_mock - } + instance.tables = {'accounts': table_accounts, 'users': users_mock} instance.relationships = [ { 'parent_table_name': 'users', 'child_table_name': 'accounts', 'child_foreign_key': 'owner', - 'parent_primary_key': 'id' + 'parent_primary_key': 'id', } ] errors = [] @@ -964,16 +928,17 @@ def validate_relationship_side_effect(*args, **kwargs): 'Keys should be columns that exist in the table.\n' "Relationship has invalid keys {'columns'}." ) - foreign_key_col_relationship_message = ( - 'Cannot use foreign keys in column relationship.' - ) + foreign_key_col_relationship_message = 'Cannot use foreign keys in column relationship.' empty_table_error_message = ( "Table 'users' has 0 columns. Use 'add_column' to specify its columns." ) assert errors == [ - '\n', expected_error_msg, foreign_key_col_relationship_message, - empty_table_error_message, foreign_key_col_relationship_message + '\n', + expected_error_msg, + foreign_key_col_relationship_message, + empty_table_error_message, + foreign_key_col_relationship_message, ] instance.tables['users'].validate.assert_called_once() @@ -1001,22 +966,12 @@ def test__validate_all_tables_connected_connected(self): 'users': Mock(), 'sessions': Mock(), 'transactions': Mock(), - 'accounts': Mock() + 'accounts': Mock(), } relationships = [ - { - 'parent_table_name': 'users', - 'child_table_name': 'sessions' - }, - { - 'parent_table_name': 'users', - 'child_table_name': 'transactions' - }, - { - 'parent_table_name': 'users', - 'child_table_name': 'accounts' - }, - + {'parent_table_name': 'users', 'child_table_name': 'sessions'}, + {'parent_table_name': 'users', 'child_table_name': 'transactions'}, + {'parent_table_name': 'users', 'child_table_name': 'accounts'}, ] parent_map = defaultdict(set) @@ -1058,15 +1013,8 @@ def test__validate_all_tables_connected_not_connected(self): 'accounts': Mock(), } relationships = [ - { - 'parent_table_name': 'users', - 'child_table_name': 'sessions' - }, - { - 'parent_table_name': 'users', - 'child_table_name': 'transactions' - }, - + {'parent_table_name': 'users', 'child_table_name': 'sessions'}, + {'parent_table_name': 'users', 'child_table_name': 'transactions'}, ] parent_map = defaultdict(set) @@ -1111,18 +1059,11 @@ def test__validate_all_tables_connected_multiple_not_connected(self): 'sessions': Mock(), 'transactions': Mock(), 'accounts': Mock(), - 'branches': Mock() + 'branches': Mock(), } relationships = [ - { - 'parent_table_name': 'users', - 'child_table_name': 'sessions' - }, - { - 'parent_table_name': 'users', - 'child_table_name': 'transactions' - }, - + {'parent_table_name': 'users', 'child_table_name': 'sessions'}, + {'parent_table_name': 'users', 'child_table_name': 
'transactions'}, ] parent_map = defaultdict(set) @@ -1145,11 +1086,7 @@ def test__validate_all_tables_connected_no_connections(self): """Test ``_validate_all_tables_connected`` when no tables are connected.""" # Setup instance = Mock() - instance.tables = { - 'users': Mock(), - 'sessions': Mock(), - 'transactions': Mock() - } + instance.tables = {'users': Mock(), 'sessions': Mock(), 'transactions': Mock()} parent_map = defaultdict(set) child_map = defaultdict(set) @@ -1235,18 +1172,13 @@ def test__validate_all_tables_connected_raises_errors(self): # Run and Assert with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_all_tables_connected( - instance._get_parent_map(), - instance._get_child_map() + instance._get_parent_map(), instance._get_child_map() ) def test_validate_child_key_is_primary_key(self): """Test it crashes if the child key is a primary key.""" # Setup - table = pd.DataFrame({ - 'pk': [1, 2, 3], - 'col1': [.1, .1, .2], - 'col2': ['a', 'b', 'c'] - }) + table = pd.DataFrame({'pk': [1, 2, 3], 'col1': [0.1, 0.1, 0.2], 'col2': ['a', 'b', 'c']}) metadata = MultiTableMetadata() metadata.detect_table_from_dataframe('table', table) metadata.update_column('table', 'pk', sdtype='id') @@ -1377,16 +1309,16 @@ def test_validate_data_data_does_not_match(self): 'nesreca': pd.DataFrame({ 'id_nesreca': np.arange(10), 'upravna_enota': np.arange(10), - 'nesreca_val': np.arange(10).astype(str) + 'nesreca_val': np.arange(10).astype(str), }), 'oseba': pd.DataFrame({ 'upravna_enota': np.arange(10), 'id_nesreca': np.arange(10), - 'oseba_val': np.arange(10).astype(str) + 'oseba_val': np.arange(10).astype(str), }), 'upravna_enota': pd.DataFrame({ 'id_upravna_enota': np.arange(10), - 'upravna_val': np.arange(10).astype(str) + 'upravna_val': np.arange(10).astype(str), }), } @@ -1414,16 +1346,16 @@ def test_validate_data_missing_foreign_keys(self): 'nesreca': pd.DataFrame({ 'id_nesreca': np.arange(0, 20, 2), 'upravna_enota': np.arange(10), - 'nesreca_val': np.arange(10) + 'nesreca_val': np.arange(10), }), 'oseba': pd.DataFrame({ 'upravna_enota': np.arange(10), 'id_nesreca': np.arange(10), - 'oseba_val': np.arange(10) + 'oseba_val': np.arange(10), }), 'upravna_enota': pd.DataFrame({ 'id_upravna_enota': np.arange(10), - 'upravna_val': np.arange(10) + 'upravna_val': np.arange(10), }), } @@ -1451,26 +1383,23 @@ def test_validate_data_datetime_warning(self): '2022-09-02', '2022-09-16', '2022-08-26', - '2022-08-26' + '2022-08-26', ] data['upravna_enota']['valid_date'] = [ '20220902110443000000', '20220916230356000000', '20220826173917000000', - '20220929111311000000' + '20220929111311000000', ] data['upravna_enota']['datetime'] = pd.to_datetime([ '20220902', '20220916', '20220826', - '20220826' + '20220826', ]) metadata.add_column('upravna_enota', 'warning_date_str', sdtype='datetime') metadata.add_column( - 'upravna_enota', - 'valid_date', - sdtype='datetime', - datetime_format='%Y%m%d%H%M%S%f' + 'upravna_enota', 'valid_date', sdtype='datetime', datetime_format='%Y%m%d%H%M%S%f' ) metadata.add_column('upravna_enota', 'datetime', sdtype='datetime') @@ -1479,7 +1408,7 @@ def test_validate_data_datetime_warning(self): 'Table Name': ['upravna_enota'], 'Column Name': ['warning_date_str'], 'sdtype': ['datetime'], - 'datetime_format': [None] + 'datetime_format': [None], }) warning_msg = ( "No 'datetime_format' is present in the metadata for the following columns:\n " @@ -1569,10 +1498,7 @@ def test_to_dict(self): 'name': {'sdtype': 'id'}, } instance = MultiTableMetadata() - instance.tables = { - 
'accounts': table_accounts, - 'branches': table_branches - } + instance.tables = {'accounts': table_accounts, 'branches': table_branches} instance.relationships = [ { 'parent_table_name': 'accounts', @@ -1598,7 +1524,7 @@ def test_to_dict(self): 'branches': { 'id': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}, - } + }, }, 'relationships': [ { @@ -1608,7 +1534,7 @@ def test_to_dict(self): 'chil_foreign_key': 'branch_id', } ], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } assert result == expected_result @@ -1640,7 +1566,7 @@ def test__set_metadata(self, mock_singletablemetadata): 'branches': { 'id': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}, - } + }, }, 'relationships': [ { @@ -1649,14 +1575,14 @@ def test__set_metadata(self, mock_singletablemetadata): 'child_table_name': 'branches', 'chil_foreign_key': 'branch_id', } - ] + ], } single_table_accounts = object() single_table_branches = object() mock_singletablemetadata.load_from_dict.side_effect = [ single_table_accounts, - single_table_branches + single_table_branches, ] instance = MultiTableMetadata() @@ -1667,7 +1593,7 @@ def test__set_metadata(self, mock_singletablemetadata): # Assert assert instance.tables == { 'accounts': single_table_accounts, - 'branches': single_table_branches + 'branches': single_table_branches, } assert instance.relationships == [ @@ -1711,7 +1637,7 @@ def test_load_from_dict(self, mock_singletablemetadata): 'branches': { 'id': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}, - } + }, }, 'relationships': [ { @@ -1720,14 +1646,14 @@ def test_load_from_dict(self, mock_singletablemetadata): 'child_table_name': 'branches', 'child_foreign_key': 'branch_id', } - ] + ], } single_table_accounts = object() single_table_branches = object() mock_singletablemetadata.load_from_dict.side_effect = [ single_table_accounts, - single_table_branches + single_table_branches, ] # Run @@ -1736,7 +1662,7 @@ def test_load_from_dict(self, mock_singletablemetadata): # Assert assert instance.tables == { 'accounts': single_table_accounts, - 'branches': single_table_branches + 'branches': single_table_branches, } assert instance.relationships == [ @@ -1781,7 +1707,7 @@ def test_load_from_dict_integer(self, mock_singletablemetadata): 'branches': { 1: {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}, - } + }, }, 'relationships': [ { @@ -1790,7 +1716,7 @@ def test_load_from_dict_integer(self, mock_singletablemetadata): 'child_table_name': 'branches', 'child_foreign_key': 1, } - ] + ], } single_table_accounts = { @@ -1806,7 +1732,7 @@ def test_load_from_dict_integer(self, mock_singletablemetadata): } mock_singletablemetadata.load_from_dict.side_effect = [ single_table_accounts, - single_table_branches + single_table_branches, ] # Run @@ -1815,7 +1741,7 @@ def test_load_from_dict_integer(self, mock_singletablemetadata): # Assert assert instance.tables == { 'accounts': single_table_accounts, - 'branches': single_table_branches + 'branches': single_table_branches, } assert instance.relationships == [ @@ -1900,12 +1826,12 @@ def test_visualize_show_relationship_and_details(self, visualize_graph_mock): 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}', 'payments': expected_payments_label, 'sessions': expected_sessions_label, - 'transactions': expected_transactions_label + 'transactions': expected_transactions_label, } expected_edges = [ ('users', 'sessions', ' user_id → id'), ('sessions', 'transactions', ' session_id → session_id'), - ('users', 'payments', ' user_id → 
id') + ('users', 'payments', ' user_id → id'), ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, None) @@ -1951,19 +1877,20 @@ def test_visualize_show_relationship_and_details_summarized(self, visualize_grap 'users': expected_user_label, 'payments': expected_payments_label, 'sessions': expected_sessions_label, - 'transactions': expected_transactions_label + 'transactions': expected_transactions_label, } expected_edges = [ ('users', 'sessions', ' user_id → id'), ('sessions', 'transactions', ' session_id → session_id'), - ('users', 'payments', ' user_id → id') + ('users', 'payments', ' user_id → id'), ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, None) @patch('sdv.metadata.multi_table.warnings') @patch('sdv.metadata.multi_table.visualize_graph') - def test_visualize_show_relationship_and_details_warning(self, visualize_graph_mock, - warnings_mock): + def test_visualize_show_relationship_and_details_warning( + self, visualize_graph_mock, warnings_mock + ): """Test the ``visualize`` method. If both the ``show_relationship_labels`` and ``show_table_details`` parameters are @@ -2004,17 +1931,18 @@ def test_visualize_show_relationship_and_details_warning(self, visualize_graph_m 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}', 'payments': expected_payments_label, 'sessions': expected_sessions_label, - 'transactions': expected_transactions_label + 'transactions': expected_transactions_label, } expected_edges = [ ('users', 'sessions', ' user_id → id'), ('sessions', 'transactions', ' session_id → session_id'), - ('users', 'payments', ' user_id → id') + ('users', 'payments', ' user_id → id'), ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, None) warnings_mock.warn.assert_called_once_with( 'Using True or False for show_table_details is deprecated. Use ' - "show_table_details='full' to show all table details.", FutureWarning + "show_table_details='full' to show all table details.", + FutureWarning, ) @patch('sdv.metadata.multi_table.visualize_graph') @@ -2048,19 +1976,18 @@ def test_visualize_show_relationship_show_table_details_none(self, visualize_gra 'users': 'users', 'payments': 'payments', 'sessions': 'sessions', - 'transactions': 'transactions' + 'transactions': 'transactions', } expected_edges = [ ('users', 'sessions', ' user_id → id'), ('sessions', 'transactions', ' session_id → session_id'), - ('users', 'payments', ' user_id → id') + ('users', 'payments', ' user_id → id'), ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, 'output.jpg') @patch('sdv.metadata.multi_table.warnings') @patch('sdv.metadata.multi_table.visualize_graph') - def test_visualize_show_relationship_only_warning(self, visualize_graph_mock, - warnings_mock): + def test_visualize_show_relationship_only_warning(self, visualize_graph_mock, warnings_mock): """Test the ``visualize`` method. 
If ``show_relationship_labels`` is True but ``show_table_details``is False, @@ -2089,17 +2016,18 @@ def test_visualize_show_relationship_only_warning(self, visualize_graph_mock, 'users': 'users', 'payments': 'payments', 'sessions': 'sessions', - 'transactions': 'transactions' + 'transactions': 'transactions', } expected_edges = [ ('users', 'sessions', ' user_id → id'), ('sessions', 'transactions', ' session_id → session_id'), - ('users', 'payments', ' user_id → id') + ('users', 'payments', ' user_id → id'), ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, 'output.jpg') warnings_mock.warn.assert_called_once_with( "Using True or False for 'show_table_details' is deprecated. " - 'Use show_table_details=None to hide table details.', FutureWarning + 'Use show_table_details=None to hide table details.', + FutureWarning, ) @patch('sdv.metadata.multi_table.visualize_graph') @@ -2144,12 +2072,12 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock): 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}', 'payments': expected_payments_label, 'sessions': expected_sessions_label, - 'transactions': expected_transactions_label + 'transactions': expected_transactions_label, } expected_edges = [ ('users', 'sessions', ''), ('sessions', 'transactions', ''), - ('users', 'payments', '') + ('users', 'payments', ''), ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, 'output.jpg') @@ -2271,10 +2199,7 @@ def test_update_columns_metadata(self): metadata._validate_table_exists = Mock() table = Mock() metadata.tables = {'table': table} - metadata_updates = { - 'col_1': {'sdtype': 'numerical'}, - 'col_2': {'sdtype': 'categorical'} - } + metadata_updates = {'col_1': {'sdtype': 'numerical'}, 'col_2': {'sdtype': 'categorical'}} # Run metadata.update_columns_metadata('table', metadata_updates) @@ -2346,7 +2271,7 @@ def test__detect_relationships(self): 'parent_table_name': 'users', 'child_table_name': 'sessions', 'parent_primary_key': 'user_id', - 'child_foreign_key': 'user_id' + 'child_foreign_key': 'user_id', } ] assert instance.relationships == expected_relationships @@ -2383,8 +2308,7 @@ def test__detect_relationships_circular(self): instance._detect_relationships() # Assert - instance.add_relationship.assert_called_once_with( - 'users', 'sessions', 'user_id', 'user_id') + instance.add_relationship.assert_called_once_with('users', 'sessions', 'user_id', 'user_id') assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'categorical' @patch('sdv.metadata.multi_table.LOGGER') @@ -2485,7 +2409,7 @@ def test_detect_from_csvs(self, tmp_path): # Assert expected_calls = [ call('table1', str(filepath1), None), - call('table2', str(filepath2), None) + call('table2', str(filepath2), None), ] instance.detect_table_from_csv.assert_has_calls(expected_calls, any_order=True) @@ -2595,12 +2519,7 @@ def test_detect_from_dataframes(self): hotels_table = pd.DataFrame() # Run - metadata.detect_from_dataframes( - data={ - 'guests': guests_table, - 'hotels': hotels_table - } - ) + metadata.detect_from_dataframes(data={'guests': guests_table, 'hotels': hotels_table}) # Assert metadata.detect_table_from_dataframe.assert_any_call('guests', guests_table) @@ -2782,7 +2701,8 @@ def test_add_constraint(self): # Assert table.add_constraint.assert_called_once_with( - 'Inequality', low_column_name='a', high_column_name='b') + 'Inequality', low_column_name='a', high_column_name='b' + ) def test_add_constraint_table_does_not_exist(self): 
"""Test the ``add_constraint`` method. @@ -2804,7 +2724,8 @@ def test_add_constraint_table_does_not_exist(self): error_message = re.escape("Unknown table name ('table')") with pytest.raises(InvalidMetadataError, match=error_message): metadata.add_constraint( - 'table', 'Inequality', low_column_name='a', high_column_name='b') + 'table', 'Inequality', low_column_name='a', high_column_name='b' + ) @patch('sdv.metadata.utils.Path') def test_load_from_json_path_does_not_exist(self, mock_path): @@ -2861,16 +2782,12 @@ def test_load_from_json(self, mock_json, mock_path, mock_open): mock_json.load.return_value = { 'tables': { 'table1': { - 'columns': { - 'animals': { - 'type': 'categorical' - } - }, + 'columns': {'animals': {'type': 'categorical'}}, 'primary_key': 'animals', - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } }, - 'relationships': {} + 'relationships': {}, } # Run @@ -2972,47 +2889,30 @@ def test__convert_relationships(self): 'upravna_enota': { 'type': 'id', 'subtype': 'integer', - 'ref': { - 'table': 'upravna_enota', - 'field': 'id_upravna_enota' - } - }, - 'id_nesreca': { - 'type': 'id', - 'subtype': 'integer' + 'ref': {'table': 'upravna_enota', 'field': 'id_upravna_enota'}, }, + 'id_nesreca': {'type': 'id', 'subtype': 'integer'}, }, - 'primary_key': 'id_nesreca' + 'primary_key': 'id_nesreca', }, 'oseba': { 'fields': { 'upravna_enota': { 'type': 'id', 'subtype': 'integer', - 'ref': { - 'table': 'upravna_enota', - 'field': 'id_upravna_enota' - } + 'ref': {'table': 'upravna_enota', 'field': 'id_upravna_enota'}, }, 'id_nesreca': { 'type': 'id', 'subtype': 'integer', - 'ref': { - 'table': 'nesreca', - 'field': 'id_nesreca' - } + 'ref': {'table': 'nesreca', 'field': 'id_nesreca'}, }, }, }, 'upravna_enota': { - 'fields': { - 'id_upravna_enota': { - 'type': 'id', - 'subtype': 'integer' - } - }, - 'primary_key': 'id_upravna_enota' - } + 'fields': {'id_upravna_enota': {'type': 'id', 'subtype': 'integer'}}, + 'primary_key': 'id_upravna_enota', + }, } } @@ -3025,20 +2925,20 @@ def test__convert_relationships(self): 'parent_table_name': 'upravna_enota', 'parent_primary_key': 'id_upravna_enota', 'child_table_name': 'nesreca', - 'child_foreign_key': 'upravna_enota' + 'child_foreign_key': 'upravna_enota', }, { 'parent_table_name': 'nesreca', 'parent_primary_key': 'id_nesreca', 'child_table_name': 'oseba', - 'child_foreign_key': 'id_nesreca' + 'child_foreign_key': 'id_nesreca', }, { 'parent_table_name': 'upravna_enota', 'parent_primary_key': 'id_upravna_enota', 'child_table_name': 'oseba', - 'child_foreign_key': 'upravna_enota' - } + 'child_foreign_key': 'upravna_enota', + }, ] for relationship in expected: assert relationship in relationships @@ -3048,7 +2948,8 @@ def test__convert_relationships(self): @patch('sdv.metadata.multi_table.convert_metadata') @patch('sdv.metadata.multi_table.MultiTableMetadata.load_from_dict') def test_upgrade_metadata( - self, from_dict_mock, convert_mock, relationships_mock, read_json_mock): + self, from_dict_mock, convert_mock, relationships_mock, read_json_mock + ): """Test the ``upgrade_metadata`` method. 
The method should validate that the ``new_filepath`` does not exist, read the old metadata @@ -3079,7 +2980,7 @@ def test_upgrade_metadata( read_json_mock.return_value = { 'tables': { 'table1': {'columns': {'column1': {'type': 'numerical'}}}, - 'table2': {'columns': {'column2': {'type': 'categorical'}}} + 'table2': {'columns': {'column2': {'type': 'categorical'}}}, } } relationships_mock.return_value = [ @@ -3087,7 +2988,7 @@ def test_upgrade_metadata( 'parent_table_name': 'table1', 'parent_primary_key': 'id', 'child_table_name': 'table2', - 'child_foreign_key': 'id' + 'child_foreign_key': 'id', } ] @@ -3099,27 +3000,27 @@ def test_upgrade_metadata( relationships_mock.assert_called_once_with({ 'tables': { 'table1': {'columns': {'column1': {'type': 'numerical'}}}, - 'table2': {'columns': {'column2': {'type': 'categorical'}}} + 'table2': {'columns': {'column2': {'type': 'categorical'}}}, } }) convert_mock.assert_has_calls([ call({'columns': {'column1': {'type': 'numerical'}}}), - call({'columns': {'column2': {'type': 'categorical'}}}) + call({'columns': {'column2': {'type': 'categorical'}}}), ]) expected_new_metadata = { 'tables': { 'table1': {'columns': {'column1': {'sdtype': 'numerical'}}}, - 'table2': {'columns': {'column2': {'sdtype': 'categorical'}}} + 'table2': {'columns': {'column2': {'sdtype': 'categorical'}}}, }, 'relationships': [ { 'parent_table_name': 'table1', 'parent_primary_key': 'id', 'child_table_name': 'table2', - 'child_foreign_key': 'id' + 'child_foreign_key': 'id', } ], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } from_dict_mock.assert_called_once_with(expected_new_metadata) new_metadata.validate.assert_called_once() @@ -3130,7 +3031,8 @@ def test_upgrade_metadata( @patch('sdv.metadata.multi_table.convert_metadata') @patch('sdv.metadata.multi_table.MultiTableMetadata.load_from_dict') def test_upgrade_metadata_validate_error( - self, from_dict_mock, convert_mock, relationships_mock, read_json_mock, warnings_mock): + self, from_dict_mock, convert_mock, relationships_mock, read_json_mock, warnings_mock + ): """Test the ``upgrade_metadata`` method. 
The method should validate that the ``new_filepath`` does not exist, read the old metadata @@ -3170,7 +3072,7 @@ def test_upgrade_metadata_validate_error( expected_new_metadata = { 'tables': {}, 'relationships': [], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } from_dict_mock.assert_called_once_with(expected_new_metadata) new_metadata.validate.assert_called_once() diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index 428f5926f..f871ef679 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -39,36 +39,41 @@ class TestSingleTableMetadata: ( 'age', 'numerical', - { - 'computer_representation': 'Int8', - 'datetime_format': None, - 'pii': True - }, + {'computer_representation': 'Int8', 'datetime_format': None, 'pii': True}, re.escape("Invalid values '(datetime_format, pii)' for numerical column 'age'."), ), ( - 'start_date', 'datetime', {'datetime_format': '%Y-%d', 'pii': True}, - re.escape("Invalid values '(pii)' for datetime column 'start_date'.") + 'start_date', + 'datetime', + {'datetime_format': '%Y-%d', 'pii': True}, + re.escape("Invalid values '(pii)' for datetime column 'start_date'."), ), ( - 'name', 'categorical', + 'name', + 'categorical', {'pii': True, 'ordering': ['a', 'b'], 'ordered': 'numerical_values'}, - re.escape("Invalid values '(ordered, ordering, pii)' for categorical column 'name'.") + re.escape("Invalid values '(ordered, ordering, pii)' for categorical column 'name'."), ), ( - 'synthetic', 'boolean', {'pii': True}, - re.escape("Invalid values '(pii)' for boolean column 'synthetic'.") + 'synthetic', + 'boolean', + {'pii': True}, + re.escape("Invalid values '(pii)' for boolean column 'synthetic'."), ), ( - 'phrase', 'id', {'regex_format': '[A-z]', 'pii': True, 'anonymization': True}, - re.escape("Invalid values '(anonymization, pii)' for id column 'phrase'.") + 'phrase', + 'id', + {'regex_format': '[A-z]', 'pii': True, 'anonymization': True}, + re.escape("Invalid values '(anonymization, pii)' for id column 'phrase'."), ), ( - 'phone', 'phone_number', {'anonymization': True, 'order_by': 'phone_number'}, + 'phone', + 'phone_number', + {'anonymization': True, 'order_by': 'phone_number'}, re.escape( "Invalid values '(anonymization, order_by)' for phone_number column 'phone'." - ) - ) + ), + ), ] # noqa: JS102 def test___init__(self): @@ -113,7 +118,8 @@ def test__validate_numerical_default_and_invalid(self): instance._validate_numerical('age', computer_representation=36) @pytest.mark.parametrize( - 'computer_representation', SingleTableMetadata._NUMERICAL_REPRESENTATIONS) + 'computer_representation', SingleTableMetadata._NUMERICAL_REPRESENTATIONS + ) def test__validate_numerical_computer_representations(self, computer_representation): """Test the ``_validate_numerical`` method. @@ -160,7 +166,8 @@ def test__validate_datetime(self): instance._validate_datetime('start_date', datetime_format='%Y-%m-%d - Synthetic') error_msg = re.escape( - "Invalid datetime format string '%1-%Y-%m-%d-%' for datetime column 'start_date'.") + "Invalid datetime format string '%1-%Y-%m-%d-%' for datetime column 'start_date'." + ) with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_datetime('start_date', datetime_format='%1-%Y-%m-%d-%') @@ -333,8 +340,7 @@ def test__validate_column_invalid_sdtype(self): instance._validate_column_args('column', 'fake_type') error_msg = re.escape( - 'Invalid sdtype: None is not a string. 
Please use one of the ' - 'supported SDV sdtypes.' + 'Invalid sdtype: None is not a string. Please use one of the ' 'supported SDV sdtypes.' ) with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_column_args('column', None) @@ -369,7 +375,8 @@ def test__validate_column_numerical(self, mock__validate_numerical, mock__valida # Assert mock__validate_kwargs.assert_called_once_with( - 'age', 'numerical', computer_representation='Int8') + 'age', 'numerical', computer_representation='Int8' + ) mock__validate_numerical.assert_called_once_with('age', computer_representation='Int8') @patch('sdv.metadata.single_table.SingleTableMetadata._validate_unexpected_kwargs') @@ -401,8 +408,7 @@ def test__validate_column_categorical(self, mock__validate_categorical, mock__va instance._validate_column_args('name', 'categorical', order=['a', 'b', 'c']) # Assert - mock__validate_kwargs.assert_called_once_with( - 'name', 'categorical', order=['a', 'b', 'c']) + mock__validate_kwargs.assert_called_once_with('name', 'categorical', order=['a', 'b', 'c']) mock__validate_categorical.assert_called_once_with('name', order=['a', 'b', 'c']) @patch('sdv.metadata.single_table.SingleTableMetadata._validate_unexpected_kwargs') @@ -495,7 +501,8 @@ def test__validate_column_id(self, mock__validate_id, mock__validate_kwargs): # Assert mock__validate_kwargs.assert_called_once_with( - 'phrase', 'id', regex_format='[A-z0-9]', pii=True) + 'phrase', 'id', regex_format='[A-z0-9]', pii=True + ) mock__validate_id.assert_called_once_with('phrase', regex_format='[A-z0-9]', pii=True) @patch('sdv.metadata.single_table.SingleTableMetadata._validate_unexpected_kwargs') @@ -538,10 +545,7 @@ def test_update_column_add_extra_value(self): # Assert assert instance.columns == { - 'a': { - 'sdtype': 'numerical', - 'computer_representation': 'Int64' - } + 'a': {'sdtype': 'numerical', 'computer_representation': 'Int64'} } def test_add_column_column_name_in_columns(self): @@ -567,7 +571,8 @@ def test_add_column_column_name_in_columns(self): # Run / Assert error_msg = re.escape( - "Column name 'age' already exists. Use 'update_column' to update an existing column.") + "Column name 'age' already exists. Use 'update_column' to update an existing column." + ) with pytest.raises(InvalidMetadataError, match=error_msg): instance.add_column('age') @@ -634,10 +639,7 @@ def test_add_column(self): instance.add_column('age', sdtype='numerical', computer_representation='Int8') # Assert - assert instance.columns['age'] == { - 'sdtype': 'numerical', - 'computer_representation': 'Int8' - } + assert instance.columns['age'] == {'sdtype': 'numerical', 'computer_representation': 'Int8'} def test_add_column_other_sdtype(self): """Test ``add_column`` with an ``sdtype`` that isn't in our base ``sdtypes``.. 
@@ -687,7 +689,8 @@ def test__validate_update_column_kwargs_with_sdtype(self): instance._validate_column_exists.assert_called_once_with('age') expected_kwargs = {'computer_representation': 'Int8'} instance._validate_column_args.assert_called_once_with( - 'age', 'numerical', **expected_kwargs) + 'age', 'numerical', **expected_kwargs + ) def test_update_column(self): """Test the ``update_column`` method.""" @@ -703,15 +706,13 @@ def test_update_column(self): instance._validate_update_column.assert_called_once_with( 'age', sdtype='categorical', order_by='numerical_value' ) - assert instance.columns['age'] == { - 'sdtype': 'categorical', - 'order_by': 'numerical_value' - } + assert instance.columns['age'] == {'sdtype': 'categorical', 'order_by': 'numerical_value'} @patch('sdv.metadata.single_table.SingleTableMetadata._validate_column_args') @patch('sdv.metadata.single_table.SingleTableMetadata._validate_column_exists') - def test_update_column_sdtype_in_kwargs(self, - mock__validate_column_exists, mock__validate_column): + def test_update_column_sdtype_in_kwargs( + self, mock__validate_column_exists, mock__validate_column + ): """Test the ``update_column`` method. Test that when calling ``update_column`` with an ``sdtype`` this is being updated as well @@ -736,13 +737,11 @@ def test_update_column_sdtype_in_kwargs(self, instance.update_column('age', sdtype='categorical', order_by='numerical_value') # Assert - assert instance.columns['age'] == { - 'sdtype': 'categorical', - 'order_by': 'numerical_value' - } + assert instance.columns['age'] == {'sdtype': 'categorical', 'order_by': 'numerical_value'} mock__validate_column_exists.assert_called_once_with('age') mock__validate_column.assert_called_once_with( - 'age', 'categorical', order_by='numerical_value') + 'age', 'categorical', order_by='numerical_value' + ) @patch('sdv.metadata.single_table.SingleTableMetadata._validate_column_args') @patch('sdv.metadata.single_table.SingleTableMetadata._validate_column_exists') @@ -773,11 +772,12 @@ def test_update_column_no_sdtype(self, mock__validate_column_exists, mock__valid # Assert assert instance.columns['age'] == { 'sdtype': 'numerical', - 'computer_representation': 'Float' + 'computer_representation': 'Float', } mock__validate_column_exists.assert_called_once_with('age') mock__validate_column.assert_called_once_with( - 'age', 'numerical', computer_representation='Float') + 'age', 'numerical', computer_representation='Float' + ) def test_update_columns_sdtype_in_kwargs_error(self): """Test the ``update_columns`` method. 
@@ -789,8 +789,7 @@ def test_update_columns_sdtype_in_kwargs_error(self): instance = SingleTableMetadata() # Run / Assert - error_msg = re.escape( - "Invalid values '(pii)' for 'numerical' sdtype.") + error_msg = re.escape("Invalid values '(pii)' for 'numerical' sdtype.") with pytest.raises(InvalidMetadataError, match=error_msg): instance.update_columns(['col_1', 'col_2'], sdtype='numerical', pii=True) @@ -806,7 +805,7 @@ def test_update_columns_multiple_erros(self): instance.columns = { 'col_1': {'sdtype': 'country_code'}, 'col_2': {'sdtype': 'numerical'}, - 'col_3': {'sdtype': 'categorical'} + 'col_3': {'sdtype': 'categorical'}, } # Run / Assert @@ -824,10 +823,7 @@ def test_update_columns(self): instance = SingleTableMetadata() instance._validate_update_column = Mock() instance._get_unexpected_kwargs = Mock(return_value=None) - instance.columns = { - 'age': {'sdtype': 'numerical'}, - 'salary': {'sdtype': 'numerical'} - } + instance.columns = {'age': {'sdtype': 'numerical'}, 'salary': {'sdtype': 'numerical'}} # Run instance.update_columns(['age', 'salary'], sdtype='categorical') @@ -836,11 +832,11 @@ def test_update_columns(self): instance._get_unexpected_kwargs.assert_called_once_with('categorical') instance._validate_update_column.assert_has_calls([ call('age', sdtype='categorical'), - call('salary', sdtype='categorical') + call('salary', sdtype='categorical'), ]) assert instance.columns == { 'age': {'sdtype': 'categorical'}, - 'salary': {'sdtype': 'categorical'} + 'salary': {'sdtype': 'categorical'}, } def test_update_columns_kwargs_without_sdtype(self): @@ -850,7 +846,7 @@ def test_update_columns_kwargs_without_sdtype(self): instance.columns = { 'col_1': {'sdtype': 'country_code'}, 'col_2': {'sdtype': 'latitude'}, - 'col_3': {'sdtype': 'longitude'} + 'col_3': {'sdtype': 'longitude'}, } # Run @@ -860,7 +856,7 @@ def test_update_columns_kwargs_without_sdtype(self): assert instance.columns == { 'col_1': {'sdtype': 'country_code', 'pii': True}, 'col_2': {'sdtype': 'latitude', 'pii': True}, - 'col_3': {'sdtype': 'longitude', 'pii': True} + 'col_3': {'sdtype': 'longitude', 'pii': True}, } assert instance._updated is True @@ -869,35 +865,29 @@ def test_update_columns_metadata(self): # Setup instance = SingleTableMetadata() instance._validate_update_column = Mock() - instance.columns = { - 'age': {'sdtype': 'numerical'}, - 'salary': {'sdtype': 'numerical'} - } + instance.columns = {'age': {'sdtype': 'numerical'}, 'salary': {'sdtype': 'numerical'}} # Run instance.update_columns_metadata({ 'age': {'sdtype': 'categorical'}, - 'salary': {'computer_representation': 'Int64'} + 'salary': {'computer_representation': 'Int64'}, }) # Assert instance._validate_update_column.assert_has_calls([ call('age', sdtype='categorical'), - call('salary', computer_representation='Int64') + call('salary', computer_representation='Int64'), ]) assert instance.columns == { 'age': {'sdtype': 'categorical'}, - 'salary': {'sdtype': 'numerical', 'computer_representation': 'Int64'} + 'salary': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, } def test_update_columns_metadata_multiple_error(self): """Test the ``update_columns_metadata`` method with multiple error.""" # Setup instance = SingleTableMetadata() - instance.columns = { - 'age': {'sdtype': 'numerical'}, - 'hours': {'sdtype': 'numerical'} - } + instance.columns = {'age': {'sdtype': 'numerical'}, 'hours': {'sdtype': 'numerical'}} # Run / Assert error_msg = re.escape( @@ -911,7 +901,7 @@ def test_update_columns_metadata_multiple_error(self): 
instance.update_columns_metadata({ 'age': {'pii': True}, 'hours': {'sdtype': 'categorical', 'datetime_format': '%Y-%m-%d'}, - 'salary': {'sdtype': 'numerical'} + 'salary': {'sdtype': 'numerical'}, }) def test_get_column_names(self): @@ -921,15 +911,14 @@ def test_get_column_names(self): metadata.columns = { 'id': {'sdtype': 'id'}, 'value1': {'sdtype': 'numerical'}, - 'value2': {'sdtype': 'numerical', 'computer_representation': 'Float'} + 'value2': {'sdtype': 'numerical', 'computer_representation': 'Float'}, } # Run matches_no_filter = metadata.get_column_names() matches_numerical = metadata.get_column_names(sdtype='numerical') matches_extra = metadata.get_column_names( - sdtype='numerical', - computer_representation='Float' + sdtype='numerical', computer_representation='Float' ) # Assert @@ -977,18 +966,35 @@ def test__determine_sdtype_for_numbers(self): instance = SingleTableMetadata() data_less_than_5_rows = pd.Series([1, np.nan, 3, 4, 5]) - data_less_than_10_percent_unique_values = pd.Series( - [1, 2, 2, None, 2, 1, 1, 1, 1, 1, 2, 2, np.nan, 2, 1, 1, 2, 1, 2, 2] + data_less_than_10_percent_unique_values = pd.Series([ + 1, + 2, + 2, + None, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + np.nan, + 2, + 1, + 1, + 2, + 1, + 2, + 2, + ]) + large_numerical_series = pd.Series( + [400, 401, 402, 403, 404, 405, 406, 500, 501, 502, 503, 504, 505, 506] * 1000 + ) + + large_categorical_series = pd.Series( + [400, 401, 402, 403, 404, 500, 501, 502, 503, 504] * 1000 ) - large_numerical_series = pd.Series([ - 400, 401, 402, 403, 404, 405, 406, - 500, 501, 502, 503, 504, 505, 506 - ] * 1000) - - large_categorical_series = pd.Series([ - 400, 401, 402, 403, 404, - 500, 501, 502, 503, 504 - ] * 1000) data_all_unique = pd.Series([1, 2, 3, 4, 5, 6]) data_numerical_int = pd.Series([1, np.nan, 3, 4, 5, 6, 7, 8, 9, 10]) data_numerical_float = pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6]) @@ -1093,8 +1099,17 @@ def test__detect_columns(self): 'id': ['id1', 'id2', 'id3', 'id4', 'id5', 'id6', 'id7', 'id8', 'id9', 'id10', 'id11'], 'numerical': [1, 2, 3, 2, 5, 6, 7, 8, 9, 10, 11], 'datetime': [ - '2022-01-01', '2022-02-01', '2022-03-01', '2022-04-01', '2022-05-01', '2022-06-01', - '2022-07-01', '2022-08-01', '2022-09-01', '2022-10-01', '2022-11-01' + '2022-01-01', + '2022-02-01', + '2022-03-01', + '2022-04-01', + '2022-05-01', + '2022-06-01', + '2022-07-01', + '2022-08-01', + '2022-09-01', + '2022-10-01', + '2022-11-01', ], 'alternate_id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'alternate_id_string': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], @@ -1102,8 +1117,17 @@ def test__detect_columns(self): 'bool': [True, False, True, False, True, False, True, False, True, False, True], 'unknown': ['a', 'b', 'c', 'c', 1, 2.2, np.nan, None, 'd', 'e', 'f'], 'first_name': [ - 'John', 'Jane', 'John', 'Jane', 'John', 'Jane', 'John', 'Jane', 'John', - 'Jane', 'John' + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + 'Jane', + 'John', ], }) @@ -1169,7 +1193,7 @@ def test__detect_columns_with_nans_nones_and_nats(self): 'cat': [None] * 100, 'num': [np.nan] * 100, 'num2': [float('nan')] * 100, - 'date': [pd.NaT] * 100 + 'date': [pd.NaT] * 100, }) stm = SingleTableMetadata() @@ -1274,8 +1298,8 @@ def test_detect_from_dataframe(self, mock_log): 'categorical': ['cat', 'dog', 'cat', np.nan], 'date': pd.to_datetime(['2021-02-02', np.nan, '2021-03-05', '2022-12-09']), 'int': [1, 2, 3, 4], - 'float': [1., 2., 3., 4], - 'bool': [np.nan, True, False, True] + 'float': [1.0, 2.0, 3.0, 4], + 'bool': 
[np.nan, True, False, True], }) # Run @@ -1287,12 +1311,12 @@ def test_detect_from_dataframe(self, mock_log): 'date': {'sdtype': 'datetime'}, 'int': {'sdtype': 'numerical'}, 'float': {'sdtype': 'numerical'}, - 'bool': {'sdtype': 'categorical'} + 'bool': {'sdtype': 'categorical'}, } expected_log_calls = [ call('Detected metadata:'), - call(json.dumps(instance.to_dict(), indent=4)) + call(json.dumps(instance.to_dict(), indent=4)), ] mock_log.info.assert_has_calls(expected_log_calls) @@ -1306,68 +1330,28 @@ def test_detect_from_dataframe_numerical_columns(self, mock_log): data = pd.DataFrame(values) correct_metadata = { 'columns': { - '1': { - 'sdtype': 'numerical' - }, - '2': { - 'sdtype': 'numerical' - }, - '3': { - 'sdtype': 'numerical' - }, - '4': { - 'sdtype': 'numerical' - }, - '5': { - 'sdtype': 'numerical' - }, - '6': { - 'sdtype': 'numerical' - }, - '7': { - 'sdtype': 'numerical' - }, - '8': { - 'sdtype': 'numerical' - }, - '9': { - 'sdtype': 'numerical' - }, - '10': { - 'sdtype': 'numerical' - }, - '11': { - 'sdtype': 'numerical' - }, - '12': { - 'sdtype': 'numerical' - }, - '13': { - 'sdtype': 'numerical' - }, - '14': { - 'sdtype': 'numerical' - }, - '15': { - 'sdtype': 'numerical' - }, - '16': { - 'sdtype': 'numerical' - }, - '17': { - 'sdtype': 'numerical' - }, - '18': { - 'sdtype': 'numerical' - }, - '19': { - 'sdtype': 'numerical' - }, - '20': { - 'sdtype': 'numerical' - } + '1': {'sdtype': 'numerical'}, + '2': {'sdtype': 'numerical'}, + '3': {'sdtype': 'numerical'}, + '4': {'sdtype': 'numerical'}, + '5': {'sdtype': 'numerical'}, + '6': {'sdtype': 'numerical'}, + '7': {'sdtype': 'numerical'}, + '8': {'sdtype': 'numerical'}, + '9': {'sdtype': 'numerical'}, + '10': {'sdtype': 'numerical'}, + '11': {'sdtype': 'numerical'}, + '12': {'sdtype': 'numerical'}, + '13': {'sdtype': 'numerical'}, + '14': {'sdtype': 'numerical'}, + '15': {'sdtype': 'numerical'}, + '16': {'sdtype': 'numerical'}, + '17': {'sdtype': 'numerical'}, + '18': {'sdtype': 'numerical'}, + '19': {'sdtype': 'numerical'}, + '20': {'sdtype': 'numerical'}, }, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } # Run @@ -1428,8 +1412,8 @@ def test_detect_from_csv(self, mock_log, tmp_path): 'categorical': ['cat', 'dog', 'tiger', np.nan], 'date': pd.to_datetime(['2021-02-02', np.nan, '2021-03-05', '2022-12-09']), 'int': [1, 2, 3, 4], - 'float': [1., 2., 3., 4], - 'bool': [np.nan, True, False, True] + 'float': [1.0, 2.0, 3.0, 4], + 'bool': [np.nan, True, False, True], }) # Run @@ -1443,12 +1427,12 @@ def test_detect_from_csv(self, mock_log, tmp_path): 'date': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'int': {'sdtype': 'numerical'}, 'float': {'sdtype': 'numerical'}, - 'bool': {'sdtype': 'categorical'} + 'bool': {'sdtype': 'categorical'}, } expected_log_calls = [ call('Detected metadata:'), - call(json.dumps(instance.to_dict(), indent=4)) + call(json.dumps(instance.to_dict(), indent=4)), ] mock_log.info.assert_has_calls(expected_log_calls) @@ -1478,8 +1462,8 @@ def test_detect_from_csv_with_kwargs(self, mock_log, tmp_path): 'categorical': ['cat', 'dog', 'tiger', np.nan], 'date': pd.to_datetime(['2021-02-02', np.nan, '2021-03-05', '2022-12-09']), 'int': [1, 2, 3, 4], - 'float': [1., 2., 3., 4], - 'bool': [np.nan, True, False, True] + 'float': [1.0, 2.0, 3.0, 4], + 'bool': [np.nan, True, False, True], }) # Run @@ -1493,12 +1477,12 @@ def test_detect_from_csv_with_kwargs(self, mock_log, tmp_path): 'date': {'sdtype': 'datetime'}, 'int': {'sdtype': 'numerical'}, 'float': 
{'sdtype': 'numerical'}, - 'bool': {'sdtype': 'categorical'} + 'bool': {'sdtype': 'categorical'}, } expected_log_calls = [ call('Detected metadata:'), - call(json.dumps(instance.to_dict(), indent=4)) + call(json.dumps(instance.to_dict(), indent=4)), ] mock_log.info.assert_has_calls(expected_log_calls) @@ -1568,9 +1552,7 @@ def test_set_primary_key_validation_dtype(self): # Setup instance = SingleTableMetadata() - err_msg = ( - "'primary_key' must be a string." - ) + err_msg = "'primary_key' must be a string." # Run / Assert with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_primary_key(1) @@ -1592,8 +1574,7 @@ def test_set_primary_key_validation_columns(self): instance.columns = {'a', 'd'} err_msg = ( - "Unknown primary key values {'b'}." - ' Keys should be columns that exist in the table.' + "Unknown primary key values {'b'}." ' Keys should be columns that exist in the table.' ) # Run / Assert with pytest.raises(InvalidMetadataError, match=err_msg): @@ -1615,9 +1596,7 @@ def test_set_primary_key_validation_categorical(self): instance.add_column('column2', sdtype='categorical') instance.add_column('column3', sdtype='id') - err_msg = re.escape( - "The primary_keys ['column1'] must be type 'id' or another PII type." - ) + err_msg = re.escape("The primary_keys ['column1'] must be type 'id' or another PII type.") # Run / Assert with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_primary_key('column1') @@ -1712,7 +1691,7 @@ def test_set_primary_key_in_alternate_keys_warning(self, warning_mock): ) warning_mock.warn.assert_has_calls([ call(alternate_key_warning_msg), - call(primary_key_warning_msg) + call(primary_key_warning_msg), ]) assert instance.primary_key == 'column1' assert instance.alternate_keys == ['column2'] @@ -1751,8 +1730,7 @@ def test_set_sequence_key_validation_columns(self): instance.columns = {'a', 'd'} err_msg = ( - "Unknown sequence key values {'b'}." - ' Keys should be columns that exist in the table.' + "Unknown sequence key values {'b'}." ' Keys should be columns that exist in the table.' ) # Run / Assert with pytest.raises(InvalidMetadataError, match=err_msg): @@ -1774,9 +1752,7 @@ def test_set_sequence_key_validation_categorical(self): instance.add_column('column2', sdtype='categorical') instance.add_column('column3', sdtype='id') - err_msg = re.escape( - "The sequence_keys ['column1'] must be type 'id' or another PII type." 
- ) + err_msg = re.escape("The sequence_keys ['column1'] must be type 'id' or another PII type.") # Run / Assert with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_sequence_key('column1') @@ -1920,7 +1896,7 @@ def test_add_alternate_keys(self): instance.columns = { 'column1': {'sdtype': 'id'}, 'column2': {'sdtype': 'id'}, - 'column3': {'sdtype': 'id'} + 'column3': {'sdtype': 'id'}, } # Run @@ -1937,7 +1913,7 @@ def test_add_alternate_keys_duplicate(self, warnings_mock): instance.columns = { 'column1': {'sdtype': 'id'}, 'column2': {'sdtype': 'id'}, - 'column3': {'sdtype': 'id'} + 'column3': {'sdtype': 'id'}, } instance.alternate_keys = ['column3'] @@ -1994,10 +1970,7 @@ def test_set_sequence_index_column_not_numerical_or_datetime(self): """Test that the method errors if the column is not numerical or datetime.""" # Setup instance = SingleTableMetadata() - instance.columns = { - 'a': {'sdtype': 'numerical'}, - 'd': {'sdtype': 'categorical'} - } + instance.columns = {'a': {'sdtype': 'numerical'}, 'd': {'sdtype': 'categorical'}} # Run / Assert error_message = "The sequence_index must be of type 'datetime' or 'numerical'." @@ -2036,17 +2009,12 @@ def test__validate_column_relationship(self): # Setup instance = SingleTableMetadata() mock_relationship_validation = Mock() - instance._COLUMN_RELATIONSHIP_TYPES = { - 'mock_relationship': mock_relationship_validation - } - relationship = { - 'type': 'mock_relationship', - 'column_names': ['a', 'b'] - } + instance._COLUMN_RELATIONSHIP_TYPES = {'mock_relationship': mock_relationship_validation} + relationship = {'type': 'mock_relationship', 'column_names': ['a', 'b']} instance.columns = { 'a': {'sdtype': 'categorical'}, 'b': {'sdtype': 'numerical'}, - 'c': {'sdtype': 'datetime'} + 'c': {'sdtype': 'datetime'}, } # Run @@ -2057,21 +2025,14 @@ def test__validate_column_relationship(self): 'a': 'categorical', 'b': 'numerical', } - mock_relationship_validation.assert_called_once_with( - expected_columns_to_sdtypes - ) + mock_relationship_validation.assert_called_once_with(expected_columns_to_sdtypes) def test__validate_column_relationship_bad_relationship_type(self): """Test validation fails for an unknown relationship type.""" # Setup instance = SingleTableMetadata() - instance._COLUMN_RELATIONSHIP_TYPES = { - 'mock_relationship': Mock() - } - relationship = { - 'type': 'bad_relationship_type', - 'column_names': ['a', 'b'] - } + instance._COLUMN_RELATIONSHIP_TYPES = {'mock_relationship': Mock()} + relationship = {'type': 'bad_relationship_type', 'column_names': ['a', 'b']} # Run and Assert msg = re.escape( @@ -2083,6 +2044,7 @@ def test__validate_column_relationship_bad_relationship_type(self): def test__validate_column_relationship_bad_columns(self): """Test validation fails for invalid columns.""" + # Setup def validation_side_effect(*args, **kwargs): raise InvalidMetadataError("Columns ['a', 'b'] have unsupported sdtype.") @@ -2090,13 +2052,8 @@ def validation_side_effect(*args, **kwargs): instance = SingleTableMetadata() mock_relationship_validation = Mock() mock_relationship_validation.side_effect = validation_side_effect - instance._COLUMN_RELATIONSHIP_TYPES = { - 'mock_relationship': mock_relationship_validation - } - relationship = { - 'type': 'mock_relationship', - 'column_names': ['a', 'b', 'c', 'x'] - } + instance._COLUMN_RELATIONSHIP_TYPES = {'mock_relationship': mock_relationship_validation} + relationship = {'type': 'mock_relationship', 'column_names': ['a', 'b', 'c', 'x']} instance.columns = { 'a': {'sdtype': 'id'}, 'b': 
{'sdtype': 'categorical'}, @@ -2114,15 +2071,8 @@ def validation_side_effect(*args, **kwargs): instance._validate_column_relationship(relationship) # Assert - expected_columns_to_sdtypes = { - 'a': 'id', - 'b': 'categorical', - 'c': 'numerical', - 'x': None - } - mock_relationship_validation.assert_called_once_with( - expected_columns_to_sdtypes - ) + expected_columns_to_sdtypes = {'a': 'id', 'b': 'categorical', 'c': 'numerical', 'x': None} + mock_relationship_validation.assert_called_once_with(expected_columns_to_sdtypes) def test__validate_column_relationship_with_other_relationships(self): """Test ``_validate_column_relationship_with_others``.""" @@ -2131,26 +2081,19 @@ def test__validate_column_relationship_with_other_relationships(self): column_relationships = [ {'type': 'relationship_one', 'column_names': ['a', 'b']}, ] - relationship_valid = { - 'type': 'relationship_two', - 'column_names': ['c', 'd'] - } - relationship_invalid = { - 'type': 'relationship_two', - 'column_names': ['b', 'e'] - } + relationship_valid = {'type': 'relationship_two', 'column_names': ['c', 'd']} + relationship_invalid = {'type': 'relationship_two', 'column_names': ['b', 'e']} # Run and Assert - instance._validate_column_relationship_with_others( - relationship_valid, column_relationships - ) + instance._validate_column_relationship_with_others(relationship_valid, column_relationships) expected_message = re.escape( "Columns 'b' is already part of a relationship of type" " 'relationship_one'. Columns cannot be part of multiple relationships." ) with pytest.raises(InvalidMetadataError, match=expected_message): instance._validate_column_relationship_with_others( - relationship_invalid, column_relationships) + relationship_invalid, column_relationships + ) def test__validate_all_column_relationships(self): """Test ``_validate_all_column_relationships`` method.""" @@ -2160,9 +2103,7 @@ def test__validate_all_column_relationships(self): instance._validate_column_relationship = mock_validate_relationship relationship_one = {'type': 'relationship_one', 'column_names': ['a', 'b']} relationship_two = {'type': 'relationship_two', 'column_names': ['c', 'd']} - column_relationships = [ - relationship_one, relationship_two - ] + column_relationships = [relationship_one, relationship_two] # Run instance._validate_all_column_relationships(column_relationships) @@ -2170,7 +2111,7 @@ def test__validate_all_column_relationships(self): # Assert mock_validate_relationship.assert_has_calls([ call(relationship_one), - call(relationship_two) + call(relationship_two), ]) def test__validate_all_column_relationships_invalid_relationship_structure(self): @@ -2181,13 +2122,11 @@ def test__validate_all_column_relationships_invalid_relationship_structure(self) instance._validate_column_relationship = mock_validate_relationship column_relationships = [ {'type': 'relationship_one', 'column_names': ['a', 'b']}, - {'type': 'relationship_two', 'bad_key': ['c', 'd']} + {'type': 'relationship_two', 'bad_key': ['c', 'd']}, ] # Run and Assert - err_msg = re.escape( - "Relationship has invalid keys {'bad_key'}." 
- ) + err_msg = re.escape("Relationship has invalid keys {'bad_key'}.") with pytest.raises(InvalidMetadataError, match=err_msg): instance._validate_all_column_relationships(column_relationships) @@ -2199,7 +2138,7 @@ def test__validate_all_column_relationships_repeated_column(self): instance._validate_column_relationship = mock_validate_relationship column_relationships = [ {'type': 'relationship_one', 'column_names': ['a', 'b']}, - {'type': 'relationship_two', 'column_names': ['b', 'c']} + {'type': 'relationship_two', 'column_names': ['b', 'c']}, ] instance.column_relationships = column_relationships # Run and Assert @@ -2212,18 +2151,18 @@ def test__validate_all_column_relationships_repeated_column(self): def test__validate_all_column_relationships_bad_relationship(self): """Test validation fails if individual relationship validation fails.""" + # Setup def mock_relationship_validate(relationship): - raise InvalidMetadataError( - f"Error in '{relationship['type']}' relationship." - ) + raise InvalidMetadataError(f"Error in '{relationship['type']}' relationship.") + instance = SingleTableMetadata() mock_validate_relationship = Mock() mock_validate_relationship.side_effect = mock_relationship_validate instance._validate_column_relationship = mock_validate_relationship column_relationships = [ {'type': 'relationship_one', 'column_names': ['a', 'b']}, - {'type': 'relationship_two', 'column_names': ['c', 'd']} + {'type': 'relationship_two', 'column_names': ['c', 'd']}, ] # Run and Assert @@ -2243,30 +2182,27 @@ def test_add_column_relationships(self): # Run instance.add_column_relationship( - relationship_type='relationship_A', - column_names=['colA', 'colB'] + relationship_type='relationship_A', column_names=['colA', 'colB'] ) instance.add_column_relationship( - relationship_type='relationship_B', - column_names=['col1', 'col2', 'col3'] + relationship_type='relationship_B', column_names=['col1', 'col2', 'col3'] ) # Assert mock_validate_column_relationships.assert_has_calls([ - call([ - {'type': 'relationship_A', 'column_names': ['colA', 'colB']} - ]), + call([{'type': 'relationship_A', 'column_names': ['colA', 'colB']}]), call([ {'type': 'relationship_B', 'column_names': ['col1', 'col2', 'col3']}, - {'type': 'relationship_A', 'column_names': ['colA', 'colB']} - ]) + {'type': 'relationship_A', 'column_names': ['colA', 'colB']}, + ]), ]) assert instance.column_relationships == [ {'type': 'relationship_A', 'column_names': ['colA', 'colB']}, - {'type': 'relationship_B', 'column_names': ['col1', 'col2', 'col3']} + {'type': 'relationship_B', 'column_names': ['col1', 'col2', 'col3']}, ] def test_add_column_relationships_silence_warnings(self): """Test ``add_column_relationship`` silences UserWarnings.""" + # Setup def raise_user_warning(*args, **kwargs): warnings.warn('This is a warning', UserWarning) @@ -2279,8 +2215,7 @@ def raise_user_warning(*args, **kwargs): with warnings.catch_warnings(record=True) as captured_warnings: warnings.simplefilter('always') instance.add_column_relationship( - relationship_type='relationship_A', - column_names=['colA', 'colB'] + relationship_type='relationship_A', column_names=['colA', 'colB'] ) # Assert @@ -2316,27 +2251,27 @@ def test_validate(self): instance._validate_column_args = Mock(side_effect=InvalidMetadataError('column_error')) err_msg = re.escape( - 'The following errors were found in the metadata:' - '\n\ncolumn_error' - '\ncolumn_error' + 'The following errors were found in the metadata:' '\n\ncolumn_error' '\ncolumn_error' ) # Run with 
pytest.raises(InvalidMetadataError, match=err_msg): instance.validate() # Assert - instance._validate_key.assert_has_calls( - [call(instance.primary_key, 'primary'), call(instance.sequence_key, 'sequence')] - ) - instance._validate_column_args.assert_has_calls( - [call('col1', sdtype='numerical'), call('col2', sdtype='numerical')] - ) + instance._validate_key.assert_has_calls([ + call(instance.primary_key, 'primary'), + call(instance.sequence_key, 'sequence'), + ]) + instance._validate_column_args.assert_has_calls([ + call('col1', sdtype='numerical'), + call('col2', sdtype='numerical'), + ]) instance._validate_alternate_keys.assert_called_once_with(instance.alternate_keys) instance._validate_sequence_index.assert_called_once_with(instance.sequence_index) instance._validate_sequence_index_not_in_sequence_key.assert_called_once() - instance._validate_all_column_relationships.assert_called_once_with( - [{'type': 'relationship_one', 'column_names': ['col1', 'col2']}] - ) + instance._validate_all_column_relationships.assert_called_once_with([ + {'type': 'relationship_one', 'column_names': ['col1', 'col2']} + ]) def test_validate_data_wrong_type(self): """Test error is raised if data is not ``pd.DataFrame``.""" @@ -2405,7 +2340,7 @@ def test_validate_data_keys_with_missing_values(self): 'sk_col3': [0, 1, 2], 'ak_col1': [0, 1, None], 'ak_col2': [0, 1, np.nan], - 'ak_col3': [0, 1, 2] + 'ak_col3': [0, 1, 2], }) metadata = SingleTableMetadata() metadata.add_column('pk_col', sdtype='id') @@ -2437,10 +2372,7 @@ def test_validate_data_keys_with_missing_with_single_sequence_key(self): Test the case with a single sequence key. """ - data = pd.DataFrame({ - 'pk_col': [1], - 'sk_col': [None] - }) + data = pd.DataFrame({'pk_col': [1], 'sk_col': [None]}) metadata = SingleTableMetadata() metadata.add_column('pk_col', sdtype='id') metadata.add_column('sk_col', sdtype='id') @@ -2461,7 +2393,7 @@ def test_validate_data_keys_not_unique(self): 'pk_col': [0, 1, 1, 0, 2], 'ak_col1': [0, 1, 0, 3, 3], 'ak_col2': [2, 2, 2, 2, 2], - 'ak_col3': [0, 1, 2, 3, 4] + 'ak_col3': [0, 1, 2, 3, 4], }) metadata = SingleTableMetadata() metadata.add_column('pk_col', sdtype='id') @@ -2591,22 +2523,22 @@ def test_validate_data_datetime_sdtype(self): '20220916230356000000', '20220826173917000000', '20220826212135000000', - '20220929111311000000' + '20220929111311000000', ], 'date_int': [ 20220902110443000000, 20220916230356000000, 20220826173917000000, 20220826212135000000, - 20220929111311000000 + 20220929111311000000, ], 'bad_date': [ 2022090, 20220916230356000000, 2022, 20220826212135000000, - 20220929111311000000 - ] + 20220929111311000000, + ], }) metadata = SingleTableMetadata() metadata.add_column('date_str', sdtype='datetime', datetime_format='%Y%m%d%H%M%S%f') @@ -2634,22 +2566,22 @@ def test_validate_data_datetime_warning(self): '2022-09-16', '2022-08-26', '2022-08-26', - '2022-09-29' + '2022-09-29', ], 'valid_date': [ '20220902110443000000', '20220916230356000000', '20220826173917000000', '20220826212135000000', - '20220929111311000000' + '20220929111311000000', ], 'datetime': pd.to_datetime([ '20220902', '20220916', '20220826', '20220826', - '20220929' - ]) + '20220929', + ]), }) metadata = SingleTableMetadata() metadata.add_column('warning_date_str', sdtype='datetime') @@ -2660,7 +2592,7 @@ def test_validate_data_datetime_warning(self): warning_frame = pd.DataFrame({ 'Column Name': ['warning_date_str'], 'sdtype': ['datetime'], - 'datetime_format': [None] + 'datetime_format': [None], }) warning_msg = ( "No 'datetime_format' 
is present in the metadata for the following columns:\n" @@ -2687,7 +2619,6 @@ def test_validate_data(self): 'numerical_col': [np.nan, -1, 1.54], 'date_col': [np.nan, '2021-02-10', '2021-05-10'], 'bool_col': [np.nan, True, False], - }) metadata = SingleTableMetadata() metadata.add_column('pk_col', sdtype='id') @@ -2725,7 +2656,7 @@ def test_to_dict(self): # Assert assert result == { 'columns': {'my_column': 'value'}, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } # Ensure that the output object does not alterate the inside object @@ -2750,7 +2681,7 @@ def test_to_dict_missing_attributes(self): # Assert assert result == { 'columns': {'my_column': 'value'}, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } def test_load_from_dict(self): @@ -2762,7 +2693,7 @@ def test_load_from_dict(self): 'alternate_keys': [], 'sequence_key': None, 'sequence_index': None, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } # Run @@ -2790,7 +2721,7 @@ def test_load_from_dict_integer(self): 'alternate_keys': [], 'sequence_key': None, 'sequence_index': None, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } # Run @@ -2856,11 +2787,7 @@ def test_load_from_json_schema_not_present(self, mock_json, mock_path, mock_open mock_path.return_value.exists.return_value = True mock_path.return_value.name = 'filepath.json' mock_json.load.return_value = { - 'columns': { - 'animals': { - 'type': 'categorical' - } - }, + 'columns': {'animals': {'type': 'categorical'}}, 'primary_key': 'animals', } @@ -2898,13 +2825,9 @@ def test_load_from_json(self, mock_json, mock_path, mock_open): mock_path.return_value.exists.return_value = True mock_path.return_value.name = 'filepath.json' mock_json.load.return_value = { - 'columns': { - 'animals': { - 'type': 'categorical' - } - }, + 'columns': {'animals': {'type': 'categorical'}}, 'primary_key': 'animals', - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } # Run @@ -3065,8 +2988,9 @@ def test_visualize_metadata_summarized(self, mock_visualize_graph): mock_visualize_graph.assert_called_once_with(expected_node, [], None) @patch('sdv.metadata.single_table.visualize_graph') - def test_visualize_metadata_with_primary_alternate_and_sequence_keys(self, - mock_visualize_graph): + def test_visualize_metadata_with_primary_alternate_and_sequence_keys( + self, mock_visualize_graph + ): """Test the ``visualize`` method when there are primary, alternate and sequence keys.""" # Setup instance = SingleTableMetadata() @@ -3076,7 +3000,7 @@ def test_visualize_metadata_with_primary_alternate_and_sequence_keys(self, 'age': {'sdtype': 'numerical'}, 'start_date': {'sdtype': 'datetime'}, 'phrase': {'sdtype': 'id'}, - 'passport': {'sdtype': 'id'} + 'passport': {'sdtype': 'id'}, } instance.primary_key = 'passport' instance.alternate_keys = ['phrase', 'name'] @@ -3135,8 +3059,7 @@ def test_upgrade_metadata(self, from_dict_mock, convert_mock, read_json_mock): @patch('sdv.metadata.single_table.read_json') @patch('sdv.metadata.single_table.convert_metadata') @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') - def test_upgrade_metadata_multiple_tables( - self, from_dict_mock, convert_mock, read_json_mock): + def test_upgrade_metadata_multiple_tables(self, from_dict_mock, convert_mock, read_json_mock): """Test the ``upgrade_metadata`` method. 
If the old metadata is in the multi-table format (has 'tables'), but it only contains one @@ -3159,9 +3082,7 @@ def test_upgrade_metadata_multiple_tables( convert_mock.return_value = {} new_metadata = Mock() from_dict_mock.return_value = new_metadata - read_json_mock.return_value = { - 'tables': {'table': {'columns': {}}} - } + read_json_mock.return_value = {'tables': {'table': {'columns': {}}}} # Run SingleTableMetadata.upgrade_metadata('old') @@ -3174,7 +3095,8 @@ def test_upgrade_metadata_multiple_tables( @patch('sdv.metadata.single_table.convert_metadata') @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') def test_upgrade_metadata_multiple_tables_fails( - self, from_dict_mock, convert_mock, read_json_mock): + self, from_dict_mock, convert_mock, read_json_mock + ): """Test the ``upgrade_metadata`` method. If the old metadata is in the multi-table format (has 'tables'), but contains multiple @@ -3197,9 +3119,7 @@ def test_upgrade_metadata_multiple_tables_fails( convert_mock.return_value = {} new_metadata = Mock() from_dict_mock.return_value = new_metadata - read_json_mock.return_value = { - 'tables': {'table1': {'columns': {}}, 'table2': {}} - } + read_json_mock.return_value = {'tables': {'table1': {'columns': {}}, 'table2': {}}} # Run message = ( @@ -3214,7 +3134,8 @@ def test_upgrade_metadata_multiple_tables_fails( @patch('sdv.metadata.single_table.convert_metadata') @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') def test_upgrade_metadata_validate_error( - self, from_dict_mock, convert_mock, read_json_mock, warnings_mock): + self, from_dict_mock, convert_mock, read_json_mock, warnings_mock + ): """Test the ``upgrade_metadata`` method. The method should raise a warning with any validation errors after the metadata is diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 43a41f6c5..3c493606f 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -11,8 +11,13 @@ from sdv import version from sdv.errors import ( - ConstraintsNotMetError, InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError, - VersionError) + ConstraintsNotMetError, + InvalidDataError, + NotFittedError, + SamplingError, + SynthesizerInputError, + VersionError, +) from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer @@ -23,7 +28,6 @@ class TestBaseMultiTableSynthesizer: - def test__initialize_models(self): """Test that this method initializes the ``self._synthezier`` for each table. 
@@ -34,11 +38,7 @@ def test__initialize_models(self): locales = ['en_CA', 'fr_CA'] instance = Mock() instance._table_synthesizers = {} - instance._table_parameters = { - 'nesreca': { - 'default_distribution': 'gamma' - } - } + instance._table_parameters = {'nesreca': {'default_distribution': 'gamma'}} instance.locales = locales instance.metadata = get_multi_table_metadata() @@ -49,13 +49,16 @@ def test__initialize_models(self): assert instance._table_synthesizers == { 'nesreca': instance._synthesizer.return_value, 'oseba': instance._synthesizer.return_value, - 'upravna_enota': instance._synthesizer.return_value + 'upravna_enota': instance._synthesizer.return_value, } instance._synthesizer.assert_has_calls([ - call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales), + call( + metadata=instance.metadata.tables['nesreca'], + default_distribution='gamma', + locales=locales, + ), call(metadata=instance.metadata.tables['oseba'], locales=locales), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales) + call(metadata=instance.metadata.tables['upravna_enota'], locales=locales), ]) def test__get_pbar_args(self): @@ -78,17 +81,11 @@ def test__get_pbar_args_kwargs(self): # Run result = BaseMultiTableSynthesizer._get_pbar_args( - instance, - desc='Process Table', - position=0 + instance, desc='Process Table', position=0 ) # Assert - assert result == { - 'disable': False, - 'desc': 'Process Table', - 'position': 0 - } + assert result == {'disable': False, 'desc': 'Process Table', 'position': 0} @patch('sdv.multi_table.base.print') def test__print(self, mock_print): @@ -105,8 +102,9 @@ def test__print(self, mock_print): @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.generate_synthesizer_id') @patch('sdv.multi_table.base.BaseMultiTableSynthesizer._check_metadata_updated') - def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_id, - mock_datetime, caplog): + def test___init__( + self, mock_check_metadata_updated, mock_generate_synthesizer_id, mock_datetime, caplog + ): """Test that when creating a new instance this sets the defaults. Test that the metadata object is being stored and also being validated. 
Afterwards, this @@ -137,7 +135,7 @@ def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_i 'EVENT': 'Instance', 'TIMESTAMP': '2024-04-19 16:20:10.037183', 'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer', - 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', }) def test__init__column_relationship_warning(self): @@ -211,23 +209,23 @@ def test_set_address_columns(self): 'city_column': {'sdtype': 'city'}, 'parent_key': {'sdtype': 'id'}, }, - 'primary_key': 'parent_key' + 'primary_key': 'parent_key', }, 'other_table': { 'columns': { 'numerical_column': {'sdtype': 'numerical'}, 'child_foreign_key': {'sdtype': 'id'}, } - } + }, }, 'relationships': [ { 'parent_table_name': 'address_table', 'parent_primary_key': 'parent_key', 'child_table_name': 'other_table', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', } - ] + ], }) columns = ('country_column', 'city_column') metadata.validate = Mock() @@ -236,9 +234,7 @@ def test_set_address_columns(self): instance._table_synthesizers['address_table'].set_address_columns = Mock() # Run - instance.set_address_columns( - 'address_table', columns, anonymization_level='street_address' - ) + instance.set_address_columns('address_table', columns, anonymization_level='street_address') # Assert instance._table_synthesizers['address_table'].set_address_columns.assert_called_once_with( @@ -281,8 +277,8 @@ def test_get_table_parameters_empty(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} - } + 'numerical_distributions': {}, + }, } def test_get_table_parameters_has_parameters(self): @@ -301,7 +297,7 @@ def test_get_table_parameters_has_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, } def test_get_parameters(self): @@ -339,7 +335,7 @@ def test_set_table_parameters(self): 'enforce_min_max_values': True, 'locales': ['en_US'], 'enforce_rounding': True, - 'numerical_distributions': {} + 'numerical_distributions': {}, } def test_set_table_parameters_invalid_enforce_min_max_values(self): @@ -442,16 +438,16 @@ def test_validate_data_does_not_match(self): 'nesreca': pd.DataFrame({ 'id_nesreca': np.arange(10), 'upravna_enota': np.arange(10), - 'nesreca_val': np.arange(10).astype(str) + 'nesreca_val': np.arange(10).astype(str), }), 'oseba': pd.DataFrame({ 'upravna_enota': np.arange(10), 'id_nesreca': np.arange(10), - 'oseba_val': np.arange(10).astype(str) + 'oseba_val': np.arange(10).astype(str), }), 'upravna_enota': pd.DataFrame({ 'id_upravna_enota': np.arange(10), - 'upravna_val': np.arange(10).astype(str) + 'upravna_val': np.arange(10).astype(str), }), } @@ -481,16 +477,16 @@ def test_validate_missing_foreign_keys(self): 'nesreca': pd.DataFrame({ 'id_nesreca': np.arange(0, 20, 2), 'upravna_enota': np.arange(10), - 'nesreca_val': np.arange(10) + 'nesreca_val': np.arange(10), }), 'oseba': pd.DataFrame({ 'upravna_enota': np.arange(10), 'id_nesreca': np.arange(10), - 'oseba_val': np.arange(10) + 'oseba_val': np.arange(10), }), 'upravna_enota': pd.DataFrame({ 'id_upravna_enota': np.arange(10), - 'upravna_val': np.arange(10) + 'upravna_val': np.arange(10), }), } instance = BaseMultiTableSynthesizer(metadata) @@ -519,8 +515,8 @@ def test_validate_constraints_not_met(self): 'constraint_parameters': { 
'low_column_name': 'nesreca_val', 'high_column_name': 'val', - 'strict_boundaries': True - } + 'strict_boundaries': True, + }, } instance.add_constraints([inequality_constraint]) @@ -548,8 +544,7 @@ def test_validate_table_synthesizers_errors(self): # Run and Assert error_msg = ( - 'The provided data does not match the metadata:\n' - 'Invalid data for PAR synthesizer.' + 'The provided data does not match the metadata:\n' 'Invalid data for PAR synthesizer.' ) with pytest.raises(InvalidDataError, match=error_msg): instance.validate(data) @@ -561,10 +556,7 @@ def test_auto_assign_transformers(self): instance = BaseMultiTableSynthesizer(metadata) table1 = pd.DataFrame({'col1': [1, 2]}) table2 = pd.DataFrame({'col2': [1, 2]}) - data = { - 'nesreca': table1, - 'oseba': table2 - } + data = {'nesreca': table1, 'oseba': table2} instance._table_synthesizers['nesreca'] = Mock() instance._table_synthesizers['oseba'] = Mock() @@ -573,19 +565,18 @@ def test_auto_assign_transformers(self): # Assert instance._table_synthesizers['nesreca'].auto_assign_transformers.assert_called_once_with( - table1) + table1 + ) instance._table_synthesizers['oseba'].auto_assign_transformers.assert_called_once_with( - table2) + table2 + ) def test_auto_assign_transformers_foreign_key_none(self): """Test that each table's foreign key transformers are set to None.""" # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) - data = { - 'nesreca': Mock(), - 'oseba': Mock() - } + data = {'nesreca': Mock(), 'oseba': Mock()} instance.validate = Mock() instance.metadata._get_all_foreign_keys = Mock(return_value=['a', 'b']) nesreca_synthesizer = Mock() @@ -622,10 +613,7 @@ def test_auto_assign_transformers_missing_column(self): synthesizer = HMASynthesizer(metadata) table1 = pd.DataFrame({'col1': [1, 2]}) table2 = pd.DataFrame({'col2': [1, 2]}) - data = { - 'nesreca': table1, - 'oseba': table2 - } + data = {'nesreca': table1, 'oseba': table2} # Run error_msg = re.escape( @@ -774,7 +762,7 @@ def test_preprocess(self): instance._table_synthesizers = { 'nesreca': synth_nesreca, 'oseba': synth_oseba, - 'upravna_enota': synth_upravna_enota + 'upravna_enota': synth_upravna_enota, } # Run @@ -784,13 +772,13 @@ def test_preprocess(self): assert result == { 'nesreca': synth_nesreca._preprocess.return_value, 'oseba': synth_oseba._preprocess.return_value, - 'upravna_enota': synth_upravna_enota._preprocess.return_value + 'upravna_enota': synth_upravna_enota._preprocess.return_value, } instance.validate.assert_called_once_with(data) assert instance.metadata._get_all_foreign_keys.call_args_list == [ call('nesreca'), call('oseba'), - call('upravna_enota') + call('upravna_enota'), ] synth_nesreca.auto_assign_transformers.assert_called_once_with(data['nesreca']) @@ -819,32 +807,26 @@ def test_preprocess_int_columns(self): 'columns': { '1': {'sdtype': 'id'}, '2': {'sdtype': 'categorical'}, - 'str': {'sdtype': 'categorical'} - } + 'str': {'sdtype': 'categorical'}, + }, }, 'second_table': { - 'columns': { - '3': {'sdtype': 'id'}, - 'str': {'sdtype': 'categorical'} - } - } + 'columns': {'3': {'sdtype': 'id'}, 'str': {'sdtype': 'categorical'}} + }, }, 'relationships': [ { 'parent_table_name': 'first_table', 'parent_primary_key': '1', 'child_table_name': 'second_table', - 'child_foreign_key': '3' + 'child_foreign_key': '3', } - ] + ], } metadata = MultiTableMetadata.load_from_dict(metadata_dict) instance = BaseMultiTableSynthesizer(metadata) instance.validate = Mock() - instance._table_synthesizers = { - 'first_table': 
Mock(), - 'second_table': Mock() - } + instance._table_synthesizers = {'first_table': Mock(), 'second_table': Mock()} multi_data = { 'first_table': pd.DataFrame({ 1: ['abc', 'def', 'ghi'], @@ -903,7 +885,7 @@ def test_preprocess_warning(self, mock_warnings): instance._table_synthesizers = { 'nesreca': synth_nesreca, 'oseba': synth_oseba, - 'upravna_enota': synth_upravna_enota + 'upravna_enota': synth_upravna_enota, } instance._fitted = True @@ -914,7 +896,7 @@ def test_preprocess_warning(self, mock_warnings): assert result == { 'nesreca': synth_nesreca._preprocess.return_value, 'oseba': synth_oseba._preprocess.return_value, - 'upravna_enota': synth_upravna_enota._preprocess.return_value + 'upravna_enota': synth_upravna_enota._preprocess.return_value, } instance.validate.assert_called_once_with(data) synth_nesreca._preprocess.assert_called_once_with(data['nesreca']) @@ -937,11 +919,11 @@ def test_fit_processed_data(self, mock_datetime, caplog): instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, - _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', ) processed_data = { 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), - 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), } # Run @@ -959,20 +941,14 @@ def test_fit_processed_data(self, mock_datetime, caplog): 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', 'TOTAL NUMBER OF TABLES': 2, 'TOTAL NUMBER OF ROWS': 6, - 'TOTAL NUMBER OF COLUMNS': 4 + 'TOTAL NUMBER OF COLUMNS': 4, }) def test_fit_processed_data_empty_table(self): """Test attributes are properly set when data is empty and that _fit is not called.""" # Setup - instance = Mock( - _fitted_sdv_version=None, - _fitted_sdv_enterprise_version=None - ) - processed_data = { - 'table1': pd.DataFrame(), - 'table2': pd.DataFrame() - } + instance = Mock(_fitted_sdv_version=None, _fitted_sdv_enterprise_version=None) + processed_data = {'table1': pd.DataFrame(), 'table2': pd.DataFrame()} # Run BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) @@ -986,15 +962,9 @@ def test_fit_processed_data_empty_table(self): def test_fit_processed_data_raises_version_error(self): """Test that fit_processed data will raise a ``VersionError``.""" # Setup - instance = Mock( - _fitted_sdv_version='1.0.0', - _fitted_sdv_enterprise_version=None - ) + instance = Mock(_fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None) instance.metadata = Mock() - processed_data = { - 'table1': pd.DataFrame(), - 'table2': pd.DataFrame() - } + processed_data = {'table1': pd.DataFrame(), 'table2': pd.DataFrame()} # Run and Assert error_msg = ( @@ -1019,12 +989,12 @@ def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, - _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', ) instance.metadata = Mock() data = { 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), - 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), } # Run @@ -1043,20 +1013,17 @@ 
def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', 'TOTAL NUMBER OF TABLES': 2, 'TOTAL NUMBER OF ROWS': 6, - 'TOTAL NUMBER OF COLUMNS': 4 + 'TOTAL NUMBER OF COLUMNS': 4, }) def test_fit_raises_version_error(self): """Test that fit will raise a ``VersionError`` if the current version is bigger.""" # Setup - instance = Mock( - _fitted_sdv_version='1.0.0', - _fitted_sdv_enterprise_version=None - ) + instance = Mock(_fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None) instance.metadata = Mock() data = { 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), - 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), } # Run and Assert @@ -1158,7 +1125,7 @@ def test_sample(self, mock_datetime, caplog): instance._fitted = True data = { 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), - 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), } instance._sample = Mock(return_value=data) @@ -1178,7 +1145,7 @@ def test_sample(self, mock_datetime, caplog): 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', 'TOTAL NUMBER OF TABLES': 2, 'TOTAL NUMBER OF ROWS': 6, - 'TOTAL NUMBER OF COLUMNS': 4 + 'TOTAL NUMBER OF COLUMNS': 4, }) def test_get_learned_distributions_raises_an_unfitted_error(self): @@ -1255,9 +1222,7 @@ def test_add_constraint_warning(self): instance._fitted = True # Run and Assert - warn_msg = ( - "For these constraints to take effect, please refit the synthesizer using 'fit'." - ) + warn_msg = "For these constraints to take effect, please refit the synthesizer using 'fit'." 
with pytest.warns(UserWarning, match=warn_msg): instance.add_constraints([]) @@ -1271,18 +1236,12 @@ def test_add_constraints(self): positive_constraint = { 'constraint_class': 'Positive', 'table_name': 'nesreca', - 'constraint_parameters': { - 'column_name': 'nesreca_val', - 'strict_boundaries': True - } + 'constraint_parameters': {'column_name': 'nesreca_val', 'strict_boundaries': True}, } negative_constraint = { 'constraint_class': 'Negative', 'table_name': 'oseba', - 'constraint_parameters': { - 'column_name': 'oseba_val', - 'strict_boundaries': False - } + 'constraint_parameters': {'column_name': 'oseba_val', 'strict_boundaries': False}, } # Run @@ -1291,17 +1250,11 @@ def test_add_constraints(self): # Assert positive_constraint = { 'constraint_class': 'Positive', - 'constraint_parameters': { - 'column_name': 'nesreca_val', - 'strict_boundaries': True - } + 'constraint_parameters': {'column_name': 'nesreca_val', 'strict_boundaries': True}, } negative_constraint = { 'constraint_class': 'Negative', - 'constraint_parameters': { - 'column_name': 'oseba_val', - 'strict_boundaries': False - } + 'constraint_parameters': {'column_name': 'oseba_val', 'strict_boundaries': False}, } output_nesreca = instance._table_synthesizers['nesreca'].get_constraints() assert output_nesreca == [positive_constraint] @@ -1319,7 +1272,7 @@ def test_add_constraints_unique(self): 'table_name': 'oseba', 'constraint_parameters': { 'column_name': 'id_nesreca', - } + }, } # Run and Assert @@ -1340,18 +1293,12 @@ def test_get_constraints(self): positive_constraint = { 'constraint_class': 'Positive', 'table_name': 'nesreca', - 'constraint_parameters': { - 'column_name': 'nesreca_val', - 'strict_boundaries': True - } + 'constraint_parameters': {'column_name': 'nesreca_val', 'strict_boundaries': True}, } negative_constraint = { 'constraint_class': 'Negative', 'table_name': 'oseba', - 'constraint_parameters': { - 'column_name': 'oseba_val', - 'strict_boundaries': False - } + 'constraint_parameters': {'column_name': 'oseba_val', 'strict_boundaries': False}, } constraints = [positive_constraint, negative_constraint] instance.add_constraints(constraints) @@ -1388,15 +1335,12 @@ def test_load_custom_constraint_classes(self): # Run BaseMultiTableSynthesizer.load_custom_constraint_classes( - instance, - 'path/to/file.py', - ['Custom', 'Constr', 'UpperPlus'] + instance, 'path/to/file.py', ['Custom', 'Constr', 'UpperPlus'] ) # Assert table_synth_mock.load_custom_constraint_classes.assert_called_once_with( - 'path/to/file.py', - ['Custom', 'Constr', 'UpperPlus'] + 'path/to/file.py', ['Custom', 'Constr', 'UpperPlus'] ) def test_load_custom_constraint_classes_multi_tables(self): @@ -1409,19 +1353,15 @@ def test_load_custom_constraint_classes_multi_tables(self): # Run BaseMultiTableSynthesizer.load_custom_constraint_classes( - instance, - 'path/to/file.py', - ['Custom', 'Constr', 'UpperPlus'] + instance, 'path/to/file.py', ['Custom', 'Constr', 'UpperPlus'] ) # Assert table_synth_mock.load_custom_constraint_classes.assert_called_once_with( - 'path/to/file.py', - ['Custom', 'Constr', 'UpperPlus'] + 'path/to/file.py', ['Custom', 'Constr', 'UpperPlus'] ) table_synth_mock_2.load_custom_constraint_classes.assert_called_once_with( - 'path/to/file.py', - ['Custom', 'Constr', 'UpperPlus'] + 'path/to/file.py', ['Custom', 'Constr', 'UpperPlus'] ) def test_add_custom_constraint_class(self): @@ -1433,16 +1373,11 @@ def test_add_custom_constraint_class(self): instance._table_synthesizers = {'table': table_synth_mock} # Run - 
BaseMultiTableSynthesizer.add_custom_constraint_class( - instance, - constraint_mock, - 'custom' - ) + BaseMultiTableSynthesizer.add_custom_constraint_class(instance, constraint_mock, 'custom') # Assert table_synth_mock.add_custom_constraint_class.assert_called_once_with( - constraint_mock, - 'custom' + constraint_mock, 'custom' ) def test_add_custom_constraint_class_multi_tables(self): @@ -1455,20 +1390,14 @@ def test_add_custom_constraint_class_multi_tables(self): instance._table_synthesizers = {'table': table_synth_mock, 'table_2': table_synth_mock_2} # Run - BaseMultiTableSynthesizer.add_custom_constraint_class( - instance, - constraint_mock, - 'custom' - ) + BaseMultiTableSynthesizer.add_custom_constraint_class(instance, constraint_mock, 'custom') # Assert table_synth_mock.add_custom_constraint_class.assert_called_once_with( - constraint_mock, - 'custom' + constraint_mock, 'custom' ) table_synth_mock_2.add_custom_constraint_class.assert_called_once_with( - constraint_mock, - 'custom' + constraint_mock, 'custom' ) @patch('sdv.multi_table.base.version') @@ -1503,7 +1432,7 @@ def test_get_info(self, mock_version): 'creation_date': '2023-01-23', 'is_fit': False, 'last_fit_date': None, - 'fitted_sdv_version': None + 'fitted_sdv_version': None, } # Run @@ -1516,7 +1445,7 @@ def test_get_info(self, mock_version): 'creation_date': '2023-01-23', 'is_fit': True, 'last_fit_date': '2023-01-23', - 'fitted_sdv_version': '1.0.0' + 'fitted_sdv_version': '1.0.0', } @patch('sdv.multi_table.base.version') @@ -1551,7 +1480,7 @@ def test_get_info_with_enterprise(self, mock_version): 'creation_date': '2023-01-23', 'is_fit': False, 'last_fit_date': None, - 'fitted_sdv_version': None + 'fitted_sdv_version': None, } # Run @@ -1565,7 +1494,7 @@ def test_get_info_with_enterprise(self, mock_version): 'is_fit': True, 'last_fit_date': '2023-01-23', 'fitted_sdv_version': '1.0.0', - 'fitted_sdv_enterprise_version': '1.1.0' + 'fitted_sdv_enterprise_version': '1.1.0', } @patch('sdv.multi_table.base.datetime') @@ -1598,9 +1527,16 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): @patch('sdv.multi_table.base.check_sdv_versions_and_warn') @patch('sdv.multi_table.base.cloudpickle') @patch('builtins.open', new_callable=mock_open) - def test_load(self, mock_file, cloudpickle_mock, - mock_check_sdv_versions_and_warn, mock_check_synthesizer_version, - mock_generate_synthesizer_id, mock_datetime, caplog): + def test_load( + self, + mock_file, + cloudpickle_mock, + mock_check_sdv_versions_and_warn, + mock_check_synthesizer_version, + mock_generate_synthesizer_id, + mock_datetime, + caplog, + ): """Test that the ``load`` method loads a stored synthesizer.""" # Setup synthesizer_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' @@ -1633,12 +1569,14 @@ def test_load(self, mock_file, cloudpickle_mock, def test_load_runtime_error(self, cloudpickle_mock, mock_open): """Test that the synthesizer's load method errors with the correct message.""" # Setup - cloudpickle_mock.load.side_effect = RuntimeError(( - 'Attempting to deserialize object on a CUDA device but ' - 'torch.cuda.is_available() is False. If you are running on a CPU-only machine,' - " please use torch.load with map_location=torch.device('cpu') " - 'to map your storages to the CPU.' - )) + cloudpickle_mock.load.side_effect = RuntimeError( + ( + 'Attempting to deserialize object on a CUDA device but ' + 'torch.cuda.is_available() is False. 
If you are running on a CPU-only machine,' + " please use torch.load with map_location=torch.device('cpu') " + 'to map your storages to the CPU.' + ) + ) # Run and Assert err_msg = re.escape( diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index c40e7b080..4cb0ff202 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -13,7 +13,6 @@ class TestHMASynthesizer: - def test___init__(self): """Test the default initialization of the ``HMASynthesizer``.""" # Run @@ -38,9 +37,7 @@ def test_set_table_parameters_errors_gaussian_kde(self): # Setup default_table_parameters = {'default_distribution': 'gaussian_kde'} numerical_distribution_parameters = { - 'numerical_distributions': { - 'id_nesreca': 'gaussian_kde' - } + 'numerical_distributions': {'id_nesreca': 'gaussian_kde'} } metadata = get_multi_table_metadata() instance = HMASynthesizer(metadata) @@ -65,10 +62,7 @@ def test__get_extension(self): """ # Setup metadata = get_multi_table_metadata() - child_table = pd.DataFrame({ - 'id_nesreca': [0, 1, 2, 3], - 'upravna_enota': [0, 1, 2, 3] - }) + child_table = pd.DataFrame({'id_nesreca': [0, 1, 2, 3], 'upravna_enota': [0, 1, 2, 3]}) instance = HMASynthesizer(metadata) # Run @@ -76,11 +70,11 @@ def test__get_extension(self): # Assert expected = pd.DataFrame({ - '__nesreca__upravna_enota__univariates__id_nesreca__a': [1., 1., 1., 1.], - '__nesreca__upravna_enota__univariates__id_nesreca__b': [1., 1., 1., 1.], - '__nesreca__upravna_enota__univariates__id_nesreca__loc': [0., 1., 2., 3.], + '__nesreca__upravna_enota__univariates__id_nesreca__a': [1.0, 1.0, 1.0, 1.0], + '__nesreca__upravna_enota__univariates__id_nesreca__b': [1.0, 1.0, 1.0, 1.0], + '__nesreca__upravna_enota__univariates__id_nesreca__loc': [0.0, 1.0, 2.0, 3.0], '__nesreca__upravna_enota__univariates__id_nesreca__scale': [np.nan] * 4, - '__nesreca__upravna_enota__num_rows': [1., 1., 1., 1.] + '__nesreca__upravna_enota__num_rows': [1.0, 1.0, 1.0, 1.0], }) pd.testing.assert_frame_equal(result, expected) @@ -94,18 +88,14 @@ def test__get_distributions(self): instance.get_table_parameters.side_effect = [ {'synthesizer_parameters': {'default_distribution': 'gamma'}}, {'wrong_key': {'default_distribution': 'gamma'}}, - {'synthesizer_parameters': {'not_default_distribution': 'wrong'}} + {'synthesizer_parameters': {'not_default_distribution': 'wrong'}}, ] # Run result = instance._get_distributions() # Assert - expected = { - 'nesreca': 'gamma', - 'oseba': None, - 'upravna_enota': None - } + expected = {'nesreca': 'gamma', 'oseba': None, 'upravna_enota': None} assert result == expected @patch('sdv.multi_table.hma.HMASynthesizer._estimate_num_columns') @@ -119,7 +109,7 @@ def test__print_estimate_warning(self, get_distributions_mock, estimate_mock, ca key_phrases = [ r'PerformanceAlert:', r'large number of columns.', - r'contact us at info@sdv.dev for enterprise solutions.' 
+ r'contact us at info@sdv.dev for enterprise solutions.', ] # Run @@ -151,9 +141,7 @@ def test__get_extension_foreign_key_only(self): instance._get_pbar_args.return_value = {'desc': "(1/2) Tables 'A' and 'B' ('user_id')"} instance.metadata._get_all_foreign_keys.return_value = ['id_upravna_enota'] instance._table_synthesizers = {'nesreca': Mock()} - child_table = pd.DataFrame({ - 'id_upravna_enota': [0, 1, 2, 3] - }) + child_table = pd.DataFrame({'id_upravna_enota': [0, 1, 2, 3]}) # Run result = HMASynthesizer._get_extension( @@ -161,15 +149,12 @@ def test__get_extension_foreign_key_only(self): 'nesreca', child_table, 'id_upravna_enota', - "(1/2) Tables 'A' and 'B' ('user_id')" + "(1/2) Tables 'A' and 'B' ('user_id')", ) # Assert - expected = pd.DataFrame({ - '__nesreca__id_upravna_enota__num_rows': [1, 1, 1, 1] - }) - instance._get_pbar_args.assert_called_once_with( - desc="(1/2) Tables 'A' and 'B' ('user_id')") + expected = pd.DataFrame({'__nesreca__id_upravna_enota__num_rows': [1, 1, 1, 1]}) + instance._get_pbar_args.assert_called_once_with(desc="(1/2) Tables 'A' and 'B' ('user_id')") pd.testing.assert_frame_equal(result, expected) @@ -201,16 +186,16 @@ def test__augment_table(self): 'upravna_enota': [0, 1, 2, 3], 'nesreca_val': [0, 1, 2, 3], 'value': [0, 1, 2, 3], - '__oseba__id_nesreca__correlation__0__0': [0.] * 4, - '__oseba__id_nesreca__univariates__oseba_val__a': [1.] * 4, - '__oseba__id_nesreca__univariates__oseba_val__b': [1.] * 4, - '__oseba__id_nesreca__univariates__oseba_val__loc': [0., 1., 2., 3.], + '__oseba__id_nesreca__correlation__0__0': [0.0] * 4, + '__oseba__id_nesreca__univariates__oseba_val__a': [1.0] * 4, + '__oseba__id_nesreca__univariates__oseba_val__b': [1.0] * 4, + '__oseba__id_nesreca__univariates__oseba_val__loc': [0.0, 1.0, 2.0, 3.0], '__oseba__id_nesreca__univariates__oseba_val__scale': [1e-6] * 4, - '__oseba__id_nesreca__univariates__oseba_value__a': [1.] * 4, - '__oseba__id_nesreca__univariates__oseba_value__b': [1.] * 4, - '__oseba__id_nesreca__univariates__oseba_value__loc': [0., 1., 2., 3.], + '__oseba__id_nesreca__univariates__oseba_value__a': [1.0] * 4, + '__oseba__id_nesreca__univariates__oseba_value__b': [1.0] * 4, + '__oseba__id_nesreca__univariates__oseba_value__loc': [0.0, 1.0, 2.0, 3.0], '__oseba__id_nesreca__univariates__oseba_value__scale': [1e-6] * 4, - '__oseba__id_nesreca__num_rows': [1.] 
* 4, + '__oseba__id_nesreca__num_rows': [1.0] * 4, }) pd.testing.assert_frame_equal(expected_result, result) @@ -225,11 +210,7 @@ def test__pop_foreign_keys(self): # Setup instance = Mock() instance.metadata._get_all_foreign_keys.return_value = ['a', 'b'] - table_data = pd.DataFrame({ - 'a': [1, 2, 3], - 'b': [2, 3, 4], - 'c': ['John', 'Doe', 'Johanna'] - }) + table_data = pd.DataFrame({'a': [1, 2, 3], 'b': [2, 3, 4], 'c': ['John', 'Doe', 'Johanna']}) # Run result = HMASynthesizer._pop_foreign_keys(instance, table_data, 'table_name') @@ -253,7 +234,7 @@ def test__clear_nans(self): # Assert expected_data = pd.DataFrame({ 'numerical': [0, 1, 2, 3, 1.5, 1.5], - 'categorical': ['John', 'John', 'Johanna', 'John', 'John', 'Doe'] + 'categorical': ['John', 'John', 'Johanna', 'John', 'John', 'Doe'], }) pd.testing.assert_frame_equal(expected_data, data) @@ -272,7 +253,7 @@ def test__model_tables(self): upravna_enota_model = Mock() upravna_enota_model._get_parameters.return_value = { 'col__univariates': 'univariate_param', - 'corr': 'correlation_param' + 'corr': 'correlation_param', } instance = Mock() instance._synthesizer = GaussianCopulaSynthesizer @@ -290,7 +271,7 @@ def test__model_tables(self): 'upravna_enota': pd.DataFrame({ 'id_nesreca': [0, 1, 2], 'upravna_enota': [0, 1, 2], - 'extended': ['a', 'b', 'c'] + 'extended': ['a', 'b', 'c'], }) } augmented_data = input_data.copy() @@ -303,13 +284,12 @@ def test__model_tables(self): 'id_nesreca': [0, 1, 2], 'upravna_enota': [0, 1, 2], 'extended': ['a', 'b', 'c'], - 'fk': [1, 2, 3] + 'fk': [1, 2, 3], }) pd.testing.assert_frame_equal(expected_result, augmented_data['upravna_enota']) instance._pop_foreign_keys.assert_called_once_with( - input_data['upravna_enota'], - 'upravna_enota' + input_data['upravna_enota'], 'upravna_enota' ) instance._clear_nans.assert_called_once_with(input_data['upravna_enota']) upravna_enota_model.fit_processed_data.assert_called_once_with( @@ -357,7 +337,7 @@ def test__finalize(self): metadata = Mock() metadata._get_parent_map.return_value = { 'sessions': ['users'], - 'transactions': ['sessions'] + 'transactions': ['sessions'], } instance.metadata = metadata @@ -387,18 +367,15 @@ def test__finalize(self): 'user_id': np.int64, 'session_id': str, 'os': str, - 'country': str + 'country': str, } transactions_synth = Mock() - transactions_synth._data_processor._dtypes = { - 'transaction_id': np.int64, - 'session_id': str - } + transactions_synth._data_processor._dtypes = {'transaction_id': np.int64, 'session_id': str} instance._table_synthesizers = { 'users': users_synth, 'sessions': sessions_synth, - 'transactions': transactions_synth + 'transactions': transactions_synth, } # Run @@ -437,7 +414,7 @@ def test__extract_parameters(self): instance._max_child_rows = {'__sessions__user_id__num_rows': 10} float_formatter1 = MagicMock() - float_formatter1._min_value = 0. 
+ float_formatter1._min_value = 0.0 float_formatter1._max_value = 5 float_formatter2 = MagicMock() @@ -466,7 +443,7 @@ def test__extract_parameters(self): # Assert expected_result = { - 'a': .1, + 'a': 0.1, 'b': 0.2, 'loc': 0.3, 'num_rows': 5, @@ -487,9 +464,7 @@ def test__recreate_child_synthesizer(self): instance.metadata._get_foreign_keys.return_value = ['session_id'] instance._table_parameters = {'users': {'a': 1}} instance._table_synthesizers = {'users': table_synthesizer} - instance._default_parameters = { - 'users': {'colA': 'default_param', 'colB': 'default_param'} - } + instance._default_parameters = {'users': {'colA': 'default_param', 'colB': 'default_param'}} # Run synthesizer = HMASynthesizer._recreate_child_synthesizer( @@ -505,7 +480,7 @@ def test__recreate_child_synthesizer(self): instance._synthesizer.assert_called_once_with(table_meta, a=1) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value, - {'colA': 'default_param', 'colB': 'default_param'} + {'colA': 'default_param', 'colB': 'default_param'}, ) instance._extract_parameters.assert_called_once_with(parent_row, table_name, 'session_id') @@ -544,12 +519,7 @@ def test_get_learned_distributions(self): assert list(result) == ['upravna_val'] assert result['upravna_val'] == { 'distribution': 'beta', - 'learned_parameters': { - 'a': 1.0, - 'b': 1.0, - 'loc': 10.0, - 'scale': 0.0 - } + 'learned_parameters': {'a': 1.0, 'b': 1.0, 'loc': 10.0, 'scale': 0.0}, } def test_get_learned_distributions_raises_an_error(self): @@ -596,21 +566,15 @@ def test__add_foreign_key_columns(self): }) child_table = pd.DataFrame({ 'transaction_id': pd.Series([1, 2, 3], dtype=np.int64), - 'primary_user_id': pd.Series([0, 0, 1], dtype=np.int64) + 'primary_user_id': pd.Series([0, 0, 1], dtype=np.int64), }) - instance._table_synthesizers = { - 'users': Mock(), - 'transactions': Mock() - } + instance._table_synthesizers = {'users': Mock(), 'transactions': Mock()} # Run HMASynthesizer._add_foreign_key_columns( - instance, - child_table, - parent_table, - 'transactions', - 'users') + instance, child_table, parent_table, 'transactions', 'users' + ) # Assert expected_parent_table = pd.DataFrame({ @@ -620,7 +584,7 @@ def test__add_foreign_key_columns(self): expected_child_table = pd.DataFrame({ 'transaction_id': pd.Series([1, 2, 3], dtype=np.int64), 'primary_user_id': pd.Series([0, 0, 1], dtype=np.int64), - 'secondary_user_id': pd.Series([2, 1, 2], dtype=np.int64) + 'secondary_user_id': pd.Series([2, 1, 2], dtype=np.int64), }) pd.testing.assert_frame_equal(expected_parent_table, parent_table) pd.testing.assert_frame_equal(expected_child_table, child_table) @@ -637,7 +601,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): 'id': [0, 1, 2], 'id1': [0, 1, 2], 'id2': [0, 1, 2], - 'col1': [0, 1, 2] + 'col1': [0, 1, 2], }) data = {'parent': parent, 'child': child} metadata = MultiTableMetadata.load_from_dict({ @@ -646,7 +610,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): 'primary_key': 'id', 'columns': { 'id': {'sdtype': 'id'}, - } + }, }, 'child': { 'primary_key': 'id', @@ -655,7 +619,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): 'id1': {'sdtype': 'id'}, 'id2': {'sdtype': 'id'}, 'col1': {'sdtype': 'numerical'}, - } + }, }, }, 'relationships': [ @@ -663,15 +627,15 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child', - 
'child_foreign_key': 'id1' + 'child_foreign_key': 'id1', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child', - 'child_foreign_key': 'id2' + 'child_foreign_key': 'id2', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) @@ -710,7 +674,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): 'id': [0, 1, 2], 'id1': [0, 1, 2], 'id2': [0, 1, 2], - 'col': [.2, .3, .2] + 'col': [0.2, 0.3, 0.2], }) data = { 'parent': parent, @@ -718,7 +682,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): 'child_beta': child, 'child_gamma': child, 'child_truncnorm': child, - 'child_uniform': child + 'child_uniform': child, } child_dict = { 'primary_key': 'id', @@ -727,7 +691,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): 'id1': {'sdtype': 'id'}, 'id2': {'sdtype': 'id'}, 'col': {'sdtype': 'numerical'}, - } + }, } metadata = MultiTableMetadata.load_from_dict({ 'tables': { @@ -735,7 +699,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): 'primary_key': 'id', 'columns': { 'id': {'sdtype': 'id'}, - } + }, }, 'child_norm': child_dict, 'child_beta': child_dict, @@ -748,80 +712,76 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_norm', - 'child_foreign_key': 'id1' + 'child_foreign_key': 'id1', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_norm', - 'child_foreign_key': 'id2' + 'child_foreign_key': 'id2', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_beta', - 'child_foreign_key': 'id1' + 'child_foreign_key': 'id1', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_beta', - 'child_foreign_key': 'id2' + 'child_foreign_key': 'id2', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_truncnorm', - 'child_foreign_key': 'id1' + 'child_foreign_key': 'id1', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_truncnorm', - 'child_foreign_key': 'id2' + 'child_foreign_key': 'id2', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_uniform', - 'child_foreign_key': 'id1' + 'child_foreign_key': 'id1', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_uniform', - 'child_foreign_key': 'id2' + 'child_foreign_key': 'id2', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_gamma', - 'child_foreign_key': 'id1' + 'child_foreign_key': 'id1', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'id', 'child_table_name': 'child_gamma', - 'child_foreign_key': 'id2' + 'child_foreign_key': 'id2', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) synthesizer.set_table_parameters( - table_name='child_norm', - table_parameters={'default_distribution': 'norm'} + table_name='child_norm', table_parameters={'default_distribution': 'norm'} ) synthesizer.set_table_parameters( - table_name='child_gamma', - table_parameters={'default_distribution': 'gamma'} + table_name='child_gamma', table_parameters={'default_distribution': 'gamma'} ) synthesizer.set_table_parameters( - table_name='child_truncnorm', - table_parameters={'default_distribution': 'truncnorm'} + table_name='child_truncnorm', 
table_parameters={'default_distribution': 'truncnorm'} ) synthesizer.set_table_parameters( - table_name='child_uniform', - table_parameters={'default_distribution': 'uniform'} + table_name='child_uniform', table_parameters={'default_distribution': 'uniform'} ) synthesizer._finalize = Mock(return_value=data) distributions = synthesizer._get_distributions() @@ -862,7 +822,10 @@ def test__estimate_num_columns_to_be_modeled(self): root1 = pd.DataFrame({'R1': [0, 1, 2]}) root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}) grandparent = pd.DataFrame({ - 'GP': [0, 1, 2], 'R1_1': [0, 1, 2], 'R1_2': [0, 1, 2], 'R2': [0, 1, 2] + 'GP': [0, 1, 2], + 'R1_1': [0, 1, 2], + 'R1_2': [0, 1, 2], + 'R2': [0, 1, 2], }) parent = pd.DataFrame({'P': [0, 1, 2], 'GP': [0, 1, 2]}) child = pd.DataFrame({'C': [0, 1, 2], 'P': [0, 1, 2], 'GP': [0, 1, 2]}) @@ -871,7 +834,7 @@ def test__estimate_num_columns_to_be_modeled(self): 'root2': root2, 'grandparent': grandparent, 'parent': parent, - 'child': child + 'child': child, } metadata = MultiTableMetadata.load_from_dict({ 'tables': { @@ -879,14 +842,11 @@ def test__estimate_num_columns_to_be_modeled(self): 'primary_key': 'R1', 'columns': { 'R1': {'sdtype': 'id'}, - } + }, }, 'root2': { 'primary_key': 'R2', - 'columns': { - 'R2': {'sdtype': 'id'}, - 'data': {'sdtype': 'numerical'} - } + 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, }, 'grandparent': { 'primary_key': 'GP', @@ -895,14 +855,14 @@ def test__estimate_num_columns_to_be_modeled(self): 'R1_1': {'sdtype': 'id'}, 'R1_2': {'sdtype': 'id'}, 'R2': {'sdtype': 'id'}, - } + }, }, 'parent': { 'primary_key': 'P', 'columns': { 'P': {'sdtype': 'id'}, 'GP': {'sdtype': 'id'}, - } + }, }, 'child': { 'primary_key': 'C', @@ -910,47 +870,47 @@ def test__estimate_num_columns_to_be_modeled(self): 'C': {'sdtype': 'id'}, 'P': {'sdtype': 'id'}, 'GP': {'sdtype': 'id'}, - } - } + }, + }, }, 'relationships': [ { 'parent_table_name': 'root1', 'parent_primary_key': 'R1', 'child_table_name': 'grandparent', - 'child_foreign_key': 'R1_1' + 'child_foreign_key': 'R1_1', }, { 'parent_table_name': 'root1', 'parent_primary_key': 'R1', 'child_table_name': 'grandparent', - 'child_foreign_key': 'R1_2' + 'child_foreign_key': 'R1_2', }, { 'parent_table_name': 'root2', 'parent_primary_key': 'R2', 'child_table_name': 'grandparent', - 'child_foreign_key': 'R2' + 'child_foreign_key': 'R2', }, { 'parent_table_name': 'grandparent', 'parent_primary_key': 'GP', 'child_table_name': 'parent', - 'child_foreign_key': 'GP' + 'child_foreign_key': 'GP', }, { 'parent_table_name': 'grandparent', 'parent_primary_key': 'GP', 'child_table_name': 'child', - 'child_foreign_key': 'GP' + 'child_foreign_key': 'GP', }, { 'parent_table_name': 'parent', 'parent_primary_key': 'P', 'child_table_name': 'child', - 'child_foreign_key': 'P' + 'child_foreign_key': 'P', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) @@ -998,7 +958,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): parent = pd.DataFrame({ 'P': [0, 1, 2], 'GP': [0, 1, 2], - 'numerical': [.1, .5, np.nan], + 'numerical': [0.1, 0.5, np.nan], 'categorical': ['a', np.nan, 'c'], 'datetime': [None, '2019-01-02', '2019-01-03'], 'boolean': [float('nan'), False, True], @@ -1016,14 +976,11 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): 'primary_key': 'R1', 'columns': { 'R1': {'sdtype': 'id'}, - } + }, }, 'root2': { 'primary_key': 'R2', - 'columns': { - 'R2': {'sdtype': 'id'}, - 'data': {'sdtype': 
'numerical'} - } + 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, }, 'grandparent': { 'primary_key': 'GP', @@ -1031,7 +988,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): 'GP': {'sdtype': 'id'}, 'R1': {'sdtype': 'id'}, 'R2': {'sdtype': 'id'}, - } + }, }, 'parent': { 'primary_key': 'P', @@ -1043,29 +1000,29 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): 'datetime': {'sdtype': 'datetime'}, 'boolean': {'sdtype': 'boolean'}, 'id': {'sdtype': 'id'}, - } - } + }, + }, }, 'relationships': [ { 'parent_table_name': 'root1', 'parent_primary_key': 'R1', 'child_table_name': 'grandparent', - 'child_foreign_key': 'R1' + 'child_foreign_key': 'R1', }, { 'parent_table_name': 'root2', 'parent_primary_key': 'R2', 'child_table_name': 'grandparent', - 'child_foreign_key': 'R2' + 'child_foreign_key': 'R2', }, { 'parent_table_name': 'grandparent', 'parent_primary_key': 'GP', 'child_table_name': 'parent', - 'child_foreign_key': 'GP' + 'child_foreign_key': 'GP', }, - ] + ], }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index 3917a993f..ad3770f71 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -10,14 +10,33 @@ from sdv.errors import InvalidDataError, SamplingError from sdv.metadata import MultiTableMetadata from sdv.multi_table.utils import ( - _drop_rows, _get_all_descendant_per_root_at_order_n, _get_ancestors, - _get_columns_to_drop_child, _get_disconnected_roots_from_table, _get_n_order_descendants, - _get_nan_fk_indices_table, _get_num_column_to_drop, _get_primary_keys_referenced, - _get_relationships_for_child, _get_relationships_for_parent, _get_rows_to_drop, - _get_total_estimated_columns, _print_simplified_schema_summary, _print_subsample_summary, - _simplify_child, _simplify_children, _simplify_data, _simplify_grandchildren, - _simplify_metadata, _simplify_relationships_and_tables, _subsample_ancestors, _subsample_data, - _subsample_disconnected_roots, _subsample_parent, _subsample_table_and_descendants) + _drop_rows, + _get_all_descendant_per_root_at_order_n, + _get_ancestors, + _get_columns_to_drop_child, + _get_disconnected_roots_from_table, + _get_n_order_descendants, + _get_nan_fk_indices_table, + _get_num_column_to_drop, + _get_primary_keys_referenced, + _get_relationships_for_child, + _get_relationships_for_parent, + _get_rows_to_drop, + _get_total_estimated_columns, + _print_simplified_schema_summary, + _print_subsample_summary, + _simplify_child, + _simplify_children, + _simplify_data, + _simplify_grandchildren, + _simplify_metadata, + _simplify_relationships_and_tables, + _subsample_ancestors, + _subsample_data, + _subsample_disconnected_roots, + _subsample_parent, + _subsample_table_and_descendants, +) def test__get_relationships_for_child(): @@ -26,7 +45,7 @@ def test__get_relationships_for_child(): relationships = [ {'parent_table_name': 'parent', 'child_table_name': 'child'}, {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, - {'parent_table_name': 'parent', 'child_table_name': 'grandchild'} + {'parent_table_name': 'parent', 'child_table_name': 'grandchild'}, ] # Run @@ -35,7 +54,7 @@ def test__get_relationships_for_child(): # Assert expected_result = [ {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, - {'parent_table_name': 'parent', 'child_table_name': 'grandchild'} + {'parent_table_name': 'parent', 
'child_table_name': 'grandchild'}, ] assert result == expected_result @@ -46,7 +65,7 @@ def test__get_relationships_for_parent(): relationships = [ {'parent_table_name': 'parent', 'child_table_name': 'child'}, {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, - {'parent_table_name': 'parent', 'child_table_name': 'grandchild'} + {'parent_table_name': 'parent', 'child_table_name': 'grandchild'}, ] # Run @@ -55,7 +74,7 @@ def test__get_relationships_for_parent(): # Assert expected_result = [ {'parent_table_name': 'parent', 'child_table_name': 'child'}, - {'parent_table_name': 'parent', 'child_table_name': 'grandchild'} + {'parent_table_name': 'parent', 'child_table_name': 'grandchild'}, ] assert result == expected_result @@ -80,20 +99,20 @@ def test__get_rows_to_drop(): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'child', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } + 'child_foreign_key': 'parent_foreign_key', + }, ] metadata = Mock() @@ -101,7 +120,7 @@ def test__get_rows_to_drop(): metadata.tables = { 'parent': Mock(primary_key='id_parent'), 'child': Mock(primary_key='id_child'), - 'grandchild': Mock(primary_key='id_grandchild') + 'grandchild': Mock(primary_key='id_grandchild'), } data = { @@ -112,24 +131,20 @@ def test__get_rows_to_drop(): 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5], 'id_child': [5, 6, 7, 8, 9], - 'B': ['Yes', 'No', 'No', 'No', 'No'] + 'B': ['Yes', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6], 'child_foreign_key': [9, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No'] - }) + 'C': ['Yes', 'No', 'No', 'No', 'No'], + }), } # Run result = _get_rows_to_drop(data, metadata) # Assert - expected_result = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 2, 4}, - 'parent': set() - }) + expected_result = defaultdict(set, {'child': {4}, 'grandchild': {0, 2, 4}, 'parent': set()}) assert result == expected_result @@ -141,26 +156,26 @@ def test__get_nan_fk_indices_table(): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'child', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } + 'child_foreign_key': 'parent_foreign_key', + }, ] data = { 'grandchild': pd.DataFrame({ 'parent_foreign_key': [np.nan, 1, 2, 2, np.nan], 'child_foreign_key': [9, np.nan, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No'] + 'C': ['Yes', 'No', 'No', 'No', 'No'], }) } @@ -180,20 +195,20 @@ def test__drop_rows(mock_get_rows_to_drop): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'child', 'child_table_name': 'grandchild', 
'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } + 'child_foreign_key': 'parent_foreign_key', + }, ] metadata = Mock() @@ -201,7 +216,7 @@ def test__drop_rows(mock_get_rows_to_drop): metadata.tables = { 'parent': Mock(primary_key='id_parent'), 'child': Mock(primary_key='id_child'), - 'grandchild': Mock(primary_key='id_grandchild') + 'grandchild': Mock(primary_key='id_grandchild'), } data = { 'parent': pd.DataFrame({ @@ -211,19 +226,16 @@ def test__drop_rows(mock_get_rows_to_drop): 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5], 'id_child': [5, 6, 7, 8, 9], - 'B': ['Yes', 'No', 'No', 'No', 'No'] + 'B': ['Yes', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6], 'child_foreign_key': [9, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No'] - }) + 'C': ['Yes', 'No', 'No', 'No', 'No'], + }), } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 2, 4} - }) + mock_get_rows_to_drop.return_value = defaultdict(set, {'child': {4}, 'grandchild': {0, 2, 4}}) # Run _drop_rows(data, metadata, False) @@ -235,16 +247,18 @@ def test__drop_rows(mock_get_rows_to_drop): 'id_parent': [0, 1, 2, 3, 4], 'A': [True, True, False, True, False], }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2], - 'id_child': [5, 6, 7, 8], - 'B': ['Yes', 'No', 'No', 'No'] - }, index=[0, 1, 2, 3]), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [1, 2], - 'child_foreign_key': [5, 6], - 'C': ['No', 'No'] - }, index=[1, 3]) + 'child': pd.DataFrame( + { + 'parent_foreign_key': [0, 1, 2, 2], + 'id_child': [5, 6, 7, 8], + 'B': ['Yes', 'No', 'No', 'No'], + }, + index=[0, 1, 2, 3], + ), + 'grandchild': pd.DataFrame( + {'parent_foreign_key': [1, 2], 'child_foreign_key': [5, 6], 'C': ['No', 'No']}, + index=[1, 3], + ), } for table_name, table in data.items(): pd.testing.assert_frame_equal(table, expected_result[table_name]) @@ -259,20 +273,20 @@ def test_drop_unknown_references_with_nan(mock_get_rows_to_drop): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'child', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } + 'child_foreign_key': 'parent_foreign_key', + }, ] metadata = Mock() @@ -287,18 +301,15 @@ def test_drop_unknown_references_with_nan(mock_get_rows_to_drop): 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5, None], 'id_child': [5, 6, 7, 8, 9, 10], - 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] + 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6, 4], 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] - }) + 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'], + }), } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 3, 4} - }) + mock_get_rows_to_drop.return_value = defaultdict(set, {'child': {4}, 'grandchild': {0, 3, 4}}) # Run 
_drop_rows(data, metadata, True) @@ -310,16 +321,18 @@ def test_drop_unknown_references_with_nan(mock_get_rows_to_drop): 'id_parent': [0, 1, 2, 3, 4], 'A': [True, True, False, True, False], }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0., 1., 2., 2.], - 'id_child': [5, 6, 7, 8], - 'B': ['Yes', 'No', 'No', 'No'] - }, index=[0, 1, 2, 3]), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [2, 4], - 'child_foreign_key': [5., 4.], - 'C': ['No', 'No'] - }, index=[2, 5]) + 'child': pd.DataFrame( + { + 'parent_foreign_key': [0.0, 1.0, 2.0, 2.0], + 'id_child': [5, 6, 7, 8], + 'B': ['Yes', 'No', 'No', 'No'], + }, + index=[0, 1, 2, 3], + ), + 'grandchild': pd.DataFrame( + {'parent_foreign_key': [2, 4], 'child_foreign_key': [5.0, 4.0], 'C': ['No', 'No']}, + index=[2, 5], + ), } for table_name, table in data.items(): pd.testing.assert_frame_equal(table, expected_result[table_name]) @@ -334,20 +347,20 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'child', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } + 'child_foreign_key': 'parent_foreign_key', + }, ] metadata = Mock() @@ -363,18 +376,15 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5, None], 'id_child': [5, 6, 7, 8, 9, 10], - 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] + 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6, 4], 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] - }) + 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'], + }), } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 3, 4} - }) + mock_get_rows_to_drop.return_value = defaultdict(set, {'child': {4}, 'grandchild': {0, 3, 4}}) # Run _drop_rows(data, metadata, drop_missing_values=False) @@ -386,16 +396,22 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop 'id_parent': [0, 1, 2, 3, 4], 'A': [True, True, False, True, False], }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0., 1., 2., 2., None], - 'id_child': [5, 6, 7, 8, 10], - 'B': ['Yes', 'No', 'No', 'No', 'No'] - }, index=[0, 1, 2, 3, 5]), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [1, 2, 4], - 'child_foreign_key': [np.nan, 5, 4.], - 'C': ['No', 'No', 'No'] - }, index=[1, 2, 5]) + 'child': pd.DataFrame( + { + 'parent_foreign_key': [0.0, 1.0, 2.0, 2.0, None], + 'id_child': [5, 6, 7, 8, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No'], + }, + index=[0, 1, 2, 3, 5], + ), + 'grandchild': pd.DataFrame( + { + 'parent_foreign_key': [1, 2, 4], + 'child_foreign_key': [np.nan, 5, 4.0], + 'C': ['No', 'No', 'No'], + }, + index=[1, 2, 5], + ), } for table_name, table in data.items(): pd.testing.assert_frame_equal(table, expected_result[table_name]) @@ -410,20 +426,20 @@ def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_parent', - 
'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'child', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } + 'child_foreign_key': 'parent_foreign_key', + }, ] metadata = Mock() @@ -439,18 +455,16 @@ def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5], 'id_child': [5, 6, 7, 8, 9], - 'B': ['Yes', 'No', 'No', 'No', 'No'] + 'B': ['Yes', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6], 'child_foreign_key': [9, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No'] - }) + 'C': ['Yes', 'No', 'No', 'No', 'No'], + }), } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {0, 1, 2, 3, 4} - }) + mock_get_rows_to_drop.return_value = defaultdict(set, {'child': {0, 1, 2, 3, 4}}) # Run and Assert expected_message = re.escape( @@ -482,19 +496,13 @@ def test__get_n_order_descendants(): expected_gp_order_1 = { 'order_1': ['parent', 'other_table'], } - expected_gp_order_2 = { - 'order_1': ['parent', 'other_table'], - 'order_2': ['child'] - } + expected_gp_order_2 = {'order_1': ['parent', 'other_table'], 'order_2': ['child']} expected_gp_order_3 = { 'order_1': ['parent', 'other_table'], 'order_2': ['child'], - 'order_3': ['grandchild'] - } - expected_other_order_2 = { - 'order_1': [], - 'order_2': [] + 'order_3': ['grandchild'], } + expected_other_order_2 = {'order_1': [], 'order_2': []} assert grandparent_order_1 == expected_gp_order_1 assert grandparent_order_2 == expected_gp_order_2 assert grandparent_order_3 == expected_gp_order_3 @@ -518,25 +526,32 @@ def test__get_all_descendant_per_root_at_order_n(): # Assert expected_result = { 'other_root': { - 'order_1': ['child'], 'order_2': ['grandchild'], 'order_3': [], - 'num_descendants': 2 + 'order_1': ['child'], + 'order_2': ['grandchild'], + 'order_3': [], + 'num_descendants': 2, }, 'grandparent': { - 'order_1': ['parent', 'other_table'], 'order_2': ['child'], 'order_3': ['grandchild'], - 'num_descendants': 4 - } + 'order_1': ['parent', 'other_table'], + 'order_2': ['child'], + 'order_3': ['grandchild'], + 'num_descendants': 4, + }, } assert result == expected_result -@pytest.mark.parametrize(('table_name', 'expected_result'), [ - ('grandchild', {'child', 'parent', 'grandparent', 'other_root'}), - ('child', {'parent', 'grandparent', 'other_root'}), - ('parent', {'grandparent'}), - ('other_table', {'grandparent'}), - ('grandparent', set()), - ('other_root', set()), -]) +@pytest.mark.parametrize( + ('table_name', 'expected_result'), + [ + ('grandchild', {'child', 'parent', 'grandparent', 'other_root'}), + ('child', {'parent', 'grandparent', 'other_root'}), + ('parent', {'grandparent'}), + ('other_table', {'grandparent'}), + ('grandparent', set()), + ('other_root', set()), + ], +) def test__get_ancestors(table_name, expected_result): """Test the ``_get_ancestors`` method.""" # Setup @@ -555,16 +570,19 @@ def test__get_ancestors(table_name, expected_result): assert result == expected_result -@pytest.mark.parametrize(('table_name', 'expected_result'), [ - ('grandchild', {'disconnected_root'}), - ('child', {'disconnected_root'}), - ('parent', {'disconnected_root'}), - ('other_table', {'disconnected_root', 
'other_root'}), - ('grandparent', {'disconnected_root'}), - ('other_root', {'disconnected_root'}), - ('disconnected_root', {'grandparent', 'other_root'}), - ('disconnect_child', {'grandparent', 'other_root'}), -]) +@pytest.mark.parametrize( + ('table_name', 'expected_result'), + [ + ('grandchild', {'disconnected_root'}), + ('child', {'disconnected_root'}), + ('parent', {'disconnected_root'}), + ('other_table', {'disconnected_root', 'other_root'}), + ('grandparent', {'disconnected_root'}), + ('other_root', {'disconnected_root'}), + ('disconnected_root', {'grandparent', 'other_root'}), + ('disconnect_child', {'grandparent', 'other_root'}), + ], +) def test__get_disconnected_roots_from_table(table_name, expected_result): """Test the ``_get_disconnected_roots_from_table`` method.""" # Setup @@ -602,7 +620,7 @@ def test__simplify_relationships_and_tables(): {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, {'parent_table_name': 'grandparent', 'child_table_name': 'other_table'}, {'parent_table_name': 'other_root', 'child_table_name': 'child'}, - ] + ], }) tables_to_drop = {'grandchild', 'other_root'} @@ -646,7 +664,7 @@ def test__simplify_grandchildren(): 'columns': { 'col_10': {'sdtype': 'id'}, 'col_11': {'sdtype': 'phone_number'}, - 'col_12': {'sdtype': 'categorical'} + 'col_12': {'sdtype': 'categorical'}, } }, } @@ -674,27 +692,20 @@ def test__get_num_column_to_drop(): """Test the ``_get_num_column_to_drop`` method.""" # Setup metadata = Mock() - categorical_columns = { - f'col_{i}': {'sdtype': 'categorical'} for i in range(300) - } - numerical_columns = { - f'col_{i}': {'sdtype': 'numerical'} for i in range(300, 600) - } - datetime_columns = { - f'col_{i}': {'sdtype': 'datetime'} for i in range(600, 900) - } - id_columns = { - f'col_{i}': {'sdtype': 'id'} for i in range(900, 910) - } - email_columns = { - f'col_{i}': {'sdtype': 'email'} for i in range(910, 920) - } + categorical_columns = {f'col_{i}': {'sdtype': 'categorical'} for i in range(300)} + numerical_columns = {f'col_{i}': {'sdtype': 'numerical'} for i in range(300, 600)} + datetime_columns = {f'col_{i}': {'sdtype': 'datetime'} for i in range(600, 900)} + id_columns = {f'col_{i}': {'sdtype': 'id'} for i in range(900, 910)} + email_columns = {f'col_{i}': {'sdtype': 'email'} for i in range(910, 920)} metadata = MultiTableMetadata().load_from_dict({ 'tables': { 'child': { 'columns': { - **categorical_columns, **numerical_columns, - **datetime_columns, **id_columns, **email_columns + **categorical_columns, + **numerical_columns, + **datetime_columns, + **id_columns, + **email_columns, } } } @@ -702,7 +713,7 @@ def test__get_num_column_to_drop(): child_table = 'child' max_col_per_relationship = 500 - num_modelable_columnn = (len(metadata.tables[child_table].columns) - 20) + num_modelable_columnn = len(metadata.tables[child_table].columns) - 20 # Run num_col_to_drop, modelable_columns = _get_num_column_to_drop( @@ -728,14 +739,10 @@ def test__get_columns_to_drop_child_drop_all_modelable_columns(mock_get_num_colu mock_get_num_column_to_drop.return_value = (10, modelable_column) # Run - columns_to_drop = _get_columns_to_drop_child( - metadata, 'child', max_col_per_relationship - ) + columns_to_drop = _get_columns_to_drop_child(metadata, 'child', max_col_per_relationship) # Assert - mock_get_num_column_to_drop.assert_called_once_with( - metadata, 'child', max_col_per_relationship - ) + mock_get_num_column_to_drop.assert_called_once_with(metadata, 'child', max_col_per_relationship) assert columns_to_drop == ['col_1', 'col_3', 
'col_3', 'col_4', 'col_5'] @@ -745,21 +752,14 @@ def test__get_columns_to_drop_child_only_one_sdtyoe(mock_get_num_column_to_drop) # Setup metadata = Mock() max_col_per_relationship = 10 - modelable_column = { - 'numerical': ['col_1', 'col_2', 'col_3'], - 'categorical': [] - } + modelable_column = {'numerical': ['col_1', 'col_2', 'col_3'], 'categorical': []} mock_get_num_column_to_drop.return_value = (2, modelable_column) # Run - columns_to_drop = _get_columns_to_drop_child( - metadata, 'child', max_col_per_relationship - ) + columns_to_drop = _get_columns_to_drop_child(metadata, 'child', max_col_per_relationship) # Assert - mock_get_num_column_to_drop.assert_called_once_with( - metadata, 'child', max_col_per_relationship - ) + mock_get_num_column_to_drop.assert_called_once_with(metadata, 'child', max_col_per_relationship) assert set(columns_to_drop).issubset({'col_1', 'col_2', 'col_3'}) assert len(set(columns_to_drop)) == len(columns_to_drop) == 2 @@ -777,19 +777,15 @@ def test__get_column_to_drop_child(mock_get_num_column_to_drop): max_col_per_relationship = 10 modelable_column = { 'numerical': ['col_1', 'col_2', 'col_3', 'col_4', 'col_5', 'col_6'], - 'categorical': ['col_7', 'col_8', 'col_9', 'col_10'] + 'categorical': ['col_7', 'col_8', 'col_9', 'col_10'], } mock_get_num_column_to_drop.return_value = (5, modelable_column) # Run - columns_to_drop = _get_columns_to_drop_child( - metadata, 'child', max_col_per_relationship - ) + columns_to_drop = _get_columns_to_drop_child(metadata, 'child', max_col_per_relationship) # Assert - mock_get_num_column_to_drop.assert_called_once_with( - metadata, 'child', max_col_per_relationship - ) + mock_get_num_column_to_drop.assert_called_once_with(metadata, 'child', max_col_per_relationship) numerical_set = {'col_1', 'col_2', 'col_3', 'col_4', 'col_5', 'col_6'} categorical_set = {'col_7', 'col_8', 'col_9', 'col_10'} output_set = set(columns_to_drop) @@ -852,7 +848,7 @@ def test__simplify_children_valid_children(mock_hma): # Assert mock_hma._get_num_extended_columns.assert_has_calls([ call(metadata, 'child_1', 'parent', 3), - call(metadata, 'child_2', 'parent', 3) + call(metadata, 'child_2', 'parent', 3), ]) @@ -887,17 +883,11 @@ def test__simplify_children(mock_get_columns_to_drop_child, mock_hma): child_2_before_simplify['columns']['col_8'] = {'sdtype': 'categorical'} metadata = MultiTableMetadata().load_from_dict({ 'relationships': relatioships, - 'tables': { - 'child_1': child_1_before_simplify, - 'child_2': child_2_before_simplify - } + 'tables': {'child_1': child_1_before_simplify, 'child_2': child_2_before_simplify}, }) metadata_after_simplify_2 = MultiTableMetadata().load_from_dict({ 'relationships': relatioships, - 'tables': { - 'child_1': child_1, - 'child_2': child_2 - } + 'tables': {'child_1': child_1, 'child_2': child_2}, }) mock_hma._get_num_extended_columns.side_effect = [800, 700] mock_get_columns_to_drop_child.side_effect = [['col_4'], ['col_8']] @@ -909,11 +899,11 @@ def test__simplify_children(mock_get_columns_to_drop_child, mock_hma): assert metadata.to_dict()['tables'] == metadata_after_simplify_2.to_dict()['tables'] mock_hma._get_num_extended_columns.assert_has_calls([ call(metadata, 'child_1', 'parent', 3), - call(metadata, 'child_2', 'parent', 3) + call(metadata, 'child_2', 'parent', 3), ]) mock_get_columns_to_drop_child.assert_has_calls([ call(metadata, 'child_1', 500), - call(metadata, 'child_2', 500) + call(metadata, 'child_2', 500), ]) @@ -921,10 +911,7 @@ def test__simplify_children(mock_get_columns_to_drop_child, mock_hma): 
def test__get_total_estimated_columns(mock_hma): """Test the ``_get_total_estimated_columns`` method.""" # Setup - mock_hma._estimate_num_columns.return_value = { - 'child_1': 500, - 'child_2': 700 - } + mock_hma._estimate_num_columns.return_value = {'child_1': 500, 'child_2': 700} metadata = Mock() # Run @@ -962,17 +949,13 @@ def test__simplify_metadata_no_child_simplification(mock_hma): }, 'grandchild': {'columns': {'col_7': {'sdtype': 'numerical'}}}, 'other_table': {'columns': {'col_8': {'sdtype': 'numerical'}}}, - 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}} + 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}}, } metadata = MultiTableMetadata().load_from_dict({ 'relationships': relationships, - 'tables': tables + 'tables': tables, }) - mock_hma._estimate_num_columns.return_value = { - 'child': 10, - 'parent': 20, - 'other_table': 30 - } + mock_hma._estimate_num_columns.return_value = {'child': 10, 'parent': 20, 'other_table': 30} # Run metadata_result = _simplify_metadata(metadata) @@ -1009,20 +992,20 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'parent_table_name': 'grandparent', 'child_table_name': 'parent', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, { 'parent_table_name': 'grandparent', 'child_table_name': 'other_table', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'col_9' + 'child_foreign_key': 'col_9', }, {'parent_table_name': 'other_root', 'child_table_name': 'child'}, ] @@ -1032,7 +1015,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'col_1': {'sdtype': 'numerical'}, 'id_parent': {'sdtype': 'id'}, }, - 'primary_key': 'id_parent' + 'primary_key': 'id_parent', }, 'parent': { 'columns': { @@ -1043,7 +1026,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'parent_foreign_key': {'sdtype': 'id'}, 'id_child': {'sdtype': 'id'}, }, - 'primary_key': 'id_child' + 'primary_key': 'id_child', }, 'child': { 'columns': { @@ -1062,15 +1045,13 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'col_10': {'sdtype': 'categorical'}, } }, - 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}} + 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}}, } metadata = MultiTableMetadata().load_from_dict({ 'relationships': relationships, - 'tables': tables + 'tables': tables, }) - mock_hma._estimate_num_columns.return_value = { - 'child': 800, 'parent': 900, 'other_table': 10 - } + mock_hma._estimate_num_columns.return_value = {'child': 800, 'parent': 900, 'other_table': 10} mock_hma._get_num_extended_columns.side_effect = [500, 700, 10] mock_get_columns_to_drop_child.side_effect = [ ['col_2', 'col_3'], @@ -1087,7 +1068,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'col_1': {'sdtype': 'numerical'}, 'id_parent': {'sdtype': 'id'}, }, - 'primary_key': 'id_parent' + 'primary_key': 'id_parent', }, 'parent': { 'columns': { @@ -1096,7 +1077,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'parent_foreign_key': {'sdtype': 'id'}, 'id_child': {'sdtype': 'id'}, }, - 'primary_key': 'id_child' + 'primary_key': 'id_child', }, 'child': { 'columns': { @@ -1117,19 +1098,19 @@ 
def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): 'parent_table_name': 'grandparent', 'child_table_name': 'parent', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' + 'child_foreign_key': 'parent_foreign_key', }, { 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' + 'child_foreign_key': 'child_foreign_key', }, { 'parent_table_name': 'grandparent', 'child_table_name': 'other_table', 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'col_9' + 'child_foreign_key': 'col_9', }, ] metadata_dict = metadata_result.to_dict() @@ -1152,32 +1133,25 @@ def test__simplify_data(): 'parent_table_name': 'parent', 'child_table_name': 'child', 'parent_primary_key': 'col_1', - 'child_foreign_key': 'col_2' + 'child_foreign_key': 'col_2', }, { 'parent_table_name': 'parent', 'child_table_name': 'grandchild', 'parent_primary_key': 'col_1', - 'child_foreign_key': 'col_4' + 'child_foreign_key': 'col_4', }, - ] + ], }) data = { - 'parent': pd.DataFrame({ - 'col_1': [1, 2, 3] - }), - 'child': pd.DataFrame({ - 'col_2': [2, 2, 3], - 'col_3': [7, 8, 9] - }), + 'parent': pd.DataFrame({'col_1': [1, 2, 3]}), + 'child': pd.DataFrame({'col_2': [2, 2, 3], 'col_3': [7, 8, 9]}), 'grandchild': pd.DataFrame({ 'col_4': [3, 2, 1], 'col_5': [10, 11, 12], - 'col_7': [13, 14, 15] + 'col_7': [13, 14, 15], }), - 'grandchild_2': pd.DataFrame({ - 'col_5': [10, 11, 12] - }) + 'grandchild_2': pd.DataFrame({'col_5': [10, 11, 12]}), } # Run @@ -1185,9 +1159,7 @@ def test__simplify_data(): # Assert expected_results = { - 'parent': pd.DataFrame({ - 'col_1': [1, 2, 3] - }), + 'parent': pd.DataFrame({'col_1': [1, 2, 3]}), 'child': pd.DataFrame({ 'col_2': [2, 2, 3], }), @@ -1212,9 +1184,7 @@ def test__print_simplified_schema_summary(capsys): 'col_5': [10, 11, 12], 'col_6': [13, 14, 15], }) - data_before_3 = pd.DataFrame({ - 'col_7': [10, 11, 12] - }) + data_before_3 = pd.DataFrame({'col_7': [10, 11, 12]}) data_before = { 'Table 1': data_before_1, 'Table 2': data_before_2, @@ -1227,7 +1197,6 @@ def test__print_simplified_schema_summary(capsys): }) data_after_2 = pd.DataFrame({ 'col_5': [2, 2, 3], - }) data_after = { 'Table 1': data_after_1, @@ -1326,7 +1295,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo {'parent_table_name': 'grandparent', 'child_table_name': 'other_table'}, {'parent_table_name': 'other_root', 'child_table_name': 'child'}, {'parent_table_name': 'disconnected_root', 'child_table_name': 'disconnect_child'}, - ] + ], }) mock_get_disconnected_roots_from_table.return_value = {'grandparent', 'other_root'} ratio_to_keep = 0.6 @@ -1349,8 +1318,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo @patch('sdv.multi_table.utils._drop_rows') @patch('sdv.multi_table.utils._get_nan_fk_indices_table') -def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, - mock_drop_rows): +def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, mock_drop_rows): """Test the ``_subsample_table_and_descendants`` method.""" # Setup data = { @@ -1379,9 +1347,7 @@ def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, _subsample_table_and_descendants(data, metadata, 'parent', 3) # Assert - mock_get_nan_fk_indices_table.assert_called_once_with( - data, metadata.relationships, 'parent' - ) + mock_get_nan_fk_indices_table.assert_called_once_with(data, metadata.relationships, 'parent') 
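# Editor's note: illustrative sketch only; this is not part of the diff and the helper name
# below is hypothetical, not SDV's API. The tests around this hunk (``_drop_rows``,
# ``_get_rows_to_drop``, ``_subsample_table_and_descendants``) assert one recurring behavior:
# after rows are removed or subsampled in a parent table, child rows whose foreign keys no
# longer reference a surviving parent primary key must also be dropped. A minimal,
# self-contained pandas version of that idea:

import pandas as pd

def drop_unreferenced_children(parent, child, pk, fk):
    """Keep only the child rows whose foreign key still exists in the parent table."""
    return child[child[fk].isin(parent[pk])].reset_index(drop=True)

parent = pd.DataFrame({'id_parent': [0, 1, 2]})
child = pd.DataFrame({'parent_foreign_key': [0, 1, 5], 'B': ['a', 'b', 'c']})
# The row with parent_foreign_key == 5 has no matching parent primary key and is removed.
print(drop_unreferenced_children(parent, child, 'id_parent', 'parent_foreign_key'))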
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) assert len(data['parent']) == 3 @@ -1404,9 +1370,7 @@ def test__subsample_table_and_descendants_nan_fk(mock_get_nan_fk_indices_table): _subsample_table_and_descendants(data, metadata, 'parent', 3) # Assert - mock_get_nan_fk_indices_table.assert_called_once_with( - data, metadata.relationships, 'parent' - ) + mock_get_nan_fk_indices_table.assert_called_once_with(data, metadata.relationships, 'parent') def test__get_primary_keys_referenced(): @@ -1443,7 +1407,7 @@ def test__get_primary_keys_referenced(): 'pk_gp': {'type': 'id'}, 'col_1': {'type': 'numerical'}, }, - 'primary_key': 'pk_gp' + 'primary_key': 'pk_gp', }, 'parent': { 'columns': { @@ -1451,7 +1415,7 @@ def test__get_primary_keys_referenced(): 'pk_p': {'type': 'id'}, 'col_2': {'type': 'numerical'}, }, - 'primary_key': 'pk_p' + 'primary_key': 'pk_p', }, 'child': { 'columns': { @@ -1461,7 +1425,7 @@ def test__get_primary_keys_referenced(): 'pk_c': {'type': 'id'}, 'col_3': {'type': 'numerical'}, }, - 'primary_key': 'pk_c' + 'primary_key': 'pk_c', }, 'grandchild': { 'columns': { @@ -1490,7 +1454,6 @@ def test__get_primary_keys_referenced(): 'child_table_name': 'child', 'parent_primary_key': 'pk_p', 'child_foreign_key': 'fk_gp', - }, { 'parent_table_name': 'parent', @@ -1515,8 +1478,8 @@ def test__get_primary_keys_referenced(): 'child_table_name': 'grandchild', 'parent_primary_key': 'pk_p', 'child_foreign_key': 'fk_p_4', - } - ] + }, + ], }) # Run @@ -1600,9 +1563,7 @@ def test__subsample_parent_not_all_referenced_before(): # Assert assert len(data['parent']) == 6 - assert set(data['parent']['pk_p']).issubset({ - 1, 2, 3, 4, 6, 7, 8 - }) + assert set(data['parent']['pk_p']).issubset({1, 2, 3, 4, 6, 7, 8}) def test__subsample_ancestors(): @@ -1610,13 +1571,28 @@ def test__subsample_ancestors(): # Setup data = { 'grandparent': pd.DataFrame({ - 'pk_gp': [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 - ], + 'pk_gp': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 'col_1': [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', - 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't' + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', ], }), 'parent': pd.DataFrame({ @@ -1651,7 +1627,7 @@ def test__subsample_ancestors(): 'pk_gp': {'type': 'id'}, 'col_1': {'type': 'numerical'}, }, - 'primary_key': 'pk_gp' + 'primary_key': 'pk_gp', }, 'parent': { 'columns': { @@ -1659,7 +1635,7 @@ def test__subsample_ancestors(): 'pk_p': {'type': 'id'}, 'col_2': {'type': 'numerical'}, }, - 'primary_key': 'pk_p' + 'primary_key': 'pk_p', }, 'child': { 'columns': { @@ -1669,7 +1645,7 @@ def test__subsample_ancestors(): 'pk_c': {'type': 'id'}, 'col_3': {'type': 'numerical'}, }, - 'primary_key': 'pk_c' + 'primary_key': 'pk_c', }, 'grandchild': { 'columns': { @@ -1698,7 +1674,6 @@ def test__subsample_ancestors(): 'child_table_name': 'child', 'parent_primary_key': 'pk_p', 'child_foreign_key': 'fk_gp', - }, { 'parent_table_name': 'parent', @@ -1717,8 +1692,8 @@ def test__subsample_ancestors(): 'child_table_name': 'grandchild', 'parent_primary_key': 'pk_p', 'child_foreign_key': 'fk_p_3', - } - ] + }, + ], }) # Run @@ -1726,31 +1701,56 @@ def test__subsample_ancestors(): # Assert expected_result = { - 'parent': pd.DataFrame({ - 'fk_gp': [1, 2, 3, 6], - 'pk_p': [11, 12, 13, 16], - 'col_2': ['k', 'l', 'm', 'p'], - }, index=[0, 1, 2, 5]), - 'child': pd.DataFrame({ 
- 'fk_gp': [4, 5, 6], - 'fk_p_1': [11, 11, 11], - 'fk_p_2': [12, 12, 12], - 'pk_c': [21, 22, 23], - 'col_3': ['q', 'r', 's'], - }, index=[0, 1, 2]), - 'grandchild': pd.DataFrame({ - 'fk_p_3': [11, 12, 13, 11, 13], - 'fk_c': [21, 22, 23, 21, 22], - 'col_4': [36, 37, 38, 39, 40], - }, index=[0, 1, 2, 3, 4]), + 'parent': pd.DataFrame( + { + 'fk_gp': [1, 2, 3, 6], + 'pk_p': [11, 12, 13, 16], + 'col_2': ['k', 'l', 'm', 'p'], + }, + index=[0, 1, 2, 5], + ), + 'child': pd.DataFrame( + { + 'fk_gp': [4, 5, 6], + 'fk_p_1': [11, 11, 11], + 'fk_p_2': [12, 12, 12], + 'pk_c': [21, 22, 23], + 'col_3': ['q', 'r', 's'], + }, + index=[0, 1, 2], + ), + 'grandchild': pd.DataFrame( + { + 'fk_p_3': [11, 12, 13, 11, 13], + 'fk_c': [21, 22, 23, 21, 22], + 'col_4': [36, 37, 38, 39, 40], + }, + index=[0, 1, 2, 3, 4], + ), } assert len(data['grandparent']) == 14 - assert set(data['grandparent']['pk_gp']).issubset( - { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20 - } - ) + assert set(data['grandparent']['pk_gp']).issubset({ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + }) for table_name in ['parent', 'child', 'grandchild']: pd.testing.assert_frame_equal(data[table_name], expected_result[table_name]) @@ -1760,13 +1760,28 @@ def test__subsample_ancestors_schema_diamond_shape(): # Setup data = { 'grandparent': pd.DataFrame({ - 'pk_gp': [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 - ], + 'pk_gp': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 'col_1': [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', - 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't' + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', ], }), 'parent_1': pd.DataFrame({ @@ -1783,7 +1798,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'fk_p_1': [21, 22, 23, 23, 23], 'fk_p_2': [31, 32, 33, 34, 34], 'col_4': ['q', 'r', 's', 't', 'u'], - }) + }), } primary_key_referenced = { @@ -1799,7 +1814,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'pk_gp': {'type': 'id'}, 'col_1': {'type': 'numerical'}, }, - 'primary_key': 'pk_gp' + 'primary_key': 'pk_gp', }, 'parent_1': { 'columns': { @@ -1807,7 +1822,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'pk_p': {'type': 'id'}, 'col_2': {'type': 'numerical'}, }, - 'primary_key': 'pk_p' + 'primary_key': 'pk_p', }, 'parent_2': { 'columns': { @@ -1815,7 +1830,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'pk_p': {'type': 'id'}, 'col_3': {'type': 'numerical'}, }, - 'primary_key': 'pk_p' + 'primary_key': 'pk_p', }, 'child': { 'columns': { @@ -1850,7 +1865,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'parent_primary_key': 'pk_p', 'child_foreign_key': 'fk_p_2', }, - ] + ], }) # Run @@ -1858,29 +1873,54 @@ def test__subsample_ancestors_schema_diamond_shape(): # Assert expected_result = { - 'parent_1': pd.DataFrame({ - 'fk_gp': [1, 2, 3, 6], - 'pk_p': [21, 22, 23, 26], - 'col_2': ['k', 'l', 'm', 'p'], - }, index=[0, 1, 2, 5]), - 'parent_2': pd.DataFrame({ - 'fk_gp': [7, 8, 9, 10], - 'pk_p': [31, 32, 33, 34], - 'col_3': ['k', 'l', 'm', 'n'], - }, index=[0, 1, 2, 3]), - 'child': pd.DataFrame({ - 'fk_p_1': [21, 22, 23, 23, 23], - 'fk_p_2': [31, 32, 33, 34, 34], - 'col_4': ['q', 'r', 's', 't', 'u'], - }, index=[0, 1, 2, 3, 4]), + 'parent_1': pd.DataFrame( + { + 'fk_gp': [1, 2, 3, 6], + 'pk_p': [21, 22, 23, 26], + 'col_2': ['k', 'l', 'm', 
'p'], + }, + index=[0, 1, 2, 5], + ), + 'parent_2': pd.DataFrame( + { + 'fk_gp': [7, 8, 9, 10], + 'pk_p': [31, 32, 33, 34], + 'col_3': ['k', 'l', 'm', 'n'], + }, + index=[0, 1, 2, 3], + ), + 'child': pd.DataFrame( + { + 'fk_p_1': [21, 22, 23, 23, 23], + 'fk_p_2': [31, 32, 33, 34, 34], + 'col_4': ['q', 'r', 's', 't', 'u'], + }, + index=[0, 1, 2, 3, 4], + ), } assert len(data['grandparent']) == 14 - assert set(data['grandparent']['pk_gp']).issubset( - { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20 - } - ) + assert set(data['grandparent']['pk_gp']).issubset({ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + }) for table_name in ['parent_1', 'parent_2', 'child']: pd.testing.assert_frame_equal(data[table_name], expected_result[table_name]) @@ -1897,7 +1937,7 @@ def test__subsample_data( mock_get_primary_keys_referenced, mock_subsample_ancestors, mock_subsample_table_and_descendants, - mock_subsample_disconnected_roots + mock_subsample_disconnected_roots, ): """Test the ``_subsample_data`` method.""" # Setup @@ -1907,9 +1947,7 @@ def test__subsample_data( metadata = Mock() num_rows = 5 main_table = 'main_table' - primary_key_reference = { - 'main_table': {1, 2, 4} - } + primary_key_reference = {'main_table': {1, 2, 4}} mock_get_primary_keys_referenced.return_value = primary_key_reference # Run @@ -1935,7 +1973,7 @@ def test__subsample_data( def test__subsample_data_empty_dataset( mock_validate_foreign_keys_not_null, mock_get_primary_keys_referenced, - mock_subsample_disconnected_roots + mock_subsample_disconnected_roots, ): """Test the ``subsample_data`` method when a dataset is empty.""" # Setup @@ -1961,13 +1999,28 @@ def test__print_subsample_summary(capsys): # Setup data_before = { 'grandparent': pd.DataFrame({ - 'pk_gp': [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 - ], + 'pk_gp': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 'col_1': [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', - 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't' + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', ], }), 'parent_1': pd.DataFrame({ @@ -1984,35 +2037,41 @@ def test__print_subsample_summary(capsys): 'fk_p_1': [21, 22, 23, 23, 23], 'fk_p_2': [31, 32, 33, 34, 34], 'col_4': ['q', 'r', 's', 't', 'u'], - }) + }), } data_after = { - 'grandparent': pd.DataFrame({ - 'pk_gp': [ - 1, 2, 3, 6, 7, 8, 9, 10, 14, 15, - 16, 17, 18, 20 - ], - 'col_1': [ - 'a', 'b', 'c', 'f', 'g', 'h', 'i', 'j', 'n', 'o', - 'p', 'q', 'r', 't' - ], - }, index=[0, 1, 2, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 19]), - 'parent_1': pd.DataFrame({ - 'fk_gp': [1, 2, 3, 6], - 'pk_p': [21, 22, 23, 26], - 'col_2': ['k', 'l', 'm', 'p'], - }, index=[0, 1, 2, 5]), - 'parent_2': pd.DataFrame({ - 'fk_gp': [7, 8, 9, 10], - 'pk_p': [31, 32, 33, 34], - 'col_3': ['k', 'l', 'm', 'n'], - }, index=[0, 1, 2, 3]), - 'child': pd.DataFrame({ - 'fk_p_1': [21, 22, 23, 23, 23], - 'fk_p_2': [31, 32, 33, 34, 34], - 'col_4': ['q', 'r', 's', 't', 'u'], - }, index=[0, 1, 2, 3, 4]), + 'grandparent': pd.DataFrame( + { + 'pk_gp': [1, 2, 3, 6, 7, 8, 9, 10, 14, 15, 16, 17, 18, 20], + 'col_1': ['a', 'b', 'c', 'f', 'g', 'h', 'i', 'j', 'n', 'o', 'p', 'q', 'r', 't'], + }, + index=[0, 1, 2, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 19], + ), + 'parent_1': pd.DataFrame( + { + 'fk_gp': [1, 2, 3, 6], + 'pk_p': [21, 22, 23, 26], + 'col_2': ['k', 'l', 'm', 
'p'], + }, + index=[0, 1, 2, 5], + ), + 'parent_2': pd.DataFrame( + { + 'fk_gp': [7, 8, 9, 10], + 'pk_p': [31, 32, 33, 34], + 'col_3': ['k', 'l', 'm', 'n'], + }, + index=[0, 1, 2, 3], + ), + 'child': pd.DataFrame( + { + 'fk_p_1': [21, 22, 23, 23, 23], + 'fk_p_2': [31, 32, 33, 34, 34], + 'col_4': ['q', 'r', 's', 't', 'u'], + }, + index=[0, 1, 2, 3, 4], + ), } # Run diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index b9f6d02c9..2689b2789 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -9,8 +9,7 @@ from tests.utils import DataFrameMatcher, SeriesMatcher, get_multi_table_metadata -class TestBaseHierarchicalSampler(): - +class TestBaseHierarchicalSampler: def test___init__(self): """Test the default initialization of the ``BaseHierarchicalSampler``.""" # Run @@ -44,11 +43,11 @@ def test__add_foreign_key_columns(self): child_table=pd.DataFrame(), parent_table=pd.DataFrame(), child_name='oseba', - parent_name='nescra' + parent_name='nescra', ) def test__sample_rows(self): - """Test that ``_sample_rows`` samples ``num_rows`` from the synthesizer. """ + """Test that ``_sample_rows`` samples ``num_rows`` from the synthesizer.""" synthesizer = Mock() instance = Mock() @@ -57,13 +56,10 @@ def test__sample_rows(self): # Assert assert result == synthesizer._sample_batch.return_value - synthesizer._sample_batch.assert_called_once_with( - 10, - keep_extra_columns=True - ) + synthesizer._sample_batch.assert_called_once_with(10, keep_extra_columns=True) def test__sample_rows_missing_num_rows(self): - """Test that ``_sample_rows`` falls back to ``synthesizer._num_rows``. """ + """Test that ``_sample_rows`` falls back to ``synthesizer._num_rows``.""" synthesizer = Mock() synthesizer._num_rows = 10 instance = Mock() @@ -73,10 +69,7 @@ def test__sample_rows_missing_num_rows(self): # Assert assert result == synthesizer._sample_batch.return_value - synthesizer._sample_batch.assert_called_once_with( - 10, - keep_extra_columns=True - ) + synthesizer._sample_batch.assert_called_once_with(10, keep_extra_columns=True) def test__add_child_rows(self): """Test adding child rows when sampled data is empty.""" @@ -89,10 +82,7 @@ def test__add_child_rows(self): sessions_meta = Mock() users_meta = Mock() users_meta.primary_key = 'user_id' - metadata.tables = { - 'users': users_meta, - 'sessions': sessions_meta - } + metadata.tables = {'users': users_meta, 'sessions': sessions_meta} metadata._get_foreign_keys.return_value = ['user_id'] instance.metadata = metadata @@ -104,7 +94,7 @@ def test__add_child_rows(self): parent_row = pd.DataFrame({ 'user_id': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna'], - '__sessions__user_id__num_rows': [10, 10, 10] + '__sessions__user_id__num_rows': [10, 10, 10], }) sampled_data = {} @@ -118,7 +108,7 @@ def test__add_child_rows(self): 'session_id': ['a', 'b', 'c'], 'os': ['linux', 'mac', 'win'], 'country': ['us', 'us', 'es'], - 'user_id': [1, 2, 3] + 'user_id': [1, 2, 3], }) pd.testing.assert_frame_equal(sampled_data['sessions'], expected_result) @@ -136,10 +126,7 @@ def test__add_child_rows_with_sampled_data(self): sessions_meta = Mock() users_meta = Mock() users_meta.primary_key.return_value = 'user_id' - metadata.tables = { - 'users': users_meta, - 'sessions': sessions_meta - } + metadata.tables = {'users': users_meta, 'sessions': sessions_meta} metadata._get_foreign_keys.return_value = ['user_id'] instance.metadata = metadata instance._synthesizer_kwargs 
= {'a': 0.1, 'b': 0.5, 'loc': 0.25} @@ -152,7 +139,7 @@ def test__add_child_rows_with_sampled_data(self): parent_row = pd.DataFrame({ 'user_id': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna'], - '__sessions__user_id__num_rows': [10, 10, 10] + '__sessions__user_id__num_rows': [10, 10, 10], }) sampled_data = { 'sessions': pd.DataFrame({ @@ -165,7 +152,8 @@ def test__add_child_rows_with_sampled_data(self): # Run BaseHierarchicalSampler._add_child_rows( - instance, 'sessions', 'users', parent_row, sampled_data) + instance, 'sessions', 'users', parent_row, sampled_data + ) # Assert expected_result = pd.DataFrame({ @@ -181,11 +169,12 @@ def test__sample_children(self): ``_sample_table`` does not sample the root parents of a graph, only the children. """ + # Setup def sample_children(table_name, sampled_data, scale): sampled_data['transactions'] = pd.DataFrame({ 'transaction_id': [1, 2, 3], - 'session_id': ['a', 'a', 'b'] + 'session_id': ['a', 'a', 'b'], }) def _add_child_rows(child_name, parent_name, parent_row, sampled_data): @@ -195,7 +184,7 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data): 'user_id': [1, 1], 'session_id': ['a', 'b'], 'os': ['windows', 'linux'], - 'country': ['us', 'us'] + 'country': ['us', 'us'], }) if parent_row['user_id'] == 3: @@ -203,11 +192,12 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data): 'user_id': [3], 'session_id': ['c'], 'os': ['mac'], - 'country': ['es'] + 'country': ['es'], }) - sampled_data[child_name] = pd.concat( - [sampled_data[child_name], row] - ).reset_index(drop=True) + sampled_data[child_name] = pd.concat([ + sampled_data[child_name], + row, + ]).reset_index(drop=True) instance = Mock() instance.metadata._get_child_map.return_value = {'users': ['sessions', 'transactions']} @@ -218,30 +208,28 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data): instance._add_child_rows.side_effect = _add_child_rows # Run - result = { - 'users': pd.DataFrame({ - 'user_id': [1, 3] - }) - } + result = {'users': pd.DataFrame({'user_id': [1, 3]})} BaseHierarchicalSampler._sample_children( - self=instance, - table_name='users', - sampled_data=result + self=instance, table_name='users', sampled_data=result ) # Assert expected_calls = [ - call(child_name='sessions', parent_name='users', - parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)), - sampled_data=result), - call(child_name='sessions', parent_name='users', - parent_row=SeriesMatcher(pd.Series({'user_id': 3}, name=1)), - sampled_data=result) + call( + child_name='sessions', + parent_name='users', + parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)), + sampled_data=result, + ), + call( + child_name='sessions', + parent_name='users', + parent_row=SeriesMatcher(pd.Series({'user_id': 3}, name=1)), + sampled_data=result, + ), ] expected_result = { - 'users': pd.DataFrame({ - 'user_id': [1, 3] - }), + 'users': pd.DataFrame({'user_id': [1, 3]}), 'sessions': pd.DataFrame({ 'user_id': [1, 1, 3], 'session_id': ['a', 'b', 'c'], @@ -250,8 +238,8 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data): }), 'transactions': pd.DataFrame({ 'transaction_id': [1, 2, 3], - 'session_id': ['a', 'a', 'b'] - }) + 'session_id': ['a', 'a', 'b'], + }), } instance._add_child_rows.assert_has_calls(expected_calls) for result_frame, expected_frame in zip(result.values(), expected_result.values()): @@ -263,19 +251,17 @@ def test__sample_children_no_rows_sampled(self): ``_sample_table`` should select the parent row with the highest ``num_rows`` 
value and force a child to be created from that row. """ + # Setup def sample_children(table_name, sampled_data, scale): sampled_data['transactions'] = pd.DataFrame({ 'transaction_id': [1, 2], - 'session_id': ['a', 'a'] + 'session_id': ['a', 'a'], }) def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=None): if num_rows is not None: - sampled_data['sessions'] = pd.DataFrame({ - 'user_id': [1], - 'session_id': ['a'] - }) + sampled_data['sessions'] = pd.DataFrame({'user_id': [1], 'session_id': ['a']}) instance = Mock() instance.metadata._get_child_map.return_value = {'users': ['sessions', 'transactions']} @@ -287,44 +273,35 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows= instance._add_child_rows.side_effect = _add_child_rows # Run - result = { - 'users': pd.DataFrame({ - 'user_id': [1], - '__sessions__user_id__num_rows': [1] - }) - } + result = {'users': pd.DataFrame({'user_id': [1], '__sessions__user_id__num_rows': [1]})} BaseHierarchicalSampler._sample_children( - self=instance, - table_name='users', - sampled_data=result + self=instance, table_name='users', sampled_data=result ) # Assert - expected_parent_row = pd.Series({ - 'user_id': 1, - '__sessions__user_id__num_rows': 1 - }, name=0) + expected_parent_row = pd.Series({'user_id': 1, '__sessions__user_id__num_rows': 1}, name=0) expected_calls = [ - call(child_name='sessions', parent_name='users', - parent_row=SeriesMatcher(expected_parent_row), - sampled_data=result), - call(child_name='sessions', parent_name='users', - parent_row=SeriesMatcher(expected_parent_row), - sampled_data=result, num_rows=1) + call( + child_name='sessions', + parent_name='users', + parent_row=SeriesMatcher(expected_parent_row), + sampled_data=result, + ), + call( + child_name='sessions', + parent_name='users', + parent_row=SeriesMatcher(expected_parent_row), + sampled_data=result, + num_rows=1, + ), ] expected_result = { - 'users': pd.DataFrame({ - 'user_id': [1], - '__sessions__user_id__num_rows': [1] - }), + 'users': pd.DataFrame({'user_id': [1], '__sessions__user_id__num_rows': [1]}), 'sessions': pd.DataFrame({ 'user_id': [1], 'session_id': ['a'], }), - 'transactions': pd.DataFrame({ - 'transaction_id': [1, 2], - 'session_id': ['a', 'a'] - }) + 'transactions': pd.DataFrame({'transaction_id': [1, 2], 'session_id': ['a', 'a']}), } instance._add_child_rows.assert_has_calls(expected_calls) for result_frame, expected_frame in zip(result.values(), expected_result.values()): @@ -336,19 +313,17 @@ def test__sample_children_no_rows_sampled_no_num_rows(self): ``_sample_table`` should select randomly select a parent row and force a child to be created from that row. 
""" + # Setup def sample_children(table_name, sampled_data, scale): sampled_data['transactions'] = pd.DataFrame({ 'transaction_id': [1, 2], - 'session_id': ['a', 'a'] + 'session_id': ['a', 'a'], }) def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=None): if num_rows is not None: - sampled_data['sessions'] = pd.DataFrame({ - 'user_id': [1], - 'session_id': ['a'] - }) + sampled_data['sessions'] = pd.DataFrame({'user_id': [1], 'session_id': ['a']}) instance = Mock() instance.metadata._get_child_map.return_value = {'users': ['sessions', 'transactions']} @@ -360,38 +335,34 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows= instance._add_child_rows.side_effect = _add_child_rows # Run - result = { - 'users': pd.DataFrame({ - 'user_id': [1] - }) - } + result = {'users': pd.DataFrame({'user_id': [1]})} BaseHierarchicalSampler._sample_children( - self=instance, - table_name='users', - sampled_data=result + self=instance, table_name='users', sampled_data=result ) # Assert expected_calls = [ - call(child_name='sessions', parent_name='users', - parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)), - sampled_data=result), - call(child_name='sessions', parent_name='users', - parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)), - sampled_data=result, num_rows=1) + call( + child_name='sessions', + parent_name='users', + parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)), + sampled_data=result, + ), + call( + child_name='sessions', + parent_name='users', + parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)), + sampled_data=result, + num_rows=1, + ), ] expected_result = { - 'users': pd.DataFrame({ - 'user_id': [1] - }), + 'users': pd.DataFrame({'user_id': [1]}), 'sessions': pd.DataFrame({ 'user_id': [1], 'session_id': ['a'], }), - 'transactions': pd.DataFrame({ - 'transaction_id': [1, 2], - 'session_id': ['a', 'a'] - }) + 'transactions': pd.DataFrame({'transaction_id': [1, 2], 'session_id': ['a', 'a']}), } instance._add_child_rows.assert_has_calls(expected_calls) for result_frame, expected_frame in zip(result.values(), expected_result.values()): @@ -404,7 +375,7 @@ def test__finalize(self): metadata = Mock() metadata._get_parent_map.return_value = { 'sessions': ['users'], - 'transactions': ['sessions'] + 'transactions': ['sessions'], } instance.metadata = metadata @@ -434,18 +405,15 @@ def test__finalize(self): 'user_id': np.int64, 'session_id': str, 'os': str, - 'country': str + 'country': str, } transactions_synth = Mock() - transactions_synth._data_processor._dtypes = { - 'transaction_id': np.int64, - 'session_id': str - } + transactions_synth._data_processor._dtypes = {'transaction_id': np.int64, 'session_id': str} instance._table_synthesizers = { 'users': users_synth, 'sessions': sessions_synth, - 'transactions': transactions_synth + 'transactions': transactions_synth, } # Run @@ -481,20 +449,17 @@ def test__sample(self): 4. All extra columns are dropped by calling ``_finalize``. 
""" # Setup - users = pd.DataFrame({ - 'id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] - }) + users = pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) sessions = pd.DataFrame({ 'user_id': [1, 1, 3], 'session_id': ['a', 'b', 'c'], 'os': ['windows', 'linux', 'mac'], - 'country': ['us', 'us', 'es'] + 'country': ['us', 'us', 'es'], }) transactions = pd.DataFrame({ 'user_id': [1, 2, 3], 'transaction_id': [1, 2, 3], - 'transaction_amount': [100, 1000, 200] + 'transaction_amount': [100, 1000, 200], }) def _sample_children_dummy(table_name, sampled_data, scale): @@ -502,30 +467,26 @@ def _sample_children_dummy(table_name, sampled_data, scale): sampled_data['transactions'] = transactions instance = Mock() - instance._table_sizes = { - 'users': 3, - 'transactions': 9, - 'sessions': 5 - } + instance._table_sizes = {'users': 3, 'transactions': 9, 'sessions': 5} instance.metadata.relationships = [ { 'parent_table_name': 'users', 'parent_primary_key': 'id', 'child_table_name': 'sessions', - 'child_foreign_key': 'user_id' + 'child_foreign_key': 'user_id', }, { 'parent_table_name': 'users', 'parent_primary_key': 'id', 'child_table_name': 'transactions', - 'child_foreign_key': 'user_id' - } + 'child_foreign_key': 'user_id', + }, ] users_synthesizer = Mock() instance._table_synthesizers = defaultdict(Mock, {'users': users_synthesizer}) instance.metadata._get_parent_map.return_value = { 'sessions': ['users'], - 'transactions': ['users'] + 'transactions': ['users'], } instance.metadata.tables = { 'users': Mock(), @@ -540,41 +501,34 @@ def _sample_children_dummy(table_name, sampled_data, scale): # Assert expected_sample = { - 'users': DataFrameMatcher(pd.DataFrame({ - 'id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] - })), - 'sessions': DataFrameMatcher(pd.DataFrame({ - 'user_id': [1, 1, 3], - 'session_id': ['a', 'b', 'c'], - 'os': ['windows', 'linux', 'mac'], - 'country': ['us', 'us', 'es'] - })), - 'transactions': DataFrameMatcher(pd.DataFrame({ - 'user_id': [1, 2, 3], - 'transaction_id': [1, 2, 3], - 'transaction_amount': [100, 1000, 200] - })) + 'users': DataFrameMatcher( + pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) + ), + 'sessions': DataFrameMatcher( + pd.DataFrame({ + 'user_id': [1, 1, 3], + 'session_id': ['a', 'b', 'c'], + 'os': ['windows', 'linux', 'mac'], + 'country': ['us', 'us', 'es'], + }) + ), + 'transactions': DataFrameMatcher( + pd.DataFrame({ + 'user_id': [1, 2, 3], + 'transaction_id': [1, 2, 3], + 'transaction_amount': [100, 1000, 200], + }) + ), } assert result == instance._finalize.return_value instance._sample_children.assert_called_once_with( - table_name='users', - sampled_data=expected_sample, - scale=1.0 + table_name='users', sampled_data=expected_sample, scale=1.0 ) instance._add_foreign_key_columns.assert_has_calls([ + call(expected_sample['sessions'], expected_sample['users'], 'sessions', 'users'), call( - expected_sample['sessions'], - expected_sample['users'], - 'sessions', - 'users' + expected_sample['transactions'], expected_sample['users'], 'transactions', 'users' ), - call( - expected_sample['transactions'], - expected_sample['users'], - 'transactions', - 'users' - ) ]) instance._finalize.assert_called_once_with(expected_sample) @@ -585,25 +539,14 @@ def test___enforce_table_size_too_many_rows(self): """ # Setup instance = MagicMock() - data = { - 'parent': pd.DataFrame({ - 'fk': ['a', 'b', 'c'], - '__child__fk__num_rows': [1, 2, 3] - }) - } + data = {'parent': pd.DataFrame({'fk': ['a', 'b', 'c'], '__child__fk__num_rows': 
[1, 2, 3]})} instance.metadata._get_foreign_keys.return_value = ['fk'] instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} # Run - BaseHierarchicalSampler._enforce_table_size( - instance, - 'child', - 'parent', - 1.0, - data - ) + BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [1, 1, 2] @@ -615,25 +558,14 @@ def test___enforce_table_size_not_enough_rows(self): """ # Setup instance = MagicMock() - data = { - 'parent': pd.DataFrame({ - 'fk': ['a', 'b', 'c'], - '__child__fk__num_rows': [1, 1, 1] - }) - } + data = {'parent': pd.DataFrame({'fk': ['a', 'b', 'c'], '__child__fk__num_rows': [1, 1, 1]})} instance.metadata._get_foreign_keys.return_value = ['fk'] instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} # Run - BaseHierarchicalSampler._enforce_table_size( - instance, - 'child', - 'parent', - 1.0, - data - ) + BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [2, 1, 1] @@ -645,25 +577,14 @@ def test___enforce_table_size_clipping(self): """ # Setup instance = MagicMock() - data = { - 'parent': pd.DataFrame({ - 'fk': ['a', 'b', 'c'], - '__child__fk__num_rows': [1, 2, 5] - }) - } + data = {'parent': pd.DataFrame({'fk': ['a', 'b', 'c'], '__child__fk__num_rows': [1, 2, 5]})} instance.metadata._get_foreign_keys.return_value = ['fk'] instance._min_child_rows = {'__child__fk__num_rows': 2} instance._max_child_rows = {'__child__fk__num_rows': 4} instance._table_sizes = {'child': 8} # Run - BaseHierarchicalSampler._enforce_table_size( - instance, - 'child', - 'parent', - 1.0, - data - ) + BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [2, 2, 4] @@ -675,25 +596,14 @@ def test___enforce_table_size_too_small_sample(self): """ # Setup instance = MagicMock() - data = { - 'parent': pd.DataFrame({ - 'fk': ['a', 'b', 'c'], - '__child__fk__num_rows': [1, 2, 3] - }) - } + data = {'parent': pd.DataFrame({'fk': ['a', 'b', 'c'], '__child__fk__num_rows': [1, 2, 3]})} instance.metadata._get_foreign_keys.return_value = ['fk'] instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} # Run - BaseHierarchicalSampler._enforce_table_size( - instance, - 'child', - 'parent', - .001, - data - ) + BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 0.001, data) # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [0, 0, 0] diff --git a/tests/unit/sampling/test_independent_sampler.py b/tests/unit/sampling/test_independent_sampler.py index a46bf2cce..f45215ba2 100644 --- a/tests/unit/sampling/test_independent_sampler.py +++ b/tests/unit/sampling/test_independent_sampler.py @@ -8,8 +8,7 @@ from tests.utils import DataFrameMatcher, get_multi_table_metadata -class TestBaseIndependentSampler(): - +class TestBaseIndependentSampler: def test___init__(self): """Test the default initialization of the ``BaseIndependentSampler``.""" # Run @@ -33,7 +32,7 @@ def test__add_foreign_key_columns(self): child_table=pd.DataFrame(), parent_table=pd.DataFrame(), child_name='oseba', - parent_name='nescra' + 
parent_name='nescra', ) def test__sample_table(self): @@ -42,7 +41,7 @@ def test__sample_table(self): table_synthesizer = Mock() table_synthesizer._sample_batch.return_value = pd.DataFrame({ 'user_id': [1, 2, 3], - 'name': ['John', 'Doe', 'Johanna'] + 'name': ['John', 'Doe', 'Johanna'], }) instance = Mock() @@ -59,19 +58,17 @@ def test__sample_table(self): 'name': ['John', 'Doe', 'Johanna'], }) } - table_synthesizer._sample_batch.assert_called_once_with( - 3, - keep_extra_columns=True - ) + table_synthesizer._sample_batch.assert_called_once_with(3, keep_extra_columns=True) pd.testing.assert_frame_equal(result['users'], expected_result['users']) def test__connect_table(self): """Test the method adds all foreign key columns to each table.""" + def _get_all_foreign_keys(child): foreign_keys = { 'users': [], 'sessions': ['users_id'], - 'transactions': ['users_id', 'sessions_id'] + 'transactions': ['users_id', 'sessions_id'], } return foreign_keys[child] @@ -90,19 +87,19 @@ def _add_foreign_key_columns(child_data, parent_data, child, parent): instance.metadata._get_parent_map.return_value = { 'sessions': {'users'}, 'transactions': {'sessions', 'users'}, - 'users': set() + 'users': set(), } instance.metadata._get_child_map.return_value = { 'users': ['transactions', 'sessions'], 'sessions': {'transactions'}, - 'transactions': set() + 'transactions': set(), } instance.metadata._get_all_foreign_keys.side_effect = _get_all_foreign_keys instance._add_foreign_key_columns.side_effect = _add_foreign_key_columns sampled_data = { 'users': pd.DataFrame(dtype='object'), 'sessions': pd.DataFrame(dtype='object'), - 'transactions': pd.DataFrame(dtype='object') + 'transactions': pd.DataFrame(dtype='object'), } # Run @@ -110,15 +107,24 @@ def _add_foreign_key_columns(child_data, parent_data, child, parent): # Assert _add_foreign_key_columns_mock.assert_has_calls([ - call(DataFrameMatcher(pd.DataFrame(dtype='object')), - DataFrameMatcher(pd.DataFrame(dtype='object')), - 'transactions', 'users'), - call(DataFrameMatcher(pd.DataFrame(dtype='object')), - DataFrameMatcher(pd.DataFrame(dtype='object')), - 'sessions', 'users'), - call(DataFrameMatcher(pd.DataFrame(dtype='object', columns=['users_id'])), - DataFrameMatcher(pd.DataFrame(dtype='object', columns=['users_id'])), - 'transactions', 'sessions'), + call( + DataFrameMatcher(pd.DataFrame(dtype='object')), + DataFrameMatcher(pd.DataFrame(dtype='object')), + 'transactions', + 'users', + ), + call( + DataFrameMatcher(pd.DataFrame(dtype='object')), + DataFrameMatcher(pd.DataFrame(dtype='object')), + 'sessions', + 'users', + ), + call( + DataFrameMatcher(pd.DataFrame(dtype='object', columns=['users_id'])), + DataFrameMatcher(pd.DataFrame(dtype='object', columns=['users_id'])), + 'transactions', + 'sessions', + ), ]) def test__finalize(self): @@ -129,7 +135,7 @@ def test__finalize(self): metadata._get_parent_map.return_value = { 'sessions': ['users'], 'transactions': ['sessions'], - 'users': set() + 'users': set(), } instance.metadata = metadata @@ -159,18 +165,15 @@ def test__finalize(self): 'user_id': np.int64, 'session_id': str, 'os': str, - 'country': str + 'country': str, } transactions_synth = Mock() - transactions_synth._data_processor._dtypes = { - 'transaction_id': np.int64, - 'session_id': str - } + transactions_synth._data_processor._dtypes = {'transaction_id': np.int64, 'session_id': str} instance._table_synthesizers = { 'users': users_synth, 'sessions': sessions_synth, - 'transactions': transactions_synth + 'transactions': transactions_synth, } # Run @@ 
@@ -205,7 +208,7 @@ def test__finalize_id_being_string(self, mock_logger):
         metadata._get_parent_map.return_value = {
             'sessions': ['users'],
             'transactions': ['sessions'],
-            'users': set()
+            'users': set(),
         }
         instance.metadata = metadata

@@ -243,7 +246,7 @@ def test__finalize_id_being_string(self, mock_logger):
             'user_id': np.int64,
             'session_id': str,
             'os': str,
-            'country': str
+            'country': str,
         }

         sessions_synth._data_processor._DTYPE_TO_SDTYPE = {
@@ -254,10 +257,7 @@ def test__finalize_id_being_string(self, mock_logger):
             'M': 'datetime',
         }
         transactions_synth = Mock()
-        transactions_synth._data_processor._dtypes = {
-            'transaction_id': np.int64,
-            'session_id': str
-        }
+        transactions_synth._data_processor._dtypes = {'transaction_id': np.int64, 'session_id': str}
         transactions_synth._data_processor._DTYPE_TO_SDTYPE = {
             'i': 'numerical',
             'f': 'numerical',
@@ -269,7 +269,7 @@ def test__finalize_id_being_string(self, mock_logger):
         instance._table_synthesizers = {
             'users': users_synth,
             'sessions': sessions_synth,
-            'transactions': transactions_synth
+            'transactions': transactions_synth,
         }

         # Run
@@ -318,26 +318,27 @@ def test__sample(self):
         _connect_tables_mock = Mock()

         expected_sample = {
-            'users': pd.DataFrame({
-                'user_id': [1, 2, 3],
-                'name': ['John', 'Doe', 'Johanna']
-            }),
+            'users': pd.DataFrame({'user_id': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}),
             'sessions': pd.DataFrame({
                 'user_id': [1, 1, 3],
                 'session_id': ['a', 'b', 'c'],
                 'os': ['windows', 'linux', 'mac'],
-                'country': ['us', 'us', 'es']
+                'country': ['us', 'us', 'es'],
             }),
-            'transactions': pd.DataFrame(dtype='Int64')
+            'transactions': pd.DataFrame(dtype='Int64'),
         }
         connected_transactions = pd.DataFrame({
             'user_id': [1, 3, 1],
-            'sessions_id': ['a', 'c', 'b']
+            'sessions_id': ['a', 'c', 'b'],
         })

         def _sample_table(synthesizer, table_name, num_rows, sampled_data):
-            _sample_table_mock(synthesizer=synthesizer, table_name=table_name, num_rows=num_rows,
-                               sampled_data=sampled_data.copy())
+            _sample_table_mock(
+                synthesizer=synthesizer,
+                table_name=table_name,
+                num_rows=num_rows,
+                sampled_data=sampled_data.copy(),
+            )
             sampled_data[table_name] = expected_sample[table_name]

         def _connect_tables(sampled_data):
@@ -347,17 +348,13 @@ def _connect_tables(sampled_data):
         instance._sample_table.side_effect = _sample_table
         instance._connect_tables.side_effect = _connect_tables

-        instance._table_sizes = {
-            'users': 3,
-            'transactions': 9,
-            'sessions': 5
-        }
+        instance._table_sizes = {'users': 3, 'transactions': 9, 'sessions': 5}
         instance.metadata.relationships = [
             {
                 'parent_table_name': 'users',
                 'parent_primary_key': 'id',
                 'child_table_name': 'sessions',
-                'child_foreign_key': 'id'
+                'child_foreign_key': 'id',
             }
         ]
         users_synthesizer = Mock()
@@ -366,7 +363,7 @@ def _connect_tables(sampled_data):
         instance._table_synthesizers = {
             'users': users_synthesizer,
             'sessions': sessions_synthesizer,
-            'transactions': transactions_synthesizer
+            'transactions': transactions_synthesizer,
         }
         instance.metadata.tables = {
             'users': Mock(),
@@ -378,20 +375,23 @@ def _connect_tables(sampled_data):
         result = BaseIndependentSampler._sample(instance)

         # Assert
-        users_call = call(synthesizer=users_synthesizer, table_name='users', num_rows=3,
-                          sampled_data={})
-        sessions_call = call(synthesizer=sessions_synthesizer, table_name='sessions', num_rows=5,
-                             sampled_data={
-                                 'users': DataFrameMatcher(expected_sample['users'])
-                             })
+        users_call = call(
+            synthesizer=users_synthesizer, table_name='users', num_rows=3, sampled_data={}
+        )
+        sessions_call = call(
+            synthesizer=sessions_synthesizer,
+            table_name='sessions',
+            num_rows=5,
+            sampled_data={'users': DataFrameMatcher(expected_sample['users'])},
+        )
         transactions_call = call(
             synthesizer=transactions_synthesizer,
             table_name='transactions',
             num_rows=9,
             sampled_data={
                 'users': DataFrameMatcher(expected_sample['users']),
-                'sessions': DataFrameMatcher(expected_sample['sessions'])
-            }
+                'sessions': DataFrameMatcher(expected_sample['sessions']),
+            },
         )

         _sample_table_mock.assert_has_calls([users_call, sessions_call, transactions_call])
@@ -401,6 +401,6 @@ def _connect_tables(sampled_data):
         instance._finalize.assert_called_once_with({
             'users': DataFrameMatcher(expected_sample['users']),
             'sessions': DataFrameMatcher(expected_sample['sessions']),
-            'transactions': DataFrameMatcher(connected_transactions)
+            'transactions': DataFrameMatcher(connected_transactions),
         })
         assert result == instance._finalize.return_value
diff --git a/tests/unit/sampling/test_tabular.py b/tests/unit/sampling/test_tabular.py
index 65f6048bf..e135bcc58 100644
--- a/tests/unit/sampling/test_tabular.py
+++ b/tests/unit/sampling/test_tabular.py
@@ -1,9 +1,9 @@
 """Tests for the sdv.sampling.tabular module."""
-from sdv.sampling.tabular import Condition
+
+from sdv.sampling.tabular import Condition


-class TestCondition():
-
+class TestCondition:
     def test___init__(self):
         """Test ```Condition.__init__`` method.
diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py
index 04ec2eaf6..79a5c105d 100644
--- a/tests/unit/sequential/test_par.py
+++ b/tests/unit/sequential/test_par.py
@@ -16,7 +16,6 @@


 class TestPARSynthesizer:
-
     def get_metadata(self, add_sequence_key=True, add_sequence_index=False):
         metadata = SingleTableMetadata()
         metadata.add_column('time', sdtype='datetime')
@@ -36,7 +35,7 @@ def get_data(self):
             'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
             'gender': ['F', 'M', 'M'],
             'name': ['Jane', 'John', 'Doe'],
-            'measurement': [55, 60, 65]
+            'measurement': [55, 60, 65],
         })

         return data
@@ -59,7 +58,7 @@ def test___init__(self):
             epochs=10,
             sample_size=5,
             cuda=False,
-            verbose=False
+            verbose=False,
         )

         # Assert
@@ -72,14 +71,14 @@ def test___init__(self):
             'epochs': 10,
             'sample_size': 5,
             'cuda': False,
-            'verbose': False
+            'verbose': False,
         }
         assert isinstance(synthesizer._data_processor, DataProcessor)
         assert synthesizer._data_processor.metadata == metadata
         assert isinstance(synthesizer._context_synthesizer, GaussianCopulaSynthesizer)
         assert synthesizer._context_synthesizer.metadata.columns == {
             'gender': {'sdtype': 'categorical'},
-            'name': {'sdtype': 'id'}
+            'name': {'sdtype': 'id'},
         }

     def test___init___no_sequence_key(self):
@@ -94,7 +93,6 @@ def test___init___no_sequence_key(self):
         error_message = (
             'The PARSythesizer is designed for multi-sequence data, identifiable through a '
             'sequence key. Your metadata does not include a sequence key.'
-
         )
         with pytest.raises(SynthesizerInputError, match=error_message):
             PARSynthesizer(
@@ -106,44 +104,33 @@ def test___init___no_sequence_key(self):
                 epochs=10,
                 sample_size=5,
                 cuda=False,
-                verbose=False
+                verbose=False,
             )

     def test_add_constraints(self):
         """Test that that only simple constraints can be added to PARSynthesizer."""
         # Setup
         metadata = self.get_metadata()
-        synthesizer = PARSynthesizer(metadata=metadata,
-                                     context_columns=['name', 'measurement'])
+        synthesizer = PARSynthesizer(metadata=metadata, context_columns=['name', 'measurement'])
         name_constraint = {
             'constraint_class': 'Mock',
-            'constraint_parameters': {
-                'column_name': 'name'
-            }
+            'constraint_parameters': {'column_name': 'name'},
         }
         measurement_constraint = {
             'constraint_class': 'Mock',
-            'constraint_parameters': {
-                'column_name': 'measurement'
-            }
+            'constraint_parameters': {'column_name': 'measurement'},
         }
         gender_constraint = {
             'constraint_class': 'Mock',
-            'constraint_parameters': {
-                'column_name': 'gender'
-            }
+            'constraint_parameters': {'column_name': 'gender'},
         }
         time_constraint = {
             'constraint_class': 'Mock',
-            'constraint_parameters': {
-                'column_name': 'time'
-            }
+            'constraint_parameters': {'column_name': 'time'},
         }
         multi_constraint = {
             'constraint_class': 'Mock',
-            'constraint_parameters': {
-                'column_names': ['name', 'time']
-            }
+            'constraint_parameters': {'column_names': ['name', 'time']},
         }
         overlapping_error_msg = re.escape(
             'The PARSynthesizer cannot accommodate multiple constraints '
@@ -184,9 +171,7 @@ def test_load_custom_constraint_classes(self):
         synthesizer = PARSynthesizer(metadata=metadata)

         # Run and Assert
-        error_message = re.escape(
-            'The PARSynthesizer cannot accommodate custom constraints.'
-        )
+        error_message = re.escape('The PARSynthesizer cannot accommodate custom constraints.')
         with pytest.raises(SynthesizerInputError, match=error_message):
             synthesizer.load_custom_constraint_classes(filepath='test', class_names=[])

@@ -197,9 +182,7 @@ def test_add_custom_constraint_class(self):
         synthesizer = PARSynthesizer(metadata=metadata)

         # Run and Assert
-        error_message = re.escape(
-            'The PARSynthesizer cannot accommodate custom constraints.'
-        )
+        error_message = re.escape('The PARSynthesizer cannot accommodate custom constraints.')
         with pytest.raises(SynthesizerInputError, match=error_message):
             synthesizer.add_custom_constraint_class(Mock(), class_name='Mock')

@@ -217,7 +200,7 @@ def test_get_parameters(self):
             epochs=10,
             sample_size=5,
             cuda=False,
-            verbose=False
+            verbose=False,
         )

         # Run
@@ -234,7 +217,7 @@ def test_get_parameters(self):
             'epochs': 10,
             'sample_size': 5,
             'cuda': False,
-            'verbose': False
+            'verbose': False,
         }

     def test_get_metadata(self):
@@ -250,7 +233,7 @@ def test_get_metadata(self):
             epochs=10,
             sample_size=5,
             cuda=False,
-            verbose=False
+            verbose=False,
         )

         # Run
@@ -279,10 +262,7 @@ def test_validate_context_columns_unique_per_sequence_key(self):
         metadata.add_column('ct_col1', sdtype='numerical')
         metadata.add_column('ct_col2', sdtype='numerical')
         metadata.set_sequence_key('sk_col1')
-        instance = PARSynthesizer(
-            metadata=metadata,
-            context_columns=['ct_col1', 'ct_col2']
-        )
+        instance = PARSynthesizer(metadata=metadata, context_columns=['ct_col1', 'ct_col2'])

         # Run and Assert
         err_msg = re.escape(
@@ -302,14 +282,12 @@ def test_validate_context_columns_unique_per_sequence_key(self):
     def test__transform_sequence(self):
         # Setup
         metadata = self.get_metadata(add_sequence_index=True)
-        par = PARSynthesizer(
-            metadata=metadata
-        )
+        par = PARSynthesizer(metadata=metadata)
         data = pd.DataFrame({
             'time': [1, 2, 4, 5],
             'gender': ['F', 'M', 'M', 'M'],
             'name': ['Jane', 'John', 'John', 'John'],
-            'measurement': [55, 60, 65, 68]
+            'measurement': [55, 60, 65, 68],
         })

         # Run
@@ -317,11 +295,11 @@ def test__transform_sequence(self):

         # Assert
         expected = pd.DataFrame({
-            'time': [1., 2., 2., 1.],
+            'time': [1.0, 2.0, 2.0, 1.0],
             'gender': ['F', 'M', 'M', 'M'],
             'name': ['Jane', 'John', 'John', 'John'],
             'measurement': [55, 60, 65, 68],
-            'time.context': [1, 2, 2, 2]
+            'time.context': [1, 2, 2, 2],
         })
         pd.testing.assert_frame_equal(transformed_data, expected)
         assert par._extra_context_columns == {'time.context': {'sdtype': 'numerical'}}
@@ -331,9 +309,7 @@ def test__transform_sequence_index_single_instances(self):
         # Setup
         metadata = self.get_metadata(add_sequence_index=True)
-        par = PARSynthesizer(
-            metadata=metadata
-        )
+        par = PARSynthesizer(metadata=metadata)
         data = self.get_data()

         # Run
@@ -341,11 +317,11 @@ def test__transform_sequence_index_single_instances(self):

         # Assert
         expected = pd.DataFrame({
-            'time': [0., 0., 0.],
+            'time': [0.0, 0.0, 0.0],
             'gender': ['F', 'M', 'M'],
             'name': ['Jane', 'John', 'Doe'],
             'measurement': [55, 60, 65],
-            'time.context': ['2020-01-01', '2020-01-02', '2020-01-03']
+            'time.context': ['2020-01-01', '2020-01-02', '2020-01-03'],
         })
         pd.testing.assert_frame_equal(transformed_data, expected)
         assert par._extra_context_columns == {'time.context': {'sdtype': 'numerical'}}
@@ -355,9 +331,7 @@ def test__transform_sequence_index_non_unique_sequence_key(self):
         # Setup
         metadata = self.get_metadata(add_sequence_index=True)
-        par = PARSynthesizer(
-            metadata=metadata
-        )
+        par = PARSynthesizer(metadata=metadata)
         data = self.get_data()
         data = data[data['name'] == 'John'].reset_index(drop=True)

         # Run
@@ -366,11 +340,11 @@ def test__transform_sequence_index_non_unique_sequence_key(self):

         # Assert
         expected = pd.DataFrame({
-            'time': [0.],
+            'time': [0.0],
             'gender': ['M'],
             'name': ['John'],
             'measurement': [60],
-            'time.context': ['2020-01-02']
+            'time.context': ['2020-01-02'],
         })
         pd.testing.assert_frame_equal(transformed_data, expected)
         assert par._extra_context_columns == {'time.context': {'sdtype': 'numerical'}}
@@ -386,9 +360,7 @@ def test_preprocess_transformers_not_assigned(self, base_preprocess_mock):
         """
         # Setup
         metadata = self.get_metadata()
-        par = PARSynthesizer(
-            metadata=metadata
-        )
+        par = PARSynthesizer(metadata=metadata)
         par.auto_assign_transformers = Mock()
         par.update_transformers = Mock()
         data = self.get_data()
@@ -410,9 +382,7 @@ def test_preprocess(self, base_preprocess_mock):
         """
         # Setup
         metadata = self.get_metadata(add_sequence_index=True)
-        par = PARSynthesizer(
-            metadata=metadata
-        )
+        par = PARSynthesizer(metadata=metadata)
         par._transform_sequence_index = Mock()
         par.auto_assign_transformers = Mock()
         par.update_transformers = Mock()
@@ -475,14 +445,7 @@ def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
         par = PARSynthesizer(metadata, context_columns=['gender'])
         initial_synthesizer = Mock()
         context_metadata = SingleTableMetadata.load_from_dict({
-            'columns': {
-                'gender': {
-                    'sdtype': 'categorical'
-                },
-                'name': {
-                    'sdtype': 'id'
-                }
-            }
+            'columns': {'gender': {'sdtype': 'categorical'}, 'name': {'sdtype': 'id'}}
         })
         par._context_synthesizer = initial_synthesizer
         par._get_context_metadata = Mock()
@@ -495,12 +458,12 @@ def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
         gaussian_copula_mock.assert_called_with(
             context_metadata,
             enforce_min_max_values=initial_synthesizer.enforce_min_max_values,
-            enforce_rounding=initial_synthesizer.enforce_rounding
+            enforce_rounding=initial_synthesizer.enforce_rounding,
         )
         fitted_data = gaussian_copula_mock().fit.mock_calls[0][1][0]
         expected_fitted_data = pd.DataFrame({
             'name': ['Doe', 'Jane', 'John'],
-            'gender': ['M', 'F', 'M']
+            'gender': ['M', 'F', 'M'],
         })
         pd.testing.assert_frame_equal(fitted_data.sort_values(by='name'), expected_fitted_data)

@@ -524,10 +487,7 @@ def test__fit_context_model_without_context_columns(self, uuid_mock, gaussian_co

         # Assert
         fitted_data = par._context_synthesizer.fit.mock_calls[0][1][0]
-        expected_fitted_data = pd.DataFrame({
-            'name': ['Doe', 'Jane', 'John'],
-            'abc': [0, 0, 0]
-        })
+        expected_fitted_data = pd.DataFrame({'name': ['Doe', 'Jane', 'John'], 'abc': [0, 0, 0]})
         pd.testing.assert_frame_equal(fitted_data.sort_values(by='name'), expected_fitted_data)

     @patch('sdv.sequential.par.PARModel')
@@ -542,14 +502,11 @@ def test__fit_sequence_columns(self, assemble_sequences_mock, model_mock):
         # Setup
         data = self.get_data()
         metadata = self.get_metadata()
-        par = PARSynthesizer(
-            metadata=metadata,
-            context_columns=['gender']
-        )
+        par = PARSynthesizer(metadata=metadata, context_columns=['gender'])
         sequences = [
             {'context': np.array(['M'], dtype=object), 'data': [['2020-01-03'], [65]]},
             {'context': np.array(['F'], dtype=object), 'data': [['2020-01-01'], [55]]},
-            {'context': np.array(['M'], dtype=object), 'data': [['2020-01-02'], [60]]}
+            {'context': np.array(['M'], dtype=object), 'data': [['2020-01-02'], [60]]},
         ]
         assemble_sequences_mock.return_value = sequences
@@ -558,18 +515,11 @@ def test__fit_sequence_columns(self, assemble_sequences_mock, model_mock):

         # Assert
         assemble_sequences_mock.assert_called_once_with(
-            data,
-            ['name'],
-            ['gender'],
-            None,
-            None,
-            drop_sequence_index=False
+            data, ['name'], ['gender'], None, None, drop_sequence_index=False
         )
         model_mock.assert_called_once_with(epochs=128, sample_size=1, cuda=True, verbose=False)
         model_mock.return_value.fit_sequences.assert_called_once_with(
-            sequences,
-            ['categorical'],
-            ['categorical', 'continuous']
+            sequences, ['categorical'], ['categorical', 'continuous']
         )

     @patch('sdv.sequential.par.PARModel')
@@ -586,19 +536,16 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock
             'time': [1, 2, 3, 5, 8],
             'gender': ['F', 'F', 'M', 'M', 'M'],
             'name': ['Jane', 'Jane', 'John', 'John', 'John'],
-            'measurement': [55, 60, 65, 65, 70]
+            'measurement': [55, 60, 65, 65, 70],
         })
         metadata = self.get_metadata()
         metadata.set_sequence_index('time')
-        par = PARSynthesizer(
-            metadata=metadata,
-            context_columns=['gender']
-        )
+        par = PARSynthesizer(metadata=metadata, context_columns=['gender'])
         sequences = [
             {'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]},
             {
                 'context': np.array(['M'], dtype=object),
-                'data': [[2, 2, 3], [65, 65, 70], [3, 3, 3]]
+                'data': [[2, 2, 3], [65, 65, 70], [3, 3, 3]],
             },
         ]
         assemble_sequences_mock.return_value = sequences
@@ -618,14 +565,12 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock
             {'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]},
             {
                 'context': np.array(['M'], dtype=object),
-                'data': [[2, 2, 3], [65, 65, 70], [3, 3, 3]]
-            }
+                'data': [[2, 2, 3], [65, 65, 70], [3, 3, 3]],
+            },
         ]
         model_mock.assert_called_once_with(epochs=128, sample_size=1, cuda=True, verbose=False)
         model_mock.return_value.fit_sequences.assert_called_once_with(
-            expected_sequences,
-            ['categorical'],
-            ['continuous', 'continuous']
+            expected_sequences, ['categorical'], ['continuous', 'continuous']
         )

     @patch('sdv.sequential.par.PARModel')
@@ -634,19 +579,16 @@ def test__fit_sequence_columns_bad_dtype(self, assemble_sequences_mock, model_mo
         """Test the method when a column has an unsupported dtype."""
         # Setup
         datetime = pd.Series(
-            [pd.to_datetime('1/1/1999'), pd.to_datetime('1/2/1999'), '1/3/1999'],
-            dtype='