Skip to content

Commit

Permalink
Primary keys may not be unique for variable length regexes (#2161)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 authored Aug 6, 2024
1 parent a904634 commit 87318d0
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 3 deletions.
36 changes: 36 additions & 0 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@

import pandas as pd
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from rdt.transformers.utils import _GENERATORS

from sdv import version
from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError

try:
from re import _parser as sre_parse
except ImportError:
import sre_parse


def _cast_to_iterable(value):
"""Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
Expand Down Expand Up @@ -403,3 +409,33 @@ def generate_synthesizer_id(synthesizer):
synth_version = version.public
unique_id = ''.join(str(uuid.uuid4()).split('-'))
return f'{class_name}_{synth_version}_{unique_id}'


def _get_chars_for_option(option, params):
    """Return the characters a single parsed regex operation can produce.

    Args:
        option:
            The operation code of one element of a parsed regex.
        params:
            The arguments that accompany ``option`` in the parsed regex.

    Returns:
        list:
            The characters taken from the first element of the result of the generator
            registered for ``option`` in ``_GENERATORS``.

    Raises:
        ValueError:
            If ``option`` has no entry in ``_GENERATORS``.
    """
    if option not in _GENERATORS:
        raise ValueError(f'REGEX operation: {option} is not supported by SDV.')

    is_repeat = option == sre_parse.MAX_REPEAT
    if not is_repeat:
        generated = _GENERATORS[option](params, 1)
        return list(generated[0])

    # For a repeat, index 2 of ``params`` holds the nested operations. Only the first
    # nested operation is inspected, and it is resolved recursively.
    nested_option, nested_params = params[2][0]
    return _get_chars_for_option(nested_option, nested_params)


def get_possible_chars(regex, num_subpatterns=None):
    """Get the list of possible characters a regex can create.

    Position anchors (``sre_parse.AT`` operations) are dropped before the sub-patterns
    are counted.

    Args:
        regex (str):
            The regex to parse.
        num_subpatterns (int):
            The number of sub-patterns, counted from the start of the regex, to find
            characters for. A falsy value (``None`` or ``0``) means every sub-pattern.

    Returns:
        list:
            The possible characters of each inspected sub-pattern, concatenated in order.
    """
    operations = [item for item in sre_parse.parse(regex) if item[0] != sre_parse.AT]
    limit = num_subpatterns if num_subpatterns else len(operations)
    return [
        char
        for option, params in operations[:limit]
        for char in _get_chars_for_option(option, params)
    ]
15 changes: 12 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand Down Expand Up @@ -363,9 +364,17 @@ def preprocess(self, data):
processed_data = {}
pbar_args = self._get_pbar_args(desc='Preprocess Tables')
for table_name, table_data in tqdm(data.items(), **pbar_args):
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
try:
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
except SynthesizerInputError as e:
if INT_REGEX_ZERO_ERROR_MESSAGE in str(e):
raise SynthesizerInputError(
f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
)

raise e

for table in list_of_changed_tables:
data[table].columns = self._original_table_columns[table]
Expand Down
17 changes: 17 additions & 0 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
check_sdv_versions_and_warn,
check_synthesizer_version,
generate_synthesizer_id,
get_possible_chars,
)
from sdv.constraints.errors import AggregateConstraintsError
from sdv.data_processing.data_processor import DataProcessor
Expand All @@ -41,6 +42,10 @@

COND_IDX = str(uuid.uuid4())
FIXED_RNG_SEED = 73251
# Shared tail of the error raised when an integer primary key has a regex that allows the
# value to start with "0". Callers prepend the primary key or table name to it.
# NOTE: the text already ends with a period.
INT_REGEX_ZERO_ERROR_MESSAGE = (
'is stored as an int but the Regex allows it to start with "0". Please remove the Regex '
'or update it to correspond to valid ints.'
)


class BaseSynthesizer:
Expand Down Expand Up @@ -163,6 +168,17 @@ def _validate(self, data):
"""
return []

def _validate_primary_key(self, data):
    """Reject integer primary keys whose regex allows the value to start with ``"0"``.

    The check only applies when the metadata defines a primary key, the matching column
    in ``data`` has an integer dtype, and the column metadata has a ``regex_format``.
    Only the first sub-pattern of the regex is inspected.

    Args:
        data (pandas.DataFrame):
            The data to validate.

    Raises:
        SynthesizerInputError:
            If the first sub-pattern of the primary key regex can produce ``"0"``.
    """
    primary_key = self.metadata.primary_key
    is_int = primary_key and pd.api.types.is_integer_dtype(data[primary_key])
    regex = self.metadata.columns.get(primary_key, {}).get('regex_format')
    if not (is_int and regex):
        return

    if '0' in get_possible_chars(regex, 1):
        # ``INT_REGEX_ZERO_ERROR_MESSAGE`` already ends with a period, so no extra one
        # is appended here (doing so produced a message ending in "..").
        raise SynthesizerInputError(
            f'Primary key "{primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
        )

def validate(self, data):
"""Validate data.
Expand All @@ -184,6 +200,7 @@ def validate(self, data):
* values of a column don't satisfy their sdtype
"""
self._validate_metadata(data)
self._validate_primary_key(data)
self._validate_constraints(data)

# Retaining the logic of returning errors and raising them here to maintain consistency
Expand Down
31 changes: 31 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,3 +1988,34 @@ def test_hma_synthesizer_with_fixed_combinations():
assert len(sampled['users']) > 1
assert len(sampled['records']) > 1
assert len(sampled['locations']) > 1


# Regex formats whose generated values can begin with "0": each one opens with a literal
# "0" or with a character class/category that admits "0".
REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w']


@pytest.mark.parametrize('regex', REGEXES)
def test_fit_int_primary_key_regex_includes_zero(regex):
    """Test that sdv errors if the primary key has a regex, is an int, and can start with 0."""
    import re  # local import: used only to escape the expected message below

    # Setup
    parent_data = pd.DataFrame({
        'parent_id': [1, 2, 3, 4, 5, 6],
        'col': ['a', 'b', 'a', 'b', 'a', 'b'],
    })
    child_data = pd.DataFrame({'id': [1, 2, 3, 4, 5, 6], 'parent_id': [1, 2, 3, 4, 5, 6]})
    data = {
        'parent_data': parent_data,
        'child_data': child_data,
    }
    metadata = MultiTableMetadata()
    metadata.detect_from_dataframes(data)
    metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex)
    metadata.set_primary_key('parent_data', 'parent_id')

    # Run and Assert
    instance = HMASynthesizer(metadata)
    message = (
        'Primary key for table "parent_data" is stored as an int but the Regex allows it to start '
        'with "0". Please remove the Regex or update it to correspond to valid ints.'
    )
    # ``match`` is a regex search, so the literal message is escaped to stop its periods
    # from matching arbitrary characters.
    with pytest.raises(SynthesizerInputError, match=re.escape(message)):
        instance.fit(data)
