Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gsheni committed Jun 18, 2024
1 parent dc569b6 commit 0490d81
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
self._augmented_tables = []
self._learned_relationships = 0
self._default_parameters = {}
self.parent_extended_columns = defaultdict(list)
self._parent_extended_columns = defaultdict(list)
self.verbose = verbose
BaseHierarchicalSampler.__init__(
self, self.metadata, self._table_synthesizers, self._table_sizes
Expand Down Expand Up @@ -217,8 +217,8 @@ def _get_distributions(self):
distributions = {}
for table in self.metadata.tables:
parameters = self.get_table_parameters(table)
synthesizer_parameter = parameters.get('synthesizer_parameters', {})
distributions[table] = synthesizer_parameter.get('default_distribution', None)
synthesizer_parameters = parameters.get('synthesizer_parameters', {})
distributions[table] = synthesizer_parameters.get('default_distribution', None)

return distributions

Expand Down Expand Up @@ -273,10 +273,10 @@ def preprocess(self, data):
def _set_extended_columns_distributions(self, synthesizer, table_name, valid_columns):
numerical_distributions = {}
if (
table_name in self.parent_extended_columns
and len(self.parent_extended_columns[table_name]) > 0
table_name in self._parent_extended_columns
and len(self._parent_extended_columns[table_name]) > 0
):
for extended_column in self.parent_extended_columns[table_name]:
for extended_column in self._parent_extended_columns[table_name]:
if extended_column in valid_columns:
numerical_distributions[extended_column] = 'truncnorm'

Expand Down Expand Up @@ -411,7 +411,7 @@ def _augment_table(self, table, tables, table_name):
self._min_child_rows[num_rows_key] = table[num_rows_key].min()

if len(extension.columns) > 0:
self.parent_extended_columns[table_name].extend(list(extension.columns))
self._parent_extended_columns[table_name].extend(list(extension.columns))

tables[table_name] = table
self._learned_relationships += 1
Expand Down

0 comments on commit 0490d81

Please sign in to comment.