Skip to content

Commit

Permalink
fix the lifelong example problem from backend and constant
Browse files Browse the repository at this point in the history
- fix sklearn backend: support args in train/eval/infer
- fix lifelong constant

Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>
  • Loading branch information
JoeyHwong-gk committed Aug 13, 2021
1 parent 626a892 commit c06106e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def evaluate(self, data: BaseDataSource,
**kwargs):
from sklearn import metrics as sk_metrics

result, tasks = self.predict(data, kwargs=kwargs)
result, tasks = self.predict(data, **kwargs)
m_dict = {}
if metrics:
if callable(metrics): # if metrics is a function
Expand Down
19 changes: 10 additions & 9 deletions lib/sedna/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,29 @@ def parse_kwargs(func, **kwargs):
return kwargs
return {k: v for k, v in kwargs.items() if k in need_kw.args}

def train(self, **kwargs):
def train(self, *args, **kwargs):
"""Train model."""
if callable(self.estimator):
varkw = self.parse_kwargs(self.estimator, **kwargs)
self.estimator = self.estimator(**varkw)
varkw = self.parse_kwargs(self.estimator.train, **kwargs)
return self.estimator.train(**varkw)
fit_method = getattr(self.estimator, "fit", self.estimator.train)
varkw = self.parse_kwargs(fit_method, **kwargs)
return fit_method(*args, **varkw)

def predict(self, **kwargs):
def predict(self, *args, **kwargs):
"""Inference model."""
varkw = self.parse_kwargs(self.estimator.predict, **kwargs)
return self.estimator.predict(**varkw)
return self.estimator.predict(*args, **varkw)

def predict_proba(self, **kwargs):
def predict_proba(self, *args, **kwargs):
"""Compute probabilities of possible outcomes for samples in X."""
varkw = self.parse_kwargs(self.estimator.predict_proba, **kwargs)
return self.estimator.predict_proba(**varkw)
return self.estimator.predict_proba(*args, **varkw)

def evaluate(self, **kwargs):
def evaluate(self, *args, **kwargs):
"""evaluate model."""
varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs)
return self.estimator.evaluate(**varkw)
return self.estimator.evaluate(*args, **varkw)

def save(self, model_url="", model_name=None):
mname = model_name or self.model_name
Expand Down
4 changes: 2 additions & 2 deletions lib/sedna/core/lifelong_learning/lifelong_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self,
output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
)
task_index = FileOps.join_path(config['output_url'],
KBResourceConstant.KB_INDEX_NAME)
KBResourceConstant.KB_INDEX_NAME.value)
config['task_index'] = task_index
super(LifelongLearning, self).__init__(
estimator=e, config=config
Expand Down Expand Up @@ -141,7 +141,7 @@ def train(self, train_data,

save_extractor = FileOps.join_path(
self.config.output_url,
KBResourceConstant.TASK_EXTRACTOR_NAME
KBResourceConstant.TASK_EXTRACTOR_NAME.value
)
extractor = FileOps.dump(extractor, save_extractor)
try:
Expand Down

0 comments on commit c06106e

Please sign in to comment.