diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala b/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala index 9477b72f452ce0..3394b038044e2b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala @@ -23,11 +23,11 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper import com.johnsnowlabs.nlp.Annotation - -import scala.collection.JavaConverters._ import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import org.tensorflow.{Session, Tensor} +import scala.collection.JavaConverters._ + private[johnsnowlabs] class M2M100( val onnxWrappers: EncoderDecoderWithoutPastWrappers, val spp: SentencePieceWrapper, diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala index 17e4b3f2ab6aa0..4454ef783633d0 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala @@ -297,7 +297,10 @@ private[johnsnowlabs] class Whisper( case TensorFlow.name => val session = tensorflowWrapper.get - .getTFSessionWithSignature(configProtoBytes, savedSignatures = signatures) + .getTFSessionWithSignature( + configProtoBytes, + savedSignatures = signatures, + initAllTables = false) val encodedBatchFeatures: Tensor = encode(featuresBatch, Some(session), None).asInstanceOf[Tensor] diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index 5c9156539d4cd0..e985f2b0bcac99 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -16,8 +16,6 @@ package com.johnsnowlabs.ml.onnx -import ai.onnxruntime.OrtSession.SessionOptions -import com.johnsnowlabs.util.FileHelper import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.sql.SparkSession @@ -32,11 +30,10 @@ trait WriteOnnxModel { path: String, spark: SparkSession, onnxWrappersWithNames: Seq[(OnnxWrapper, String)], - suffix: String, - dataFileSuffix: String = "_data"): Unit = { + suffix: String): Unit = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) - val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) + val fileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) // 1. Create tmp folder val tmpFolder = Files @@ -51,15 +48,16 @@ trait WriteOnnxModel { onnxWrapper.saveToFile(onnxFile) // 3. Copy to dest folder - fs.copyFromLocalFile(new Path(onnxFile), new Path(path)) + fileSystem.copyFromLocalFile(new Path(onnxFile), new Path(path)) // 4. check if there is a onnx_data file - if (onnxWrapper.onnxModelPath.isDefined) { - val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix) - if (fs.exists(onnxDataFile)) { - fs.copyFromLocalFile(onnxDataFile, new Path(path)) + if (onnxWrapper.dataFileDirectory.isDefined) { + val onnxDataFile = new Path(onnxWrapper.dataFileDirectory.get) + if (fileSystem.exists(onnxDataFile)) { + fileSystem.copyFromLocalFile(onnxDataFile, new Path(path)) } } + } // 4. 
Remove tmp folder @@ -74,7 +72,6 @@ trait WriteOnnxModel { fileName: String): Unit = { writeOnnxModels(path, spark, Seq((onnxWrapper, fileName)), suffix) } - } trait ReadOnnxModel { @@ -86,38 +83,61 @@ trait ReadOnnxModel { suffix: String, zipped: Boolean = true, useBundle: Boolean = false, - sessionOptions: Option[SessionOptions] = None, - dataFileSuffix: String = "_data"): OnnxWrapper = { + modelName: Option[String] = None, + tmpFolder: Option[String] = None, + dataFilePostfix: Option[String] = None): OnnxWrapper = { + + // 1. Copy to local tmp dir + val localModelFile = if (modelName.isDefined) modelName.get else onnxFile + val srcPath = new Path(path, localModelFile) + val fileSystem = getFileSystem(path, spark) + val localTmpFolder = if (tmpFolder.isDefined) tmpFolder.get else createTmpDirectory(suffix) + fileSystem.copyToLocalFile(srcPath, new Path(localTmpFolder)) + + // 2. Copy onnx_data file if exists + val fsPath = new Path(path, localModelFile).toString + + val onnxDataFile: Option[String] = if (modelName.isDefined && dataFilePostfix.isDefined) { + Some(fsPath.replaceAll(modelName.get, s"${suffix}_${modelName.get}${dataFilePostfix.get}")) + } else None + + if (onnxDataFile.isDefined) { + val onnxDataFilePath = new Path(onnxDataFile.get) + if (fileSystem.exists(onnxDataFilePath)) { + fileSystem.copyToLocalFile(onnxDataFilePath, new Path(localTmpFolder)) + } + } + + // 3. Read ONNX state + val onnxFileTmpPath = new Path(localTmpFolder, localModelFile).toString + val onnxWrapper = + OnnxWrapper.read( + spark, + onnxFileTmpPath, + zipped = zipped, + useBundle = useBundle, + modelName = if (modelName.isDefined) modelName.get else onnxFile, + onnxFileSuffix = Some(suffix)) + + onnxWrapper + + } + private def getFileSystem(path: String, sparkSession: SparkSession): FileSystem = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) - val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) + val fileSystem = FileSystem.get(uri, sparkSession.sparkContext.hadoopConfiguration) + fileSystem + } + + private def createTmpDirectory(suffix: String): String = { // 1. Create tmp directory val tmpFolder = Files - .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix) + .createTempDirectory(s"${UUID.randomUUID().toString.takeRight(12)}_$suffix") .toAbsolutePath .toString - // 2. Copy to local dir - fs.copyToLocalFile(new Path(path, onnxFile), new Path(tmpFolder)) - - val localPath = new Path(tmpFolder, onnxFile).toString - - val fsPath = new Path(path, onnxFile) - - // 3. Copy onnx_data file if exists - val onnxDataFile = new Path(fsPath + dataFileSuffix) - - if (fs.exists(onnxDataFile)) { - fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder)) - } - // 4. Read ONNX state - val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle) - - // 5. Remove tmp folder - FileHelper.delete(tmpFolder) - - onnxWrapper + tmpFolder } def readOnnxModels( @@ -127,43 +147,23 @@ trait ReadOnnxModel { suffix: String, zipped: Boolean = true, useBundle: Boolean = false, - dataFileSuffix: String = "_data"): Map[String, OnnxWrapper] = { + dataFilePostfix: String = "_data"): Map[String, OnnxWrapper] = { - val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) - val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) - - // 1. 
Create tmp directory - val tmpFolder = Files - .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix) - .toAbsolutePath - .toString + val tmpFolder = Some(createTmpDirectory(suffix)) val wrappers = (modelNames map { modelName: String => - // 2. Copy to local dir - val localModelFile = modelName - fs.copyToLocalFile(new Path(path, localModelFile), new Path(tmpFolder)) - - val localPath = new Path(tmpFolder, localModelFile).toString - - val fsPath = new Path(path, localModelFile).toString - - // 3. Copy onnx_data file if exists - val onnxDataFile = new Path(fsPath + dataFileSuffix) - - if (fs.exists(onnxDataFile)) { - fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder)) - } - - // 4. Read ONNX state - val onnxWrapper = - OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle, modelName = modelName) - + val onnxWrapper = readOnnxModel( + path, + spark, + suffix, + zipped, + useBundle, + Some(modelName), + tmpFolder, + Option(dataFilePostfix)) (modelName, onnxWrapper) }).toMap - // 4. Remove tmp folder - FileHelper.delete(tmpFolder) - wrappers } diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 5478a52282990d..3b08931558a41a 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -21,15 +21,16 @@ import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel} import ai.onnxruntime.providers.OrtCUDAProviderOptions import ai.onnxruntime.{OrtEnvironment, OrtSession} import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil} -import org.apache.commons.io.FileUtils +import org.apache.spark.SparkFiles +import org.apache.spark.sql.SparkSession import org.slf4j.{Logger, LoggerFactory} -import org.apache.hadoop.fs.{FileSystem, Path} + import java.io._ import java.nio.file.{Files, Paths} import java.util.UUID import scala.util.{Failure, Success, Try} -class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] = None) +class OnnxWrapper(var modelFileName: Option[String] = None, var dataFileDirectory: Option[String]) extends Serializable { /** For Deserialization */ @@ -43,10 +44,15 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] def getSession(onnxSessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) = this.synchronized { - // TODO: After testing it works remove the Map.empty if (ortSession == null && ortEnv == null) { + val modelFilePath = if (modelFileName.isDefined) { + SparkFiles.get(modelFileName.get) + } else { + throw new UnsupportedOperationException("modelFileName not defined") + } + val (session, env) = - OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions, onnxModelPath) + OnnxWrapper.withSafeOnnxModelLoader(onnxSessionOptions, Some(modelFilePath)) ortEnv = env ortSession = session } @@ -60,17 +66,11 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] .toAbsolutePath .toString - // 2. Save onnx model - val fileName = Paths.get(file).getFileName.toString - val onnxFile = Paths - .get(tmpFolder, fileName) - .toString - - FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel) - // 4. Zip folder - if (zip) ZipArchiveUtil.zip(tmpFolder, file) + val tmpModelFilePath = SparkFiles.get(modelFileName.get) + // 2. Zip folder + if (zip) ZipArchiveUtil.zip(tmpModelFilePath, file) - // 5. Remove tmp directory + // 3. 
Remove tmp directory FileHelper.delete(tmpFolder) } @@ -82,7 +82,6 @@ object OnnxWrapper { // TODO: make sure this.synchronized is needed or it's not a bottleneck private def withSafeOnnxModelLoader( - onnxModel: Array[Byte], sessionOptions: Map[String, String], onnxModelPath: Option[String] = None): (OrtSession, OrtEnvironment) = this.synchronized { @@ -96,19 +95,18 @@ object OnnxWrapper { val session = env.createSession(onnxModelPath.get, sessionOptionsObject) (session, env) } else { - val session = env.createSession(onnxModel, sessionOptionsObject) - (session, env) + throw new UnsupportedOperationException("onnxModelPath not defined") } } - // TODO: the parts related to onnx_data should be refactored once we support addFile() def read( + sparkSession: SparkSession, modelPath: String, zipped: Boolean = true, useBundle: Boolean = false, modelName: String = "model", - dataFileSuffix: String = "_data"): OnnxWrapper = { - + dataFileSuffix: Option[String] = Some("_data"), + onnxFileSuffix: Option[String] = None): OnnxWrapper = { // 1. Create tmp folder val tmpFolder = Files .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx") @@ -118,11 +116,10 @@ object OnnxWrapper { // 2. Unpack archive val folder = if (zipped) - ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder)) + ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), onnxFileSuffix) else modelPath - val sessionOptions = new OnnxSession().getSessionOptions val onnxFile = if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString else Paths.get(folder, new File(folder).list().head).toString @@ -134,38 +131,23 @@ object OnnxWrapper { val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath val onnxDataFileExist: Boolean = { - onnxDataFile = Paths.get(parentDir, modelName + dataFileSuffix).toFile - onnxDataFile.exists() + if (onnxFileSuffix.isDefined && dataFileSuffix.isDefined) { + val onnxDataFilePath = s"${onnxFileSuffix.get}_$modelName${dataFileSuffix.get}" + onnxDataFile = Paths.get(parentDir, onnxDataFilePath).toFile + onnxDataFile.exists() + } else false } if (onnxDataFileExist) { - val onnxDataFileTmp = - Paths.get(tmpFolder, modelName + dataFileSuffix).toFile - FileUtils.copyFile(onnxDataFile, onnxDataFileTmp) + sparkSession.sparkContext.addFile(onnxDataFile.toString) } - val modelFile = new File(onnxFile) - val modelBytes = FileUtils.readFileToByteArray(modelFile) - var session: OrtSession = null - var env: OrtEnvironment = null - if (onnxDataFileExist) { - val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile)) - session = _session - env = _env - } else { - val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, None) - session = _session - env = _env + sparkSession.sparkContext.addFile(onnxFile) - } - // 4. 
Remove tmp folder - FileHelper.delete(tmpFolder) + val onnxFileName = Some(new File(onnxFile).getName) + val dataFileDirectory = if (onnxDataFileExist) Some(onnxDataFile.toString) else None + val onnxWrapper = new OnnxWrapper(onnxFileName, dataFileDirectory) - val onnxWrapper = - if (onnxDataFileExist) new OnnxWrapper(modelBytes, Option(onnxFile)) - else new OnnxWrapper(modelBytes) - onnxWrapper.ortSession = session - onnxWrapper.ortEnv = env onnxWrapper } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala index 203cc50603a672..b6564ab629dc80 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala @@ -449,7 +449,7 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel { spark, Seq("encoder_model", "decoder_model", "decoder_with_past_model"), WhisperForCTC.suffix, - dataFileSuffix = ".onnx_data") + dataFilePostfix = ".onnx_data") val onnxWrappers = EncoderDecoderWrappers( wrappers("encoder_model"), @@ -580,24 +580,30 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => val onnxWrapperEncoder = OnnxWrapper.read( + spark, localModelPath, zipped = false, useBundle = true, - modelName = "encoder_model") + modelName = "encoder_model", + onnxFileSuffix = None) val onnxWrapperDecoder = OnnxWrapper.read( + spark, localModelPath, zipped = false, useBundle = true, - modelName = "decoder_model") + modelName = "decoder_model", + onnxFileSuffix = None) val onnxWrapperDecoderWithPast = OnnxWrapper.read( + spark, localModelPath, zipped = false, useBundle = true, - modelName = "decoder_with_past_model") + modelName = "decoder_with_past_model", + onnxFileSuffix = None) val onnxWrappers = EncoderDecoderWrappers( onnxWrapperEncoder, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 56fbf3dc80889e..902459309b571e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -328,13 +328,7 @@ trait ReadAlbertForQuestionAnsweringDLModel instance.setModelIfNotSet(spark, Some(tf), None, spp) case ONNX.name => val onnxWrapper = - readOnnxModel( - path, - spark, - "_albert_classification_onnx", - zipped = true, - useBundle = false, - None) + readOnnxModel(path, spark, "albert_qa_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) case _ => throw new Exception(notSupportedEngineError) @@ -372,7 +366,12 @@ trait ReadAlbertForQuestionAnsweringDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + onnxFileSuffix = None) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index 16b9e6c196e37d..1b598ee20c987e 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -382,13 +382,7 @@ trait ReadAlbertForSequenceDLModel instance.setModelIfNotSet(spark, Some(tf), None, spp) case ONNX.name => val onnxWrapper = - readOnnxModel( - path, - spark, - "_albert_classification_onnx", - zipped = true, - useBundle = false, - None) + readOnnxModel(path, spark, "albert_sequence_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) case _ => throw new Exception(notSupportedEngineError) @@ -428,7 +422,12 @@ trait ReadAlbertForSequenceDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + onnxFileSuffix = None) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala index 845af80e4fa753..a4fc9dc7c24e20 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala @@ -353,13 +353,7 @@ trait ReadAlbertForTokenDLModel instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) case ONNX.name => val onnxWrapper = - readOnnxModel( - path, - spark, - "_albert_classification_onnx", - zipped = true, - useBundle = false, - None) + readOnnxModel(path, spark, "albert_token_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) case _ => throw new Exception(notSupportedEngineError) @@ -399,7 +393,12 @@ trait ReadAlbertForTokenDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + onnxFileSuffix = None) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala index 3c7fe2d857ec23..5f2331088f9b60 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala @@ -330,7 +330,7 @@ trait ReadBertForQuestionAnsweringDLModel extends ReadTensorflowModel with ReadO instance.setModelIfNotSet(spark, Some(tensorFlow), None) case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_bert_classification_onnx") + readOnnxModel(path, spark, "bert_qa_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) @@ -369,7 +369,8 @@ trait ReadBertForQuestionAnsweringDLModel extends ReadTensorflowModel with ReadO .setSignatures(_signatures) .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val 
onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala index 1bc3df28beb65f..cb9248cce6965d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala @@ -384,7 +384,7 @@ trait ReadBertForSequenceDLModel extends ReadTensorflowModel with ReadOnnxModel instance.setModelIfNotSet(spark, Some(tensorFlow), None) case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_bert_classification_onnx") + readOnnxModel(path, spark, "bert_sequence_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) @@ -424,7 +424,8 @@ trait ReadBertForSequenceDLModel extends ReadTensorflowModel with ReadOnnxModel .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala index 3e54cd7c84e425..24a1858d7d1842 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala @@ -350,7 +350,7 @@ trait ReadBertForTokenDLModel extends ReadTensorflowModel with ReadOnnxModel { instance.setModelIfNotSet(spark, Some(tensorFlow), None) case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_bert_classification_onnx") + readOnnxModel(path, spark, "bert_token_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) @@ -389,7 +389,8 @@ trait ReadBertForTokenDLModel extends ReadTensorflowModel with ReadOnnxModel { .setSignatures(_signatures) .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala index 1a8a77ca84b582..33d3b61f0042f1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala @@ -400,7 +400,7 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel with ReadOnnxModel instance.setModelIfNotSet(spark, Some(tensorFlow), None) case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_bert_classification_onnx") + readOnnxModel(path, spark, 
"bert_zs_classification_onnx") instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) @@ -462,7 +462,8 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel with ReadOnnxModel .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala index c08be00d37318d..4ba692bfc2a906 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala @@ -336,7 +336,7 @@ trait ReadCamemBertForQADLModel readOnnxModel( path, spark, - "_camembert_classification_onnx", + "camembert_qa_classification_onnx", zipped = true, useBundle = false, None) @@ -377,7 +377,8 @@ trait ReadCamemBertForQADLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala index abb417366b87f0..d56b7528abefb5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala @@ -388,7 +388,7 @@ trait ReadCamemBertForSequenceDLModel readOnnxModel( path, spark, - "_camembert_classification_onnx", + "camembert_sequence_classification_onnx", zipped = true, useBundle = false, None) @@ -432,7 +432,8 @@ trait ReadCamemBertForSequenceDLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala index c5a3637a96f13a..5669945561dd79 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala @@ -357,7 +357,7 @@ trait ReadCamemBertForTokenDLModel readOnnxModel( path, spark, - "_camembert_classification_onnx", + "camembert_token_classification_onnx", zipped = true, useBundle = false, None) @@ -399,7 +399,8 @@ trait ReadCamemBertForTokenDLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) 
case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala index b27adcb4846651..8671f1ef441aac 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala @@ -333,7 +333,7 @@ trait ReadDeBertaForQuestionAnsweringDLModel readOnnxModel( path, spark, - "_deberta_classification_onnx", + "deberta_qa_classification_onnx", zipped = true, useBundle = false, None) @@ -373,7 +373,8 @@ trait ReadDeBertaForQuestionAnsweringDLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala index d77564f8e4c2a9..841676cecc83a6 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala @@ -385,7 +385,7 @@ trait ReadDeBertaForSequenceDLModel readOnnxModel( path, spark, - "_deberta_classification_onnx", + "deberta_sequence_classification_onnx", zipped = true, useBundle = false, None) @@ -427,7 +427,8 @@ trait ReadDeBertaForSequenceDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala index 60e273ee56fed5..f2e3c1722aa6ab 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala @@ -356,7 +356,7 @@ trait ReadDeBertaForTokenDLModel readOnnxModel( path, spark, - "_deberta_classification_onnx", + "deberta_token_classification_onnx", zipped = true, useBundle = false, None) @@ -396,7 +396,8 @@ trait ReadDeBertaForTokenDLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala index 9b5215bd7618a1..7f8f118370eb12 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala @@ -336,7 +336,7 @@ trait ReadDistilBertForQuestionAnsweringDLModel extends ReadTensorflowModel with readOnnxModel( path, spark, - "_distilbert_classification_onnx", + "distilbert_qa_classification_onnx", zipped = true, useBundle = false, None) @@ -378,7 +378,8 @@ trait ReadDistilBertForQuestionAnsweringDLModel extends ReadTensorflowModel with .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala index 8a60e65bcfeb6f..3defa1451cbb3d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala @@ -388,7 +388,7 @@ trait ReadDistilBertForSequenceDLModel extends ReadTensorflowModel with ReadOnnx readOnnxModel( path, spark, - "_albert_classification_onnx", + "distilbert_sequence_classification_onnx", zipped = true, useBundle = false, None) @@ -433,7 +433,8 @@ trait ReadDistilBertForSequenceDLModel extends ReadTensorflowModel with ReadOnnx .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala index 351ac574d4a148..1b13ee828787a1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala @@ -358,7 +358,7 @@ trait ReadDistilBertForTokenDLModel extends ReadTensorflowModel with ReadOnnxMod readOnnxModel( path, spark, - "_distilbert_classification_onnx", + "distilbert_token_classification_onnx", zipped = true, useBundle = false, None) @@ -401,7 +401,8 @@ trait ReadDistilBertForTokenDLModel extends ReadTensorflowModel with ReadOnnxMod .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForQuestionAnswering.scala index 469a7aa0bb1fc2..d0d7aa698b008a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForQuestionAnswering.scala @@ -296,13 +296,7 @@ trait ReadMPNetForQuestionAnsweringDLModel extends ReadOnnxModel { instance.getEngine match { case ONNX.name => val onnxWrapper = - readOnnxModel( - path, - spark, - "_mpnet_question_answering_onnx", - zipped = true, - useBundle = false, - None) + readOnnxModel(path, spark, "mpnet_qa_onnx", zipped = true, useBundle = false, None) instance.setModelIfNotSet(spark, Some(onnxWrapper)) case _ => throw new NotImplementedError("Tensorflow models are not supported.") @@ -328,7 +322,8 @@ trait ReadMPNetForQuestionAnsweringDLModel extends ReadOnnxModel { case TensorFlow.name => throw new NotImplementedError("Tensorflow models are not supported.") case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, Some(onnxWrapper)) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForSequenceClassification.scala index 882a871f44600b..f59bbb6808ad50 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/MPNetForSequenceClassification.scala @@ -358,7 +358,7 @@ trait ReadMPNetForSequenceDLModel extends ReadOnnxModel { readOnnxModel( path, spark, - "_mpnet_classification_onnx", + "mpnet_sequence_classification_onnx", zipped = true, useBundle = false, None) @@ -388,7 +388,8 @@ trait ReadMPNetForSequenceDLModel extends ReadOnnxModel { case TensorFlow.name => throw new NotImplementedError("Tensorflow Models are currently not supported.") case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, Some(onnxWrapper)) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala index a62e4aef0bcfcd..53db6fe18d4569 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala @@ -348,7 +348,7 @@ trait ReadRoBertaForQuestionAnsweringDLModel extends ReadTensorflowModel with Re readOnnxModel( path, spark, - "roberta_classification_onnx", + "roberta_qa_classification_onnx", zipped = true, useBundle = false, None) @@ -397,7 +397,8 @@ trait ReadRoBertaForQuestionAnsweringDLModel extends ReadTensorflowModel with Re .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel 
.setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala index 190835558ea6a4..93eae76247cfcf 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala @@ -400,7 +400,7 @@ trait ReadRoBertaForSequenceDLModel extends ReadTensorflowModel with ReadOnnxMod readOnnxModel( path, spark, - "roberta_classification_onnx", + "roberta_sequence_classification_onnx", zipped = true, useBundle = false, None) @@ -447,7 +447,8 @@ trait ReadRoBertaForSequenceDLModel extends ReadTensorflowModel with ReadOnnxMod .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala index 34ae2ae7bd203f..0dbfe4326ed5eb 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala @@ -370,7 +370,7 @@ trait ReadRoBertaForTokenDLModel extends ReadTensorflowModel with ReadOnnxModel readOnnxModel( path, spark, - "roberta_classification_onnx", + "roberta_token_classification_onnx", zipped = true, useBundle = false, None) @@ -418,7 +418,8 @@ trait ReadRoBertaForTokenDLModel extends ReadTensorflowModel with ReadOnnxModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala index b885f640d94145..09ec74f2a42607 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala @@ -326,7 +326,7 @@ trait ReadXlmRoBertaForQuestionAnsweringDLModel readOnnxModel( path, spark, - "xlm_roberta_classification_onnx", + "xlm_roberta_qa_classification_onnx", zipped = true, useBundle = false, None) @@ -367,7 +367,8 @@ trait ReadXlmRoBertaForQuestionAnsweringDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala index 366b0ce0fa8ad6..3b8b30bfd90b46 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala @@ -377,7 +377,7 @@ trait ReadXlmRoBertaForSequenceDLModel readOnnxModel( path, spark, - "xlm_roberta_classification_onnx", + "xlm_roberta_sequence_classification_onnx", zipped = true, useBundle = false, None) @@ -421,7 +421,8 @@ trait ReadXlmRoBertaForSequenceDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala index f9f933c1f8d018..3dd353251c09b3 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala @@ -349,7 +349,7 @@ trait ReadXlmRoBertaForTokenDLModel readOnnxModel( path, spark, - "xlm_roberta_classification_onnx", + "xlm_roberta_token_classification_onnx", zipped = true, useBundle = false, None) @@ -390,7 +390,8 @@ trait ReadXlmRoBertaForTokenDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/CLIPForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/CLIPForZeroShotClassification.scala index 15e766b81b6b34..dd630a96230b3c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/CLIPForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/CLIPForZeroShotClassification.scala @@ -420,7 +420,7 @@ trait ReadCLIPForZeroShotClassificationModel extends ReadTensorflowModel with Re throw new Exception("Tensorflow is currently not supported by this annotator.") case ONNX.name => val onnxWrapper = - OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), preprocessorConfig) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index 0c2970e26683e8..dc4232a1a46344 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -311,7 +311,7 @@ trait ReadLLAMA2TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM this: ParamsAndFeaturesReadable[LLAMA2Transformer] => override val onnxFile: String = "llama2_onnx" - val suffix: String = "_llama2" + val suffix: String = 
"llama2" override val sppFile: String = "llama2_spp" def readModel(instance: LLAMA2Transformer, path: String, spark: SparkSession): Unit = { @@ -378,10 +378,13 @@ trait ReadLLAMA2TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM case ONNX.name => val onnxWrapperDecoder = OnnxWrapper.read( + spark, localModelPath, zipped = false, useBundle = true, - modelName = "decoder_model") + modelName = "decoder_model", + dataFileSuffix = Some(".onnx_data"), + onnxFileSuffix = Some(suffix)) val onnxWrappers = DecoderWrappers(onnxWrapperDecoder) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala index 356ade7cf96601..d17ec3bdafe696 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala @@ -38,6 +38,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.SparkSession import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import com.johnsnowlabs.util.FileHelper import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -553,16 +554,20 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM case ONNX.name => val onnxWrapperEncoder = OnnxWrapper.read( + spark, localModelPath, zipped = false, useBundle = true, - modelName = "encoder_model") + modelName = "encoder_model", + onnxFileSuffix = None) val onnxWrapperDecoder = OnnxWrapper.read( + spark, localModelPath, zipped = false, useBundle = true, - modelName = "decoder_model") + modelName = "decoder_model", + onnxFileSuffix = None) val onnxWrappers = EncoderDecoderWithoutPastWrappers( @@ -571,7 +576,6 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM annotatorModel .setModelIfNotSet(spark, onnxWrappers, spModel) - case _ => throw new Exception(notSupportedEngineError) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala index 41b022fe42c074..6302668156f79b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala @@ -696,12 +696,14 @@ trait ReadMarianMTDLModel OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR) val onnxEncoder = OnnxWrapper.read( + spark, localModelPath, modelName = "encoder_model", zipped = false, useBundle = true) val onnxDecoder = OnnxWrapper.read( + spark, localModelPath, modelName = "decoder_model_merged", zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala index 1e8a42a6d7416d..9b071f20498b7b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala @@ -666,12 +666,14 @@ trait ReadT5TransformerDLModel OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR) val onnxEncoder = OnnxWrapper.read( + spark, localModelPath, modelName = "encoder_model", zipped = false, useBundle = true) val onnxDecoder = OnnxWrapper.read( + spark, localModelPath, modelName = "decoder_model_merged", zipped = false, diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala index ddb2f45e17b82d..0fe6e8b8b17bb3 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala @@ -405,7 +405,7 @@ trait ReadAlbertDLModel case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_albert_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_albert_onnx", zipped = true, useBundle = false) val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) } @@ -445,7 +445,12 @@ trait ReadAlbertDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + onnxFileSuffix = None) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala index 139f9efd2dbf66..3f701c4307d8dd 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala @@ -461,7 +461,8 @@ trait ReadBGEDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala index b123d93e83a310..89cd6e52d40eb5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala @@ -473,7 +473,8 @@ trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index c2e36695688a38..a808c1068c59d1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -502,7 +502,8 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala 
b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala index f59d0d46c0fa41..d1ab0358224c58 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala @@ -413,7 +413,8 @@ trait ReadCamemBertDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala index 56f57238e3a84e..de1beb85ad10db 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala @@ -425,7 +425,8 @@ trait ReadDeBertaDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala index d28ce903c48eb0..06a1809973b7f6 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala @@ -475,7 +475,8 @@ trait ReadDistilBertDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala index 38ead9b55ac086..7ec4c1daf4a739 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala @@ -459,7 +459,8 @@ trait ReadE5DLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala index 0f9b0288a14436..763bb1e5853a7a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala @@ -453,7 +453,8 @@ trait ReadMPNetDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, 
zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala index 4dc491fe1e9d3f..253dc9376b2673 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala @@ -497,7 +497,8 @@ trait ReadRobertaDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala index f82fc3f2e994c1..3f869f745aeecf 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala @@ -515,7 +515,8 @@ trait ReadUAEDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala index 2d59b18fdb3292..6a4e64efe64f29 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala @@ -447,7 +447,8 @@ trait ReadXlmRobertaDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala index 0fe866c3d1303a..b71110f3d79f86 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala @@ -429,7 +429,8 @@ trait ReadXlmRobertaSentenceDLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala b/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala index d37a6ce90e7a11..8c85f2915561f3 100644 --- a/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala +++ b/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala @@ -113,7 +113,10 @@ object ZipArchiveUtil { throw new 
IllegalArgumentException("only folder and file input are valid") } - def unzip(file: File, destDirPath: Option[String] = None): String = { + def unzip( + file: File, + destDirPath: Option[String] = None, + suffix: Option[String] = None): String = { val fileName = file.getName val basename = if (fileName.indexOf('.') >= 0) { @@ -132,10 +135,10 @@ object ZipArchiveUtil { val zip = new ZipFile(file) zip.entries.asScala foreach { entry => - val entryName = entry.getName + val entryName = if (suffix.isDefined) suffix.get + "_" + entry.getName else entry.getName val entryPath = { if (entryName.startsWith(basename)) - entryName.substring(basename.length) + entryName.substring(0, basename.length) else entryName } @@ -161,4 +164,5 @@ object ZipArchiveUtil { destDir.getPath } + } diff --git a/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala b/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala index e8fad8f4775f7b..fa8310452deccd 100644 --- a/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala @@ -16,12 +16,15 @@ package com.johnsnowlabs.ml.onnx +import com.johnsnowlabs.nlp.util.io.ResourceHelper import com.johnsnowlabs.tags.FastTest import org.scalatest.flatspec.AnyFlatSpec -import java.nio.file.{Files, Paths, Path} + +import java.nio.file.{Files, Path, Paths} import java.io.File import com.johnsnowlabs.util.FileHelper import org.scalatest.BeforeAndAfter + import java.util.UUID class OnnxWrapperTestSpec extends AnyFlatSpec with BeforeAndAfter { @@ -68,16 +71,19 @@ class OnnxWrapperTestSpec extends AnyFlatSpec with BeforeAndAfter { } "a dummy onnx wrapper" should "get session correctly" taggedAs FastTest in { - val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath)) - val dummyOnnxWrapper = new OnnxWrapper(modelBytes) + ResourceHelper.spark.sparkContext.addFile(modelPath) + val onnxFileName = Some(new File(modelPath).getName) + val dummyOnnxWrapper = new OnnxWrapper(onnxFileName, None) dummyOnnxWrapper.getSession(onnxSessionOptions) } "a dummy onnx wrapper" should "saveToFile correctly" taggedAs FastTest in { - val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath)) - val dummyOnnxWrapper = new OnnxWrapper(modelBytes) + ResourceHelper.spark.sparkContext.addFile(modelPath) + val onnxFileName = Some(new File(modelPath).getName) + val dummyOnnxWrapper = new OnnxWrapper(onnxFileName, None) dummyOnnxWrapper.saveToFile(Paths.get(tmpFolder, "modelFromTest.zip").toString) // verify file existence assert(new File(tmpFolder, "modelFromTest.zip").exists()) } + }
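
Reviewer note: a minimal sketch (not part of the patch) of the new SparkFiles-based loading flow introduced in OnnxWrapper, mirroring the updated OnnxWrapperTestSpec above. The model path is illustrative; the session-options map comes from the existing OnnxSession helper in the same package.

import java.io.File
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.nlp.util.io.ResourceHelper

// Illustrative path to a locally exported ONNX model (replace with a real file).
val modelPath = "/tmp/model.onnx"

// Register the file with Spark; SparkFiles.get(fileName) resolves it later inside getSession.
ResourceHelper.spark.sparkContext.addFile(modelPath)

// The wrapper now stores only the registered file name and an optional external-data path,
// instead of holding the full model byte array in memory.
val wrapper = new OnnxWrapper(Some(new File(modelPath).getName), None)

// The ORT session and environment are created lazily on first use.
val (ortSession, ortEnv) = wrapper.getSession(new OnnxSession().getSessionOptions)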