Skip to content

Commit

Permalink
[R] Fix integer inputs with NA. (#9522)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Aug 28, 2023
1 parent 1b87a1d commit c3574d9
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
24 changes: 19 additions & 5 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,25 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
ctx.nthread = asInteger(n_threads);
std::int32_t threads = ctx.Threads();

xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
for (size_t j = 0; j < ncol; ++j) {
data[i * ncol + j] = is_int ? static_cast<float>(iin[i + nrow * j]) : din[i + nrow * j];
}
});
// Copy the input matrix into `data`, converting layout on the way: the source
// buffers are read at `i + nrow * j` (column-major) and `data` is written at
// `i * ncol + j` (row-major). The loop over rows `i` is dispatched through
// ParallelFor with `threads` threads; each call of the lambda fills one row.
if (is_int) {
xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
for (size_t j = 0; j < ncol; ++j) {
auto v = iin[i + nrow * j];
// NA_INTEGER is R's marker for a missing integer. It is an ordinary int
// value, so casting it straight to float would produce a regular number
// instead of a missing value; map it to NaN explicitly.
if (v == NA_INTEGER) {
data[i * ncol + j] = std::numeric_limits<float>::quiet_NaN();
} else {
data[i * ncol + j] = static_cast<float>(v);
}
}
});
} else {
// Non-integer input: values are copied without any NA special-casing.
// NOTE(review): presumably safe because R encodes a double NA as a NaN
// already — confirm against the R API documentation.
xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
for (size_t j = 0; j < ncol; ++j) {
data[i * ncol + j] = din[i + nrow * j];
}
});
}

DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol,
asReal(missing), &handle, threads));
Expand Down
36 changes: 36 additions & 0 deletions R-package/tests/testthat/test_dmatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,42 @@ test_that("xgb.DMatrix: basic construction", {
expect_equal(raw_fd, raw_dgc)
})

# An integer matrix containing NA must produce the same DMatrix as the
# equivalent double matrix. Both are saved to disk and compared byte-for-byte.
test_that("xgb.DMatrix: NA", {
  n_samples <- 3
  # sample() returns integers, so `x` is an integer matrix; assigning NA keeps
  # it integer and stores NA_integer_ in the first cell.
  x <- cbind(
    x1 = sample(x = 4, size = n_samples, replace = TRUE),
    x2 = sample(x = 4, size = n_samples, replace = TRUE)
  )
  x[1, "x1"] <- NA

  # Write into the session's temporary directory instead of the working
  # directory, and register the cleanup up front so the files are removed even
  # when a later statement raises an error.
  int_path <- tempfile(fileext = ".dmatrix")
  float_path <- tempfile(fileext = ".dmatrix")
  on.exit(unlink(c(int_path, float_path)), add = TRUE)

  m <- xgb.DMatrix(x)
  xgb.DMatrix.save(m, int_path)

  # Same values and column names, but stored as doubles.
  x <- matrix(as.numeric(x), nrow = n_samples, ncol = 2)
  colnames(x) <- c("x1", "x2")
  m <- xgb.DMatrix(x)
  xgb.DMatrix.save(m, float_path)

  expect_equal(file.size(int_path), file.size(float_path))

  # readBin() accepts a file path directly; it opens and closes the file
  # itself, so no connection has to be managed (or can be leaked) here.
  idmatrix <- readBin(int_path, "raw", n = file.size(int_path))
  fdmatrix <- readBin(float_path, "raw", n = file.size(float_path))

  expect_equal(length(idmatrix), length(fdmatrix))
  expect_equal(idmatrix, fdmatrix)
})

test_that("xgb.DMatrix: saving, loading", {
# save to a local file
dtest1 <- xgb.DMatrix(test_data, label = test_label)
Expand Down

0 comments on commit c3574d9

Please sign in to comment.