Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Primary keys may not be unique for variable length regexes #2161

Merged
merged 7 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@

import pandas as pd
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from rdt.transformers.utils import _GENERATORS

from sdv import version
from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError

try:
from re import _parser as sre_parse
except ImportError:
import sre_parse


def _cast_to_iterable(value):
"""Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
Expand Down Expand Up @@ -403,3 +409,33 @@ def generate_synthesizer_id(synthesizer):
synth_version = version.public
unique_id = ''.join(str(uuid.uuid4()).split('-'))
return f'{class_name}_{synth_version}_{unique_id}'


def _get_chars_for_option(option, params):
if option not in _GENERATORS:
raise ValueError(f'REGEX operation: {option} is not supported by SDV.')

if option == sre_parse.MAX_REPEAT:
new_option, new_params = params[2][0] # The value at the second index is the nested option
return _get_chars_for_option(new_option, new_params)

return list(_GENERATORS[option](params, 1)[0])


def get_possible_chars(regex, num_subpatterns=None):
"""Get the list of possible characters a regex can create.

Args:
regex (str):
The regex to parse.
num_subpatterns (int):
The number of sub-patterns from the regex to find characters for.
"""
parsed = sre_parse.parse(regex)
parsed = [p for p in parsed if p[0] != sre_parse.AT]
num_subpatterns = num_subpatterns or len(parsed)
possible_chars = []
for option, params in parsed[:num_subpatterns]:
possible_chars += _get_chars_for_option(option, params)

return possible_chars
15 changes: 12 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand Down Expand Up @@ -363,9 +364,17 @@ def preprocess(self, data):
processed_data = {}
pbar_args = self._get_pbar_args(desc='Preprocess Tables')
for table_name, table_data in tqdm(data.items(), **pbar_args):
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
try:
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
except SynthesizerInputError as e:
if INT_REGEX_ZERO_ERROR_MESSAGE in str(e):
raise SynthesizerInputError(
f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
)

raise e
pvk-developer marked this conversation as resolved.
Show resolved Hide resolved

for table in list_of_changed_tables:
data[table].columns = self._original_table_columns[table]
Expand Down
17 changes: 17 additions & 0 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
check_sdv_versions_and_warn,
check_synthesizer_version,
generate_synthesizer_id,
get_possible_chars,
)
from sdv.constraints.errors import AggregateConstraintsError
from sdv.data_processing.data_processor import DataProcessor
Expand All @@ -41,6 +42,10 @@

COND_IDX = str(uuid.uuid4())
FIXED_RNG_SEED = 73251
INT_REGEX_ZERO_ERROR_MESSAGE = (
'is stored as an int but the Regex allows it to start with "0". Please remove the Regex '
'or update it to correspond to valid ints.'
)


class BaseSynthesizer:
Expand Down Expand Up @@ -163,6 +168,17 @@ def _validate(self, data):
"""
return []

def _validate_primary_key(self, data):
primary_key = self.metadata.primary_key
is_int = primary_key and pd.api.types.is_integer_dtype(data[primary_key])
regex = self.metadata.columns.get(primary_key, {}).get('regex_format')
if is_int and regex:
possible_characters = get_possible_chars(regex, 1)
if '0' in possible_characters:
raise SynthesizerInputError(
f'Primary key "{primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}.'
)

def validate(self, data):
"""Validate data.

Expand All @@ -184,6 +200,7 @@ def validate(self, data):
* values of a column don't satisfy their sdtype
"""
self._validate_metadata(data)
self._validate_primary_key(data)
self._validate_constraints(data)

# Retaining the logic of returning errors and raising them here to maintain consistency
Expand Down
31 changes: 31 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,3 +1988,34 @@ def test_hma_synthesizer_with_fixed_combinations():
assert len(sampled['users']) > 1
assert len(sampled['records']) > 1
assert len(sampled['locations']) > 1


REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w']


@pytest.mark.parametrize('regex', REGEXES)
def test_fit_int_primary_key_regex_includes_zero(regex):
"""Test that sdv errors if the primary key has a regex, is an int, and can start with 0."""
# Setup
parent_data = pd.DataFrame({
'parent_id': [1, 2, 3, 4, 5, 6],
'col': ['a', 'b', 'a', 'b', 'a', 'b'],
})
child_data = pd.DataFrame({'id': [1, 2, 3, 4, 5, 6], 'parent_id': [1, 2, 3, 4, 5, 6]})
data = {
'parent_data': parent_data,
'child_data': child_data,
}
metadata = MultiTableMetadata()
metadata.detect_from_dataframes(data)
metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex)
metadata.set_primary_key('parent_data', 'parent_id')

# Run and Assert
instance = HMASynthesizer(metadata)
message = (
'Primary key for table "parent_data" is stored as an int but the Regex allows it to start '
'with "0". Please remove the Regex or update it to correspond to valid ints.'
)
with pytest.raises(SynthesizerInputError, match=message):
instance.fit(data)
28 changes: 28 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,31 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class):

# Assert
assert sample.columns.tolist() == data.columns.tolist()


REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w']


