Skip to content

Commit

Permalink
fix lifelong issue
Browse files Browse the repository at this point in the history
- Reduce parameters for initial
- show all interfaces of lifelong learning in example
- fix bugs from deep learning framework

Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>
  • Loading branch information
JoeyHwong-gk committed Aug 8, 2021
1 parent cfd99d4 commit becfe48
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 264 deletions.
19 changes: 13 additions & 6 deletions examples/lifelong_learning/atcii/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json

from sedna.datasources import CSVDataParse
from sedna.common.config import Context, BaseConfig
from sedna.common.config import BaseConfig
from sedna.core.lifelong_learning import LifelongLearning

from interface import DATACONF, Estimator, feature_process
Expand All @@ -26,17 +26,24 @@ def main():
valid_data = CSVDataParse(data_type="valid", func=feature_process)
valid_data.parse(test_dataset_url, label=DATACONF["LABEL"])
attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
model_threshold = float(Context.get_parameters('model_threshold', 0))

task_definition = {
"method": "TaskDefinitionByDataAttr",
"param": attribute
}

ll_job = LifelongLearning(
estimator=Estimator,
task_definition="TaskDefinitionByDataAttr",
task_definition_param=attribute
task_definition=task_definition,
task_relationship_discovery=None,
task_mining=None,
task_remodeling=None,
inference_integrate=None,
unseen_task_detect=None
)
eval_experiment = ll_job.evaluate(
data=valid_data, metrics="precision_score",
metrics_param={"average": "micro"},
model_threshold=model_threshold
metrics_param={"average": "micro"}
)
return eval_experiment

Expand Down
32 changes: 23 additions & 9 deletions examples/lifelong_learning/atcii/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,29 @@

def main():

utd = Context.get_parameters("UTD_NAME", "TaskAttr")
utd = Context.get_parameters("UTD_NAME", "TaskAttrFilter")
attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp")

ll_job = LifelongLearning(
task_mining = {
"method": "TaskMiningByDataAttr",
"param": attribute
}

unseen_task_detect = {
"method": utd,
"param": utd_parameters
}

ll_service = LifelongLearning(
estimator=Estimator,
task_mining="TaskMiningByDataAttr",
task_mining_param=attribute,
unseen_task_detect=utd,
unseen_task_detect_param=utd_parameters)
task_mining=task_mining,
task_definition=None,
task_relationship_discovery=None,
task_remodeling=None,
inference_integrate=None,
unseen_task_detect=unseen_task_detect)

infer_dataset_url = Context.get_parameters('infer_dataset_url')
file_handle = open(infer_dataset_url, "r", encoding="utf-8")
Expand All @@ -60,12 +72,14 @@ def main():
rows = reader[0]
data = dict(zip(header, rows))
infer_data.parse(data, label=DATACONF["LABEL"])
rsl, is_unseen, target_task = ll_job.inference(infer_data)
rsl, is_unseen, target_task = ll_service.inference(infer_data)

rows.append(list(rsl)[0])

output = "\t".join(map(str, rows)) + "\n"
if is_unseen:
unseen_sample.write("\t".join(map(str, rows)) + "\n")
output_sample.write("\t".join(map(str, rows)) + "\n")
unseen_sample.write(output)
output_sample.write(output)
unseen_sample.close()
output_sample.close()

Expand Down
18 changes: 14 additions & 4 deletions examples/lifelong_learning/atcii/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,23 @@ def main():
train_data.parse(train_dataset_url, label=DATACONF["LABEL"])
attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
early_stopping_rounds = int(
Context.get_parameters(
"early_stopping_rounds", 100))
Context.get_parameters("early_stopping_rounds", 100)
)
metric_name = Context.get_parameters("metric_name", "mlogloss")

task_definition = {
"method": "TaskDefinitionByDataAttr",
"param": attribute
}

ll_job = LifelongLearning(
estimator=Estimator,
task_definition="TaskDefinitionByDataAttr",
task_definition_param=attribute
task_definition=task_definition,
task_relationship_discovery=None,
task_mining=None,
task_remodeling=None,
inference_integrate=None,
unseen_task_detect=None
)
train_experiment = ll_job.train(
train_data=train_data,
Expand Down
Loading

0 comments on commit becfe48

Please sign in to comment.