28 changes: 28 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,31 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class):

# Assert
assert sample.columns.tolist() == data.columns.tolist()


# Regex formats whose generated values can begin with "0": each one opens with a literal
# "0" or with a character class/category that admits "0".
REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w']


@pytest.mark.parametrize('regex', REGEXES)
@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex):
    """Test that sdv errors if the primary key has a regex, is an int, and can start with 0."""
    import re  # local import: used only to escape the expected message below

    # Setup
    data = pd.DataFrame({
        'a': [1, 2, 3],
        'b': [4, 5, 6],
        'c': ['a', 'b', 'c'],
    })
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data)
    metadata.update_column('a', sdtype='id', regex_format=regex)
    metadata.set_primary_key('a')

    # Run and Assert
    instance = synthesizer_class(metadata)
    message = (
        'Primary key "a" is stored as an int but the Regex allows it to start with "0". Please '
        'remove the Regex or update it to correspond to valid ints.'
    )
    # ``match`` is a regex search, so the literal message is escaped to stop its periods
    # from matching arbitrary characters.
    with pytest.raises(SynthesizerInputError, match=re.escape(message)):
        instance.fit(data)
79 changes: 79 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.multi_table.hma import HMASynthesizer
from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer
from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata
Expand Down Expand Up @@ -908,6 +909,84 @@ def test_preprocess_warning(self, mock_warnings):
"please refit the model using 'fit' or 'fit_processed_data'."
)

