diff --git a/sdv/_utils.py b/sdv/_utils.py index 9138dbcce..3ec466537 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -10,10 +10,16 @@ import pandas as pd from pandas.core.tools.datetimes import _guess_datetime_format_for_array +from rdt.transformers.utils import _GENERATORS from sdv import version from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError +try: + from re import _parser as sre_parse +except ImportError: + import sre_parse + def _cast_to_iterable(value): """Return a ``list`` if the input object is not a ``list`` or ``tuple``.""" @@ -403,3 +409,33 @@ def generate_synthesizer_id(synthesizer): synth_version = version.public unique_id = ''.join(str(uuid.uuid4()).split('-')) return f'{class_name}_{synth_version}_{unique_id}' + + +def _get_chars_for_option(option, params): + if option not in _GENERATORS: + raise ValueError(f'REGEX operation: {option} is not supported by SDV.') + + if option == sre_parse.MAX_REPEAT: + new_option, new_params = params[2][0] # The value at the second index is the nested option + return _get_chars_for_option(new_option, new_params) + + return list(_GENERATORS[option](params, 1)[0]) + + +def get_possible_chars(regex, num_subpatterns=None): + """Get the list of possible characters a regex can create. + + Args: + regex (str): + The regex to parse. + num_subpatterns (int): + The number of sub-patterns from the regex to find characters for. + """ + parsed = sre_parse.parse(regex) + parsed = [p for p in parsed if p[0] != sre_parse.AT] + num_subpatterns = num_subpatterns or len(parsed) + possible_chars = [] + for option, params in parsed[:num_subpatterns]: + possible_chars += _get_chars_for_option(option, params) + + return possible_chars diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index f457d6786..006f6636e 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -26,6 +26,7 @@ SynthesizerInputError, ) from sdv.logging import disable_single_table_logger, get_sdv_logger +from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') @@ -363,9 +364,17 @@ def preprocess(self, data): processed_data = {} pbar_args = self._get_pbar_args(desc='Preprocess Tables') for table_name, table_data in tqdm(data.items(), **pbar_args): - synthesizer = self._table_synthesizers[table_name] - self._assign_table_transformers(synthesizer, table_name, table_data) - processed_data[table_name] = synthesizer._preprocess(table_data) + try: + synthesizer = self._table_synthesizers[table_name] + self._assign_table_transformers(synthesizer, table_name, table_data) + processed_data[table_name] = synthesizer._preprocess(table_data) + except SynthesizerInputError as e: + if INT_REGEX_ZERO_ERROR_MESSAGE in str(e): + raise SynthesizerInputError( + f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}' + ) + + raise e for table in list_of_changed_tables: data[table].columns = self._original_table_columns[table] diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 77ab31ce0..e9ce2178f 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -24,6 +24,7 @@ check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, + get_possible_chars, ) from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor @@ -41,6 +42,10 @@ COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 +INT_REGEX_ZERO_ERROR_MESSAGE = ( + 'is stored as an int but the Regex allows it to start with "0". Please remove the Regex ' + 'or update it to correspond to valid ints.' +) class BaseSynthesizer: @@ -163,6 +168,17 @@ def _validate(self, data): """ return [] + def _validate_primary_key(self, data): + primary_key = self.metadata.primary_key + is_int = primary_key and pd.api.types.is_integer_dtype(data[primary_key]) + regex = self.metadata.columns.get(primary_key, {}).get('regex_format') + if is_int and regex: + possible_characters = get_possible_chars(regex, 1) + if '0' in possible_characters: + raise SynthesizerInputError( + f'Primary key "{primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}.' + ) + def validate(self, data): """Validate data. @@ -184,6 +200,7 @@ def validate(self, data): * values of a column don't satisfy their sdtype """ self._validate_metadata(data) + self._validate_primary_key(data) self._validate_constraints(data) # Retaining the logic of returning errors and raising them here to maintain consistency diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index ef7f0b464..e6fa27e2e 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1988,3 +1988,34 @@ def test_hma_synthesizer_with_fixed_combinations(): assert len(sampled['users']) > 1 assert len(sampled['records']) > 1 assert len(sampled['locations']) > 1 + + +REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w'] + + +@pytest.mark.parametrize('regex', REGEXES) +def test_fit_int_primary_key_regex_includes_zero(regex): + """Test that sdv errors if the primary key has a regex, is an int, and can start with 0.""" + # Setup + parent_data = pd.DataFrame({ + 'parent_id': [1, 2, 3, 4, 5, 6], + 'col': ['a', 'b', 'a', 'b', 'a', 'b'], + }) + child_data = pd.DataFrame({'id': [1, 2, 3, 4, 5, 6], 'parent_id': [1, 2, 3, 4, 5, 6]}) + data = { + 'parent_data': parent_data, + 'child_data': child_data, + } + metadata = MultiTableMetadata() + metadata.detect_from_dataframes(data) + metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) + metadata.set_primary_key('parent_data', 'parent_id') + + # Run and Assert + instance = HMASynthesizer(metadata) + message = ( + 'Primary key for table "parent_data" is stored as an int but the Regex allows it to start ' + 'with "0". Please remove the Regex or update it to correspond to valid ints.' + ) + with pytest.raises(SynthesizerInputError, match=message): + instance.fit(data) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index dc9c328d5..27ab09d06 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -801,3 +801,31 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class): # Assert assert sample.columns.tolist() == data.columns.tolist() + + +REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w'] + + +@pytest.mark.parametrize('regex', REGEXES) +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex): + """Test that sdv errors if the primary key has a regex, is an int, and can start with 0.""" + # Setup + data = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [4, 5, 6], + 'c': ['a', 'b', 'c'], + }) + metadata = SingleTableMetadata() + metadata.detect_from_dataframe(data) + metadata.update_column('a', sdtype='id', regex_format=regex) + metadata.set_primary_key('a') + + # Run and Assert + instance = synthesizer_class(metadata) + message = ( + 'Primary key "a" is stored as an int but the Regex allows it to start with "0". Please ' + 'remove the Regex or update it to correspond to valid ints.' + ) + with pytest.raises(SynthesizerInputError, match=message): + instance.fit(data) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 3e3abdbb9..c1fb04aac 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -23,6 +23,7 @@ from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer from sdv.multi_table.hma import HMASynthesizer +from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata @@ -908,6 +909,84 @@ def test_preprocess_warning(self, mock_warnings): "please refit the model using 'fit' or 'fit_processed_data'." ) + def test_preprocess_single_table_preprocess_raises_error_0_int_regex(self): + """Test that if the single table synthesizer raises a specific error, it is reformatted. + + If a single table synthesizer raises an error about the primary key being an integer + with a regex that can start with zero, the error should be reformatted to include the + table name. + """ + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + instance.validate = Mock() + data = { + 'nesreca': pd.DataFrame({ + 'id_nesreca': np.arange(0, 20, 2), + 'upravna_enota': np.arange(10), + }), + 'oseba': pd.DataFrame({ + 'upravna_enota': np.arange(10), + 'id_nesreca': np.arange(10), + }), + 'upravna_enota': pd.DataFrame({ + 'id_upravna_enota': np.arange(10), + }), + } + + synth_nesreca = Mock() + synth_oseba = Mock() + synth_upravna_enota = Mock() + synth_nesreca._preprocess.side_effect = SynthesizerInputError(INT_REGEX_ZERO_ERROR_MESSAGE) + instance._table_synthesizers = { + 'nesreca': synth_nesreca, + 'oseba': synth_oseba, + 'upravna_enota': synth_upravna_enota, + } + + # Run + message = f'Primary key for table "nesreca" {INT_REGEX_ZERO_ERROR_MESSAGE}' + with pytest.raises(SynthesizerInputError, match=message): + instance.preprocess(data) + + def test_preprocess_single_table_preprocess_raises_error(self): + """Test that if the single table synthesizer raises any other error, it is raised. + + If a single table synthesizer raises an error besides the one concerning int primary keys + starting with 0 and having a regex, then the error should be raised as is. + """ + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + instance.validate = Mock() + data = { + 'nesreca': pd.DataFrame({ + 'id_nesreca': np.arange(0, 20, 2), + 'upravna_enota': np.arange(10), + }), + 'oseba': pd.DataFrame({ + 'upravna_enota': np.arange(10), + 'id_nesreca': np.arange(10), + }), + 'upravna_enota': pd.DataFrame({ + 'id_upravna_enota': np.arange(10), + }), + } + + synth_nesreca = Mock() + synth_oseba = Mock() + synth_upravna_enota = Mock() + synth_nesreca._preprocess.side_effect = SynthesizerInputError('blah') + instance._table_synthesizers = { + 'nesreca': synth_nesreca, + 'oseba': synth_oseba, + 'upravna_enota': synth_upravna_enota, + } + + # Run + with pytest.raises(SynthesizerInputError, match='blah'): + instance.preprocess(data) + @patch('sdv.multi_table.base.datetime') def test_fit_processed_data(self, mock_datetime, caplog): """Test that fit processed data calls ``_augment_tables`` and ``_model_tables``. diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 3830f3f4f..eae85337f 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -600,6 +600,47 @@ def test_validate_raises_invalid_data_for_metadata(self): instance._validate_constraints.assert_called_once_with(data) instance._validate.assert_not_called() + def test_validate_int_primary_key_regex_starts_with_zero(self): + """Test that an error is raised if the primary key is an int that can start with 0. + + If the the primary key is stored as an int, but a regex is used with it, it is possible + that the first character can be a 0. If this happens, then we can get duplicate primary + key values since two different strings can be the same when converted ints + (eg. '00123' and '0123'). + """ + # Setup + data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']}) + metadata = Mock() + metadata.primary_key = 'key' + metadata.column_relationships = [] + metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[0-9]{3,4}'}} + instance = BaseSingleTableSynthesizer(metadata) + + # Run and Assert + message = ( + 'Primary key "key" is stored as an int but the Regex allows it to start with ' + '"0". Please remove the Regex or update it to correspond to valid ints.' + ) + with pytest.raises(SynthesizerInputError, match=message): + instance.validate(data) + + def test_validate_int_primary_key_regex_does_not_start_with_zero(self): + """Test that no error is raised if the primary key is an int that can't start with 0. + + If the the primary key is stored as an int, but a regex is used with it, it is possible + that the first character can be a 0. If it isn't possible, then no error should be raised. + """ + # Setup + data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']}) + metadata = Mock() + metadata.primary_key = 'key' + metadata.column_relationships = [] + metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[1-9]{3,4}'}} + instance = BaseSingleTableSynthesizer(metadata) + + # Run and Assert + instance.validate(data) + def test_update_transformers_invalid_keys(self): """Test error is raised if passed transformer doesn't match key column. diff --git a/tests/unit/test__utils.py b/tests/unit/test__utils.py index 87520c5df..1cbcf3416 100644 --- a/tests/unit/test__utils.py +++ b/tests/unit/test__utils.py @@ -1,5 +1,6 @@ import operator import re +import string from datetime import datetime from unittest.mock import Mock, patch @@ -12,6 +13,7 @@ _compare_versions, _convert_to_timedelta, _create_unique_name, + _get_chars_for_option, _get_datetime_format, _get_root_tables, _is_datetime_type, @@ -19,12 +21,18 @@ check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, + get_possible_chars, ) from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.base import BaseSingleTableSynthesizer from tests.utils import SeriesMatcher +try: + from re import _parser as sre_parse +except ImportError: + import sre_parse + @patch('sdv._utils.pd.to_timedelta') def test__convert_to_timedelta(to_timedelta_mock): @@ -626,3 +634,82 @@ def test_generate_synthesizer_id(mock_version, mock_uuid): # Assert assert result == 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + + +@patch('sdv._utils._get_chars_for_option') +def test_get_possible_chars_excludes_at(mock_get_chars): + """Test that 'at' regex operations aren't included when getting chars.""" + # Setup + regex = '^[1-9]{1,2}$' + mock_get_chars.return_value = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + + # Run + possible_chars = get_possible_chars(regex) + + # Assert + mock_get_chars.assert_called_once() + mock_call = mock_get_chars.mock_calls[0] + assert mock_call[1][0] == sre_parse.MAX_REPEAT + assert mock_call[1][1][0] == 1 + assert mock_call[1][1][1] == 2 + assert mock_call[1][1][2].data == [(sre_parse.IN, [(sre_parse.RANGE, (49, 57))])] + assert possible_chars == [str(i) for i in range(10)] + + +def test_get_possible_chars_unsupported_regex(): + """Test that an error is raised if the regex contains unsupported options.""" + # Setup + regex = '(ab)*' + + # Run and assert + message = 'REGEX operation: SUBPATTERN is not supported by SDV.' + with pytest.raises(ValueError, match=message): + get_possible_chars(regex) + + +@patch('sdv._utils._get_chars_for_option') +def test_get_possible_chars_handles_max_repeat(mock_get_chars): + """Test that MAX_REPEATS are handled by recursively finding the first non MAX_REPEAT. + + One valid regex option is a MAX_REPEAT. Getting all possible values for this could be slow, + so we just look for the first nexted option that isn't a max_repeat to get the possible + characters instead. + """ + # Setup + regex = '[1-9]{1,2}' + mock_get_chars.side_effect = lambda x, y: _get_chars_for_option(x, y) + + # Run + possible_chars = get_possible_chars(regex) + + # Assert + assert len(mock_get_chars.mock_calls) == 2 + assert mock_get_chars.mock_calls[1][1] == mock_get_chars.mock_calls[0][1][1][2][0] + assert possible_chars == [str(i) for i in range(1, 10)] + + +def test_get_possible_chars_num_subpatterns(): + """Test that only characters for first x subpatterns are returned.""" + # Setup + regex = 'HID_[0-9]{3}_[a-z]{3}' + + # Run + possible_chars = get_possible_chars(regex, 3) + + # Assert + assert possible_chars == ['H', 'I', 'D'] + + +def test_get_possible_chars(): + """Test that all characters for regex are returned.""" + # Setup + regex = 'HID_[0-9]{3}_[a-z]{3}' + + # Run + possible_chars = get_possible_chars(regex) + + # Assert + prefix = ['H', 'I', 'D', '_'] + nums = [str(i) for i in range(10)] + lowercase_letters = list(string.ascii_lowercase) + assert possible_chars == prefix + nums + ['_'] + lowercase_letters