diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs
index df55942f25..bbb66b0e35 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs
@@ -34,7 +34,7 @@ public static void Example()
var mlContext = new MLContext();
var data = GetTensorData();
var idv = mlContext.Data.LoadFromEnumerable(data);
- var pipeline = mlContext.Transforms.ApplyOnnxModel(modelPath, new[] { outputInfo.Key }, new[] { inputInfo.Key });
+ var pipeline = mlContext.Transforms.ApplyOnnxModel(new[] { outputInfo.Key }, new[] { inputInfo.Key }, modelPath);
// Run the pipeline and get the transformed values
var transformedValues = pipeline.Fit(idv).Transform(idv);
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs b/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs
index db5f6e1114..f33a45c32b 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs
@@ -33,15 +33,15 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
/// Applies a pre-trained Onnx model.
///
/// The transform's catalog.
- /// The path of the file containing the ONNX model.
/// The output column resulting from the transformation.
/// The input column.
+ /// The path of the file containing the ONNX model.
/// Optional GPU device ID to run execution on, to run on CPU.
/// If GPU error, raise exception or fallback to CPU.
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
- string modelFile,
string outputColumnName,
string inputColumnName,
+ string modelFile,
int? gpuDeviceId = null,
bool fallbackToCpu = false)
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), new[] { outputColumnName }, new[] { inputColumnName }, modelFile, gpuDeviceId, fallbackToCpu);
@@ -50,15 +50,15 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
/// Applies a pre-trained Onnx model.
///
/// The transform's catalog.
- /// The path of the file containing the ONNX model.
/// The output columns resulting from the transformation.
/// The input columns.
+ /// The path of the file containing the ONNX model.
/// Optional GPU device ID to run execution on, to run on CPU.
/// If GPU error, raise exception or fallback to CPU.
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
- string modelFile,
string[] outputColumnNames,
string[] inputColumnNames,
+ string modelFile,
int? gpuDeviceId = null,
bool fallbackToCpu = false)
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, modelFile, gpuDeviceId, fallbackToCpu);
diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
index 84ec16bb8c..7cae32ce02 100644
--- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
@@ -101,7 +101,7 @@ void TestSimpleCase()
var xyData = new List { new TestDataXY() { A = new float[inputSize] } };
var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } };
var sizeData = new List { new TestDataSize() { data_0 = new float[2] } };
- var pipe = ML.Transforms.ApplyOnnxModel(modelFile, new[] { "softmaxout_1" }, new[] { "data_0" });
+ var pipe = ML.Transforms.ApplyOnnxModel(new[] { "softmaxout_1" }, new[] { "data_0" }, modelFile);
var invalidDataWrongNames = ML.Data.LoadFromEnumerable(xyData);
var invalidDataWrongTypes = ML.Data.LoadFromEnumerable(stringData);
@@ -137,7 +137,7 @@ void TestOldSavingAndLoading(int? gpuDeviceId, bool fallbackToCpu)
var inputNames = new[] { "data_0" };
var outputNames = new[] { "softmaxout_1" };
- var est = ML.Transforms.ApplyOnnxModel(modelFile, outputNames, inputNames, gpuDeviceId, fallbackToCpu);
+ var est = ML.Transforms.ApplyOnnxModel(outputNames, inputNames, modelFile, gpuDeviceId, fallbackToCpu);
var transformer = est.Fit(dataView);
var result = transformer.Transform(dataView);
var resultRoles = new RoleMappedData(result);
@@ -241,7 +241,7 @@ public void OnnxModelScenario()
}
});
- var onnx = ML.Transforms.ApplyOnnxModel(modelFile, "softmaxout_1", "data_0").Fit(dataView).Transform(dataView);
+ var onnx = ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile).Fit(dataView).Transform(dataView);
var scoreCol = onnx.Schema["softmaxout_1"];
using (var curs = onnx.GetRowCursor(scoreCol))
@@ -271,7 +271,7 @@ public void OnnxModelMultiInput()
inb = new float[] {1,2,3,4,5}
}
});
- var onnx = ML.Transforms.ApplyOnnxModel(modelFile, new[] { "outa", "outb" }, new[] { "ina", "inb" }).Fit(dataView).Transform(dataView);
+ var onnx = ML.Transforms.ApplyOnnxModel(new[] { "outa", "outb" }, new[] { "ina", "inb" }, modelFile).Fit(dataView).Transform(dataView);
var outaCol = onnx.Schema["outa"];
var outbCol = onnx.Schema["outb"];
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index eb22493efc..723ab59055 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -78,7 +78,7 @@ public void SimpleEndToEndOnnxConversionTest()
// Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
- var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath, outputNames, inputNames);
+ var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
@@ -162,7 +162,7 @@ public void KmeansOnnxConversionTest()
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
- var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath, outputNames, inputNames);
+ var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3);