diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index b251bf312..f7c4ca26a 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -160,7 +160,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} - self.parent_extended_columns = defaultdict(list) + self._parent_extended_columns = defaultdict(list) self.verbose = verbose BaseHierarchicalSampler.__init__( self, self.metadata, self._table_synthesizers, self._table_sizes @@ -217,8 +217,8 @@ def _get_distributions(self): distributions = {} for table in self.metadata.tables: parameters = self.get_table_parameters(table) - synthesizer_parameter = parameters.get('synthesizer_parameters', {}) - distributions[table] = synthesizer_parameter.get('default_distribution', None) + synthesizer_parameters = parameters.get('synthesizer_parameters', {}) + distributions[table] = synthesizer_parameters.get('default_distribution', None) return distributions @@ -273,10 +273,10 @@ def preprocess(self, data): def _set_extended_columns_distributions(self, synthesizer, table_name, valid_columns): numerical_distributions = {} if ( - table_name in self.parent_extended_columns - and len(self.parent_extended_columns[table_name]) > 0 + table_name in self._parent_extended_columns + and len(self._parent_extended_columns[table_name]) > 0 ): - for extended_column in self.parent_extended_columns[table_name]: + for extended_column in self._parent_extended_columns[table_name]: if extended_column in valid_columns: numerical_distributions[extended_column] = 'truncnorm' @@ -411,7 +411,7 @@ def _augment_table(self, table, tables, table_name): self._min_child_rows[num_rows_key] = table[num_rows_key].min() if len(extension.columns) > 0: - self.parent_extended_columns[table_name].extend(list(extension.columns)) + self._parent_extended_columns[table_name].extend(list(extension.columns)) tables[table_name] = table self._learned_relationships += 1