From 698f3910d34d553dba37892be237d4cdd7d1034f Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 12 Jun 2024 23:51:43 -0500 Subject: [PATCH 1/3] Add a minimum number of rows for sample --- sdv/sampling/hierarchical_sampler.py | 14 ++++++++++++- tests/integration/multi_table/test_hma.py | 24 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 3169a162c..1cec13ece 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -1,5 +1,6 @@ """Hierarchical Samplers.""" import logging +import warnings import pandas as pd @@ -138,7 +139,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): sampled_data (dict): A dictionary mapping table names to sampled data (pd.DataFrame). """ - total_num_rows = round(self._table_sizes[child_name] * scale) + total_num_rows = max(round(self._table_sizes[child_name] * scale), 1) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): num_rows_key = f'__{child_name}__{foreign_key}__num_rows' min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key] @@ -273,13 +274,24 @@ def _sample(self, scale=1.0): # DFS to sample roots and then their children non_root_parents = set(self.metadata._get_parent_map().keys()) root_parents = set(self.metadata.tables.keys()) - non_root_parents + send_min_sample_warning = False for table in root_parents: num_rows = round(self._table_sizes[table] * scale) + if num_rows <= 0: + send_min_sample_warning = True + num_rows = 1 synthesizer = self._table_synthesizers[table] LOGGER.info(f'Sampling {num_rows} rows from table {table}') sampled_data[table] = self._sample_rows(synthesizer, num_rows) self._sample_children(table_name=table, sampled_data=sampled_data, scale=scale) + if send_min_sample_warning: + warn_msg = ( + "The 'scale' parameter it too small. Some tables may have 1 row." + ' For better quality data, please choose a larger scale.' + ) + warnings.warn(warn_msg) + added_relationships = set() for relationship in self.metadata.relationships: parent_name = relationship['parent_table_name'] diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4be8bd53f..8360ed551 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1839,3 +1839,27 @@ def test_disjointed_tables(): # Assert for table in real_data: assert list(real_data[table].columns) == list(disjoin_synthetic_data[table].columns) + + +def test_small_sample(): + """Test that the sample function still works with a small scale""" + # Setup + data, metadata = download_demo( + modality='multi_table', + dataset_name='fake_hotels' + ) + synthesizer = HMASynthesizer(metadata) + synthesizer.fit(data) + + # Run and Assert + warn_msg = re.escape( + "The 'scale' parameter it too small. Some tables may have 1 row." + ' For better quality data, please choose a larger scale.' + ) + with pytest.warns(Warning, match=warn_msg): + synthetic_data = synthesizer.sample(scale=0.01) + + assert (len(synthetic_data['hotels']) == 1) + assert (len(synthetic_data['guests']) >= len(data['guests']) * .01) + assert synthetic_data['hotels'].columns.tolist() == data['hotels'].columns.tolist() + assert synthetic_data['guests'].columns.tolist() == data['guests'].columns.tolist() From cb848668e8ffe551386800a28aceee5aaffff54b Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 13 Jun 2024 00:01:49 -0500 Subject: [PATCH 2/3] chile table shouldn't affected --- sdv/sampling/hierarchical_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 1cec13ece..f937e81ca 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -139,7 +139,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): sampled_data (dict): A dictionary mapping table names to sampled data (pd.DataFrame). """ - total_num_rows = max(round(self._table_sizes[child_name] * scale), 1) + total_num_rows = round(self._table_sizes[child_name] * scale) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): num_rows_key = f'__{child_name}__{foreign_key}__num_rows' min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key] From cb45d65df0b06cdaf58771b2b1bc193e9f605ba3 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 14 Jun 2024 10:17:06 -0500 Subject: [PATCH 3/3] Add test for ensuring that enfore_table_size can handle a small dataset --- sdv/sampling/hierarchical_sampler.py | 2 +- tests/integration/multi_table/test_hma.py | 2 +- .../sampling/test_hierarchical_sampler.py | 30 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index f937e81ca..96ae041fb 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -287,7 +287,7 @@ def _sample(self, scale=1.0): if send_min_sample_warning: warn_msg = ( - "The 'scale' parameter it too small. Some tables may have 1 row." + "The 'scale' parameter is too small. Some tables may have 1 row." ' For better quality data, please choose a larger scale.' ) warnings.warn(warn_msg) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 8360ed551..60488e054 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1853,7 +1853,7 @@ def test_small_sample(): # Run and Assert warn_msg = re.escape( - "The 'scale' parameter it too small. Some tables may have 1 row." + "The 'scale' parameter is too small. Some tables may have 1 row." ' For better quality data, please choose a larger scale.' ) with pytest.warns(Warning, match=warn_msg): diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index aa1bf6570..b9f6d02c9 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -667,3 +667,33 @@ def test___enforce_table_size_clipping(self): # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [2, 2, 4] + + def test___enforce_table_size_too_small_sample(self): + """Test it enforces the sampled data to have the same size as the real data. + + If the sample scale is too small ensure that the function doesn't error out. + """ + # Setup + instance = MagicMock() + data = { + 'parent': pd.DataFrame({ + 'fk': ['a', 'b', 'c'], + '__child__fk__num_rows': [1, 2, 3] + }) + } + instance.metadata._get_foreign_keys.return_value = ['fk'] + instance._min_child_rows = {'__child__fk__num_rows': 1} + instance._max_child_rows = {'__child__fk__num_rows': 3} + instance._table_sizes = {'child': 4} + + # Run + BaseHierarchicalSampler._enforce_table_size( + instance, + 'child', + 'parent', + .001, + data + ) + + # Assert + assert data['parent']['__child__fk__num_rows'].to_list() == [0, 0, 0]