@pytest.mark.parametrize('regex', REGEXES)
@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex):
"""Test that sdv errors if the primary key has a regex, is an int, and can start with 0."""
# Setup
data = pd.DataFrame({
'a': [1, 2, 3],
'b': [4, 5, 6],
'c': ['a', 'b', 'c'],
})
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
metadata.update_column('a', sdtype='id', regex_format=regex)
metadata.set_primary_key('a')

# Run and Assert
instance = synthesizer_class(metadata)
message = (
'Primary key "a" is stored as an int but the Regex allows it to start with "0". Please '
'remove the Regex or update it to correspond to valid ints.'
)
with pytest.raises(SynthesizerInputError, match=message):
instance.fit(data)
79 changes: 79 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.multi_table.hma import HMASynthesizer
from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer
from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata
Expand Down Expand Up @@ -908,6 +909,84 @@ def test_preprocess_warning(self, mock_warnings):
"please refit the model using 'fit' or 'fit_processed_data'."
)

def test_preprocess_single_table_preprocess_raises_error_0_int_regex(self):
"""Test that if the single table synthesizer raises a specific error, it is reformatted.

If a single table synthesizer raises an error about the primary key being an integer
with a regex that can start with zero, the error should be reformatted to include the
table name.
"""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance.validate = Mock()
data = {
'nesreca': pd.DataFrame({
'id_nesreca': np.arange(0, 20, 2),
'upravna_enota': np.arange(10),
}),
'oseba': pd.DataFrame({
'upravna_enota': np.arange(10),
'id_nesreca': np.arange(10),
}),
'upravna_enota': pd.DataFrame({
'id_upravna_enota': np.arange(10),
}),
}

synth_nesreca = Mock()
synth_oseba = Mock()
synth_upravna_enota = Mock()
synth_nesreca._preprocess.side_effect = SynthesizerInputError(INT_REGEX_ZERO_ERROR_MESSAGE)
instance._table_synthesizers = {
'nesreca': synth_nesreca,
'oseba': synth_oseba,
'upravna_enota': synth_upravna_enota,
}

# Run
message = f'Primary key for table "nesreca" {INT_REGEX_ZERO_ERROR_MESSAGE}'
with pytest.raises(SynthesizerInputError, match=message):
instance.preprocess(data)

def test_preprocess_single_table_preprocess_raises_error(self):
"""Test that if the single table synthesizer raises any other error, it is raised.

If a single table synthesizer raises an error besides the one concerning int primary keys
starting with 0 and having a regex, then the error should be raised as is.
"""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance.validate = Mock()
data = {
'nesreca': pd.DataFrame({
'id_nesreca': np.arange(0, 20, 2),
'upravna_enota': np.arange(10),
}),
'oseba': pd.DataFrame({
'upravna_enota': np.arange(10),
'id_nesreca': np.arange(10),
}),
'upravna_enota': pd.DataFrame({
'id_upravna_enota': np.arange(10),
}),
}

synth_nesreca = Mock()
synth_oseba = Mock()
synth_upravna_enota = Mock()
synth_nesreca._preprocess.side_effect = SynthesizerInputError('blah')
instance._table_synthesizers = {
'nesreca': synth_nesreca,
'oseba': synth_oseba,
'upravna_enota': synth_upravna_enota,
}

# Run
with pytest.raises(SynthesizerInputError, match='blah'):
instance.preprocess(data)

@patch('sdv.multi_table.base.datetime')
def test_fit_processed_data(self, mock_datetime, caplog):
"""Test that fit processed data calls ``_augment_tables`` and ``_model_tables``.
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,47 @@ def test_validate_raises_invalid_data_for_metadata(self):
instance._validate_constraints.assert_called_once_with(data)
instance._validate.assert_not_called()

def test_validate_int_primary_key_regex_starts_with_zero(self):
"""Test that an error is raised if the primary key is an int that can start with 0.

If the the primary key is stored as an int, but a regex is used with it, it is possible
that the first character can be a 0. If this happens, then we can get duplicate primary
key values since two different strings can be the same when converted ints
(eg. '00123' and '0123').
"""
# Setup
data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']})
metadata = Mock()
metadata.primary_key = 'key'
metadata.column_relationships = []
metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[0-9]{3,4}'}}
instance = BaseSingleTableSynthesizer(metadata)

# Run and Assert
message = (
'Primary key "key" is stored as an int but the Regex allows it to start with '
'"0". Please remove the Regex or update it to correspond to valid ints.'
)
with pytest.raises(SynthesizerInputError, match=message):
instance.validate(data)

def test_validate_int_primary_key_regex_does_not_start_with_zero(self):
"""Test that no error is raised if the primary key is an int that can't start with 0.

If the the primary key is stored as an int, but a regex is used with it, it is possible
that the first character can be a 0. If it isn't possible, then no error should be raised.
"""
# Setup
data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']})
metadata = Mock()
metadata.primary_key = 'key'
metadata.column_relationships = []
metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[1-9]{3,4}'}}
instance = BaseSingleTableSynthesizer(metadata)

# Run and Assert
instance.validate(data)

def test_update_transformers_invalid_keys(self):
"""Test error is raised if passed transformer doesn't match key column.

Expand Down
Loading
Loading