diff --git a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py index f3397c911..458af34fa 100644 --- a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py +++ b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py @@ -286,7 +286,7 @@ def evaluate(self, data: BaseDataSource, **kwargs): from sklearn import metrics as sk_metrics - result, tasks = self.predict(data, kwargs=kwargs) + result, tasks = self.predict(data, **kwargs) m_dict = {} if metrics: if callable(metrics): # if metrics is a function diff --git a/lib/sedna/backend/base.py b/lib/sedna/backend/base.py index 8a89cb459..88f3ab963 100644 --- a/lib/sedna/backend/base.py +++ b/lib/sedna/backend/base.py @@ -49,28 +49,29 @@ def parse_kwargs(func, **kwargs): return kwargs return {k: v for k, v in kwargs.items() if k in need_kw.args} - def train(self, **kwargs): + def train(self, *args, **kwargs): """Train model.""" if callable(self.estimator): varkw = self.parse_kwargs(self.estimator, **kwargs) self.estimator = self.estimator(**varkw) - varkw = self.parse_kwargs(self.estimator.train, **kwargs) - return self.estimator.train(**varkw) + fit_method = getattr(self.estimator, "fit", self.estimator.train) + varkw = self.parse_kwargs(fit_method, **kwargs) + return fit_method(*args, **varkw) - def predict(self, **kwargs): + def predict(self, *args, **kwargs): """Inference model.""" varkw = self.parse_kwargs(self.estimator.predict, **kwargs) - return self.estimator.predict(**varkw) + return self.estimator.predict(*args, **varkw) - def predict_proba(self, **kwargs): + def predict_proba(self, *args, **kwargs): """Compute probabilities of possible outcomes for samples in X.""" varkw = self.parse_kwargs(self.estimator.predict_proba, **kwargs) - return self.estimator.predict_proba(**varkw) + return self.estimator.predict_proba(*args, **varkw) - def evaluate(self, **kwargs): + def evaluate(self, *args, **kwargs): """evaluate model.""" varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs) - return self.estimator.evaluate(**varkw) + return self.estimator.evaluate(*args, **varkw) def save(self, model_url="", model_name=None): mname = model_name or self.model_name diff --git a/lib/sedna/core/lifelong_learning/lifelong_learning.py b/lib/sedna/core/lifelong_learning/lifelong_learning.py index 17e4c613c..464c42525 100644 --- a/lib/sedna/core/lifelong_learning/lifelong_learning.py +++ b/lib/sedna/core/lifelong_learning/lifelong_learning.py @@ -65,7 +65,7 @@ def __init__(self, output_url=Context.get_parameters("OUTPUT_URL", "/tmp") ) task_index = FileOps.join_path(config['output_url'], - KBResourceConstant.KB_INDEX_NAME) + KBResourceConstant.KB_INDEX_NAME.value) config['task_index'] = task_index super(LifelongLearning, self).__init__( estimator=e, config=config @@ -141,7 +141,7 @@ def train(self, train_data, save_extractor = FileOps.join_path( self.config.output_url, - KBResourceConstant.TASK_EXTRACTOR_NAME + KBResourceConstant.TASK_EXTRACTOR_NAME.value ) extractor = FileOps.dump(extractor, save_extractor) try: