diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs index df55942f25..bbb66b0e35 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs @@ -34,7 +34,7 @@ public static void Example() var mlContext = new MLContext(); var data = GetTensorData(); var idv = mlContext.Data.LoadFromEnumerable(data); - var pipeline = mlContext.Transforms.ApplyOnnxModel(modelPath, new[] { outputInfo.Key }, new[] { inputInfo.Key }); + var pipeline = mlContext.Transforms.ApplyOnnxModel(new[] { outputInfo.Key }, new[] { inputInfo.Key }, modelPath); // Run the pipeline and get the transformed values var transformedValues = pipeline.Fit(idv).Transform(idv); diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs b/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs index db5f6e1114..f33a45c32b 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs @@ -33,15 +33,15 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog /// Applies a pre-trained Onnx model. /// /// The transform's catalog. - /// The path of the file containing the ONNX model. /// The output column resulting from the transformation. /// The input column. + /// The path of the file containing the ONNX model. /// Optional GPU device ID to run execution on, to run on CPU. /// If GPU error, raise exception or fallback to CPU. public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog, - string modelFile, string outputColumnName, string inputColumnName, + string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false) => new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), new[] { outputColumnName }, new[] { inputColumnName }, modelFile, gpuDeviceId, fallbackToCpu); @@ -50,15 +50,15 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog /// Applies a pre-trained Onnx model. /// /// The transform's catalog. - /// The path of the file containing the ONNX model. /// The output columns resulting from the transformation. /// The input columns. + /// The path of the file containing the ONNX model. /// Optional GPU device ID to run execution on, to run on CPU. /// If GPU error, raise exception or fallback to CPU. public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog, - string modelFile, string[] outputColumnNames, string[] inputColumnNames, + string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false) => new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, modelFile, gpuDeviceId, fallbackToCpu); diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs index 84ec16bb8c..7cae32ce02 100644 --- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs +++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs @@ -101,7 +101,7 @@ void TestSimpleCase() var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; var sizeData = new List { new TestDataSize() { data_0 = new float[2] } }; - var pipe = ML.Transforms.ApplyOnnxModel(modelFile, new[] { "softmaxout_1" }, new[] { "data_0" }); + var pipe = ML.Transforms.ApplyOnnxModel(new[] { "softmaxout_1" }, new[] { "data_0" }, modelFile); var invalidDataWrongNames = ML.Data.LoadFromEnumerable(xyData); var invalidDataWrongTypes = ML.Data.LoadFromEnumerable(stringData); @@ -137,7 +137,7 @@ void TestOldSavingAndLoading(int? gpuDeviceId, bool fallbackToCpu) var inputNames = new[] { "data_0" }; var outputNames = new[] { "softmaxout_1" }; - var est = ML.Transforms.ApplyOnnxModel(modelFile, outputNames, inputNames, gpuDeviceId, fallbackToCpu); + var est = ML.Transforms.ApplyOnnxModel(outputNames, inputNames, modelFile, gpuDeviceId, fallbackToCpu); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -241,7 +241,7 @@ public void OnnxModelScenario() } }); - var onnx = ML.Transforms.ApplyOnnxModel(modelFile, "softmaxout_1", "data_0").Fit(dataView).Transform(dataView); + var onnx = ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile).Fit(dataView).Transform(dataView); var scoreCol = onnx.Schema["softmaxout_1"]; using (var curs = onnx.GetRowCursor(scoreCol)) @@ -271,7 +271,7 @@ public void OnnxModelMultiInput() inb = new float[] {1,2,3,4,5} } }); - var onnx = ML.Transforms.ApplyOnnxModel(modelFile, new[] { "outa", "outb" }, new[] { "ina", "inb" }).Fit(dataView).Transform(dataView); + var onnx = ML.Transforms.ApplyOnnxModel(new[] { "outa", "outb" }, new[] { "ina", "inb" }, modelFile).Fit(dataView).Transform(dataView); var outaCol = onnx.Schema["outa"]; var outbCol = onnx.Schema["outb"]; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index eb22493efc..723ab59055 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -78,7 +78,7 @@ public void SimpleEndToEndOnnxConversionTest() // Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); - var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath, outputNames, inputNames); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(data); var onnxResult = onnxTransformer.Transform(data); @@ -162,7 +162,7 @@ public void KmeansOnnxConversionTest() // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); - var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath, outputNames, inputNames); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(data); var onnxResult = onnxTransformer.Transform(data); CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3);