Skip to content

Commit

Permalink
Allow for simple PARSynthesizer constraints (#2044)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jun 7, 2024
1 parent 70c2bf7 commit 2a7e348
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 20 deletions.
59 changes: 51 additions & 8 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inspect
import logging
import uuid
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -135,13 +134,57 @@ def get_parameters(self):
return instantiated_parameters

def add_constraints(self, constraints):
"""Warn the user that constraints can't be added to the ``PARSynthesizer``."""
warnings.warn(
'The PARSynthesizer does not yet support constraints. This model will ignore any '
'constraints in the metadata.'
)
self._data_processor._constraints = []
self._data_processor._constraints_list = []
"""Add constraints to the synthesizer.
For PARSynthesizers only allow a list of constraints that follow these rules:
1) All constraints must be either for all contextual columns or non-contextual column.
No mixing constraints that cover both contextual and non-contextual columns
2) No overlapping constraints (there are no constraints that act on the same column)
3) No custom constraints
Args:
constraints (list):
List of constraints described as dictionaries in the following format:
* ``constraint_class``: Name of the constraint to apply.
* ``constraint_parameters``: A dictionary with the constraint parameters.
"""
context_set = set(self.context_columns)
constraint_cols = []
for constraint in constraints:
constraint_parameters = constraint['constraint_parameters']
columns = []
for param in constraint_parameters:
if 'column_name' in param:
col_names = constraint_parameters[param]
if isinstance(col_names, list):
columns.extend(col_names)
else:
columns.append(col_names)
for col in columns:
if col in constraint_cols:
raise SynthesizerInputError(
'The PARSynthesizer cannot accommodate multiple constraints '
'that overlap on the same columns.')
constraint_cols.append(col)

all_context = all(col in context_set for col in constraint_cols)
no_context = all(col not in context_set for col in constraint_cols)

if all_context or no_context:
super().add_constraints(constraints)
else:
raise SynthesizerInputError(
'The PARSynthesizer cannot accommodate constraints '
'with a mix of context and non-context columns.')

def load_custom_constraint_classes(self, filepath, class_names):
"""Error that tells the user custom constraints can't be used in the ``PARSynthesizer``."""
raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.')

def add_custom_constraint_class(self, class_object, class_name):
"""Error that tells the user custom constraints can't be used in the ``PARSynthesizer``."""
raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.')

def _validate_context_columns(self, data):
errors = []
Expand Down
57 changes: 57 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime
import re

import numpy as np
import pandas as pd
import pytest
from deepecho import load_demo

from sdv.datasets.demo import download_demo
from sdv.errors import SynthesizerInputError
from sdv.metadata import SingleTableMetadata
from sdv.sequential import PARSynthesizer

Expand Down Expand Up @@ -284,6 +287,60 @@ def test_par_missing_sequence_index():
assert (sampled.dtypes == data.dtypes).all()


def test_constraints_on_par():
"""Test if only simple constraints work on PARSynthesizer."""
# Setup
real_data, metadata = download_demo(
modality='sequential',
dataset_name='nasdaq100_2019'
)

synthesizer = PARSynthesizer(
metadata,
epochs=5,
context_columns=['Sector', 'Industry']
)

market_constraint = {
'constraint_class': 'Positive',
'constraint_parameters': {
'column_name': 'MarketCap',
'strict_boundaries': True
}
}
volume_constraint = {
'constraint_class': 'Positive',
'constraint_parameters': {
'column_name': 'Volume',
'strict_boundaries': True
}
}

context_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {
'column_name': 'Sector',
'strict_boundaries': True
}
}

# Run
synthesizer.add_constraints([volume_constraint, market_constraint])
synthesizer.fit(real_data)
samples = synthesizer.sample(50, 10)

# Assert
assert not (samples['MarketCap'] < 0).any().any()
assert not (samples['Volume'] < 0).any().any()
mixed_constraint_error_msg = re.escape(
'The PARSynthesizer cannot accommodate constraints '
'with a mix of context and non-context columns.'
)

with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
synthesizer.add_constraints([volume_constraint, context_constraint])


def test_par_unique_sequence_index_with_enforce_min_max():
"""Test to see if there are duplicate sequence index values
when sequence_length is higher than real data
Expand Down
100 changes: 88 additions & 12 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rdt.transformers import FloatFormatter, UnixTimestampEncoder

from sdv.data_processing.data_processor import DataProcessor
from sdv.data_processing.errors import InvalidConstraintsError
from sdv.errors import InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sampling import Condition
Expand Down Expand Up @@ -108,24 +109,99 @@ def test___init___no_sequence_key(self):
verbose=False
)

@patch('sdv.sequential.par.warnings')
def test_add_constraints(self, warnings_mock):
"""Test that if constraints are being added, a warning is raised."""
def test_add_constraints(self):
"""Test that that only simple constraints can be added to PARSynthesizer."""
# Setup
metadata = self.get_metadata()
synthesizer = PARSynthesizer(metadata=metadata,
context_columns=['name', 'measurement'])
name_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {
'column_name': 'name'
}
}
measurement_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {
'column_name': 'measurement'
}
}
gender_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {
'column_name': 'gender'
}
}
time_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {
'column_name': 'time'
}
}
multi_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {
'column_names': ['name', 'time']
}
}
overlapping_error_msg = re.escape(
'The PARSynthesizer cannot accommodate multiple constraints '
'that overlap on the same columns.'
)
mixed_constraint_error_msg = re.escape(
'The PARSynthesizer cannot accommodate constraints '
'with a mix of context and non-context columns.'
)

# Run and Assert
with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
synthesizer.add_constraints([name_constraint, gender_constraint])

with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
synthesizer.add_constraints([time_constraint, measurement_constraint])

with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
synthesizer.add_constraints([multi_constraint])

with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
synthesizer.add_constraints([multi_constraint, name_constraint])

with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
synthesizer.add_constraints([name_constraint, name_constraint])

with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
synthesizer.add_constraints([gender_constraint, gender_constraint])

# Custom constraint will not be found
with pytest.raises(InvalidConstraintsError):
synthesizer.add_constraints([gender_constraint])

def test_load_custom_constraint_classes(self):
"""Test that if custom constraint is being added, an error is raised."""
# Setup
metadata = self.get_metadata()
synthesizer = PARSynthesizer(metadata=metadata)

# Run
synthesizer.add_constraints([object()])
# Run and Assert
error_message = re.escape(
'The PARSynthesizer cannot accommodate custom constraints.'
)
with pytest.raises(SynthesizerInputError, match=error_message):
synthesizer.load_custom_constraint_classes(filepath='test', class_names=[])

# Assert
warning_message = (
'The PARSynthesizer does not yet support constraints. This model will ignore any '
'constraints in the metadata.'
def test_add_custom_constraint_class(self):
"""Test that if custom constraint is being added, an error is raised."""
# Setup
metadata = self.get_metadata()
synthesizer = PARSynthesizer(metadata=metadata)

# Run and Assert
error_message = re.escape(
'The PARSynthesizer cannot accommodate custom constraints.'
)
warnings_mock.warn.assert_called_once_with(warning_message)
assert synthesizer._data_processor._constraints == []
assert synthesizer._data_processor._constraints_list == []
with pytest.raises(SynthesizerInputError, match=error_message):
synthesizer.add_custom_constraint_class(Mock(), class_name='Mock')

def test_get_parameters(self):
"""Test that it returns every ``init`` parameter without the ``metadata``."""
Expand Down

0 comments on commit 2a7e348

Please sign in to comment.