From 8599f87597dc230dddee831686a036af176f5c46 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 5 Aug 2020 15:21:11 +0800 Subject: [PATCH] Update JSON schema. (#5982) * Update JSON schema for pseudo huber. * Update JSON model schema. --- doc/model.schema | 46 +++++++++++++++++++++++++++++++ tests/python/test_basic_models.py | 19 +++++++++++++ 2 files changed, 65 insertions(+) diff --git a/doc/model.schema b/doc/model.schema index 322e610c213d..e9b9b3ead396 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -177,6 +177,17 @@ } } }, + "aft_loss_param": { + "type": "object", + "properties": { + "aft_loss_distribution": { + "type": "string" + }, + "aft_loss_distribution_scale": { + "type": "string" + } + } + }, "softmax_multiclass_param": { "type": "object", "properties": { @@ -273,6 +284,17 @@ "reg_loss_param" ] }, + { + "type": "object", + "properties": { + "name": { "const": "reg:pseudohubererror" }, + "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} + }, + "required": [ + "name", + "reg_loss_param" + ] + }, { "type": "object", "properties": { @@ -284,6 +306,17 @@ "reg_loss_param" ] }, + { + "type": "object", + "properties": { + "name": { "const": "reg:linear" }, + "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} + }, + "required": [ + "name", + "reg_loss_param" + ] + }, { "type": "object", "properties": { @@ -420,6 +453,19 @@ "name", "lambda_rank_param" ] + }, + { + "type": "object", + "properties": { + "name": {"const": "survival:aft"}, + "aft_loss_param": { "$ref": "#/definitions/aft_loss_param"} + } + }, + { + "type": "object", + "properties": { + "name": {"const": "binary:hinge"} + } } ] }, diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index c1802caad877..3eafdf71d821 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -346,6 +346,25 @@ def test_json_io_schema(self): schema=schema) os.remove(model_path) + try: + xgb.train({'objective': 'foo'}, dtrain, num_boost_round=1) + except ValueError as e: + e_str = str(e) + beg = e_str.find('Objective candidate') + end = e_str.find('Stack trace') + e_str = e_str[beg: end] + e_str = e_str.strip() + splited = e_str.splitlines() + objectives = [s.split(': ')[1] for s in splited] + j_objectives = schema['properties']['learner']['properties'][ + 'objective']['oneOf'] + objectives_from_schema = set() + for j_obj in j_objectives: + objectives_from_schema.add( + j_obj['properties']['name']['const']) + objectives = set(objectives) + assert objectives == objectives_from_schema + @pytest.mark.skipif(**tm.no_json_schema()) def test_json_dump_schema(self): import jsonschema