diff --git a/sdv/evaluation/multi_table.py b/sdv/evaluation/multi_table.py
index 25669af82..34ed0b94e 100644
--- a/sdv/evaluation/multi_table.py
+++ b/sdv/evaluation/multi_table.py
@@ -77,8 +77,8 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name
             1D marginal distribution plot (i.e. a histogram) of the columns.
     """
     metadata = metadata.tables[table_name]
-    real_data = real_data[table_name]
-    synthetic_data = synthetic_data[table_name]
+    real_data = real_data[table_name] if real_data else None
+    synthetic_data = synthetic_data[table_name] if synthetic_data else None
     return single_table_visualization.get_column_plot(
         real_data,
         synthetic_data,
@@ -118,8 +118,8 @@ def get_column_pair_plot(
             2D bivariate distribution plot (i.e. a scatterplot) of the columns.
     """
     metadata = metadata.tables[table_name]
-    real_data = real_data[table_name]
-    synthetic_data = synthetic_data[table_name]
+    real_data = real_data[table_name] if real_data else None
+    synthetic_data = synthetic_data[table_name] if synthetic_data else None
     return single_table_visualization.get_column_pair_plot(
         real_data, synthetic_data, metadata, column_names, sample_size, plot_type
     )
diff --git a/tests/unit/evaluation/test_multi_table.py b/tests/unit/evaluation/test_multi_table.py
index 4cdea58a6..2995f9dd5 100644
--- a/tests/unit/evaluation/test_multi_table.py
+++ b/tests/unit/evaluation/test_multi_table.py
@@ -73,9 +73,31 @@ def test_get_column_plot(mock_plot):
     assert plot == 'plot'
 
 
+@patch('sdv.evaluation.single_table.get_column_plot')
+def test_get_column_plot_only_real_or_synthetic(mock_plot):
+    """Test that ``get_column_plot`` works when only real or synthetic data is provided."""
+    # Setup
+    table1 = pd.DataFrame({'col': [1, 2, 3]})
+    data1 = {'table': table1}
+    metadata = MultiTableMetadata()
+    metadata.detect_table_from_dataframe('table', table1)
+    mock_plot.return_value = 'plot'
+
+    # Run
+    get_column_plot(data1, None, metadata, 'table', 'col')
+    get_column_plot(None, data1, metadata, 'table', 'col')
+
+    # Assert
+    call_metadata = metadata.tables['table']
+    mock_plot.assert_has_calls([
+        ((table1, None, call_metadata, 'col', None), {}),
+        ((None, table1, call_metadata, 'col', None), {}),
+    ])
+
+
 @patch('sdv.evaluation.single_table.get_column_pair_plot')
 def test_get_column_pair_plot(mock_plot):
-    """Test that ``get_column_pair`` plot is being called with the expected objects."""
+    """Test that ``get_column_pair_plot`` is being called with the expected objects."""
     # Setup
     table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]})
     table2 = pd.DataFrame({'col1': [2, 1, 3], 'col2': [1, 2, 3]})
@@ -94,6 +116,28 @@ def test_get_column_pair_plot(mock_plot):
     assert plot == 'plot'
 
 
+@patch('sdv.evaluation.single_table.get_column_pair_plot')
+def test_get_column_pair_plot_only_real_or_synthetic(mock_plot):
+    """Test that ``get_column_pair_plot`` works when only real or synthetic data is provided."""
+    # Setup
+    table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]})
+    data1 = {'table': table1}
+    metadata = MultiTableMetadata()
+    metadata.detect_table_from_dataframe('table', table1)
+    mock_plot.return_value = 'plot'
+
+    # Run
+    get_column_pair_plot(data1, None, metadata, 'table', ['col1', 'col2'], 2)
+    get_column_pair_plot(None, data1, metadata, 'table', ['col1', 'col2'], 2)
+
+    # Assert
+    call_metadata = metadata.tables['table']
+    mock_plot.assert_has_calls([
+        ((table1, None, call_metadata, ['col1', 'col2'], None, 2), {}),
+        ((None, table1, call_metadata, ['col1', 'col2'], None, 2), {}),
+    ])
+
+
 @patch('sdmetrics.visualization.get_cardinality_plot')
 def test_get_cardinality_plot(mock_plot):
     """Test it calls ``get_column_cardinality_plot`` in sdmetrics with the parent primary key."""