[ML-18] Auto detect KVS port for oneCCL to avoid port conflict (#19)
* auto detect port for oneccl

* nit

* temp disable CI test for oneCCL fixes

* nit

* update use SparkConf to set kvs ip port
xwu99 authored Feb 7, 2021
1 parent b60158b commit ac20823
Showing 9 changed files with 101 additions and 14 deletions.
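
The net effect of the change set: the driver now auto-detects a free KVS port on the first executor, passes a combined "<ip>_<port>" string to oneCCL, and lets both the IP and the port be overridden through SparkConf. As a minimal sketch (not part of the commit), an application could pin the endpoint via the two keys introduced in the Scala diffs below; the object name and the host/port values here are placeholders.

import org.apache.spark.{SparkConf, SparkContext}

object KvsConfExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("oap-mllib-kvs-conf-example")
      // Both keys are optional: the IP falls back to the first executor's address,
      // the port falls back to the auto-detected free port.
      .set("spark.oap.mllib.oneccl.kvs.ip", "10.0.0.1")
      .set("spark.oap.mllib.oneccl.kvs.port", "3000")
    val sc = new SparkContext(conf)

    // ... run OAP MLlib KMeans / PCA as usual ...
    sc.stop()
  }
}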
5 changes: 3 additions & 2 deletions .github/workflows/oap-mllib-ci.yml
@@ -1,6 +1,6 @@
name: OAP MLlib CI

on: [push]
on: [push, pull_request]

jobs:
build:
@@ -38,4 +38,5 @@ jobs:
source /opt/intel/oneapi/dal/latest/env/vars.sh
source /opt/intel/oneapi/tbb/latest/env/vars.sh
source /tmp/oneCCL/build/_install/env/setvars.sh
./test.sh
# temp disable and will enable for new release of oneCCL
#./build.sh
1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
*.o
*.log
.vscode
*.iml
target/
.idea/
.idea_modules/
53 changes: 52 additions & 1 deletion mllib-dal/src/main/native/OneCCL.cpp
@@ -1,5 +1,13 @@
#include <iostream>
#include <chrono>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <oneapi/ccl.hpp>

#include "org_apache_spark_ml_util_OneCCL__.h"

// todo: fill initial comm_size and rank_id
@@ -17,10 +25,12 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init

std::cout << "oneCCL (native): init" << std::endl;

auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();

const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);

auto kvs_attr = ccl::create_kvs_attr();
kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);
@@ -30,6 +40,10 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init

g_comms.push_back(ccl::create_communicator(size, rank, kvs));

auto t2 = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>( t2 - t1 ).count();
std::cout << "oneCCL (native): init took " << duration << " secs" << std::endl;

rank_id = getComm().rank();
comm_size = getComm().size();

@@ -97,3 +111,40 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_setEnv

return err;
}

/*
* Class: org_apache_spark_ml_util_OneCCL__
* Method: getAvailPort
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_getAvailPort
(JNIEnv *env, jobject obj, jstring localIP) {

const int port_start_base = 3000;

const char *local_host_ip = env->GetStringUTFChars(localIP, NULL);

struct sockaddr_in main_server_address;
int server_listen_sock;

if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("OneCCL (native) getAvailPort error!");
return -1;
}

main_server_address.sin_family = AF_INET;
main_server_address.sin_addr.s_addr = inet_addr(local_host_ip);
// sin_port is in network byte order
main_server_address.sin_port = htons(port_start_base);

// Probe upwards from the base port until bind() succeeds
while (bind(server_listen_sock,
            (const struct sockaddr *)&main_server_address,
            sizeof(main_server_address)) < 0) {
    main_server_address.sin_port = htons(ntohs(main_server_address.sin_port) + 1);
}

close(server_listen_sock);

env->ReleaseStringUTFChars(localIP, local_host_ip);

// Convert back to host byte order for the caller
return ntohs(main_server_address.sin_port);
}
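
The native getAvailPort above simply binds to ascending port numbers starting at 3000 until one succeeds, then releases the socket and reports that port. A JVM-side sketch of the same probing idea follows, for illustration only: it is not code added by this commit, and the object and method names are made up. The usual caveat applies: a port found this way can still be taken by another process before oneCCL binds it.

import java.net.{InetAddress, ServerSocket}
import scala.annotation.tailrec
import scala.util.Try

object AvailPortSketch {

  /** Return the first port >= startPort that can be bound on localIp. */
  def findAvailPort(localIp: String, startPort: Int = 3000): Int = {
    val addr = InetAddress.getByName(localIp)

    @tailrec
    def probe(port: Int): Int = {
      require(port <= 65535, s"no free port found on $localIp")
      val bound = Try {
        val socket = new ServerSocket(port, 1, addr) // throws if the port is taken
        socket.close()                               // release it immediately
      }
      if (bound.isSuccess) port else probe(port + 1)
    }

    probe(startPort)
  }
}

// e.g. AvailPortSketch.findAvailPort("127.0.0.1") would typically return 3000 on an idle host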
1 change: 1 addition & 0 deletions mllib-dal/src/main/native/build-jni.sh
@@ -19,3 +19,4 @@ javah -d $WORK_DIR/javah -classpath "$WORK_DIR/../../../target/classes:$DAAL_JAR
org.apache.spark.ml.util.OneDAL$ \
org.apache.spark.ml.clustering.KMeansDALImpl \
org.apache.spark.ml.feature.PCADALImpl

Some generated files are not rendered by default.

mllib-dal/src/main/scala/org/apache/spark/ml/clustering/KMeansDALImpl.scala
@@ -43,6 +43,11 @@ class KMeansDALImpl (
val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
val kvsIP = data.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)

val kvsPortDetected = Utils.checkExecutorAvailPort(data.sparkContext, kvsIP)
val kvsPort = data.sparkContext.conf.getInt("spark.oap.mllib.oneccl.kvs.port", kvsPortDetected)

val kvsIPPort = kvsIP+"_"+kvsPort

// repartition to executorNum if not enough partitions
val dataForConversion = if (data.getNumPartitions < executorNum) {
data.repartition(executorNum).setName("Repartitioned for conversion").cache()
@@ -114,7 +119,7 @@ class KMeansDALImpl (

val results = coalescedTables.mapPartitionsWithIndex { (rank, table) =>
val tableArr = table.next()
OneCCL.init(executorNum, rank, kvsIP)
OneCCL.init(executorNum, rank, kvsIPPort)

val initCentroids = OneDAL.makeNumericTable(centers)
val result = new KMeansResult()
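
KMeansDALImpl (and PCADALImpl below) resolve the endpoint the same way: explicit SparkConf values win, otherwise the first executor's IP and the detected free port are used, and the result is joined into the "<ip>_<port>" form that c_init hands to oneCCL's KVS. A sketch of that resolution logic only, with the repo's internal helpers (Utils.sparkFirstExecutorIP, Utils.checkExecutorAvailPort) passed in as plain functions since they are not public API:

import org.apache.spark.SparkContext

object KvsEndpointSketch {
  // Mirrors the driver-side resolution in KMeansDALImpl / PCADALImpl:
  // explicit conf values take precedence over the detected defaults.
  def resolveKvsIpPort(sc: SparkContext,
                       detectIp: SparkContext => String,
                       detectPort: (SparkContext, String) => Int): String = {
    val kvsIP = sc.getConf.get("spark.oap.mllib.oneccl.kvs.ip", detectIp(sc))
    val kvsPort = sc.getConf.getInt("spark.oap.mllib.oneccl.kvs.port", detectPort(sc, kvsIP))
    s"${kvsIP}_${kvsPort}"   // oneCCL expects the KVS endpoint as "<ip>_<port>"
  }
}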
mllib-dal/src/main/scala/org/apache/spark/ml/feature/PCADALImpl.scala
@@ -40,18 +40,23 @@ class PCADALImpl (
res.map(_.asML)
}

def fitWithDAL(input: RDD[Vector]) : MLlibPCAModel = {
def fitWithDAL(data: RDD[Vector]) : MLlibPCAModel = {

val normalizedData = normalizeData(input)
val normalizedData = normalizeData(data)

val coalescedTables = OneDAL.rddVectorToNumericTables(normalizedData, executorNum)

val executorIPAddress = Utils.sparkFirstExecutorIP(input.sparkContext)
val kvsIP = input.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)
val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
val kvsIP = data.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)

val kvsPortDetected = Utils.checkExecutorAvailPort(data.sparkContext, kvsIP)
val kvsPort = data.sparkContext.conf.getInt("spark.oap.mllib.oneccl.kvs.port", kvsPortDetected)

val kvsIPPort = kvsIP+"_"+kvsPort

val results = coalescedTables.mapPartitionsWithIndex { (rank, table) =>
val tableArr = table.next()
OneCCL.init(executorNum, rank, kvsIP)
OneCCL.init(executorNum, rank, kvsIPPort)

val result = new PCAResult()
cPCATrainDAL(
11 changes: 6 additions & 5 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/OneCCL.scala
@@ -26,7 +26,7 @@ object OneCCL {
// var kvsIPPort = sys.env.getOrElse("CCL_KVS_IP_PORT", "")
// var worldSize = sys.env.getOrElse("CCL_WORLD_SIZE", "1").toInt

var kvsPort = 5000
// var kvsPort = 5000

// private def checkEnv() {
// val altTransport = sys.env.getOrElse("CCL_ATL_TRANSPORT", "")
@@ -57,21 +57,21 @@
// // setEnv("CCL_LOG_LEVEL", "2")
// }

def init(executor_num: Int, rank: Int, ip: String) = {
def init(executor_num: Int, rank: Int, ip_port: String) = {

// setExecutorEnv(executor_num, ip, port)
println(s"oneCCL: Initializing with IP_PORT: ${ip}_${kvsPort}")
println(s"oneCCL: Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip+"_"+kvsPort.toString, cclParam)
c_init(executor_num, rank, ip_port, cclParam)

// executor number should equal to oneCCL world size
assert(executor_num == cclParam.commSize, "executor number should equal to oneCCL world size")

println(s"oneCCL: Initialized with executorNum: $executor_num, commSize, ${cclParam.commSize}, rankId: ${cclParam.rankId}")

// Use a new port when calling init again
kvsPort = kvsPort + 1
// kvsPort = kvsPort + 1

}

@@ -87,4 +87,5 @@
@native def rankID() : Int

@native def setEnv(key: String, value: String, overwrite: Boolean = true): Int
@native def getAvailPort(localIP: String): Int
}
14 changes: 14 additions & 0 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/Utils.scala
@@ -71,6 +71,20 @@ object Utils {
ip
}

def checkExecutorAvailPort(sc: SparkContext, localIP: String) : Int = {
val executor_num = Utils.sparkExecutorNum(sc)
val data = sc.parallelize(1 to executor_num, executor_num)
// Probe for a free port on an executor (partition index 0 only); the other
// partitions contribute nothing, so collect() returns a single value.
val result = data.mapPartitionsWithIndex { (index, p) =>
LibLoader.loadLibraries()
if (index == 0)
Iterator(OneCCL.getAvailPort(localIP))
else
Iterator()
}.collect()

result(0)
}

def checkClusterPlatformCompatibility(sc: SparkContext) : Boolean = {
LibLoader.loadLibraries()

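
checkExecutorAvailPort runs the probe inside a Spark task rather than on the driver because the port has to be free on the executor host that will own the KVS endpoint. The underlying pattern, evaluating a function in exactly one task and collecting its single result, can be sketched generically as follows (a hypothetical helper, not part of the commit):

import org.apache.spark.SparkContext
import scala.reflect.ClassTag

object FirstExecutorSketch {
  // Evaluate f() inside the task with partition index 0 and return its single result.
  def runOnOneExecutor[T: ClassTag](sc: SparkContext, slots: Int)(f: () => T): T =
    sc.parallelize(1 to slots, slots)
      .mapPartitionsWithIndex { (index, _) =>
        if (index == 0) Iterator(f()) else Iterator.empty
      }
      .collect()
      .head
}

// Hypothetical usage: val port = FirstExecutorSketch.runOnOneExecutor(sc, executorNum)(() => OneCCL.getAvailPort(kvsIP))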
