Skip to content

Commit

Permalink
Do not error if sequence_index is numerical (#2080)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Jun 20, 2024
1 parent e06ceca commit 486519d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ def auto_assign_transformers(self, data):
# Ensure that sequence index does not get auto assigned with enforce_min_max_values
if self._sequence_index:
sequence_index_transformer = self.get_transformers()[self._sequence_index]
if sequence_index_transformer.enforce_min_max_values:
if sequence_index_transformer and getattr(
sequence_index_transformer, 'enforce_min_max_values', False
):
sequence_index_transformer.enforce_min_max_values = False

def _preprocess(self, data):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,22 @@ def test_par_unique_sequence_index_with_enforce_min_max():
seq_df = synth_df[synth_df['s_key'] == i]
has_duplicates = seq_df['visits'].duplicated().any()
assert not has_duplicates


def test_par_sequence_index_is_numerical():
metadata_dict = {
'sequence_index': 'time_in_cycles',
'columns': {
'engine_no': {'sdtype': 'id'},
'time_in_cycles': {'sdtype': 'numerical'},
},
'sequence_key': 'engine_no',
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
}
metadata = SingleTableMetadata.load_from_dict(metadata_dict)
data = pd.DataFrame({'engine_no': [0, 0, 1, 1], 'time_in_cycles': [1, 2, 3, 4]})

s1 = PARSynthesizer(metadata)
s1.fit(data)
sample = s1.sample(2, 5)
assert sample.columns.to_list() == data.columns.to_list()
28 changes: 28 additions & 0 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,34 @@ def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
})
pd.testing.assert_frame_equal(fitted_data.sort_values(by='name'), expected_fitted_data)

@patch('sdv.sequential.par.PARSynthesizer.get_transformers')
def test_auto_assign_transformers_without_enforce_min_max(self, mock_get_transfomers):
"""Test to see if auto_assign_transformers does not add enforce_min_max_values if the transformer
does not contain it already
"""
# Setup
datetime = pd.Series(
[pd.to_datetime('1/1/1999'), pd.to_datetime('1/2/1999'), '1/3/1999'], dtype='<M8[ns]'
)
data = pd.DataFrame({
'time': datetime,
'gender': ['F', 'F', 'M'],
'name': ['Jane', 'Jane', 'John'],
'measurement': [55, 60, 65],
})
metadata = self.get_metadata()
metadata.set_sequence_index('time')
mock_get_transfomers.return_value = {'time': FloatFormatter}

# Run
par = PARSynthesizer(metadata=metadata, context_columns=['gender'])
par.auto_assign_transformers(data)

# Assert
assert (
hasattr(par.get_transformers()[par._sequence_index], 'enforce_min_max_values') is False
)

@patch('sdv.sequential.par.GaussianCopulaSynthesizer')
@patch('sdv.sequential.par.uuid')
def test__fit_context_model_without_context_columns(self, uuid_mock, gaussian_copula_mock):
Expand Down

0 comments on commit 486519d

Please sign in to comment.