diff --git a/examples/lifelong_learning/atcii/eval.py b/examples/lifelong_learning/atcii/eval.py
index 8ed719bb4..15e589cf4 100644
--- a/examples/lifelong_learning/atcii/eval.py
+++ b/examples/lifelong_learning/atcii/eval.py
@@ -15,7 +15,7 @@
 import json
 
 from sedna.datasources import CSVDataParse
-from sedna.common.config import Context, BaseConfig
+from sedna.common.config import BaseConfig
 from sedna.core.lifelong_learning import LifelongLearning
 
 from interface import DATACONF, Estimator, feature_process
@@ -26,17 +26,24 @@ def main():
     valid_data = CSVDataParse(data_type="valid", func=feature_process)
     valid_data.parse(test_dataset_url, label=DATACONF["LABEL"])
     attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
-    model_threshold = float(Context.get_parameters('model_threshold', 0))
+
+    task_definition = {
+        "method": "TaskDefinitionByDataAttr",
+        "param": attribute
+    }
 
     ll_job = LifelongLearning(
         estimator=Estimator,
-        task_definition="TaskDefinitionByDataAttr",
-        task_definition_param=attribute
+        task_definition=task_definition,
+        task_relationship_discovery=None,
+        task_mining=None,
+        task_remodeling=None,
+        inference_integrate=None,
+        unseen_task_detect=None
     )
     eval_experiment = ll_job.evaluate(
         data=valid_data, metrics="precision_score",
-        metrics_param={"average": "micro"},
-        model_threshold=model_threshold
+        metrics_param={"average": "micro"}
     )
     return eval_experiment
 
diff --git a/examples/lifelong_learning/atcii/inference.py b/examples/lifelong_learning/atcii/inference.py
index 947107e0b..73843c36f 100644
--- a/examples/lifelong_learning/atcii/inference.py
+++ b/examples/lifelong_learning/atcii/inference.py
@@ -26,17 +26,29 @@ def main():
-    utd = Context.get_parameters("UTD_NAME", "TaskAttr")
+    utd = Context.get_parameters("UTD_NAME", "TaskAttrFilter")
     attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
     utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
     ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp")
 
-    ll_job = LifelongLearning(
+    task_mining = {
+        "method": "TaskMiningByDataAttr",
+        "param": attribute
+    }
+
+    unseen_task_detect = {
+        "method": utd,
+        "param": utd_parameters
+    }
+
+    ll_service = LifelongLearning(
         estimator=Estimator,
-        task_mining="TaskMiningByDataAttr",
-        task_mining_param=attribute,
-        unseen_task_detect=utd,
-        unseen_task_detect_param=utd_parameters)
+        task_mining=task_mining,
+        task_definition=None,
+        task_relationship_discovery=None,
+        task_remodeling=None,
+        inference_integrate=None,
+        unseen_task_detect=unseen_task_detect)
 
     infer_dataset_url = Context.get_parameters('infer_dataset_url')
     file_handle = open(infer_dataset_url, "r", encoding="utf-8")
@@ -60,12 +72,14 @@ def main():
         rows = reader[0]
         data = dict(zip(header, rows))
         infer_data.parse(data, label=DATACONF["LABEL"])
-        rsl, is_unseen, target_task = ll_job.inference(infer_data)
+        rsl, is_unseen, target_task = ll_service.inference(infer_data)
         rows.append(list(rsl)[0])
+
+        output = "\t".join(map(str, rows)) + "\n"
         if is_unseen:
-            unseen_sample.write("\t".join(map(str, rows)) + "\n")
-        output_sample.write("\t".join(map(str, rows)) + "\n")
+            unseen_sample.write(output)
+        output_sample.write(output)
     unseen_sample.close()
     output_sample.close()
 
diff --git a/examples/lifelong_learning/atcii/train.py b/examples/lifelong_learning/atcii/train.py
index cbf7b29fb..e96d4a6ec 100644
--- a/examples/lifelong_learning/atcii/train.py
+++ b/examples/lifelong_learning/atcii/train.py
@@ -28,13 +28,23 @@ def main():
     train_data.parse(train_dataset_url, label=DATACONF["LABEL"])
     attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
     early_stopping_rounds = int(
-        Context.get_parameters(
-            "early_stopping_rounds", 100))
+        Context.get_parameters("early_stopping_rounds", 100)
+    )
     metric_name = Context.get_parameters("metric_name", "mlogloss")
+
+    task_definition = {
+        "method": "TaskDefinitionByDataAttr",
+        "param": attribute
+    }
+
     ll_job = LifelongLearning(
         estimator=Estimator,
-        task_definition="TaskDefinitionByDataAttr",
-        task_definition_param=attribute
+        task_definition=task_definition,
+        task_relationship_discovery=None,
+        task_mining=None,
+        task_remodeling=None,
+        inference_integrate=None,
+        unseen_task_detect=None
     )
     train_experiment = ll_job.train(
         train_data=train_data,
diff --git a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
index 1af4edc5c..f3397c911 100644
--- a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
+++ b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import json
-import joblib
+import pandas as pd
 
 from sedna.datasources import BaseDataSource
 from sedna.backend import set_backend
 from sedna.common.log import LOGGER
-from sedna.common.config import Context
 from sedna.common.file_ops import FileOps
+from sedna.common.config import Context
+from sedna.common.constant import KBResourceConstant
 from sedna.common.class_factory import ClassFactory, ClassType
 
 from .task_jobs.artifact import Model, Task, TaskGroup
@@ -36,117 +36,113 @@ class MulTaskLearning:
     def __init__(self,
                  estimator=None,
-                 task_definition="TaskDefinitionByDataAttr",
+                 task_definition=None,
                  task_relationship_discovery=None,
                  task_mining=None,
                  task_remodeling=None,
-                 inference_integrate=None,
-                 task_definition_param=None,
-                 relationship_discovery_param=None,
-                 task_mining_param=None,
-                 task_remodeling_param=None,
-                 inference_integrate_param=None
+                 inference_integrate=None
                  ):
-        if not task_relationship_discovery:
-            task_relationship_discovery = "DefaultTaskRelationDiscover"
-        if not task_remodeling:
-            task_remodeling = "DefaultTaskRemodeling"
-        if not inference_integrate:
-            inference_integrate = "DefaultInferenceIntegrate"
-        self.method_selection = dict(
-            task_definition=task_definition,
-            task_relationship_discovery=task_relationship_discovery,
-            task_mining=task_mining,
-            task_remodeling=task_remodeling,
-            inference_integrate=inference_integrate,
-            task_definition_param=task_definition_param,
-            task_relationship_discovery_param=relationship_discovery_param,
-            task_mining_param=task_mining_param,
-            task_remodeling_param=task_remodeling_param,
-            inference_integrate_param=inference_integrate_param)
+
+        self.task_definition = task_definition or {
+            "method": "TaskDefinitionByDataAttr"
+        }
+        self.task_relationship_discovery = task_relationship_discovery or {
+            "method": "DefaultTaskRelationDiscover"
+        }
+        self.task_mining = task_mining or {}
+        self.task_remodeling = task_remodeling or {
+            "method": "DefaultTaskRemodeling"
+        }
+        self.inference_integrate = inference_integrate or {
+            "method": "DefaultInferenceIntegrate"
+        }
         self.models = None
         self.extractor = None
         self.base_model = estimator
         self.task_groups = None
-        self.task_index_url = Context.get_parameters(
-            "MODEL_URLS", '/tmp/index.pkl'
-        )
+        self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value
         self.min_train_sample = int(Context.get_parameters(
-            "MIN_TRAIN_SAMPLE", '10'
+            "MIN_TRAIN_SAMPLE", KBResourceConstant.MIN_TRAIN_SAMPLE.value
         ))
 
     @staticmethod
     def parse_param(param_str):
         if not param_str:
             return {}
+        if isinstance(param_str, dict):
+            return param_str
         try:
             raw_dict = json.loads(param_str, encoding="utf-8")
         except json.JSONDecodeError:
             raw_dict = {}
         return raw_dict
 
-    def task_definition(self, samples):
+    def _task_definition(self, samples):
         """
         Task attribute extractor and multi-task definition
         """
-        method_name = self.method_selection.get(
-            "task_definition", "TaskDefinitionByDataAttr")
+        method_name = self.task_definition.get(
+            "method", "TaskDefinitionByDataAttr"
+        )
         extend_param = self.parse_param(
-            self.method_selection.get("task_definition_param"))
+            self.task_definition.get("param")
+        )
         method_cls = ClassFactory.get_cls(
             ClassType.MTL, method_name)(**extend_param)
         return method_cls(samples)
 
-    def task_relationship_discovery(self, tasks):
+    def _task_relationship_discovery(self, tasks):
         """
         Merge tasks from task_definition
         """
-        method_name = self.method_selection.get("task_relationship_discovery")
+        method_name = self.task_relationship_discovery.get("method")
         extend_param = self.parse_param(
-            self.method_selection.get("task_relationship_discovery_param")
+            self.task_relationship_discovery.get("param")
         )
         method_cls = ClassFactory.get_cls(
             ClassType.MTL, method_name)(**extend_param)
         return method_cls(tasks)
 
-    def task_mining(self, samples):
+    def _task_mining(self, samples):
        """
        Mining tasks of inference sample base on task attribute extractor
        """
-        method_name = self.method_selection.get("task_mining")
+        method_name = self.task_mining.get("method")
         extend_param = self.parse_param(
-            self.method_selection.get("task_mining_param"))
+            self.task_mining.get("param")
+        )
         if not method_name:
-            task_definition = self.method_selection.get(
-                "task_definition", "TaskDefinitionByDataAttr")
+            task_definition = self.task_definition.get(
+                "method", "TaskDefinitionByDataAttr"
+            )
             method_name = self._method_pair.get(task_definition,
                                                 'TaskMiningByDataAttr')
             extend_param = self.parse_param(
-                self.method_selection.get("task_definition_param"))
+                self.task_definition.get("param"))
         method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)(
             task_extractor=self.extractor, **extend_param
         )
         return method_cls(samples=samples)
 
-    def task_remodeling(self, samples, mappings):
+    def _task_remodeling(self, samples, mappings):
         """
         Remodeling tasks from task mining
         """
-        method_name = self.method_selection.get("task_remodeling")
+        method_name = self.task_remodeling.get("method")
         extend_param = self.parse_param(
-            self.method_selection.get("task_remodeling_param"))
+            self.task_remodeling.get("param"))
         method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)(
             models=self.models, **extend_param)
         return method_cls(samples=samples, mappings=mappings)
 
-    def inference_integrate(self, tasks):
+    def _inference_integrate(self, tasks):
         """
         Aggregate inference results from target models
         """
-        method_name = self.method_selection.get("inference_integrate")
+        method_name = self.inference_integrate.get("method")
         extend_param = self.parse_param(
-            self.method_selection.get("inference_integrate_param"))
+            self.inference_integrate.get("param"))
         method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)(
             models=self.models, **extend_param)
         return method_cls(tasks=tasks) if method_cls else tasks
@@ -154,12 +150,12 @@ def inference_integrate(self, tasks):
     def train(self, train_data: BaseDataSource,
               valid_data: BaseDataSource = None,
               post_process=None, **kwargs):
-        tasks, task_extractor, train_data = self.task_definition(train_data)
+        tasks, task_extractor, train_data = self._task_definition(train_data)
         self.extractor = task_extractor
-        task_groups = self.task_relationship_discovery(tasks)
+        task_groups = self._task_relationship_discovery(tasks)
         self.models = []
         callback = None
-        if post_process:
+        if isinstance(post_process, str):
             callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
         self.task_groups = []
         feedback = {}
@@ -180,18 +176,31 @@ def train(self, train_data: BaseDataSource,
                 continue
             LOGGER.info(f"MTL Train start {i} : {task.entry}")
-            model_obj = set_backend(estimator=self.base_model)
-            res = model_obj.train(train_data=task.samples, **kwargs)
-            if callback:
-                res = callback(model_obj, res)
-            model_path = model_obj.save(model_name=f"{task.entry}.model")
-            model = Model(index=i, entry=task.entry,
-                          model=model_path, result=res)
-            model.meta_attr = [t.meta_attr for t in task.tasks]
+            model = None
+            for t in task.tasks:  # if model has train in tasks
+                if not (t.model and t.result):
+                    continue
+                model_path = t.model.save(model_name=f"{task.entry}.model")
+                t.model = model_path
+                model = Model(index=i, entry=t.entry,
+                              model=model_path, result=t.result)
+                model.meta_attr = t.meta_attr
+                break
+            if not model:
+                model_obj = set_backend(estimator=self.base_model)
+                res = model_obj.train(train_data=task.samples, **kwargs)
+                if callback:
+                    res = callback(model_obj, res)
+                model_path = model_obj.save(model_name=f"{task.entry}.model")
+                model = Model(index=i, entry=task.entry,
+                              model=model_path, result=res)
+
+                model.meta_attr = [t.meta_attr for t in task.tasks]
             task.model = model
             self.models.append(model)
-            feedback[task.entry] = res
+            feedback[task.entry] = model.result
             self.task_groups.append(task)
+
         if len(rare_task):
             model_obj = set_backend(estimator=self.base_model)
             res = model_obj.train(train_data=train_data, **kwargs)
@@ -211,39 +220,41 @@ def train(self, train_data: BaseDataSource,
                 self.models[i] = model
                 feedback[entry] = res
                 self.task_groups[i] = task
-        extractor_file = FileOps.join_path(
-            os.path.dirname(self.task_index_url),
-            "kb_extractor.pkl"
-        )
-        joblib.dump(self.extractor, extractor_file)
+
         task_index = {
-            "extractor": extractor_file,
+            "extractor": self.extractor,
             "task_groups": self.task_groups
        }
-        joblib.dump(task_index, self.task_index_url)
         if valid_data:
-            feedback = self.evaluate(valid_data, **kwargs)
+            feedback, _ = self.evaluate(valid_data, **kwargs)
+        try:
+            FileOps.dump(task_index, self.task_index_url)
+        except TypeError:
+            return feedback, task_index
+        return feedback, self.task_index_url
 
-        return feedback
+    def load(self, task_index_url=None):
+        if task_index_url:
+            self.task_index_url = task_index_url
+        assert FileOps.exists(self.task_index_url), FileExistsError(
+            f"Task index miss: {self.task_index_url}"
+        )
+        task_index = FileOps.load(self.task_index_url)
+        self.extractor = task_index['extractor']
+        if isinstance(self.extractor, str):
+            self.extractor = FileOps.load(self.extractor)
+        self.task_groups = task_index['task_groups']
+        self.models = [task.model for task in self.task_groups]
 
     def predict(self, data: BaseDataSource,
                 post_process=None, **kwargs):
+
         if not (self.models and self.extractor):
-            task_index = joblib.load(self.task_index_url)
-            extractor_file = FileOps.join_path(
-                os.path.dirname(self.task_index_url),
-                "kb_extractor.pkl"
-            )
-            if (not callable(task_index['extractor']) and
-                    isinstance(task_index['extractor'], str)):
-                FileOps.download(task_index['extractor'], extractor_file)
-                self.extractor = joblib.load(extractor_file)
-            else:
-                self.extractor = task_index['extractor']
-            self.task_groups = task_index['task_groups']
-            self.models = [task.model for task in self.task_groups]
-        data, mappings = self.task_mining(samples=data)
-        samples, models = self.task_remodeling(samples=data, mappings=mappings)
+            self.load()
+
+        data, mappings = self._task_mining(samples=data)
+        samples, models = self._task_remodeling(samples=data,
+                                                mappings=mappings)
 
         callback = None
         if post_process:
@@ -254,17 +265,19 @@ def predict(self, data: BaseDataSource,
             m = models[inx]
             if not isinstance(m, Model):
                 continue
-            model_obj = set_backend(estimator=self.base_model)
-            evaluator = model_obj.load(m.model) if isinstance(
-                m.model, str) else m.model
-            pred = evaluator.predict(df.x, kwargs=kwargs)
+            if isinstance(m.model, str):
+                evaluator = set_backend(estimator=self.base_model)
+                evaluator.load(m.model)
+            else:
+                evaluator = m.model
+            pred = evaluator.predict(df.x, **kwargs)
             if callable(callback):
                 pred = callback(pred, df)
             task = Task(entry=m.entry, samples=df)
             task.result = pred
             task.model = m
             tasks.append(task)
-        res = self.inference_integrate(tasks)
+        res = self._inference_integrate(tasks)
         return res, tasks
 
     def evaluate(self, data: BaseDataSource,
@@ -293,7 +306,7 @@ def evaluate(self, data: BaseDataSource,
             m_dict = {
                 metrics: getattr(sk_metrics, metrics, sk_metrics.log_loss)
             }
-        elif isinstance(metrics, dict):   # if metrics with name
+        elif isinstance(metrics, dict):  # if metrics with name
             for k, v in metrics.items():
                 if isinstance(v, str):
                     v = getattr(sk_metrics, v)
@@ -306,8 +319,9 @@
             }
             metrics_param = {"average": "micro"}
 
-        data.x['pred_y'] = result
-        data.x['real_y'] = data.y
+        if isinstance(data.x, pd.DataFrame):
+            data.x['pred_y'] = result
+            data.x['real_y'] = data.y
         if not metrics_param:
             metrics_param = {}
         elif isinstance(metrics_param, str):
diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py
index e56bd154a..fd7c1e9cd 100644
--- a/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py
+++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py
@@ -22,6 +22,7 @@ def __init__(self, entry, samples, meta_attr=None):
         self.entry = entry
         self.samples = samples
         self.meta_attr = meta_attr
+        self.test_samples = None  # assign on task definition and use in TRD
         self.model = None  # assign on running
         self.result = None  # assign on running
 
diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py
index f33e23840..ebefc2cca 100644
--- a/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py
+++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py
@@ -15,6 +15,7 @@
 from typing import List
 
 import numpy as np
+import pandas as pd
 
 from sedna.datasources import BaseDataSource
 from sedna.common.class_factory import ClassFactory, ClassType
@@ -34,11 +35,18 @@ def __call__(self, samples: BaseDataSource, mappings: List):
         for m in np.unique(mappings):
             task_df = BaseDataSource(data_type=d_type)
             _inx = np.where(mappings == m)
-            task_df.x = samples.x.iloc[_inx]
+            if isinstance(samples.x, pd.DataFrame):
+                task_df.x = samples.x.iloc[_inx]
+            else:
+                task_df.x = np.array(samples.x)[_inx]
             if d_type != "test":
-                task_df.y = samples.y.iloc[_inx]
+                if isinstance(samples.x, pd.DataFrame):
+                    task_df.y = samples.y.iloc[_inx]
+                else:
+                    task_df.y = np.array(samples.y)[_inx]
             task_df.inx = _inx[0].tolist()
-            task_df.meta_attr = samples.meta_attr.iloc[_inx].values
+            if samples.meta_attr is not None:
+                task_df.meta_attr = np.array(samples.meta_attr)[_inx]
             data.append(task_df)
             model = self.models[m] or self.models[0]
             models.append(model)
diff --git a/lib/sedna/algorithms/unseen_task_detect/__init__.py b/lib/sedna/algorithms/unseen_task_detect/__init__.py
index 375996060..f31596a27 100644
--- a/lib/sedna/algorithms/unseen_task_detect/__init__.py
+++ b/lib/sedna/algorithms/unseen_task_detect/__init__.py
@@ -12,60 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Unseen Task detect Algorithms for Lifelong Learning"""
-
-import abc
-from typing import List
-
-import numpy as np
-
-from sedna.algorithms.multi_task_learning.task_jobs.artifact import Task
-from sedna.common.class_factory import ClassFactory, ClassType
-
-__all__ = ('ModelProbeFilter', 'TaskAttrFilter')
-
-
-class BaseFilter(metaclass=abc.ABCMeta):
-    """The base class to define unified interface."""
-
-    def __call__(self, task: Task = None):
-        """predict function, and it must be implemented by
-        different methods class.
-
-        :param task: inference task
-        :return: `True` means unseen task, `False` means not an unseen task.
-        """
-        raise NotImplementedError
-
-
-@ClassFactory.register(ClassType.UTD, alias="ModelProbe")
-class ModelProbeFilter(BaseFilter, abc.ABC):
-    def __init__(self):
-        pass
-
-    def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs):
-        all_proba = []
-        for task in tasks:
-            sample = task.samples
-            model = task.model
-            if hasattr(model, "predict_proba"):
-                proba = model.predict_proba(sample)
-                all_proba.append(np.max(proba))
-        return np.mean(all_proba) > threshold if all_proba else True
-
-
-@ClassFactory.register(ClassType.UTD, alias="TaskAttr")
-class TaskAttrFilter(BaseFilter, abc.ABC):
-    def __init__(self):
-        pass
-
-    def __call__(self, tasks: List[Task] = None, **kwargs):
-        for task in tasks:
-            model_attr = list(map(list, task.model.meta_attr))
-            sample_attr = list(map(list, task.samples.meta_attr))
-
-            if not (model_attr and sample_attr):
-                continue
-            if list(model_attr) == list(sample_attr):
-                return False
-        return True
+from .unseen_task_detect import ModelProbeFilter, TaskAttrFilter
diff --git a/lib/sedna/algorithms/unseen_task_detect/unseen_task_detect.py b/lib/sedna/algorithms/unseen_task_detect/unseen_task_detect.py
new file mode 100644
index 000000000..f8e40397b
--- /dev/null
+++ b/lib/sedna/algorithms/unseen_task_detect/unseen_task_detect.py
@@ -0,0 +1,71 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unseen task detection algorithms for Lifelong Learning"""
+
+import abc
+from typing import List
+
+import numpy as np
+
+from sedna.algorithms.multi_task_learning.task_jobs.artifact import Task
+from sedna.common.class_factory import ClassFactory, ClassType
+
+__all__ = ('ModelProbeFilter', 'TaskAttrFilter')
+
+
+class BaseFilter(metaclass=abc.ABCMeta):
+    """The base class to define unified interface."""
+
+    def __call__(self, task: Task = None):
+        """predict function, and it must be implemented by
+        different methods class.
+
+        :param task: inference task
+        :return: `True` means unseen task, `False` means not an unseen task.
+        """
+        raise NotImplementedError
+
+
+@ClassFactory.register(ClassType.UTD)
+class ModelProbeFilter(BaseFilter, abc.ABC):
+    def __init__(self):
+        pass
+
+    def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs):
+        all_proba = []
+        for task in tasks:
+            sample = task.samples
+            model = task.model
+            if hasattr(model, "predict_proba"):
+                proba = model.predict_proba(sample)
+                all_proba.append(np.max(proba))
+        return np.mean(all_proba) > threshold if all_proba else True
+
+
+@ClassFactory.register(ClassType.UTD)
+class TaskAttrFilter(BaseFilter, abc.ABC):
+    def __init__(self):
+        pass
+
+    def __call__(self, tasks: List[Task] = None, **kwargs):
+        for task in tasks:
+            model_attr = list(map(list, task.model.meta_attr))
+            sample_attr = list(map(list, task.samples.meta_attr))
+
+            if not (model_attr and sample_attr):
+                continue
+            if list(model_attr) == list(sample_attr):
+                return False
+        return True
diff --git a/lib/sedna/backend/base.py b/lib/sedna/backend/base.py
index 4aae8a04b..8a89cb459 100644
--- a/lib/sedna/backend/base.py
+++ b/lib/sedna/backend/base.py
@@ -35,7 +35,7 @@ def model_name(self):
         if self.default_name:
             return self.default_name
         model_postfix = {"pytorch": ".pth",
-                         "keras": ".h5", "tensorflow": ".pb"}
+                         "keras": ".pb", "tensorflow": ".pb"}
         continue_flag = "_finetune_" if self.fine_tune else ""
         post_fix = model_postfix.get(self.framework, ".pkl")
         return f"model{continue_flag}{self.framework}{post_fix}"
@@ -107,9 +107,11 @@ def load(self, model_url="", model_name=None, **kwargs):
             self.model_save_path, mname = os.path.split(self.model_save_path)
         model_path = FileOps.join_path(self.model_save_path, mname)
         if model_url:
-            FileOps.download(model_url, model_path)
+            model_path = FileOps.download(model_url, model_path)
         self.has_load = True
-
+        if not (hasattr(self.estimator, "load")
+                and os.path.exists(model_path)):
+            return
         return self.estimator.load(model_url=model_path)
 
     def set_weights(self, weights):
diff --git a/lib/sedna/backend/tensorflow/__init__.py b/lib/sedna/backend/tensorflow/__init__.py
index c029b2e5b..d59bb08c8 100644
--- a/lib/sedna/backend/tensorflow/__init__.py
+++ b/lib/sedna/backend/tensorflow/__init__.py
@@ -25,10 +25,12 @@
     # version 2.0 tf
     ConfigProto = tf.compat.v1.ConfigProto
     Session = tf.compat.v1.Session
+    reset_default_graph = tf.compat.v1.reset_default_graph
 else:
     # version 1
     ConfigProto = tf.ConfigProto
     Session = tf.Session
+    reset_default_graph = tf.reset_default_graph
 
 
 class TFBackend(BackendBase):
@@ -64,24 +66,27 @@ def train(self, train_data, valid_data=None, **kwargs):
             self.estimator = self.estimator()
         if self.fine_tune and FileOps.exists(self.model_save_path):
             self.finetune()
-
+        self.has_load = True
+        varkw = self.parse_kwargs(self.estimator.train, **kwargs)
         return self.estimator.train(
             train_data=train_data,
             valid_data=valid_data,
-            **kwargs
+            **varkw
         )
 
     def predict(self, data, **kwargs):
         if not self.has_load:
-            tf.reset_default_graph()
-            self.sess = self.load()
-        return self.estimator.predict(data, **kwargs)
+            reset_default_graph()
+            self.load()
+        varkw = self.parse_kwargs(self.estimator.predict, **kwargs)
+        return self.estimator.predict(data=data, **varkw)
 
     def evaluate(self, data, **kwargs):
         if not self.has_load:
-            tf.reset_default_graph()
-            self.sess = self.load()
-        return self.estimator.evaluate(data, **kwargs)
+            reset_default_graph()
+            self.load()
+        varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs)
+        return self.estimator.evaluate(data, **varkw)
 
     def finetune(self):
         """todo: no support yet"""
@@ -99,23 +104,25 @@ def set_weights(self, weights):
     def model_info(self, model, relpath=None, result=None):
         ckpt = os.path.dirname(model)
+        _, _type = os.path.splitext(model)
         if relpath:
             _url = FileOps.remove_path_prefix(model, relpath)
             ckpt_url = FileOps.remove_path_prefix(ckpt, relpath)
         else:
             _url = model
             ckpt_url = ckpt
-        results = [
-            {
-                "format": "pb",
+        _type = _type.lstrip(".").lower()
+        results = [{
+                "format": _type,
                 "url": _url,
                 "metrics": result
-            }, {
+        }]
+        if _type == "pb":  # report ckpt path when model save as pb file
+            results.append({
                 "format": "ckpt",
                 "url": ckpt_url,
                 "metrics": result
-            }
-        ]
+            })
         return results
 
diff --git a/lib/sedna/common/constant.py b/lib/sedna/common/constant.py
index 79ca191b7..e2e1c379d 100644
--- a/lib/sedna/common/constant.py
+++ b/lib/sedna/common/constant.py
@@ -12,23 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 from enum import Enum
 
-LOG = logging.getLogger(__name__)
-
-
-class ModelType(Enum):
-    GlobalModel = 1
-    PersonalizedModel = 2
-
-
-class Framework(Enum):
-    Tensorflow = "tensorflow"
-    Keras = "keras"
-    Pytorch = "pytorch"
-    Mindspore = "mindspore"
-
 
 class K8sResourceKind(Enum):
     DEFAULT = "default"
@@ -42,3 +27,9 @@ class K8sResourceKindStatus(Enum):
     COMPLETED = "completed"
     FAILED = "failed"
     RUNNING = "running"
+
+
+class KBResourceConstant(Enum):
+    MIN_TRAIN_SAMPLE = 10
+    KB_INDEX_NAME = "index.pkl"
+    TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl"
diff --git a/lib/sedna/common/file_ops.py b/lib/sedna/common/file_ops.py
index 7694a1a7f..668fbbc53 100644
--- a/lib/sedna/common/file_ops.py
+++ b/lib/sedna/common/file_ops.py
@@ -17,11 +17,12 @@
 import os
 import re
+import joblib
 import codecs
 import pickle
 import shutil
-import tempfile
 import hashlib
+import tempfile
 
 from urllib.parse import urlparse
 
 from .utils import singleton
@@ -98,15 +99,23 @@ def clean_folder(cls, target, clean=True):
             if not args[0]:
                 args[0] = os.path.sep
             _path = cls.join_path(*args)
-            if os.path.isdir(_path) and clean:
-                shutil.rmtree(_path)
+            if clean:
+                cls.delete(_path)
             if os.path.isfile(_path):
-                if clean:
-                    os.remove(_path)
                 _path = cls.join_path(*args[:len(args) - 1])
             os.makedirs(_path, exist_ok=True)
         return target
 
+    @classmethod
+    def delete(cls, path):
+        try:
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            if os.path.isfile(path):
+                os.remove(path)
+        except Exception:
+            pass
+
     @classmethod
     def make_base_dir(cls, *args):
         """Make new a base directory.
@@ -179,6 +188,7 @@ def load_pickle(cls, filename):
         :rtype: object or None.
 
         """
+        filename = cls.download(filename)
         if not os.path.isfile(filename):
             return None
         with open(filename, "rb") as f:
@@ -203,8 +213,7 @@ def copy_folder(cls, src, dst):
             name = os.path.join(src, files)
             back_name = os.path.join(dst, files)
             if os.path.isfile(name):
-                if os.path.isfile(back_name):
-                    shutil.copy(name, back_name)
+                shutil.copy(name, back_name)
             else:
                 if not os.path.isdir(back_name):
                     shutil.copytree(name, back_name)
@@ -219,7 +228,7 @@ def copy_file(cls, src, dst):
         :param str dst: destination path.
 
         """
-        if dst is None or dst == "":
+        if not dst:
             return
 
         if os.path.isfile(src):
@@ -237,10 +246,34 @@ def copy_file(cls, src, dst):
             cls.copy_folder(src, dst)
 
     @classmethod
-    def download(cls, src, dst, unzip=False) -> str:
-        if dst is None:
-            dst = tempfile.mkdtemp()
+    def dump(cls, obj, dst=None) -> str:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+        joblib.dump(obj, name)
+        return cls.upload(name, dst)
+
+    @classmethod
+    def load(cls, src: str):
+        src = cls.download(src)
+        obj = joblib.load(src)
+        return obj
+
+    @classmethod
+    def is_remote(cls, src):
+        if src.startswith((
+            cls._GCS_PREFIX,
+            cls._S3_PREFIX
+        )):
+            return True
+        if re.search(cls._URI_RE, src):
+            return True
+        return False
+
+    @classmethod
+    def download(cls, src, dst=None, unzip=False) -> str:
+        if dst is None:
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
         cls.clean_folder([os.path.dirname(dst)], clean=False)
         if src.startswith(cls._GCS_PREFIX):
             cls.gcs_download(src, dst)
@@ -255,18 +288,29 @@ def download(cls, src, dst, unzip=False) -> str:
         return dst
 
     @classmethod
-    def upload(cls, src, dst, tar=False) -> str:
+    def upload(cls, src, dst, tar=False, clean=True) -> str:
         if dst is None:
-            dst = tempfile.mkdtemp()
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
+        if not cls.is_local(src):
+            fd, name = tempfile.mkstemp()
+            os.close(fd)
+            cls.download(src, name)
+            src = name
         if tar:
             cls._tar(src, f"{src}.tar.gz")
             src = f"{src}.tar.gz"
+
         if dst.startswith(cls._GCS_PREFIX):
             cls.gcs_upload(src, dst)
         elif dst.startswith(cls._S3_PREFIX):
             cls.s3_upload(src, dst)
-        elif cls.is_local(dst):
+        else:
             cls.copy_file(src, dst)
+        if cls.is_local(src) and clean:
+            if cls.is_local(dst) and os.path.samefile(src, dst):
+                return dst
+            cls.delete(src)
         return dst
 
     @classmethod
@@ -287,21 +331,24 @@ def _download_s3(cls, client, uri, out_dir):
         bucket_name = bucket_args[0]
         bucket_path = len(bucket_args) > 1 and bucket_args[1] or ""
 
-        objects = client.list_objects(bucket_name,
-                                      prefix=bucket_path,
-                                      recursive=True,
-                                      use_api_v1=True)
+        objects = list(client.list_objects(bucket_name,
+                                           prefix=bucket_path,
+                                           recursive=True,
+                                           use_api_v1=True))
         count = 0
-
+        num = len(objects)
         for obj in objects:
             # Replace any prefix from the object key with out_dir
             subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
             # fget_object handles directory creation if does not exist
             if not obj.is_dir:
-                local_file = os.path.join(
-                    out_dir,
-                    subdir_object_key or os.path.basename(obj.object_name)
-                )
+                if num == 1 and not os.path.isdir(out_dir):
+                    local_file = out_dir
+                else:
+                    local_file = os.path.join(
+                        out_dir,
+                        subdir_object_key or os.path.basename(obj.object_name)
+                    )
                 client.fget_object(bucket_name, obj.object_name, local_file)
                 count += 1
 
@@ -311,9 +358,10 @@ def s3_download(cls, src, dst):
         s3 = _create_minio_client()
         count = cls._download_s3(s3, src, dst)
+
         if count == 0:
             raise RuntimeError("Failed to fetch files."
-                               "The path %s does not exist." % (src))
+                               "The path %s does not exist." % src)
 
     @classmethod
     def s3_upload(cls, src, dst):
diff --git a/lib/sedna/core/base.py b/lib/sedna/core/base.py
index 858d3c48d..b7f877242 100644
--- a/lib/sedna/core/base.py
+++ b/lib/sedna/core/base.py
@@ -27,35 +27,7 @@
 __all__ = ('JobBase',)
 
 
-class DistributedWorker:
-    """"Class of Distributed Worker use to manage all jobs"""
-    # original params
-    __worker_path__ = None
-    __worker_module__ = None
-    # id params
-    __worker_id__ = 0
-
-    def __init__(self):
-        DistributedWorker.__worker_id__ += 1
-        self._worker_id = DistributedWorker.__worker_id__
-        self.timeout = 0
-
-    @property
-    def worker_id(self):
-        """Property: worker_id."""
-        return self._worker_id
-
-    @worker_id.setter
-    def worker_id(self, value):
-        """Setter: set worker_id with value.
-
-        :param value: worker id
-        :type value: int
-        """
-        self._worker_id = value
-
-
-class JobBase(DistributedWorker):
+class JobBase:
     """ sedna feature base class """
     parameters = Context
@@ -68,8 +40,7 @@ def __init__(self, estimator, config=None):
         self.estimator = set_backend(estimator=estimator, config=self.config)
         self.job_kind = K8sResourceKind.DEFAULT.value
         self.job_name = self.config.job_name or self.config.service_name
-        work_name = f"{self.job_name}-{self.worker_id}"
-        self.worker_name = self.config.worker_name or work_name
+        self.worker_name = self.config.worker_name or self.job_name
 
     @property
     def initial_hem(self):
diff --git a/lib/sedna/core/lifelong_learning/lifelong_learning.py b/lib/sedna/core/lifelong_learning/lifelong_learning.py
index 64658a825..17e4c613c 100644
--- a/lib/sedna/core/lifelong_learning/lifelong_learning.py
+++ b/lib/sedna/core/lifelong_learning/lifelong_learning.py
@@ -15,12 +15,11 @@
 import os
 import tempfile
 
-import joblib
-
 from sedna.backend import set_backend
 from sedna.core.base import JobBase
 from sedna.common.file_ops import FileOps
 from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
+from sedna.common.constant import KBResourceConstant
 from sedna.common.config import Context
 from sedna.common.class_factory import ClassType, ClassFactory
 from sedna.algorithms.multi_task_learning import MulTaskLearning
@@ -34,40 +33,39 @@ class LifelongLearning(JobBase):
     def __init__(self,
                  estimator,
-                 task_definition="TaskDefinitionByDataAttr",
+                 task_definition=None,
                  task_relationship_discovery=None,
                  task_mining=None,
                  task_remodeling=None,
                  inference_integrate=None,
-                 unseen_task_detect="TaskAttrFilter",
-
-                 task_definition_param=None,
-                 relationship_discovery_param=None,
-                 task_mining_param=None,
-                 task_remodeling_param=None,
-                 inference_integrate_param=None,
-                 unseen_task_detect_param=None):
+                 unseen_task_detect=None):
+
+        if not task_definition:
+            task_definition = {
+                "method": "TaskDefinitionByDataAttr"
+            }
+        if not unseen_task_detect:
+            unseen_task_detect = {
+                "method": "TaskAttrFilter"
+            }
         e = MulTaskLearning(
             estimator=estimator,
             task_definition=task_definition,
             task_relationship_discovery=task_relationship_discovery,
             task_mining=task_mining,
             task_remodeling=task_remodeling,
-            inference_integrate=inference_integrate,
-            task_definition_param=task_definition_param,
-            relationship_discovery_param=relationship_discovery_param,
-            task_mining_param=task_mining_param,
-            task_remodeling_param=task_remodeling_param,
-            inference_integrate_param=inference_integrate_param)
-        self.unseen_task_detect = unseen_task_detect
+            inference_integrate=inference_integrate)
+        self.unseen_task_detect = unseen_task_detect.get("method",
+                                                         "TaskAttrFilter")
         self.unseen_task_detect_param = e.parse_param(
-            unseen_task_detect_param
+            unseen_task_detect.get("param", {})
         )
         config = dict(
             ll_kb_server=Context.get_parameters("KB_SERVER"),
             output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
         )
-        task_index = FileOps.join_path(config['output_url'], 'index.pkl')
+        task_index = FileOps.join_path(config['output_url'],
+                                       KBResourceConstant.KB_INDEX_NAME)
         config['task_index'] = task_index
         super(LifelongLearning, self).__init__(
             estimator=e, config=config
         )
@@ -91,51 +89,80 @@ def train(self, train_data,
         if post_process is not None:
             callback_func = ClassFactory.get_cls(
                 ClassType.CALLBACK, post_process)
-        res = self.estimator.train(
+        res, task_index_url = self.estimator.train(
             train_data=train_data,
             valid_data=valid_data,
             **kwargs
         )  # todo: Distinguishing incremental update and fully overwrite
-        task_groups = self.estimator.estimator.task_groups
-        extractor_file = FileOps.join_path(
-            os.path.dirname(self.estimator.estimator.task_index_url),
-            "kb_extractor.pkl"
-        )
-        try:
-            extractor_file = self.kb_server.upload_file(extractor_file)
-        except Exception as err:
-            self.log.error(
-                f"Upload task extractor_file fail {extractor_file}: {err}")
-            extractor_file = joblib.load(extractor_file)
+
+        if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
+            task_index = FileOps.load(task_index_url)
+        else:
+            task_index = task_index_url
+
+        extractor = task_index['extractor']
+        task_groups = task_index['task_groups']
+
+        model_upload_key = {}
         for task in task_groups:
+            model_file = task.model.model
+            save_model = FileOps.join_path(
+                self.config.output_url,
+                os.path.basename(model_file)
+            )
+            if model_file not in model_upload_key:
+                model_upload_key[model_file] = FileOps.upload(model_file,
+                                                              save_model)
+            model_file = model_upload_key[model_file]
+
             try:
-                model = self.kb_server.upload_file(task.model.model)
-            except Exception:
-                model_obj = set_backend(
+                model = self.kb_server.upload_file(save_model)
+            except Exception as err:
+                self.log.error(
+                    f"Upload task model of {model_file} fail: {err}"
+                )
+                model = set_backend(
                     estimator=self.estimator.estimator.base_model
                 )
-                model = model_obj.load(task.model.model)
+                model.load(model_file)
             task.model.model = model
+            for _task in task.tasks:
+                sample_dir = FileOps.join_path(
+                    self.config.output_url,
+                    f"{_task.samples.data_type}_{_task.entry}.sample")
+                task.samples.save(sample_dir)
+                try:
+                    sample_dir = self.kb_server.upload_file(sample_dir)
+                except Exception as err:
+                    self.log.error(
+                        f"Upload task samples of {_task.entry} fail: {err}")
+                _task.samples.data_url = sample_dir
+
+        save_extractor = FileOps.join_path(
+            self.config.output_url,
+            KBResourceConstant.TASK_EXTRACTOR_NAME
+        )
+        extractor = FileOps.dump(extractor, save_extractor)
+        try:
+            extractor = self.kb_server.upload_file(extractor)
+        except Exception as err:
+            self.log.error(f"Upload task extractor fail: {err}")
         task_info = {
             "task_groups": task_groups,
-            "extractor": extractor_file
+            "extractor": extractor
         }
         fd, name = tempfile.mkstemp()
-        joblib.dump(task_info, name)
+        FileOps.dump(task_info, name)
+
         index_file = self.kb_server.update_db(name)
         if not index_file:
             self.log.error(f"KB update Fail !")
             index_file = name
         FileOps.upload(index_file, self.config.task_index)
-        if os.path.isfile(name):
-            os.close(fd)
-            os.remove(name)
+
         task_info_res = self.estimator.model_info(
-            self.config.task_index, result=res,
+            self.config.task_index,
             relpath=self.config.data_path_prefix)
         self.report_task_info(
             None, K8sResourceKindStatus.COMPLETED.value, task_info_res)
@@ -152,7 +179,7 @@ def update(self, train_data, valid_data=None, post_process=None, **kwargs):
             **kwargs
         )
 
-    def evaluate(self, data, post_process=None, model_threshold=0.1, **kwargs):
+    def evaluate(self, data, post_process=None, **kwargs):
         callback_func = None
         if callable(post_process):
             callback_func = post_process
@@ -167,14 +194,35 @@ def evaluate(self, data, post_process=None, model_threshold=0.1, **kwargs):
         FileOps.download(task_index_url, index_url)
         res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)
         drop_tasks = []
+
+        model_filter_operator = self.get_parameters("operator", ">")
+        model_threshold = float(self.get_parameters('model_threshold', 0.1))
+
+        operator_map = {
+            ">": lambda x, y: x > y,
+            "<": lambda x, y: x < y,
+            "=": lambda x, y: x == y,
+            ">=": lambda x, y: x >= y,
+            "<=": lambda x, y: x <= y,
+        }
+        if model_filter_operator not in operator_map:
+            self.log.warn(
+                f"operator {model_filter_operator} use to "
+                f"compare is not allow, set to <"
+            )
+            model_filter_operator = "<"
+        operator_func = operator_map[model_filter_operator]
+
         for detail in tasks_detail:
             scores = detail.scores
             entry = detail.entry
-            self.log.info(f"{entry} socres: {scores}")
-            if any(map(lambda x: float(x) < model_threshold, scores.values())):
+            self.log.info(f"{entry} scores: {scores}")
+            if any(map(lambda x: operator_func(float(x),
+                                               model_threshold),
+                       scores.values())):
                 self.log.warn(
-                    f"{entry} will not be deploy "
-                    f"because scores lt {model_threshold}")
+                    f"{entry} will not be deploy because all "
+                    f"scores {model_filter_operator} {model_threshold}")
                 drop_tasks.append(entry)
                 continue
         drop_task = ",".join(drop_tasks)
@@ -196,6 +244,7 @@ def evaluate(self, data, post_process=None, **kwargs):
         return callback_func(res) if callback_func else res
 
     def inference(self, data=None, post_process=None, **kwargs):
+
         task_index_url = self.get_parameters(
             "MODEL_URLS", self.config.task_index)
         index_url = self.estimator.estimator.task_index_url
diff --git a/lib/sedna/datasources/__init__.py b/lib/sedna/datasources/__init__.py
index 25d48dc82..b793e163f 100644
--- a/lib/sedna/datasources/__init__.py
+++ b/lib/sedna/datasources/__init__.py
@@ -14,7 +14,6 @@
 
 from abc import ABC
 
-import joblib
 import numpy as np
 import pandas as pd
 
@@ -51,7 +50,7 @@ def is_test_data(self):
         return self.data_type == "test"
 
     def save(self, output=""):
-        joblib.dump(self, output)
+        return FileOps.dump(self, output)
 
 
 class TxtDataParse(BaseDataSource, ABC):
diff --git a/lib/sedna/service/run_kb.py b/lib/sedna/service/run_kb.py
index 73d682f6c..4cca7de4c 100644
--- a/lib/sedna/service/run_kb.py
+++ b/lib/sedna/service/run_kb.py
@@ -22,13 +22,14 @@ def main():
     init_db()
     server = os.getenv("KnowledgeBaseServer", "")
+    kb_dir = os.getenv("KnowledgeBasePath", "")
     match = re.compile(
         "(https?)://([0-9]{1,3}(?:\\.[0-9]{1,3}){3}):([0-9]+)").match(server)
     if match:
         _, host, port = match.groups()
     else:
         host, port = '0.0.0.0', 9020
-    KBServer(host=host, http_port=int(port)).start()
+    KBServer(host=host, http_port=int(port), save_dir=kb_dir).start()
 
 
 if __name__ == '__main__':
diff --git a/lib/sedna/service/server/knowledgeBase/server.py b/lib/sedna/service/server/knowledgeBase/server.py
index f570de05a..a489e0588 100644
--- a/lib/sedna/service/server/knowledgeBase/server.py
+++ b/lib/sedna/service/server/knowledgeBase/server.py
@@ -27,6 +27,7 @@
 
 from sedna.service.server.base import BaseServer
 from sedna.common.file_ops import FileOps
+from sedna.common.constant import KBResourceConstant
 
 from .model import *
 
@@ -52,7 +53,7 @@ def __init__(self, host: str, http_port: int = 8080,
                                        http_port=http_port, workers=workers)
         self.save_dir = FileOps.clean_folder([save_dir], clean=False)[0]
         self.url = f"{self.url}/{servername}"
-        self.latest = 0
+        self.kb_index = KBResourceConstant.KB_INDEX_NAME.value
         self.app = FastAPI(
             routes=[
                 APIRoute(
@@ -94,8 +95,7 @@ def query(self):
         pass
 
     def _get_db_index(self):
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         if not FileOps.exists(_index_path):  # todo: get from kb
             pass
         return _index_path
@@ -130,8 +130,7 @@ def update_status(self, data: KBUpdateResult = Body(...)):
             }, synchronize_session=False)
 
         # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         task_info = joblib.load(_index_path)
 
         new_task_group = []
@@ -143,13 +142,9 @@ def update_status(self, data: KBUpdateResult = Body(...)):
                 continue
             new_task_group.append(task_group)
         task_info["task_groups"] = new_task_group
-        self.latest += 1
-
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        joblib.dump(task_info, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        FileOps.dump(task_info, _index_path)
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
 
     def update(self, task: UploadFile = File(...)):
         tasks = task.file.read()
@@ -178,21 +173,16 @@ def update(self, task: UploadFile = File(...)):
                 if t_create:
                     session.add(t_obj)
 
-                sampel_obj = Samples(
+                sample_obj = Samples(
                     data_type=task.samples.data_type,
-                    sample_num=len(task.samples)
+                    sample_num=len(task.samples),
+                    data_url=getattr(task, 'data_url', '')
                 )
-                session.add(sampel_obj)
+                session.add(sample_obj)
                 session.flush()
                 session.commit()
 
-                sample_dir = FileOps.join_path(
-                    self.save_dir,
-                    f"{sampel_obj.data_type}_{sampel_obj.id}.pkl")
-                task.samples.save(sample_dir)
-                sampel_obj.data_url = sample_dir
-
-                tsample = TaskSample(sample=sampel_obj, task=t_obj)
+                tsample = TaskSample(sample=sample_obj, task=t_obj)
                 session.add(tsample)
                 session.flush()
                 t_id.append(t_obj.id)
@@ -221,15 +211,8 @@ def update(self, task: UploadFile = File(...)):
 
         session.commit()
 
-        self.latest += 1
-        extractor_file = upload_info["extractor"]
-        extractor_path = FileOps.join_path(self.save_dir,
-                                           f"kb_extractor.pkl")
-        FileOps.upload(extractor_file, extractor_path)
-        # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        FileOps.upload(name, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        _index_path = FileOps.dump(upload_info, _index_path)
+
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
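
Usage note: a minimal sketch of how the refactored dict-based configuration introduced by this patch is driven, mirroring the examples/lifelong_learning/atcii scripts above. The dataset path and the DATACONF/Estimator/feature_process interface are assumptions carried over from the atcii example rather than anything guaranteed by this diff.

import json

from sedna.datasources import CSVDataParse
from sedna.core.lifelong_learning import LifelongLearning

# interface module comes from the atcii example (assumed to be on the path)
from interface import DATACONF, Estimator, feature_process

# Each pipeline step is now configured with a {"method": ..., "param": ...} dict;
# passing None lets MulTaskLearning fall back to its built-in defaults.
task_definition = {
    "method": "TaskDefinitionByDataAttr",
    "param": json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
}

train_data = CSVDataParse(data_type="train", func=feature_process)
train_data.parse("/tmp/trainData.csv", label=DATACONF["LABEL"])  # hypothetical path

ll_job = LifelongLearning(
    estimator=Estimator,
    task_definition=task_definition,
    task_relationship_discovery=None,
    task_mining=None,
    task_remodeling=None,
    inference_integrate=None,
    unseen_task_detect=None
)
ll_job.train(train_data=train_data)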