diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index 9ef314ca5b0b..7f5a6184d363 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -93,19 +93,10 @@ jobs: - name: Test XGBoost4J (Core, Spark, Examples) run: | - rm -rfv build/ cd jvm-packages mvn -B test - if: matrix.os == 'ubuntu-latest' # Distributed training doesn't work on Windows - env: - RABIT_MOCK: ON - - name: Build and Test XGBoost4J with scala 2.13 run: | - rm -rfv build/ cd jvm-packages mvn -B clean install test -Pdefault,scala-2.13 - if: matrix.os == 'ubuntu-latest' # Distributed training doesn't work on Windows - env: - RABIT_MOCK: ON diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 00639fccf1e4..669fd7ced2c6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala index dc2a52c4dd81..53811a781dcd 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala @@ -89,22 +89,18 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { val workerCount: Int = numWorkers - val dummyTasks = rdd.mapPartitions { iter => + rdd.mapPartitions { iter => val index = iter.next() Communicator.init(trackerEnvs) + val a = Array(1.0f, 2.0f, 3.0f) + System.out.println(a.mkString(", ")) + val b = Communicator.allReduce(a, Communicator.OpType.SUM) + for (i <- 0 to 2) { + assert(a(i) * workerCount == b(i)) + } Communicator.shutdown() Iterator(index) - }.cache() - - val sparkThread = new Thread() { - override def run(): Unit = { - // forces a Spark job. - dummyTasks.foreachPartition(() => _) - } - } - - sparkThread.setUncaughtExceptionHandler(tracker) - sparkThread.start() + }.collect() } test("should allow the dataframe containing communicator calls to be partially evaluated for" +