diff --git a/src/main/scala/com/microsoft/lightgbm/CSRUtils.scala b/src/main/scala/com/microsoft/lightgbm/CSRUtils.scala new file mode 100644 index 00000000000..be2b0f454b0 --- /dev/null +++ b/src/main/scala/com/microsoft/lightgbm/CSRUtils.scala @@ -0,0 +1,21 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.lightgbm + +/** Temporary class that accepts int32_t arrays instead of void pointer arguments. + * TODO: Need to generate a new lightGBM jar with utility to convert int array + * to void pointer and then remove this file. + */ +object CSRUtils { + // scalastyle:off parameter.number + def LGBM_DatasetCreateFromCSR(var0: SWIGTYPE_p_int, var1: Int, var2: SWIGTYPE_p_int, var3: SWIGTYPE_p_void, + var4: Int, var5: Int, var6: Int, + var7: Int, var8: String, var9: SWIGTYPE_p_void, + var10: SWIGTYPE_p_p_void): Int = { + lightgbmlibJNI.LGBM_DatasetCreateFromCSR(SWIGTYPE_p_int.getCPtr(var0), var1, SWIGTYPE_p_int.getCPtr(var2), + SWIGTYPE_p_void.getCPtr(var3), var4, var5, var6, + var7, var8, SWIGTYPE_p_void.getCPtr(var9), SWIGTYPE_p_p_void.getCPtr(var10)) + } + // scalastyle:on parameter.number +} diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala index 3c5e653515d..0383e0a6d29 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala @@ -255,6 +255,29 @@ object LightGBMUtils { dataset } + + def newDoubleArray(array: Array[Double]): (SWIGTYPE_p_void, SWIGTYPE_p_double) = { + val data = lightgbmlib.new_doubleArray(array.length) + array.zipWithIndex.foreach { + case (value, index) => lightgbmlib.doubleArray_setitem(data, index, value) + } + (lightgbmlib.double_to_voidp_ptr(data), data) + } + + def newIntArray(array: Array[Int]): (SWIGTYPE_p_int, SWIGTYPE_p_int) = { + val data = lightgbmlib.new_intArray(array.length) + array.zipWithIndex.foreach { + case (value, index) => lightgbmlib.intArray_setitem(data, index, value) + } + (lightgbmlib.int_to_int32_t_ptr(data), data) + } + + def intToPtr(value: Int): SWIGTYPE_p_long = { + val longPtr = lightgbmlib.new_longp() + lightgbmlib.longp_assign(longPtr, value) + longPtr + } + /** Generates a sparse dataset in CSR format. * * @param sparseRows The rows of sparse vector. @@ -264,21 +287,41 @@ object LightGBMUtils { referenceDataset: Option[LightGBMDataset], featureNamesOpt: Option[Array[String]], trainParams: TrainParams): LightGBMDataset = { - val numCols = sparseRows(0).size + var values: Option[(SWIGTYPE_p_void, SWIGTYPE_p_double)] = None + var indexes: Option[(SWIGTYPE_p_int, SWIGTYPE_p_int)] = None + var indptrNative: Option[(SWIGTYPE_p_int, SWIGTYPE_p_int)] = None + try { + val valuesArray = sparseRows.flatMap(_.values) + values = Some(newDoubleArray(valuesArray)) + val indexesArray = sparseRows.flatMap(_.indices) + indexes = Some(newIntArray(indexesArray)) + val indptr = new Array[Int](sparseRows.length + 1) + sparseRows.zipWithIndex.foreach { + case (row, index) => indptr(index + 1) = indptr(index) + row.numNonzeros + } + indptrNative = Some(newIntArray(indptr)) + val numCols = sparseRows(0).size - val datasetOutPtr = lightgbmlib.voidpp_handle() - val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " + - (if (trainParams.categoricalFeatures.isEmpty) "" - else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}") - // Generate the dataset for features - LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark( - sparseRows.asInstanceOf[Array[Object]], - sparseRows.length, - numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull, - datasetOutPtr), - "Dataset create") - val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr)) - dataset.setFeatureNames(featureNamesOpt, numCols) - dataset + val datasetOutPtr = lightgbmlib.voidpp_handle() + val datasetParams = "max_bin=255 is_pre_partition=True" + val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32 + val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64 + // Generate the dataset for features + LightGBMUtils.validate(CSRUtils.LGBM_DatasetCreateFromCSR( + indptrNative.get._1, dataInt32bitType, + indexes.get._1, values.get._1, data64bitType, + indptr.length, valuesArray.length, + numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull, + datasetOutPtr), + "Dataset create") + val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr)) + dataset.setFeatureNames(featureNamesOpt, numCols) + dataset + } finally { + // Delete the input rows + if (values.isDefined) lightgbmlib.delete_doubleArray(values.get._2) + if (indexes.isDefined) lightgbmlib.delete_intArray(indexes.get._2) + if (indptrNative.isDefined) lightgbmlib.delete_intArray(indptrNative.get._2) + } } }