diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 3169a162c..96ae041fb 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -1,5 +1,6 @@ """Hierarchical Samplers.""" import logging +import warnings import pandas as pd @@ -273,13 +274,24 @@ def _sample(self, scale=1.0): # DFS to sample roots and then their children non_root_parents = set(self.metadata._get_parent_map().keys()) root_parents = set(self.metadata.tables.keys()) - non_root_parents + send_min_sample_warning = False for table in root_parents: num_rows = round(self._table_sizes[table] * scale) + if num_rows <= 0: + send_min_sample_warning = True + num_rows = 1 synthesizer = self._table_synthesizers[table] LOGGER.info(f'Sampling {num_rows} rows from table {table}') sampled_data[table] = self._sample_rows(synthesizer, num_rows) self._sample_children(table_name=table, sampled_data=sampled_data, scale=scale) + if send_min_sample_warning: + warn_msg = ( + "The 'scale' parameter is too small. Some tables may have 1 row." + ' For better quality data, please choose a larger scale.' 
+ ) + warnings.warn(warn_msg) + + added_relationships = set() for relationship in self.metadata.relationships: parent_name = relationship['parent_table_name'] diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4be8bd53f..60488e054 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1839,3 +1839,27 @@ def test_disjointed_tables(): # Assert for table in real_data: assert list(real_data[table].columns) == list(disjoin_synthetic_data[table].columns) + + +def test_small_sample(): + """Test that the sample function still works and emits a warning when the scale is too small.""" + # Setup + data, metadata = download_demo( + modality='multi_table', + dataset_name='fake_hotels' + ) + synthesizer = HMASynthesizer(metadata) + synthesizer.fit(data) + + # Run and Assert + warn_msg = re.escape( + "The 'scale' parameter is too small. Some tables may have 1 row." + ' For better quality data, please choose a larger scale.' + ) + with pytest.warns(Warning, match=warn_msg): + synthetic_data = synthesizer.sample(scale=0.01) + + assert (len(synthetic_data['hotels']) == 1) + assert (len(synthetic_data['guests']) >= len(data['guests']) * .01) + assert synthetic_data['hotels'].columns.tolist() == data['hotels'].columns.tolist() + assert synthetic_data['guests'].columns.tolist() == data['guests'].columns.tolist() diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index aa1bf6570..b9f6d02c9 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -667,3 +667,33 @@ def test___enforce_table_size_clipping(self): # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [2, 2, 4] + + def test___enforce_table_size_too_small_sample(self): + """Test it enforces the table size when the sample scale is very small. 
+ + If the sample scale is too small, ensure that the function doesn't error out. + """ + # Setup + instance = MagicMock() + data = { + 'parent': pd.DataFrame({ + 'fk': ['a', 'b', 'c'], + '__child__fk__num_rows': [1, 2, 3] + }) + } + instance.metadata._get_foreign_keys.return_value = ['fk'] + instance._min_child_rows = {'__child__fk__num_rows': 1} + instance._max_child_rows = {'__child__fk__num_rows': 3} + instance._table_sizes = {'child': 4} + + # Run + BaseHierarchicalSampler._enforce_table_size( + instance, + 'child', + 'parent', + .001, + data + ) + + # Assert + assert data['parent']['__child__fk__num_rows'].to_list() == [0, 0, 0]