Skip to content

Commit

Permalink
Do not enforce min/max on sequence index column (#2043)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Jun 5, 2024
1 parent 29cf341 commit 70c2bf7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
22 changes: 22 additions & 0 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,28 @@ def _transform_sequence_index(self, data):

return data

def auto_assign_transformers(self, data):
    """Automatically assign the required transformers for the given data and constraints.

    Delegates to the parent implementation and then switches off
    ``enforce_min_max_values`` on the transformer assigned to the sequence index
    column, so the min/max range is not enforced on that column.

    Args:
        data:
            The data to assign transformers for; forwarded unchanged to the parent
            ``auto_assign_transformers``.
            NOTE(review): presumably a single ``pandas.DataFrame`` for this
            synthesizer rather than a mapping of table names -- confirm against
            the parent class.

    Raises:
        InvalidDataError:
            Presumably propagated from the parent implementation when the data
            does not match the metadata -- confirm against the parent class.
    """
    super().auto_assign_transformers(data)

    # Ensure that sequence index does not get auto assigned with enforce_min_max_values
    if self._sequence_index:
        # Tolerant lookup: the column may have no entry, or its transformer may be
        # ``None`` or a type that has no ``enforce_min_max_values`` option. In all of
        # those cases there is nothing to disable, so do nothing instead of raising.
        sequence_index_transformer = self.get_transformers().get(self._sequence_index)
        if getattr(sequence_index_transformer, 'enforce_min_max_values', False):
            sequence_index_transformer.enforce_min_max_values = False

def _preprocess(self, data):
"""Transform the raw data to numerical space.
Expand Down
42 changes: 42 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,45 @@ def test_par_missing_sequence_index():
# Assert
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()


def test_par_unique_sequence_index_with_enforce_min_max():
    """Check that no sampled sequence repeats a sequence index value.

    The synthesizer is fit with ``enforce_min_max_values=True`` and then asked for
    sequences longer than the ones in the real data; every sampled sequence must
    still have unique ``visits`` values.
    """
    # Setup
    test_df = pd.DataFrame({
        'id': list(range(10)),
        's_key': [0] * 5 + [1] * 5,
        'visits': [
            '2021-01-01', '2021-01-03', '2021-01-05', '2021-01-07', '2021-01-09',
            '2021-09-11', '2021-09-17', '2021-10-01', '2021-10-08', '2021-11-01',
        ],
        'pre_date': [
            '2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05',
            '2021-04-01', '2021-04-02', '2021-04-03', '2021-04-04', '2021-04-05',
        ],
    })
    # Parse both date columns; unparseable entries would become NaT.
    for date_column in ('visits', 'pre_date'):
        test_df[date_column] = pd.to_datetime(
            test_df[date_column], format='%Y-%m-%d', errors='coerce'
        )

    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(test_df)
    metadata.update_column(column_name='s_key', sdtype='id')
    metadata.set_sequence_key('s_key')
    metadata.set_sequence_index('visits')
    synthesizer = PARSynthesizer(
        metadata,
        enforce_min_max_values=True,
        enforce_rounding=False,
        epochs=100,
        verbose=True,
    )

    # Run
    synthesizer.fit(test_df)
    synth_df = synthesizer.sample(num_sequences=50, sequence_length=50)

    # Assert
    sequence_keys = synth_df['s_key']
    for key in sequence_keys.unique():
        sequence_visits = synth_df.loc[sequence_keys == key, 'visits']
        assert not sequence_visits.duplicated().any()
3 changes: 3 additions & 0 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def test_preprocess(self, base_preprocess_mock):
par._transform_sequence_index = Mock()
par.auto_assign_transformers = Mock()
par.update_transformers = Mock()
get_transform_mock = Mock()
get_transform_mock.return_value = {'time': Mock()}
par.get_transformers = get_transform_mock
par._data_processor._prepared_for_fitting = True
data = self.get_data()

Expand Down

0 comments on commit 70c2bf7

Please sign in to comment.