Skip to content

Commit

Permalink
Send default configuration from metric to objective. (#8760)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Feb 9, 2023
1 parent 5f76edd commit 199c421
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 7 deletions.
5 changes: 5 additions & 0 deletions include/xgboost/objective.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class ObjFunction : public Configurable {

/*! \return the default evaluation metric for the objective */
virtual const char* DefaultEvalMetric() const = 0;
/**
* \brief Return the configuration for the default metric.
*/
virtual Json DefaultMetricConfig() const { return Json{Null{}}; }

// the following functions are optional, most of time default implementation is good enough
/*!
* \brief transform prediction values, this is only called when Prediction is called
Expand Down
5 changes: 4 additions & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ class LearnerConfiguration : public Learner {

auto const& objective_fn = learner_parameters.at("objective");
if (!obj_) {
CHECK_EQ(get<String const>(objective_fn["name"]), tparam_.objective);
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
}
obj_->LoadConfig(objective_fn);
Expand Down Expand Up @@ -1311,8 +1312,10 @@ class LearnerImpl : public LearnerIO {
std::ostringstream os;
os.precision(std::numeric_limits<double>::max_digits10);
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
if (metrics_.empty() && tparam_.disable_default_eval_metric <= 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &ctx_));
auto config = obj_->DefaultMetricConfig();
metrics_.back()->LoadConfig(config);
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}

Expand Down
6 changes: 6 additions & 0 deletions src/objective/aft_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ class AFTObj : public ObjFunction {
void LoadConfig(Json const& in) override {
FromJson(in["aft_loss_param"], &param_);
}
Json DefaultMetricConfig() const override {
Json config{Object{}};
config["name"] = String{this->DefaultEvalMetric()};
config["aft_loss_param"] = ToJson(param_);
return config;
}

private:
AFTParam param_;
Expand Down
53 changes: 47 additions & 6 deletions tests/python/test_survival.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from typing import Optional, Tuple

import numpy as np
import pytest
Expand All @@ -10,16 +11,56 @@
dpath = tm.data_dir(__file__)


def test_aft_survival_toy_data():
# See demo/aft_survival/aft_survival_viz_demo.py
@pytest.fixture(scope="module")
def toy_data() -> Tuple[xgb.DMatrix, np.ndarray, np.ndarray]:
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
INF = np.inf
y_lower = np.array([ 10, 15, -INF, 30, 100])
y_upper = np.array([INF, INF, 20, 50, INF])
y_lower = np.array([10, 15, -INF, 30, 100])
y_upper = np.array([INF, INF, 20, 50, INF])

dmat = xgb.DMatrix(X)
dmat.set_float_info('label_lower_bound', y_lower)
dmat.set_float_info('label_upper_bound', y_upper)
dmat.set_float_info("label_lower_bound", y_lower)
dmat.set_float_info("label_upper_bound", y_upper)
return dmat, y_lower, y_upper


def test_default_metric(toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]) -> None:
Xy, y_lower, y_upper = toy_data

def run(evals: Optional[list]) -> None:
# test with or without actual evaluation.
booster = xgb.train(
{"objective": "survival:aft", "aft_loss_distribution": "extreme"},
Xy,
num_boost_round=1,
evals=evals,
)
config = json.loads(booster.save_config())
metrics = config["learner"]["metrics"]
assert len(metrics) == 1
assert metrics[0]["aft_loss_param"]["aft_loss_distribution"] == "extreme"

booster = xgb.train(
{"objective": "survival:aft"},
Xy,
num_boost_round=1,
evals=evals,
)
config = json.loads(booster.save_config())
metrics = config["learner"]["metrics"]
assert len(metrics) == 1
assert metrics[0]["aft_loss_param"]["aft_loss_distribution"] == "normal"

run([(Xy, "Train")])
run(None)


def test_aft_survival_toy_data(
toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]
) -> None:
# See demo/aft_survival/aft_survival_viz_demo.py
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
dmat, y_lower, y_upper = toy_data

# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred)
Expand Down

0 comments on commit 199c421

Please sign in to comment.