Skip to content

Commit

Permalink
[ML-25] Fix oneCCL KVS port auto detect and improve logging (oap-proj…
Browse files Browse the repository at this point in the history
…ect#24)

* Fix port auto detect
Improve logging:
  Use stderr for native logging
  Use spark.internal.Logging for Scala logging
  Use slf4j for Java logging

* nit
  • Loading branch information
xwu99 authored Feb 9, 2021
1 parent 47aae78 commit 9080d02
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 68 deletions.
14 changes: 7 additions & 7 deletions mllib-dal/src/main/java/org/apache/spark/ml/util/LibLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import java.io.*;
import java.util.UUID;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.intel.daal.utils.LibUtils;

Expand All @@ -30,8 +31,7 @@ public final class LibLoader {
// Make sure loading libraries from different temp directory for each process
private final static String subDir = "MLlibDAL_" + UUID.randomUUID();

private static final Logger logger = Logger.getLogger(LibLoader.class.getName());
private static final Level logLevel = Level.INFO;
private static final Logger log = LoggerFactory.getLogger("LibLoader");

/**
* Get temp dir for exacting lib files
Expand Down Expand Up @@ -81,12 +81,12 @@ private static synchronized void loadLibMLlibDAL() throws IOException {
* @param name library name
*/
private static void loadFromJar(String path, String name) throws IOException {
logger.log(logLevel, "Loading " + name + " ...");
log.debug("Loading " + name + " ...");

File fileOut = createTempFile(path, name);
// File exists already
if (fileOut == null) {
logger.log(logLevel, "DONE: Loading library as resource.");
log.debug("DONE: Loading library as resource.");
return;
}

Expand All @@ -96,7 +96,7 @@ private static void loadFromJar(String path, String name) throws IOException {
}

try (OutputStream streamOut = new FileOutputStream(fileOut)) {
logger.log(logLevel, "Writing resource to temp file.");
log.debug("Writing resource to temp file.");

byte[] buffer = new byte[32768];
while (true) {
Expand All @@ -115,7 +115,7 @@ private static void loadFromJar(String path, String name) throws IOException {
}

System.load(fileOut.toString());
logger.log(logLevel, "DONE: Loading library as resource.");
log.debug("DONE: Loading library as resource.");
}

/**
Expand Down
40 changes: 33 additions & 7 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ccl::communicator &getComm() {
JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init
(JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, jobject param) {

std::cout << "oneCCL (native): init" << std::endl;
std::cerr << "OneCCL (native): init" << std::endl;

auto t1 = std::chrono::high_resolution_clock::now();

Expand All @@ -42,7 +42,7 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init

auto t2 = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>( t2 - t1 ).count();
std::cout << "oneCCL (native): init took " << duration << " secs" << std::endl;
std::cerr << "OneCCL (native): init took " << duration << " secs" << std::endl;

rank_id = getComm().rank();
comm_size = getComm().size();
Expand All @@ -68,7 +68,7 @@ JNIEXPORT void JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1cleanup

g_comms.pop_back();

std::cout << "oneCCL (native): cleanup" << std::endl;
std::cerr << "OneCCL (native): cleanup" << std::endl;

}

Expand Down Expand Up @@ -112,6 +112,24 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_setEnv
return err;
}

#define GET_IP_CMD "hostname -I"
#define MAX_KVS_VAL_LENGTH 130
#define READ_ONLY "r"

static bool is_valid_ip(char ip[]) {
FILE *fp;
// TODO: use getifaddrs instead of popen
if ((fp = popen(GET_IP_CMD, READ_ONLY)) == NULL) {
printf("Can't get host IP\n");
exit(1);
}
char host_ips[MAX_KVS_VAL_LENGTH];
fgets(host_ips, MAX_KVS_VAL_LENGTH, fp);
pclose(fp);

return strstr(host_ips, ip) ? true : false;
}

/*
* Class: org_apache_spark_ml_util_OneCCL__
* Method: getAvailPort
Expand All @@ -120,12 +138,18 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_setEnv
JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_getAvailPort
(JNIEnv *env, jobject obj, jstring localIP) {

// start from beginning of dynamic port
const int port_start_base = 3000;

char* local_host_ip = (char *) env->GetStringUTFChars(localIP, NULL);

// check if the input ip is one of host's ips
if (!is_valid_ip(local_host_ip))
return -1;

struct sockaddr_in main_server_address;
int server_listen_sock;
in_port_t port = port_start_base;

if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("OneCCL (native) getAvailPort error!");
Expand All @@ -134,17 +158,19 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_getAvailPort

main_server_address.sin_family = AF_INET;
main_server_address.sin_addr.s_addr = inet_addr(local_host_ip);
main_server_address.sin_port = port_start_base;
main_server_address.sin_port = htons(port);

// search for available port
while (bind(server_listen_sock,
(const struct sockaddr *)&main_server_address,
sizeof(main_server_address)) < 0) {
main_server_address.sin_port++;
port++;
main_server_address.sin_port = htons(port);
}

close(server_listen_sock);
close(server_listen_sock);

env->ReleaseStringUTFChars(localIP, local_host_ip);

return main_server_address.sin_port;
return port;
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class KMeansDALImpl (
val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
val kvsIP = data.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)

val kvsPortDetected = Utils.checkExecutorAvailPort(data.sparkContext, kvsIP)
val kvsPortDetected = Utils.checkExecutorAvailPort(data, kvsIP)
val kvsPort = data.sparkContext.conf.getInt("spark.oap.mllib.oneccl.kvs.port", kvsPortDetected)

val kvsIPPort = kvsIP+"_"+kvsPort
Expand All @@ -70,14 +70,14 @@ class KMeansDALImpl (
val it = entry._3
val numCols = partitionDims(index)._2

println(s"KMeansDALImpl: Partition index: $index, numCols: $numCols, numRows: $numRows")
logDebug(s"KMeansDALImpl: Partition index: $index, numCols: $numCols, numRows: $numRows")

// Build DALMatrix, this will load libJavaAPI, libtbb, libtbbmalloc
val context = new DaalContext()
val matrix = new DALMatrix(context, classOf[java.lang.Double],
numCols.toLong, numRows.toLong, NumericTable.AllocationFlag.DoAllocate)

println("KMeansDALImpl: Loading native libraries" )
logDebug("KMeansDALImpl: Loading native libraries" )
// oneDAL libs should be loaded by now, extract libMLlibDAL.so to temp file and load
LibLoader.loadLibraries()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
package org.apache.spark.ml.feature

import java.util.Arrays

import com.intel.daal.data_management.data.{HomogenNumericTable, NumericTable}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.util.{OneCCL, OneDAL, Utils}
import org.apache.spark.mllib.feature.{PCAModel => MLlibPCAModel}
import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.feature.{ StandardScaler => MLlibStandardScaler }
import org.apache.spark.mllib.feature.{StandardScaler => MLlibStandardScaler}

class PCADALImpl (
val k: Int,
val executorNum: Int,
val executorCores: Int) extends Serializable {
val executorCores: Int)
extends Serializable with Logging {

// Normalize data before apply fitWithDAL
private def normalizeData(input: RDD[Vector]) : RDD[Vector] = {
Expand All @@ -49,7 +50,7 @@ class PCADALImpl (
val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
val kvsIP = data.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)

val kvsPortDetected = Utils.checkExecutorAvailPort(data.sparkContext, kvsIP)
val kvsPortDetected = Utils.checkExecutorAvailPort(data, kvsIP)
val kvsPort = data.sparkContext.conf.getInt("spark.oap.mllib.oneccl.kvs.port", kvsPortDetected)

val kvsIPPort = kvsIP+"_"+kvsPort
Expand Down
52 changes: 11 additions & 41 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/OneCCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,32 @@

package org.apache.spark.ml.util

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging

object OneCCL {
object OneCCL extends Logging {

var cclParam = new CCLParam()

// var kvsIPPort = sys.env.getOrElse("CCL_KVS_IP_PORT", "")
// var worldSize = sys.env.getOrElse("CCL_WORLD_SIZE", "1").toInt

// var kvsPort = 5000

// private def checkEnv() {
// val altTransport = sys.env.getOrElse("CCL_ATL_TRANSPORT", "")
// val pmType = sys.env.getOrElse("CCL_PM_TYPE", "")
// val ipExchange = sys.env.getOrElse("CCL_KVS_IP_EXCHANGE", "")
//
// assert(altTransport == "ofi")
// assert(pmType == "resizable")
// assert(ipExchange == "env")
// assert(kvsIPPort != "")
//
// }

// Run on Executor
// def setExecutorEnv(executor_num: Int, ip: String, port: Int): Unit = {
// // Work around ccl by passings in a spark.executorEnv.CCL_KVS_IP_PORT.
// val ccl_kvs_ip_port = sys.env.getOrElse("CCL_KVS_IP_PORT", s"${ip}_${port}")
//
// println(s"oneCCL: Initializing with CCL_KVS_IP_PORT: $ccl_kvs_ip_port")
//
// setEnv("CCL_PM_TYPE", "resizable")
// setEnv("CCL_ATL_TRANSPORT","ofi")
// setEnv("CCL_ATL_TRANSPORT_PATH", LibLoader.getTempSubDir())
// setEnv("CCL_KVS_IP_EXCHANGE","env")
// setEnv("CCL_KVS_IP_PORT", ccl_kvs_ip_port)
// setEnv("CCL_WORLD_SIZE", s"${executor_num}")
// // Uncomment this if you whant to debug oneCCL
// // setEnv("CCL_LOG_LEVEL", "2")
// }
def setExecutorEnv(): Unit = {
setEnv("CCL_ATL_TRANSPORT","ofi")
// Uncomment this if you whant to debug oneCCL
// setEnv("CCL_LOG_LEVEL", "2")
}

def init(executor_num: Int, rank: Int, ip_port: String) = {

// setExecutorEnv(executor_num, ip, port)
println(s"oneCCL: Initializing with IP_PORT: ${ip_port}")
setExecutorEnv()

logInfo(s"Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip_port, cclParam)

// executor number should equal to oneCCL world size
assert(executor_num == cclParam.commSize, "executor number should equal to oneCCL world size")

println(s"oneCCL: Initialized with executorNum: $executor_num, commSize, ${cclParam.commSize}, rankId: ${cclParam.rankId}")

// Use a new port when calling init again
// kvsPort = kvsPort + 1

logInfo(s"Initialized with executorNum: $executor_num, commSize, ${cclParam.commSize}, rankId: ${cclParam.rankId}")
}

// Run on Executor
Expand Down
12 changes: 6 additions & 6 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ object Utils {
ip
}

def checkExecutorAvailPort(sc: SparkContext, localIP: String) : Int = {
val executor_num = Utils.sparkExecutorNum(sc)
val data = sc.parallelize(1 to executor_num, executor_num)
val result = data.mapPartitionsWithIndex { (index, p) =>
def checkExecutorAvailPort(data: RDD[_], localIP: String) : Int = {
val sc = data.sparkContext
val result = data.mapPartitions { p =>
LibLoader.loadLibraries()
if (index == 0)
Iterator(OneCCL.getAvailPort(localIP))
val port = OneCCL.getAvailPort(localIP)
if (port != -1)
Iterator(port)
else
Iterator()
}.collect()
Expand Down

0 comments on commit 9080d02

Please sign in to comment.