R-package: map integer NA to NaN when building a DMatrix from a dense matrix

XGDMatrixCreateFromMat_R copies the column-major R matrix (read at
`i + nrow * j`) into a row-major float buffer (written at `i * ncol + j`).
Before this change every integer cell was passed through static_cast<float>,
including cells equal to NA_INTEGER. The copy is now split by input type:

  * integer input: a cell equal to NA_INTEGER is stored as a quiet NaN,
    every other cell is converted with static_cast<float>;
  * non-integer input: cells are copied from `din` unchanged.

Hoisting the `is_int` test out of the per-cell expression also means the
branch is taken once per call rather than once per element.

The new testthat case builds one small matrix containing an NA, constructs
an xgb.DMatrix from it both as produced by sample()/cbind() and after
as.numeric(), saves both, and expects the two saved files to be the same
size and byte-for-byte equal.

NOTE(review): the test writes "int.dmatrix" and "float.dmatrix" using
relative paths and removes them only in its final statements, so an error
raised earlier in the test leaves the files (and the two open connections)
behind. Consider tempfile() paths in a follow-up.

diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc
index 805e63a32e87..a76536adee1b 100644
--- a/R-package/src/xgboost_R.cc
+++ b/R-package/src/xgboost_R.cc
@@ -120,11 +120,25 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
   ctx.nthread = asInteger(n_threads);
   std::int32_t threads = ctx.Threads();
 
-  xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
-    for (size_t j = 0; j < ncol; ++j) {
-      data[i * ncol + j] = is_int ? static_cast<float>(iin[i + nrow * j]) : din[i + nrow * j];
-    }
-  });
+  if (is_int) {
+    xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
+      for (size_t j = 0; j < ncol; ++j) {
+        auto v = iin[i + nrow * j];
+        if (v == NA_INTEGER) {
+          data[i * ncol + j] = std::numeric_limits<float>::quiet_NaN();
+        } else {
+          data[i * ncol + j] = static_cast<float>(v);
+        }
+      }
+    });
+  } else {
+    xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
+      for (size_t j = 0; j < ncol; ++j) {
+        data[i * ncol + j] = din[i + nrow * j];
+      }
+    });
+  }
+
   DMatrixHandle handle;
   CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol, asReal(missing), &handle,
                                         threads));
diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R
index 57cc82c170ed..8d74a0357057 100644
--- a/R-package/tests/testthat/test_dmatrix.R
+++ b/R-package/tests/testthat/test_dmatrix.R
@@ -56,6 +56,42 @@ test_that("xgb.DMatrix: basic construction", {
   expect_equal(raw_fd, raw_dgc)
 })
 
+test_that("xgb.DMatrix: NA", {
+  n_samples <- 3
+  x <- cbind(
+    x1 = sample(x = 4, size = n_samples, replace = TRUE),
+    x2 = sample(x = 4, size = n_samples, replace = TRUE)
+  )
+  x[1, "x1"] <- NA
+
+  m <- xgb.DMatrix(x)
+  xgb.DMatrix.save(m, "int.dmatrix")
+
+  x <- matrix(as.numeric(x), nrow = n_samples, ncol = 2)
+  colnames(x) <- c("x1", "x2")
+  m <- xgb.DMatrix(x)
+
+  xgb.DMatrix.save(m, "float.dmatrix")
+
+  iconn <- file("int.dmatrix", "rb")
+  fconn <- file("float.dmatrix", "rb")
+
+  expect_equal(file.size("int.dmatrix"), file.size("float.dmatrix"))
+
+  bytes <- file.size("int.dmatrix")
+  idmatrix <- readBin(iconn, "raw", n = bytes)
+  fdmatrix <- readBin(fconn, "raw", n = bytes)
+
+  expect_equal(length(idmatrix), length(fdmatrix))
+  expect_equal(idmatrix, fdmatrix)
+
+  close(iconn)
+  close(fconn)
+
+  file.remove("int.dmatrix")
+  file.remove("float.dmatrix")
+})
+
 test_that("xgb.DMatrix: saving, loading", {
   # save to a local file
   dtest1 <- xgb.DMatrix(test_data, label = test_label)