Skip to content

Commit

Permalink
support offset in classif.xgboost for multiclass
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Jan 18, 2025
1 parent a639de1 commit 99b3e70
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
19 changes: 15 additions & 4 deletions R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,14 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
xgboost::setinfo(xgb_data, "weight", task$weights$weight)
}

# TODO: multiclass
if ("offset" %in% task$properties) {
base_margin = task$data(cols = task$col_roles$offset)[[1L]]
offset = task$data(cols = task$col_roles$offset)
if (startsWith(pv$objective, "binary")) {
base_margin = offset[[1L]]
} else {
# multiclass needs a numeric matrix of shape (n_samples, n_classes), i.e. one offset column per class
base_margin = as_numeric_matrix(offset)
}
xgboost::setinfo(xgb_data, "base_margin", base_margin)
}

Expand All @@ -274,9 +279,15 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
xgboost::setinfo(xgb_valid_data, "weight", internal_valid_task$weights$weight)
}

# TODO: multiclass
if ("offset" %in% internal_valid_task$properties) {
base_margin = internal_valid_task$data(cols = internal_valid_task$col_roles$offset)[[1L]]
valid_offset = internal_valid_task$data(cols = internal_valid_task$col_roles$offset)
if (startsWith(pv$objective, "binary")) {
base_margin = valid_offset[[1L]]
} else {
# multiclass needs a numeric matrix of shape (n_samples, n_classes), i.e. one offset column per class
base_margin = as_numeric_matrix(valid_offset)
}

xgboost::setinfo(xgb_valid_data, "base_margin", base_margin)
}

Expand Down
30 changes: 21 additions & 9 deletions tests/testthat/test_classif_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -368,15 +368,6 @@ test_that("mlr3measures are equal to internal measures", {
})

test_that("base_margin (offset)", {
# # multiclass task
# task = tsk("iris")
#
# # same task with multiclass offset
# data = task$data()
# set(data, j = "offset_Adelie", value = runif(nrow(data)))
# set(data, j = "offset_Chinstrap", value = runif(nrow(data)))
# task = as_task_classif(data, target = "species")

# binary classification task
task = tsk("sonar")

Expand Down Expand Up @@ -406,4 +397,25 @@ test_that("base_margin (offset)", {

expect_equal(p1$prob, p2$prob) # zero offset => same predictions
expect_false(all(p1$prob[, 1L] == p3$prob[, 1L])) # non-zero offset => different predictions

# multiclass task
task = tsk("iris")

# same task with multiclass offset
data = task$data()
set(data, j = "offset_setosa", value = runif(nrow(data)))
set(data, j = "offset_versicolor", value = runif(nrow(data)))
set(data, j = "offset_virginica", value = runif(nrow(data)))
task_offset = as_task_classif(data, target = "Species")
task_offset2 = task_offset$clone()
task_offset$set_col_roles(cols = c("offset_setosa", "offset_versicolor", "offset_virginica"), roles = "offset")
task_offset2$set_col_roles(cols = c("offset_setosa", "offset_versicolor"), roles = "offset")
part = partition(task)

l = lrn("classif.xgboost", nrounds = 5, predict_type = "prob")
expect_error(l$train(task_offset2), "Invalid shape of base_margin")
p1 = l$train(task, part$train)$predict(task, part$test) # no offset
p2 = l$train(task_offset, part$train)$predict(task_offset, part$test) # with offset

expect_false(all(p1$prob == p2$prob))
})

0 comments on commit 99b3e70

Please sign in to comment.