test out old sparse dataset create method
imatiach-msft committed Jun 2, 2021
1 parent 9e9ff1a commit 90e5f0a
Showing 6 changed files with 382 additions and 87 deletions.
21 changes: 21 additions & 0 deletions src/main/scala/com/microsoft/lightgbm/CSRUtils.scala
@@ -0,0 +1,21 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.ml.lightgbm
+
+/** Temporary class that accepts int32_t arrays instead of void pointer arguments.
+  * TODO: Need to generate a new lightGBM jar with utility to convert int array
+  * to void pointer and then remove this file.
+  */
+object CSRUtils {
+  // scalastyle:off parameter.number
+  def LGBM_DatasetCreateFromCSR(var0: SWIGTYPE_p_int, var1: Int, var2: SWIGTYPE_p_int, var3: SWIGTYPE_p_void,
+                                var4: Int, var5: Int, var6: Int,
+                                var7: Int, var8: String, var9: SWIGTYPE_p_void,
+                                var10: SWIGTYPE_p_p_void): Int = {
+    lightgbmlibJNI.LGBM_DatasetCreateFromCSR(SWIGTYPE_p_int.getCPtr(var0), var1, SWIGTYPE_p_int.getCPtr(var2),
+      SWIGTYPE_p_void.getCPtr(var3), var4, var5, var6,
+      var7, var8, SWIGTYPE_p_void.getCPtr(var9), SWIGTYPE_p_p_void.getCPtr(var10))
+  }
+  // scalastyle:on parameter.number
+}
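
Since the wrapper differs from the stock SWIG binding only in taking the row-pointer argument as a typed int pointer, a minimal call might look like the sketch below. The matrix values and parameter string are illustrative rather than taken from the commit; the allocation helpers (new_intArray, intArray_setitem, new_doubleArray, doubleArray_setitem, double_to_voidp_ptr, voidpp_handle) are the standard SWIG carrays utilities this commit also uses in LightGBMUtils.scala.

// Hedged sketch: load a 2x3 CSR matrix
//   [ 1.0  0.0  2.0 ]
//   [ 0.0  3.0  0.0 ]
val indptrValues = Array(0, 2, 3)      // row start offsets, length numRows + 1
val indexValues = Array(0, 2, 1)       // column index of each nonzero
val dataValues = Array(1.0, 2.0, 3.0)  // the nonzero values

// Copy the JVM arrays into native SWIG buffers.
val indptr = lightgbmlib.new_intArray(indptrValues.length)
indptrValues.zipWithIndex.foreach { case (v, i) => lightgbmlib.intArray_setitem(indptr, i, v) }
val indexes = lightgbmlib.new_intArray(indexValues.length)
indexValues.zipWithIndex.foreach { case (v, i) => lightgbmlib.intArray_setitem(indexes, i, v) }
val data = lightgbmlib.new_doubleArray(dataValues.length)
dataValues.zipWithIndex.foreach { case (v, i) => lightgbmlib.doubleArray_setitem(data, i, v) }

val datasetOutPtr = lightgbmlib.voidpp_handle()
val result = CSRUtils.LGBM_DatasetCreateFromCSR(
  indptr, lightgbmlibConstants.C_API_DTYPE_INT32, // typed indptr: no int-to-void* cast needed
  indexes, lightgbmlib.double_to_voidp_ptr(data), lightgbmlibConstants.C_API_DTYPE_FLOAT64,
  indptrValues.length, dataValues.length, 3,      // nindptr, nelem, num_col
  "max_bin=255", null, datasetOutPtr)             // parameters, no reference dataset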
46 changes: 37 additions & 9 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala
@@ -255,30 +255,58 @@ object LightGBMUtils {
     dataset
   }
 
+
+  def newDoubleArray(array: Array[Double]): (SWIGTYPE_p_void, SWIGTYPE_p_double) = {
+    val data = lightgbmlib.new_doubleArray(array.length)
+    array.zipWithIndex.foreach {
+      case (value, index) => lightgbmlib.doubleArray_setitem(data, index, value)
+    }
+    (lightgbmlib.double_to_voidp_ptr(data), data)
+  }
+
+  def newIntArray(array: Array[Int]): SWIGTYPE_p_int = {
+    val data = lightgbmlib.new_intArray(array.length)
+    array.zipWithIndex.foreach {
+      case (value, index) => lightgbmlib.intArray_setitem(data, index, value)
+    }
+    data
+  }
+
+  def intToPtr(value: Int): SWIGTYPE_p_long = {
+    val longPtr = lightgbmlib.new_longp()
+    lightgbmlib.longp_assign(longPtr, value)
+    longPtr
+  }
 
   /** Generates a sparse dataset in CSR format.
     *
-    * @param sparseRows The rows of sparse vector.
     * @return
     */
-  def generateSparseDataset(sparseRows: Array[SparseVector],
+  def generateSparseDataset(numCols: Long,
+                            indptrLength: Long,
+                            valuesLength: Long,
+                            values: SWIGTYPE_p_void,
+                            indexes: SWIGTYPE_p_int,
+                            indptr: SWIGTYPE_p_void,
                             referenceDataset: Option[LightGBMDataset],
                             featureNamesOpt: Option[Array[String]],
                             trainParams: TrainParams): LightGBMDataset = {
-    val numCols = sparseRows(0).size
-
     val datasetOutPtr = lightgbmlib.voidpp_handle()
     val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
       s"bin_construct_sample_cnt=${trainParams.binSampleCount} " +
       (if (trainParams.categoricalFeatures.isEmpty) ""
-       else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
+      else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
+    val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
     val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
     // Generate the dataset for features
-    LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark(
-      sparseRows.asInstanceOf[Array[Object]],
-      sparseRows.length,
+    LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSR(indptr,
+      dataInt32bitType, indexes, values, data64bitType,
+      indptrLength, valuesLength,
       numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull,
       datasetOutPtr),
       "Dataset create")
     val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr))
-    dataset.setFeatureNames(featureNamesOpt, numCols)
+    dataset.setFeatureNames(featureNamesOpt, numCols.toInt)
     dataset
   }
 }
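
With the SparseVector array dropped from the signature, a caller now has to flatten its rows into native CSR buffers itself. The sketch below shows one plausible caller under stated assumptions: the real call sites (in TrainUtils) are not part of the shown hunks, and int_to_voidp_ptr is hypothetical, since the TODO in CSRUtils.scala says exactly this int-array-to-void-pointer utility is still missing from the lightGBM jar.

import org.apache.spark.ml.linalg.SparseVector

// Illustrative caller, not from this commit.
def sparseDatasetFromRows(rows: Array[SparseVector],
                          referenceDataset: Option[LightGBMDataset],
                          featureNamesOpt: Option[Array[String]],
                          trainParams: TrainParams): LightGBMDataset = {
  val numCols = rows(0).size.toLong
  // CSR row pointer: running count of nonzeros, length numRows + 1.
  val indptrArray = rows.scanLeft(0)((count, row) => count + row.values.length)
  val indexArray = rows.flatMap(_.indices)
  val valueArray = rows.flatMap(_.values)
  val (valuesPtr, _) = LightGBMUtils.newDoubleArray(valueArray)
  val indexesPtr = LightGBMUtils.newIntArray(indexArray)
  // Hypothetical cast; per the CSRUtils TODO, no such utility exists yet.
  val indptrPtr = lightgbmlib.int_to_voidp_ptr(LightGBMUtils.newIntArray(indptrArray))
  LightGBMUtils.generateSparseDataset(numCols, indptrArray.length.toLong, valueArray.length.toLong,
    valuesPtr, indexesPtr, indptrPtr, referenceDataset, featureNamesOpt, trainParams)
}

The native buffers allocated here would presumably need to be released after dataset construction via the SWIG-generated delete_intArray/delete_doubleArray counterparts, which the sketch omits.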
@@ -387,6 +387,7 @@ private object TrainUtils extends Serializable {
     val isMainWorker = mainExecutorWorker == myTaskId
     log.info(s"Using singleDatasetMode. " +
       s"Is main worker: ${isMainWorker} for task id: ${myTaskId} and main task id: ${mainExecutorWorker}")
+    SingletonDataset.incrementArrayProcessedSignal()
     if (!isMainWorker) {
       SingletonDataset.incrementDoneSignal()
     }
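
The two counters act as executor-local coordination signals for singleDatasetMode: every task registers itself via incrementArrayProcessedSignal, and non-main workers additionally bump the done counter so the single main worker can tell when it alone remains responsible for the shared dataset. SingletonDataset's definition is not in this diff, so the following is only an invented illustration of that counting pattern:

import java.util.concurrent.atomic.AtomicInteger

// Invented illustration; the real SingletonDataset is not shown in this commit.
object SingletonDatasetSketch {
  // Tasks on this executor that have registered their data arrays.
  private val arrayProcessedSignal = new AtomicInteger(0)
  // Registered tasks that are helpers rather than the main worker.
  private val doneSignal = new AtomicInteger(0)

  def incrementArrayProcessedSignal(): Int = arrayProcessedSignal.incrementAndGet()
  def incrementDoneSignal(): Int = doneSignal.incrementAndGet()

  // One plausible check for the main worker before it builds the
  // shared dataset: all registered tasks except itself are helpers.
  def allHelpersAccountedFor: Boolean = doneSignal.get == arrayProcessedSignal.get - 1
}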
