Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Aug 19, 2024
1 parent 407535a commit b07fe36
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 60 deletions.
34 changes: 16 additions & 18 deletions mllib-dal/src/main/native/LinearRegressionImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ ridge_regression_compute(size_t rankId, ccl::communicator &comm,

#ifdef CPU_GPU_PROFILE
static jlong doLROneAPICompute(JNIEnv *env, size_t rankId,
ccl::communicator &cclComm, sycl::queue &queue,
sycl::queue &queue,
jlong pNumTabFeature, jlong featureRows,
jlong featureCols, jlong pNumTabLabel,
jlong labelCols, jboolean jfitIntercept,
Expand Down Expand Up @@ -287,15 +287,17 @@ Java_com_intel_oap_mllib_regression_LinearRegressionDALImpl_cLinearRegressionTra
rank);

jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);
int size = cclComm.size();
auto queue = getAssignedGPU(device, gpuIndices);

resultptr = doLROneAPICompute(
env, rank, cclComm, queue, feature, featureRows, featureCols,
env, rank, queue, feature, featureRows, featureCols,
label, labelCols, fitIntercept, executorNum, resultObj);
env->ReleaseIntArrayElements(gpuIdxArray, gpuIndices, 0);
#endif
} else {
ccl::communicator &cclComm = getComm();
size_t rankId = cclComm.rank();

NumericTablePtr pLabel = *((NumericTablePtr *)label);
NumericTablePtr pData = *((NumericTablePtr *)feature);

Expand All @@ -318,22 +320,18 @@ Java_com_intel_oap_mllib_regression_LinearRegressionDALImpl_cLinearRegressionTra

NumericTablePtr *coeffvectors = new NumericTablePtr(resultTable);
resultptr = (jlong)coeffvectors;
}

jlong ret = 0L;
if (rankId == ccl_root) {
// Get the class of the result object
jclass clazz = env->GetObjectClass(resultObj);
// Get Field references
jfieldID coeffNumericTableField =
env->GetFieldID(clazz, "coeffNumericTable", "J");
if (rankId == ccl_root) {
// Get the class of the result object
jclass clazz = env->GetObjectClass(resultObj);
// Get Field references
jfieldID coeffNumericTableField =
env->GetFieldID(clazz, "coeffNumericTable", "J");

env->SetLongField(resultObj, coeffNumericTableField, resultptr);
env->SetLongField(resultObj, coeffNumericTableField, resultptr);

// intercept is already in first column of coeffvectors
ret = resultptr;
} else {
ret = (jlong)0;
// intercept is already in first column of coeffvectors
resultptr = (jlong)coeffvectors;
}
}
return ret;
return resultptr;
}
44 changes: 2 additions & 42 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(

const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);
const char *device = env->GetStringUTFChars(use_device, 0);
ccl::string ccl_ip_port(str);

auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

Expand All @@ -81,46 +79,8 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");

env->SetLongField(param, size, comm_size);
env->SetLongField(param, rank, rank_id);
env->ReleaseStringUTFChars(ip_port, str);

return 1;
}

/*
* Class: com_intel_oap_mllib_OneCCL__
* Method: c_init
* Signature: ()I
*/
JNIEXPORT jint JNICALL
Java_com_intel_oap_mllib_OneCCL_00024_c_1initDpcpp(JNIEnv *env, jobject, jint size, jint rank, jobject param) {
logger::printerrln(logger::INFO, "OneCCL (native): init dpcpp");
auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();

const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);

auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

g_kvs.push_back(singletonCCLInit.kvs);


auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);

jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");

env->SetLongField(param, size, comm_size);
env->SetLongField(param, rank, rank_id);
env->SetLongField(param, fid_comm_size, comm_size);
env->SetLongField(param, fid_rank_id, rank_id);
env->ReleaseStringUTFChars(ip_port, str);

return 1;
Expand Down

0 comments on commit b07fe36

Please sign in to comment.