diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index f7c4ca26a..4e8a5af17 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -16,6 +16,7 @@ LOGGER = logging.getLogger(__name__) MAX_NUMBER_OF_COLUMNS = 1000 +DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm' class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): @@ -272,14 +273,10 @@ def preprocess(self, data): def _set_extended_columns_distributions(self, synthesizer, table_name, valid_columns): numerical_distributions = {} - if ( - table_name in self._parent_extended_columns - and len(self._parent_extended_columns[table_name]) > 0 - ): + if table_name in self._parent_extended_columns: for extended_column in self._parent_extended_columns[table_name]: if extended_column in valid_columns: - numerical_distributions[extended_column] = 'truncnorm' - + numerical_distributions[extended_column] = DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION synthesizer._set_numerical_distributions(numerical_distributions) def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):