Skip to content

Commit

Permalink
refactor the class Worker, unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Jan 14, 2025
1 parent 45947f6 commit ad39712
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
10 changes: 3 additions & 7 deletions src/syngen/ml/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def __validate_metadata(self):
validator.run()
self.merged_metadata = validator.merged_metadata

def _preprocess_data(self):
def _preprocess_data(self, table_name: str):
"""
Preprocess the data before a training process
"""
PreprocessHandler(
metadata=self.metadata,
metadata_path=self.metadata_path,
table_name=self.table_name,
table_name=table_name,
loader=self.loader
).run()

Expand Down Expand Up @@ -338,11 +338,7 @@ def __train_tables(
delta = 0.49 / len(tables_for_training)

for table in self.metadata.keys():
PreprocessHandler(
metadata=self.metadata,
metadata_path=self.metadata_path,
table_name=table
).run()
self._preprocess_data(table)

for table in tables_for_training:
self._train_table(table, metadata_for_training, delta)
Expand Down
28 changes: 14 additions & 14 deletions src/tests/unit/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,20 @@ def test_init_infer_config_with_existed_input_data_in_train_process(mocker, rp_l
mocker.patch("syngen.ml.data_loaders.DataLoader.has_existed_path", return_value=True)

infer_config = InferConfig(
destination="path/to/destination.csv",
metadata=metadata,
metadata_path="path/to/metadata.yaml",
size=100,
table_name=table_name,
run_parallel=False,
batch_size=100,
random_seed=None,
reports=["accuracy"],
both_keys=True,
log_level="DEBUG",
loader=None,
type_of_process="train"
)
destination="path/to/destination.csv",
metadata=metadata,
metadata_path="path/to/metadata.yaml",
size=100,
table_name=table_name,
run_parallel=False,
batch_size=100,
random_seed=None,
reports=["accuracy"],
both_keys=True,
log_level="DEBUG",
loader=None,
type_of_process="train"
)

assert infer_config.reports == ["accuracy"]

Expand Down

0 comments on commit ad39712

Please sign in to comment.