def test_preprocess_single_table_preprocess_raises_error_0_int_regex(self):
    """Test that if the single table synthesizer raises a specific error, it is reformatted.

    If a single table synthesizer raises an error about the primary key being an integer
    with a regex that can start with zero, the error should be reformatted to include the
    table name.
    """
    import re  # local import: used only to escape the expected message below

    # Setup
    metadata = get_multi_table_metadata()
    instance = BaseMultiTableSynthesizer(metadata)
    instance.validate = Mock()
    data = {
        'nesreca': pd.DataFrame({
            'id_nesreca': np.arange(0, 20, 2),
            'upravna_enota': np.arange(10),
        }),
        'oseba': pd.DataFrame({
            'upravna_enota': np.arange(10),
            'id_nesreca': np.arange(10),
        }),
        'upravna_enota': pd.DataFrame({
            'id_upravna_enota': np.arange(10),
        }),
    }

    synth_nesreca = Mock()
    synth_oseba = Mock()
    synth_upravna_enota = Mock()
    synth_nesreca._preprocess.side_effect = SynthesizerInputError(INT_REGEX_ZERO_ERROR_MESSAGE)
    instance._table_synthesizers = {
        'nesreca': synth_nesreca,
        'oseba': synth_oseba,
        'upravna_enota': synth_upravna_enota,
    }

    # Run and Assert
    message = f'Primary key for table "nesreca" {INT_REGEX_ZERO_ERROR_MESSAGE}'
    # ``match`` is a regex search, so the literal message is escaped to stop its periods
    # from matching arbitrary characters.
    with pytest.raises(SynthesizerInputError, match=re.escape(message)):
        instance.preprocess(data)

def test_preprocess_single_table_preprocess_raises_error(self):
    """Test that any other single table synthesizer error is propagated unchanged.

    When a single table synthesizer fails with an error that is not the one about int
    primary keys whose regex can start with 0, ``preprocess`` should raise it as is.
    """
    # Setup
    instance = BaseMultiTableSynthesizer(get_multi_table_metadata())
    instance.validate = Mock()

    nesreca = pd.DataFrame({
        'id_nesreca': np.arange(0, 20, 2),
        'upravna_enota': np.arange(10),
    })
    oseba = pd.DataFrame({
        'upravna_enota': np.arange(10),
        'id_nesreca': np.arange(10),
    })
    upravna_enota = pd.DataFrame({
        'id_upravna_enota': np.arange(10),
    })
    data = {'nesreca': nesreca, 'oseba': oseba, 'upravna_enota': upravna_enota}

    # One mocked synthesizer per table; only the first table's preprocessing fails.
    table_synthesizers = {name: Mock() for name in ('nesreca', 'oseba', 'upravna_enota')}
    table_synthesizers['nesreca']._preprocess.side_effect = SynthesizerInputError('blah')
    instance._table_synthesizers = table_synthesizers

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match='blah'):
        instance.preprocess(data)

@patch('sdv.multi_table.base.datetime')
def test_fit_processed_data(self, mock_datetime, caplog):
"""Test that fit processed data calls ``_augment_tables`` and ``_model_tables``.
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,47 @@ def test_validate_raises_invalid_data_for_metadata(self):
instance._validate_constraints.assert_called_once_with(data)
instance._validate.assert_not_called()

def test_validate_int_primary_key_regex_starts_with_zero(self):
    """Test that an error is raised if the primary key is an int that can start with 0.

    If the primary key is stored as an int, but a regex is used with it, it is possible
    that the first character can be a 0. If this happens, then we can get duplicate primary
    key values since two different strings can be the same when converted to ints
    (eg. '00123' and '0123').
    """
    import re  # local import: used only to escape the expected message below

    # Setup
    data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']})
    metadata = Mock()
    metadata.primary_key = 'key'
    metadata.column_relationships = []
    metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[0-9]{3,4}'}}
    instance = BaseSingleTableSynthesizer(metadata)

    # Run and Assert
    message = (
        'Primary key "key" is stored as an int but the Regex allows it to start with '
        '"0". Please remove the Regex or update it to correspond to valid ints.'
    )
    # ``match`` is a regex search, so the literal message is escaped to stop its periods
    # from matching arbitrary characters.
    with pytest.raises(SynthesizerInputError, match=re.escape(message)):
        instance.validate(data)

def test_validate_int_primary_key_regex_does_not_start_with_zero(self):
    """Test that no error is raised if the primary key is an int that can't start with 0.

    An int primary key with a regex is only a problem when the regex lets the first
    character be a 0. When the regex rules that out, validation should not raise.
    """
    # Setup
    metadata = Mock(
        primary_key='key',
        column_relationships=[],
        columns={'key': {'sdtype': 'id', 'regex_format': '[1-9]{3,4}'}},
    )
    instance = BaseSingleTableSynthesizer(metadata)
    data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']})

    # Run and Assert
    instance.validate(data)

def test_update_transformers_invalid_keys(self):
"""Test error is raised if passed transformer doesn't match key column.
Expand Down
Loading

0 comments on commit 87318d0

Please sign in to comment.