diff --git a/src/syngen/ml/worker/worker.py b/src/syngen/ml/worker/worker.py index 34964b17..040ee40f 100644 --- a/src/syngen/ml/worker/worker.py +++ b/src/syngen/ml/worker/worker.py @@ -60,14 +60,14 @@ def __validate_metadata(self): validator.run() self.merged_metadata = validator.merged_metadata - def _preprocess_data(self): + def _preprocess_data(self, table_name: str): """ Preprocess the data before a training process """ PreprocessHandler( metadata=self.metadata, metadata_path=self.metadata_path, - table_name=self.table_name, + table_name=table_name, loader=self.loader ).run() @@ -338,11 +338,7 @@ def __train_tables( delta = 0.49 / len(tables_for_training) for table in self.metadata.keys(): - PreprocessHandler( - metadata=self.metadata, - metadata_path=self.metadata_path, - table_name=table - ).run() + self._preprocess_data(table) for table in tables_for_training: self._train_table(table, metadata_for_training, delta) diff --git a/src/tests/unit/config/test_config.py b/src/tests/unit/config/test_config.py index abed3ff0..c57fa670 100644 --- a/src/tests/unit/config/test_config.py +++ b/src/tests/unit/config/test_config.py @@ -166,20 +166,20 @@ def test_init_infer_config_with_existed_input_data_in_train_process(mocker, rp_l mocker.patch("syngen.ml.data_loaders.DataLoader.has_existed_path", return_value=True) infer_config = InferConfig( - destination="path/to/destination.csv", - metadata=metadata, - metadata_path="path/to/metadata.yaml", - size=100, - table_name=table_name, - run_parallel=False, - batch_size=100, - random_seed=None, - reports=["accuracy"], - both_keys=True, - log_level="DEBUG", - loader=None, - type_of_process="train" - ) + destination="path/to/destination.csv", + metadata=metadata, + metadata_path="path/to/metadata.yaml", + size=100, + table_name=table_name, + run_parallel=False, + batch_size=100, + random_seed=None, + reports=["accuracy"], + both_keys=True, + log_level="DEBUG", + loader=None, + type_of_process="train" + ) assert infer_config.reports == ["accuracy"]