Commit

[SPARKNLP-1037] Adding addFile changes to replace broadcast in all ONNX based annotators (#14236)

* [SPARKNLP-1011] Adding changes to transfer ONNX files to executors through the Spark files feature

* [SPARKNLP-1011] Adding missing copyright comment

* [SPARKNLP-1011] Adding changes to add a prefix for models with an onnx_data file

* [SPARKNLP-1037] Adding changes to transfer ONNX files to executors via addFile

* [SPARKNLP-1037] Adding a unique suffix to avoid duplication in Spark files (see the sketch below)
danilojsl authored May 21, 2024
1 parent fcd4e9c commit 4419a70
Showing 49 changed files with 262 additions and 221 deletions.
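A minimal sketch of the addFile/SparkFiles pattern the commits above adopt. The file name, path, and toy job are illustrative assumptions; only the driver-side addFile plus executor-side SparkFiles.get pairing mirrors the change:

import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession

object AddFileSketch extends App {
  val spark = SparkSession.builder().master("local[2]").appName("addFile-sketch").getOrCreate()

  // Driver side: register the file once. Spark copies it to every executor's
  // scratch space instead of broadcasting the model bytes through the driver heap.
  spark.sparkContext.addFile("/tmp/model.onnx") // hypothetical model file

  // Executor side: each task resolves the node-local copy by file name and can
  // open an ONNX session from disk, with no Array[Byte] (de)serialization.
  val localPaths = spark.sparkContext
    .parallelize(1 to 2, numSlices = 2)
    .mapPartitions { part =>
      val localPath = SparkFiles.get("model.onnx")
      part.map(_ => localPath)
    }
    .collect()

  localPaths.foreach(println)
  spark.stop()
}

Relative to a Broadcast[Array[Byte]], this keeps large ONNX models (and their external onnx_data tensors) out of driver memory and off the task-serialization path.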
4 changes: 2 additions & 2 deletions src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala
@@ -23,11 +23,11 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.nlp.Annotation

import scala.collection.JavaConverters._
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._

private[johnsnowlabs] class M2M100(
val onnxWrappers: EncoderDecoderWithoutPastWrappers,
val spp: SentencePieceWrapper,
5 changes: 4 additions & 1 deletion src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala
@@ -297,7 +297,10 @@ private[johnsnowlabs] class Whisper(
case TensorFlow.name =>
val session =
tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes, savedSignatures = signatures)
.getTFSessionWithSignature(
configProtoBytes,
savedSignatures = signatures,
initAllTables = false)

val encodedBatchFeatures: Tensor =
encode(featuresBatch, Some(session), None).asInstanceOf[Tensor]
132 changes: 66 additions & 66 deletions src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala
@@ -16,8 +16,6 @@

package com.johnsnowlabs.ml.onnx

import ai.onnxruntime.OrtSession.SessionOptions
import com.johnsnowlabs.util.FileHelper
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SparkSession
@@ -32,11 +30,10 @@ trait WriteOnnxModel {
path: String,
spark: SparkSession,
onnxWrappersWithNames: Seq[(OnnxWrapper, String)],
suffix: String,
dataFileSuffix: String = "_data"): Unit = {
suffix: String): Unit = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
val fileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

// 1. Create tmp folder
val tmpFolder = Files
@@ -51,15 +48,16 @@
onnxWrapper.saveToFile(onnxFile)

// 3. Copy to dest folder
fs.copyFromLocalFile(new Path(onnxFile), new Path(path))
fileSystem.copyFromLocalFile(new Path(onnxFile), new Path(path))

// 4. check if there is a onnx_data file
if (onnxWrapper.onnxModelPath.isDefined) {
val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix)
if (fs.exists(onnxDataFile)) {
fs.copyFromLocalFile(onnxDataFile, new Path(path))
if (onnxWrapper.dataFileDirectory.isDefined) {
val onnxDataFile = new Path(onnxWrapper.dataFileDirectory.get)
if (fileSystem.exists(onnxDataFile)) {
fileSystem.copyFromLocalFile(onnxDataFile, new Path(path))
}
}

}

// 4. Remove tmp folder
@@ -74,7 +72,6 @@
fileName: String): Unit = {
writeOnnxModels(path, spark, Seq((onnxWrapper, fileName)), suffix)
}

}

trait ReadOnnxModel {
@@ -86,38 +83,61 @@
suffix: String,
zipped: Boolean = true,
useBundle: Boolean = false,
sessionOptions: Option[SessionOptions] = None,
dataFileSuffix: String = "_data"): OnnxWrapper = {
modelName: Option[String] = None,
tmpFolder: Option[String] = None,
dataFilePostfix: Option[String] = None): OnnxWrapper = {

// 1. Copy to local tmp dir
val localModelFile = if (modelName.isDefined) modelName.get else onnxFile
val srcPath = new Path(path, localModelFile)
val fileSystem = getFileSystem(path, spark)
val localTmpFolder = if (tmpFolder.isDefined) tmpFolder.get else createTmpDirectory(suffix)
fileSystem.copyToLocalFile(srcPath, new Path(localTmpFolder))

// 2. Copy onnx_data file if exists
val fsPath = new Path(path, localModelFile).toString

val onnxDataFile: Option[String] = if (modelName.isDefined && dataFilePostfix.isDefined) {
Some(fsPath.replaceAll(modelName.get, s"${suffix}_${modelName.get}${dataFilePostfix.get}"))
} else None

if (onnxDataFile.isDefined) {
val onnxDataFilePath = new Path(onnxDataFile.get)
if (fileSystem.exists(onnxDataFilePath)) {
fileSystem.copyToLocalFile(onnxDataFilePath, new Path(localTmpFolder))
}
}

// 3. Read ONNX state
val onnxFileTmpPath = new Path(localTmpFolder, localModelFile).toString
val onnxWrapper =
OnnxWrapper.read(
spark,
onnxFileTmpPath,
zipped = zipped,
useBundle = useBundle,
modelName = if (modelName.isDefined) modelName.get else onnxFile,
onnxFileSuffix = Some(suffix))

onnxWrapper

}

private def getFileSystem(path: String, sparkSession: SparkSession): FileSystem = {
val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
val fileSystem = FileSystem.get(uri, sparkSession.sparkContext.hadoopConfiguration)
fileSystem
}

private def createTmpDirectory(suffix: String): String = {

// 1. Create tmp directory
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
.createTempDirectory(s"${UUID.randomUUID().toString.takeRight(12)}_$suffix")
.toAbsolutePath
.toString

// 2. Copy to local dir
fs.copyToLocalFile(new Path(path, onnxFile), new Path(tmpFolder))

val localPath = new Path(tmpFolder, onnxFile).toString

val fsPath = new Path(path, onnxFile)

// 3. Copy onnx_data file if exists
val onnxDataFile = new Path(fsPath + dataFileSuffix)

if (fs.exists(onnxDataFile)) {
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
}
// 4. Read ONNX state
val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle)

// 5. Remove tmp folder
FileHelper.delete(tmpFolder)

onnxWrapper
tmpFolder
}

def readOnnxModels(
@@ -127,43 +147,23 @@
suffix: String,
zipped: Boolean = true,
useBundle: Boolean = false,
dataFileSuffix: String = "_data"): Map[String, OnnxWrapper] = {
dataFilePostfix: String = "_data"): Map[String, OnnxWrapper] = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

// 1. Create tmp directory
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
.toAbsolutePath
.toString
val tmpFolder = Some(createTmpDirectory(suffix))

val wrappers = (modelNames map { modelName: String =>
// 2. Copy to local dir
val localModelFile = modelName
fs.copyToLocalFile(new Path(path, localModelFile), new Path(tmpFolder))

val localPath = new Path(tmpFolder, localModelFile).toString

val fsPath = new Path(path, localModelFile).toString

// 3. Copy onnx_data file if exists
val onnxDataFile = new Path(fsPath + dataFileSuffix)

if (fs.exists(onnxDataFile)) {
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
}

// 4. Read ONNX state
val onnxWrapper =
OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle, modelName = modelName)

val onnxWrapper = readOnnxModel(
path,
spark,
suffix,
zipped,
useBundle,
Some(modelName),
tmpFolder,
Option(dataFilePostfix))
(modelName, onnxWrapper)
}).toMap

// 4. Remove tmp folder
FileHelper.delete(tmpFolder)

wrappers
}

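A worked sketch of the naming rule the new readOnnxModel applies when looking up an external data file. All values below are hypothetical; only the suffix-prefixed replacement mirrors the diff above:

object DataFileNamingSketch extends App {
  val path = "hdfs://namenode/models/albert" // hypothetical model directory
  val modelName = "encoder.onnx"
  val suffix = "albert_onnx" // the unique per-model suffix
  val dataFilePostfix = "_data"

  val fsPath = s"$path/$modelName"
  // The data file name is prefixed with the unique suffix, so two models that
  // would otherwise both ship an "encoder.onnx_data" file cannot collide once
  // addFile places them in the same per-application files directory.
  val onnxDataFile = fsPath.replaceAll(modelName, s"${suffix}_${modelName}${dataFilePostfix}")

  println(onnxDataFile) // hdfs://namenode/models/albert/albert_onnx_encoder.onnx_data
}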
78 changes: 30 additions & 48 deletions src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala
@@ -21,15 +21,16 @@ import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel}
import ai.onnxruntime.providers.OrtCUDAProviderOptions
import ai.onnxruntime.{OrtEnvironment, OrtSession}
import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil}
import org.apache.commons.io.FileUtils
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}
import org.apache.hadoop.fs.{FileSystem, Path}

import java.io._
import java.nio.file.{Files, Paths}
import java.util.UUID
import scala.util.{Failure, Success, Try}

class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] = None)
class OnnxWrapper(var modelFileName: Option[String] = None, var dataFileDirectory: Option[String])
extends Serializable {

/** For Deserialization */
@@ -43,10 +44,15 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String]

def getSession(onnxSessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) =
this.synchronized {
// TODO: After testing it works remove the Map.empty
if (ortSession == null && ortEnv == null) {
val modelFilePath = if (modelFileName.isDefined) {
SparkFiles.get(modelFileName.get)
} else {
throw new UnsupportedOperationException("modelFileName not defined")
}

val (session, env) =
OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions, onnxModelPath)
OnnxWrapper.withSafeOnnxModelLoader(onnxSessionOptions, Some(modelFilePath))
ortEnv = env
ortSession = session
}
@@ -60,17 +66,11 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String]
.toAbsolutePath
.toString

// 2. Save onnx model
val fileName = Paths.get(file).getFileName.toString
val onnxFile = Paths
.get(tmpFolder, fileName)
.toString

FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel)
// 4. Zip folder
if (zip) ZipArchiveUtil.zip(tmpFolder, file)
val tmpModelFilePath = SparkFiles.get(modelFileName.get)
// 2. Zip folder
if (zip) ZipArchiveUtil.zip(tmpModelFilePath, file)

// 5. Remove tmp directory
// 3. Remove tmp directory
FileHelper.delete(tmpFolder)
}

@@ -82,7 +82,6 @@ object OnnxWrapper {

// TODO: make sure this.synchronized is needed or it's not a bottleneck
private def withSafeOnnxModelLoader(
onnxModel: Array[Byte],
sessionOptions: Map[String, String],
onnxModelPath: Option[String] = None): (OrtSession, OrtEnvironment) =
this.synchronized {
@@ -96,19 +95,18 @@
val session = env.createSession(onnxModelPath.get, sessionOptionsObject)
(session, env)
} else {
val session = env.createSession(onnxModel, sessionOptionsObject)
(session, env)
throw new UnsupportedOperationException("onnxModelPath not defined")
}
}

// TODO: the parts related to onnx_data should be refactored once we support addFile()
def read(
sparkSession: SparkSession,
modelPath: String,
zipped: Boolean = true,
useBundle: Boolean = false,
modelName: String = "model",
dataFileSuffix: String = "_data"): OnnxWrapper = {

dataFileSuffix: Option[String] = Some("_data"),
onnxFileSuffix: Option[String] = None): OnnxWrapper = {
// 1. Create tmp folder
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx")
@@ -118,11 +116,10 @@
// 2. Unpack archive
val folder =
if (zipped)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder))
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), onnxFileSuffix)
else
modelPath

val sessionOptions = new OnnxSession().getSessionOptions
val onnxFile =
if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString
else Paths.get(folder, new File(folder).list().head).toString
@@ -134,38 +131,23 @@
val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath

val onnxDataFileExist: Boolean = {
onnxDataFile = Paths.get(parentDir, modelName + dataFileSuffix).toFile
onnxDataFile.exists()
if (onnxFileSuffix.isDefined && dataFileSuffix.isDefined) {
val onnxDataFilePath = s"${onnxFileSuffix.get}_$modelName${dataFileSuffix.get}"
onnxDataFile = Paths.get(parentDir, onnxDataFilePath).toFile
onnxDataFile.exists()
} else false
}

if (onnxDataFileExist) {
val onnxDataFileTmp =
Paths.get(tmpFolder, modelName + dataFileSuffix).toFile
FileUtils.copyFile(onnxDataFile, onnxDataFileTmp)
sparkSession.sparkContext.addFile(onnxDataFile.toString)
}

val modelFile = new File(onnxFile)
val modelBytes = FileUtils.readFileToByteArray(modelFile)
var session: OrtSession = null
var env: OrtEnvironment = null
if (onnxDataFileExist) {
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile))
session = _session
env = _env
} else {
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, None)
session = _session
env = _env
sparkSession.sparkContext.addFile(onnxFile)

}
// 4. Remove tmp folder
FileHelper.delete(tmpFolder)
val onnxFileName = Some(new File(onnxFile).getName)
val dataFileDirectory = if (onnxDataFileExist) Some(onnxDataFile.toString) else None
val onnxWrapper = new OnnxWrapper(onnxFileName, dataFileDirectory)

val onnxWrapper =
if (onnxDataFileExist) new OnnxWrapper(modelBytes, Option(onnxFile))
else new OnnxWrapper(modelBytes)
onnxWrapper.ortSession = session
onnxWrapper.ortEnv = env
onnxWrapper
}

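The executor-side half of the change, sketched below: OnnxWrapper now serializes only a file name, and getSession builds the ONNX Runtime session lazily from the node-local path. The class and its fields here are simplified stand-ins, not the real Spark NLP types:

import ai.onnxruntime.{OrtEnvironment, OrtSession}
import org.apache.spark.SparkFiles

// Stripped-down stand-in for the new OnnxWrapper: only the file name travels
// to executors; the session is created on first use from the local copy that
// sparkContext.addFile already shipped.
class LazyOnnxSession(modelFileName: String) extends Serializable {
  @transient private var session: OrtSession = _

  def get(): OrtSession = this.synchronized {
    if (session == null) {
      val localPath = SparkFiles.get(modelFileName)
      val env = OrtEnvironment.getEnvironment()
      // createSession(path, options) loads the model from disk, which can also
      // pick up any external onnx_data tensors sitting next to the file.
      session = env.createSession(localPath, new OrtSession.SessionOptions())
    }
    session
  }
}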