From 2d00351d7b4975a4d03f6a437772b6976726a252 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:57:07 -0700 Subject: [PATCH 1/3] ORT 1.20.0 Release: Cherry pick round 1 (#22526) ORT 1.20.0 release preparation: Cherry pick round 1 Approved cherry pick comments --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Hector Li Co-authored-by: Adrian Lizarraga Co-authored-by: Patrice Vignola Co-authored-by: Changming Sun Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- cmake/onnxruntime_unittests.cmake | 9 +- .../core/session/onnxruntime_c_api.h | 18 +- .../builders/impl/depthtospace_op_builder.cc | 14 + .../core/providers/cpu/nn/layer_norm_impl.cc | 82 +++-- .../core/providers/cpu/nn/layer_norm_impl.h | 10 +- .../Operators/DmlOperatorRotaryEmbedding.cpp | 291 ++++++++++++++---- .../dml/OperatorAuthorHelper/Attributes.h | 2 + .../opbuilder/layer_norm_op_builder.cc | 4 +- .../builder/opbuilder/simple_op_builder.cc | 10 + .../core/providers/qnn/builder/qnn_model.cc | 4 +- .../core/providers/qnn/builder/qnn_model.h | 1 + .../providers/qnn/builder/qnn_model_wrapper.h | 13 +- .../providers/qnn/qnn_execution_provider.cc | 32 +- .../providers/qnn/qnn_execution_provider.h | 1 + .../models/stable_diffusion/requirements.txt | 1 + .../test/contrib_ops/layer_norm_op_test.cc | 29 ++ .../contrib_ops/rotary_embedding_op_test.cc | 17 +- .../test/logging_apis/test_logging_apis.cc | 8 +- onnxruntime/test/onnx/main.cc | 8 +- .../microbenchmark/layer_normalization.cc | 17 +- .../test/perftest/command_args_parser.cc | 2 + onnxruntime/test/perftest/ort_test_session.cc | 6 +- .../test/providers/qnn/layer_norm_test.cc | 16 +- .../test/providers/qnn/matmul_test.cpp | 4 +- .../test/providers/qnn/qnn_basic_test.cc | 75 +++++ .../test/providers/qnn/qnn_test_utils.cc | 7 +- .../test/providers/qnn/qnn_test_utils.h | 11 +- .../test/qnn_ctx_gen/command_args_parser.cc | 8 +- onnxruntime/test/unittest_main/test_main.cc | 10 +- onnxruntime/test/xctest/xcgtest.mm | 43 ++- .../github/apple/get_simulator_device_info.py | 29 +- ...arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../azure-pipelines/bigmodels-ci-pipeline.yml | 2 +- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 +- .../linux-gpu-tensorrt-ci-pipeline.yml | 4 +- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- .../azure-pipelines/mac-ios-ci-pipeline.yml | 8 + .../py-cuda-package-test-pipeline.yml | 2 +- .../py-package-test-pipeline.yml | 2 +- .../azure-pipelines/py-packaging-pipeline.yml | 2 +- .../qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../stages/java-cuda-packaging-stage.yml | 4 +- .../jobs/py-linux-cuda-package-test-job.yml | 2 +- .../stages/py-cuda-packaging-stage.yml | 4 +- .../templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../templates/jobs/download_win_qnn_sdk.yml | 2 +- .../templates/py-linux-qnn.yml | 2 +- .../templates/py-packaging-stage.yml | 4 +- .../templates/py-win-arm64-qnn.yml | 2 +- .../templates/py-win-arm64ec-qnn.yml | 2 +- .../templates/py-win-x64-qnn.yml | 2 +- .../azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../stages/mac-ios-packaging-build-stage.yml | 9 + .../win-qnn-arm64-ci-pipeline.yml | 2 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- .../linux/docker/Dockerfile.manylinux2_28_cpu | 2 +- .../default/cpu/scripts/install_deps.sh | 2 +- .../inference/aarch64/python/cpu/Dockerfile | 2 +- .../default/cpu/scripts/install_deps.sh | 2 +- .../x86_64/default/cuda11/Dockerfile | 2 +- .../x86_64/default/cuda12/Dockerfile | 2 +- .../inference/x86_64/python/cpu/Dockerfile | 2 +- 63 files changed, 664 insertions(+), 205 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a2495de5dfd80..cbae6990cd0b6 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -134,9 +134,14 @@ function(AddTest) if (IOS) # target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m) + + set(_UT_IOS_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET}) + # replace any characters that are not valid in a bundle identifier with '-' + string(REGEX REPLACE "[^a-zA-Z0-9\\.-]" "-" _UT_IOS_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER}) + set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest" MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET} - MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET} + MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER} MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION} MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION} MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION} @@ -163,7 +168,7 @@ function(AddTest) set_target_properties(${_UT_TARGET}_xc PROPERTIES FOLDER "ONNXRuntimeXCTest" MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}_xc - MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET} + MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER} MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION} MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION} MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION} diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9e71997c1e442..bde27df94ed1c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3651,13 +3651,17 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Used for float32 model for HTP backend. - Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": With fp32 precision. - - "1": Default. With fp16 precision. - "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. - - "0": Default. Disabled. - - "1": Enabled. + * "enable_htp_fp16_precision": Used for float32 model for HTP backend. + * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + * - "0": With fp32 precision. + * - "1": Default. With fp16 precision. + * "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. + * - "0": Default. Disabled. + * - "1": Enabled. + * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another + * execution provider (typically CPU EP). + * - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O. + * - "1": Enabled. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index ddaa19c7fab18..fec14dfd093a0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -145,6 +145,20 @@ bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderI LOGS(logger, VERBOSE) << "DepthToSpace: CRD mode requires static shape"; return false; } + + if (mode == "DCR" && input_params.coreml_version < 7) { + int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + GetType(*input_defs[0], input_type, logger); + + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + // In CoreML version 6 (e.g., on an iOS 16 simulator) with DCR mode and float16 input, the output is all zeros + // in this unit test: TensorOpTest/1.DepthToSpaceTest_4. + // However, CoreML version 7 is fine. + // Don't support CoreML version < 7, DCR mode, and float16 input. + LOGS(logger, VERBOSE) << "DepthToSpace: DCR mode with float16 input requires at least CoreML version 7."; + return false; + } + } } else { if (mode != "DCR") { LOGS(logger, VERBOSE) << "DepthToSpace: " << mode << " mode is not supported"; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index f73efcddcedd4..24a5dcab225c4 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -24,16 +24,16 @@ void ComputeJob( const T* bias_data, const ptrdiff_t task_idx, const int64_t norm_size, - IAllocatorUniquePtr& scale_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, + const float* scale_float_ptr, + const float* bias_float_ptr, float epsilon, bool simplified, T* Y_data, U* mean_data, U* inv_std_dev_data, AllocatorPtr alloc) { - ORT_UNUSED_PARAMETER(scale_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(scale_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload ORT_UNUSED_PARAMETER(alloc); const T* p_input = X_data + task_idx * norm_size; @@ -82,14 +82,17 @@ void ComputeJob( const MLFloat16* bias_data, const ptrdiff_t task_idx, const int64_t norm_size, - IAllocatorUniquePtr& scale_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, + const float* scale_float_ptr, + const float* bias_float_ptr, float epsilon, bool simplified, MLFloat16* Y_data, U* mean_data, U* inv_std_dev_data, AllocatorPtr alloc) { + ORT_UNUSED_PARAMETER(scale_data); // only used in float/double overload + ORT_UNUSED_PARAMETER(bias_data); // only used in float/double overload + const MLFloat16* p_input = X_data + task_idx * norm_size; MLFloat16* p_output = Y_data + task_idx * norm_size; @@ -117,22 +120,10 @@ void ComputeJob( mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); } - if (!scale_float_uptr) { - scale_float_uptr = std::move(input_float_uptr); // overwrite input with scale values, since they have the same size - MlasConvertHalfToFloatBuffer(scale_data, scale_float_uptr.get(), num_elems); - } - - if (bias_data && !bias_float_uptr) { - bias_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems); - } - - const float* scale_float_ptr = scale_float_uptr.get(); - const float* bias_float_ptr = bias_float_uptr.get(); for (size_t h = 0; h < num_elems; h++) { if (simplified) { output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h]; - } else if (nullptr == bias_data) { + } else if (nullptr == bias_float_ptr) { output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h]; } else { output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h]; @@ -166,7 +157,13 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I } // namespace LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op) - : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr), bias_fp32_(nullptr) { + : OpKernel(op_kernel_info), + simplified_{simplified}, + contrib_op_{contrib_op}, + prepacked_scale_fp32_data_(nullptr), + prepacked_scale_fp32_size_(0), + prepacked_bias_fp32_data_(nullptr), + prepacked_bias_fp32_size_(0) { ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); } @@ -175,15 +172,15 @@ template Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const { // Inputs const Tensor* X = p_ctx->Input(0); - const Tensor* scale = p_ctx->Input(1); - const Tensor* bias = p_ctx->Input(2); + const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input(1); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(2); const T* X_data = X->Data(); - const T* scale_data = scale->Data(); + const T* scale_data = scale ? scale->Data() : nullptr; const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data(); const TensorShape& x_shape = X->Shape(); - const TensorShape& scale_shape = scale->Shape(); - const TensorShape& bias_shape = bias->Shape(); + size_t scale_size = scale ? static_cast(scale->Shape().Size()) : prepacked_scale_fp32_size_; + size_t bias_size = bias ? static_cast(bias->Shape().Size()) : prepacked_bias_fp32_size_; Tensor* Y = p_ctx->Output(0, x_shape); T* Y_data = Y->MutableData(); @@ -218,7 +215,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); - return ComputeWithoutContext(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data, + return ComputeWithoutContext(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data, inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc); } @@ -237,9 +234,11 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr is_packed = false; if (input_idx == 1) { // scale - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, scale_fp32_, is_packed); + prepacked_scale_fp32_size_ = static_cast(tensor.Shape().Size()); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed); } else if (input_idx == 2) { // bias - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed); + prepacked_bias_fp32_size_ = static_cast(tensor.Shape().Size()); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); } return Status::OK(); @@ -250,9 +249,9 @@ Status LayerNormImpl::ComputeWithoutContext( const T* X_data, const TensorShape& x_shape, const T* scale_data, - const TensorShape& scale_shape, + size_t scale_size, const T* bias_data, - const TensorShape& bias_shape, + size_t bias_size, T* Y_data, U* mean_data, U* inv_std_dev_data, @@ -264,19 +263,34 @@ Status LayerNormImpl::ComputeWithoutContext( int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow(axis)); int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); - const auto scale_size = scale_shape.Size(); - const auto bias_size = (bias_data) ? bias_shape.Size() : 0; - if (scale_size != norm_size || (bias_data && bias_size != norm_size)) { + if (static_cast(scale_size) != norm_size || (bias_data && static_cast(bias_size) != norm_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size of X.shape()[axis:] == ", norm_size, ". Size of scale and bias (if provided) must match this. Got scale size of ", scale_size, " and bias size of ", bias_size); } + IAllocatorUniquePtr scale_fp32; + IAllocatorUniquePtr bias_fp32; + if constexpr (std::is_same_v) { + if (prepacked_scale_fp32_data_ == nullptr) { + const size_t num_elems = static_cast(norm_size); + scale_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems); + } + if (prepacked_bias_fp32_data_ == nullptr && bias_data) { + const size_t num_elems = static_cast(norm_size); + bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + } + } + concurrency::ThreadPool::TryBatchParallelFor( thread_pool, static_cast(norm_count), [&](ptrdiff_t task_idx) { - ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_, + ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, + prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(), + prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(), epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc); }, 0); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index f6325c31cc71a..f8b528b398cba 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel { const T* X_data, const TensorShape& x_shape, const T* scale_data, - const TensorShape& scale_shape, + size_t scale_size, const T* bias_data, - const TensorShape& bias_shape, + size_t bias_size, T* Y_data, U* mean_data, U* inv_std_dev, @@ -63,8 +63,10 @@ class LayerNormImpl : public OpKernel { float epsilon_; const bool simplified_; const bool contrib_op_; - mutable IAllocatorUniquePtr scale_fp32_; - mutable IAllocatorUniquePtr bias_fp32_; + IAllocatorUniquePtr prepacked_scale_fp32_data_; + size_t prepacked_scale_fp32_size_; + IAllocatorUniquePtr prepacked_bias_fp32_data_; + size_t prepacked_bias_fp32_size_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 0f15ebf342b3a..95d9644b4ca30 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -25,6 +25,41 @@ // The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap // the sign of every adjacent element. +// Here's a representation of what the graph looks like in DML, before getting fused together: +/* + Input CosCache PositionIds SinCache + | | | | + | | +--------+-----------+ | + Split | | | | + | | Gather Gather + +-------+ | | | + | | | | + | Identity----------+ | | + | | | | | + | | | | | + | --Split-- | | | + | \ / | +-----------------+ | + | \ / | | | + | \ / Mul | + | \ / | | + | X | | + | / \ | | + | / \ | | + | Join | | + | | | | + | | +---------------------------------------------------------+ + | | | | + | Mul | + | | | + | +-----+ +------+ + | | | + | Add + | | + +-------------+ | + | | + Join +*/ + namespace Dml { class DmlOperatorRotaryEmbedding : public DmlOperator @@ -56,25 +91,45 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); - ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); - const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; - // The last dimension of the data is the hidden size, so it must be divisible by the head size - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + uint32_t numHeads = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::NumHeads, 0)); + uint32_t rotaryEmbeddingDim = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::RotaryEmbeddingDim, 0)); - // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t hiddenSize = inputIs4D ? inputDataSizes[1] * inputDataSizes[3] : inputDataSizes.back(); + + const uint32_t headSize = numHeads == 0 + ? m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2 + : hiddenSize / numHeads; + + if (rotaryEmbeddingDim > 0) + { + ORT_ENFORCE(numHeads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } + else + { + rotaryEmbeddingDim = headSize; + } + + if (numHeads == 0) + { + numHeads = hiddenSize / headSize; + } + else if (inputIs4D) + { + ORT_ENFORCE(numHeads == inputDataSizes[1], "When the input has 4 dimensions, num_heads must be 0 or have the same value as the second dimension of the input"); + } + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; - if (sequenceLength > maxSequenceLength) + const bool isPackedBatching = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::IsPackedBatching, 0)) == 1; + if (!isPackedBatching && sequenceLength > maxSequenceLength) { ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); } @@ -84,64 +139,103 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] + const std::array inputOutputShape = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, headSize}) + : std::array({batchSize, sequenceLength, numHeads, headSize}); + + const std::array splitInputOutputShape1 = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, rotaryEmbeddingDim}) + : std::array({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim}); + + const std::array splitInputOutputShape2 = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, headSize - rotaryEmbeddingDim}) + : std::array({batchSize, sequenceLength, numHeads, headSize - rotaryEmbeddingDim}); + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); - TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc splitInputOutputTensorDesc1 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape1); + TensorDesc splitInputOutputTensorDesc2 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape2); - if (inputIs4D) + // Split the input to perform the rotary embedding only on a subregion of the tensor if needed. The split inputs + // will be joined back together at the end. + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + std::array splitTensorDescs = { + splitInputOutputTensorDesc1.GetDmlDesc(), + splitInputOutputTensorDesc2.GetDmlDesc(), + }; + + DML_SPLIT_OPERATOR_DESC splitInputOperatorDesc{}; + DML_OPERATOR_DESC splitInputDmlOperatorDesc{}; + if (headSize != rotaryEmbeddingDim) { - const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; - stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + splitInputOperatorDesc.InputTensor = &inputOutputDmlTensorDesc; + splitInputOperatorDesc.OutputCount = gsl::narrow_cast(splitTensorDescs.size()); + splitInputOperatorDesc.OutputTensors = splitTensorDescs.data(); + splitInputOperatorDesc.Axis = gsl::narrow_cast(inputOutputShape.size()) - 1; + splitInputDmlOperatorDesc.Type = DML_OPERATOR_SPLIT; + splitInputDmlOperatorDesc.Desc = &splitInputOperatorDesc; } - const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); - const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); - - // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. + // Copy the partial input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + const std::array partialInputOutputShape = {batchSize, sequenceLength, numHeads, rotaryEmbeddingDim}; + TensorDesc partialStridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape); + TensorDesc partialInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape); + + if (inputIs4D) + { + const std::array partialInputOutputStrides = {rotaryEmbeddingDim * numHeads * sequenceLength, rotaryEmbeddingDim, sequenceLength * rotaryEmbeddingDim, 1}; + partialStridedInputOutputTensorDesc.SetStrides(partialInputOutputStrides); + } + + const DML_TENSOR_DESC partialStridedInputOutputDmlTensorDesc = partialStridedInputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC partialInputOutputDmlTensorDesc = partialInputOutputTensorDesc.GetDmlDesc(); + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; - copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &partialStridedInputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &partialInputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; + const uint32_t halfRoraryEmbeddingDim = rotaryEmbeddingDim / 2; + // Split the input data into 2 equal parts - const std::vector inputDataTensorShape = interleaved - ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) - : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + const std::vector partialInputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, rotaryEmbeddingDim / 2}); const std::vector splitInputDataTensorShape = interleaved - ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) - : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + ? std::vector({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, rotaryEmbeddingDim / 2}); - TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + TensorDesc partialInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape); + const DML_TENSOR_DESC partialInputDataDmlTensorDesc = partialInputDataTensorDesc.GetDmlDesc(); - const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); - - TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape); const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; - DML_SPLIT_OPERATOR_DESC splitInputDesc{}; - splitInputDesc.InputTensor = &inputDataDmlTensorDesc; - splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); - splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - splitInputDesc.Axis = interleaved + DML_SPLIT_OPERATOR_DESC splitPartialInputDesc{}; + splitPartialInputDesc.InputTensor = &partialInputDataDmlTensorDesc; + splitPartialInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitPartialInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitPartialInputDesc.Axis = interleaved ? gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; - const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + const DML_OPERATOR_DESC splitPartialInputDmlDesc = {DML_OPERATOR_SPLIT, &splitPartialInputDesc}; // Swap the 2 halves and join them together - DML_JOIN_OPERATOR_DESC joinInputDesc{}; - joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; - joinInputDesc.Axis = splitInputDesc.Axis; - joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + DML_JOIN_OPERATOR_DESC joinPartialInputDesc{}; + joinPartialInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinPartialInputDesc.OutputTensor = &joinedDataDmlTensorDesc; + joinPartialInputDesc.Axis = splitPartialInputDesc.Axis; + joinPartialInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinPartialInputDmlDesc = {DML_OPERATOR_JOIN, &joinPartialInputDesc}; // We generate a sequence from 0 to sequenceLength and add the offset to it const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; @@ -177,7 +271,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; // Gather the cos/sin values based on the position ids - const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, rotaryEmbeddingDim / 2}; TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); @@ -191,9 +285,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data const std::vector reshapedCosSinShape = interleaved - ? std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) - : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); - TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); + ? std::vector({batchSize, sequenceLength, 1, rotaryEmbeddingDim / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, rotaryEmbeddingDim / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedCosSinShape); const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); // Create a vector that contains the sign values {-1, 1} @@ -224,7 +318,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const std::vector reshapedSignShape = interleaved ? std::vector({1, 1, 1, 1, 2}) : std::vector({1, 1, 1, 2, 1}); - TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedSignShape); const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; @@ -242,11 +336,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; - addDesc.ATensor = &inputOutputDmlTensorDesc; - addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; + addDesc.ATensor = &partialInputOutputDmlTensorDesc; + addDesc.BTensor = &partialInputOutputDmlTensorDesc; + addDesc.OutputTensor = &partialStridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + DML_JOIN_OPERATOR_DESC joinOutputOperatorDesc{}; + DML_OPERATOR_DESC joinOutputDmlOperatorDesc{}; + if (headSize != rotaryEmbeddingDim) + { + joinOutputOperatorDesc.InputCount = gsl::narrow_cast(splitTensorDescs.size()); + joinOutputOperatorDesc.InputTensors = splitTensorDescs.data(); + joinOutputOperatorDesc.OutputTensor = &inputOutputDmlTensorDesc; + joinOutputOperatorDesc.Axis = gsl::narrow_cast(inputOutputShape.size()) - 1; + joinOutputDmlOperatorDesc.Type = DML_OPERATOR_JOIN; + joinOutputDmlOperatorDesc.Desc = &joinOutputOperatorDesc; + } + // Construct the graph std::vector inputEdges; std::vector intermediateEdges; @@ -254,12 +360,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector opDescs = { ©InputDmlDesc, // Copy the input data to preseve the real input shape - &splitInputDmlDesc, // Split the input data + &splitPartialInputDmlDesc, // Split the input data &gatherCosSinDmlDesc, // Gather cos &gatherCosSinDmlDesc, // Gather sin &signRangeDmlDesc, // Generate the signs - &joinInputDmlDesc, // Join the split data + &joinPartialInputDmlDesc, // Join the split data &mulCosSinDmlDesc, // Multiply cos with the non-rotated data &mulCosSinDmlDesc, // Multiply sin with the rotated data &mulSignDmlDesc, // Multiply the sign with the rotated data @@ -269,12 +375,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator enum NodeIndex : uint32_t { copyInputOpIndex, - splitInputOpIndex, + splitPartialInputOpIndex, gatherCosOpIndex, gatherSinOpIndex, signRangeOpIndex, - joinInputOpIndex, + joinPartialInputOpIndex, mulCosOpIndex, mulSinOpIndex, mulSignOpIndex, @@ -285,6 +391,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator positionIdsAddOffsetOpIndex, }; + uint32_t splitInputOpIndex = positionIdsIsOffset ? positionIdsAddOffsetOpIndex + 1 : addOpIndex + 1; + uint32_t joinOutputOpIndex = splitInputOpIndex + 1; + if (positionIdsIsOffset) { opDescs.push_back(&positionIdsRangeDmlDesc); @@ -332,11 +441,32 @@ class DmlOperatorRotaryEmbedding : public DmlOperator inputEdges.push_back(positionIdsToGatherSinEdge); } - DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; - inputToCopyInputEdge.GraphInputIndex = inputDataIndex; - inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; - inputToCopyInputEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToCopyInputEdge); + if (splitInputDmlOperatorDesc.Desc) + { + opDescs.push_back(&splitInputDmlOperatorDesc); + opDescs.push_back(&joinOutputDmlOperatorDesc); + + DML_INPUT_GRAPH_EDGE_DESC inputToSplitInputEdge = {}; + inputToSplitInputEdge.GraphInputIndex = inputDataIndex; + inputToSplitInputEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToSplitInputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC partialInputToCopyInputEdge = {}; + partialInputToCopyInputEdge.FromNodeIndex = splitInputOpIndex; + partialInputToCopyInputEdge.FromNodeOutputIndex = 0; + partialInputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + partialInputToCopyInputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(partialInputToCopyInputEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); + } DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; cosToGatherEdge.GraphInputIndex = cosCacheIndex; @@ -353,7 +483,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; inputToSplitEdge.FromNodeIndex = copyInputOpIndex; inputToSplitEdge.FromNodeOutputIndex = 0; - inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeIndex = splitPartialInputOpIndex; inputToSplitEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(inputToSplitEdge); @@ -365,16 +495,16 @@ class DmlOperatorRotaryEmbedding : public DmlOperator intermediateEdges.push_back(nonRotatedDataToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; - secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex; secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; - secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex; secondHalfDataToJoinEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(secondHalfDataToJoinEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; - firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex; firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; - firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + firstHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex; firstHalfDataToJoinEdge.ToNodeInputIndex = 1; intermediateEdges.push_back(firstHalfDataToJoinEdge); @@ -386,7 +516,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator intermediateEdges.push_back(cosToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; - rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeIndex = joinPartialInputOpIndex; rotatedDataToMulEdge.FromNodeOutputIndex = 0; rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; rotatedDataToMulEdge.ToNodeInputIndex = 0; @@ -427,11 +557,36 @@ class DmlOperatorRotaryEmbedding : public DmlOperator rotatedSinToAddEdge.ToNodeInputIndex = 1; intermediateEdges.push_back(rotatedSinToAddEdge); - DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; - addToOutputEdge.FromNodeIndex = addOpIndex; - addToOutputEdge.FromNodeOutputIndex = 0; - addToOutputEdge.GraphOutputIndex = 0; - outputEdges.push_back(addToOutputEdge); + if (splitInputDmlOperatorDesc.Desc) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC addToJoinOutputEdge = {}; + addToJoinOutputEdge.FromNodeIndex = addOpIndex; + addToJoinOutputEdge.FromNodeOutputIndex = 0; + addToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex; + addToJoinOutputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(addToJoinOutputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC remainingInputToJoinOutputEdge = {}; + remainingInputToJoinOutputEdge.FromNodeIndex = splitInputOpIndex; + remainingInputToJoinOutputEdge.FromNodeOutputIndex = 1; + remainingInputToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex; + remainingInputToJoinOutputEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(remainingInputToJoinOutputEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC joinOutputToOutputEdge = {}; + joinOutputToOutputEdge.FromNodeIndex = joinOutputOpIndex; + joinOutputToOutputEdge.FromNodeOutputIndex = 0; + joinOutputToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(joinOutputToOutputEdge); + } + else + { + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + } MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 0c5739554b800..3d23fb6206479 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -130,6 +130,8 @@ namespace AttrName static constexpr const char* UppercaseN = "N"; static constexpr const char* UppercaseK = "K"; static constexpr const char* MatMulNBitsBlockSize = "block_size"; + static constexpr const char* RotaryEmbeddingDim = "rotary_embedding_dim"; + static constexpr const char* IsPackedBatching = "is_packed_batching"; } // namespace AttrName diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index 5c4608dff9bb1..d089235ceaa02 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,9 +87,9 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18 || QNN_API_VERSION_MINOR == 19) +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 17) if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24/2.25/2.26 (QNN API version 2.17/2.18/2.19) has a validation bug for implicit bias inputs, + // Bias is implicit. QNN SDK 2.24+ (QNN API version 2.17+) has a validation bug for implicit bias inputs, // so provide an explicit bias of all 0 (quantized int32). TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 0358fae3c2115..a6c4203ad92e4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -164,6 +164,11 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, int64_t quant_axis = 0; ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(node_unit.Inputs()[0], is_per_chan_quant, quant_axis)); ORT_RETURN_IF(is_per_chan_quant, "QNN EP does not support a standalone DQ op with per-channel quantization"); + + if (qnn_model_wrapper.GetModelSettings().offload_graph_io_quantization) { + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(node_unit.Outputs()[0].node_arg.Name()), + "QNN EP is configured to not take DQ nodes that generate a graph output."); + } } if (op_type == "QuantizeLinear") { @@ -171,6 +176,11 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, int64_t quant_axis = 0; ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(node_unit.Outputs()[0], is_per_chan_quant, quant_axis)); ORT_RETURN_IF(is_per_chan_quant, "QNN EP does not support a standalone Q op with per-channel quantization"); + + if (qnn_model_wrapper.GetModelSettings().offload_graph_io_quantization) { + ORT_RETURN_IF(qnn_model_wrapper.IsGraphInput(node_unit.Inputs()[0].node_arg.Name()), + "QNN EP is configured to not take Q nodes that consume a graph input."); + } } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index f322456e0c8f0..b09ff51b666c7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -95,6 +95,7 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node, Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const qnn::ModelSettings& model_settings, const logging::Logger& logger, const QnnGraph_Config_t** graph_configs) { LOGS(logger, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name(); @@ -115,7 +116,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, model_input_index_map_, model_output_index_map_, initializer_inputs_, - qnn_backend_manager_->GetQnnBackendType()); + qnn_backend_manager_->GetQnnBackendType(), + model_settings); bool rt = true; rt = qnn_model_wrapper.CreateQnnGraph(qnn_backend_manager_->GetQnnContext(), graph_name, graph_configs); if (!rt) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 83cf8f9f08fb0..d9682cc3b3222 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -35,6 +35,7 @@ class QnnModel { Status ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const qnn::ModelSettings& model_settings, const logging::Logger& logger, const QnnGraph_Config_t** graph_configs = nullptr); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 9ab122b7f8e28..f3e52050e79e0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -29,6 +29,10 @@ struct TensorInfo { const ONNX_NAMESPACE::TensorProto* initializer_tensor; }; +struct ModelSettings { + bool offload_graph_io_quantization = false; +}; + class QnnModelWrapper { public: QnnModelWrapper(const GraphViewer& graph_viewer, @@ -38,7 +42,8 @@ class QnnModelWrapper { const std::unordered_map& input_index_map, const std::unordered_map& output_index_map, const std::unordered_set& initializer_lookup, - QnnBackendType qnn_backend_type) + QnnBackendType qnn_backend_type, + const ModelSettings& model_settings) : graph_viewer_(graph_viewer), logger_(logger), qnn_interface_(qnn_interface), @@ -46,12 +51,15 @@ class QnnModelWrapper { input_index_map_(input_index_map), output_index_map_(output_index_map), initializer_lookup_(initializer_lookup), - qnn_backend_type_(qnn_backend_type) { + qnn_backend_type_(qnn_backend_type), + model_settings_(model_settings) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModelWrapper); ~QnnModelWrapper() = default; + const ModelSettings& GetModelSettings() const { return model_settings_; } + bool CreateQnnGraph(const Qnn_ContextHandle_t& context, const std::string& graph_name, const QnnGraph_Config_t** graph_configs = nullptr); @@ -279,6 +287,7 @@ class QnnModelWrapper { const std::unordered_map& output_index_map_; const std::unordered_set& initializer_lookup_; QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; + ModelSettings model_settings_ = {}; }; // QnnModelWrapper } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 24132b98e3757..4cd5d403e95b8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -161,6 +161,23 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic } } +static bool ParseBoolOption(const std::string& key, bool default_value, + const std::unordered_map& options) { + bool result = default_value; + auto it = options.find(key); + if (it != options.end()) { + if ("1" == it->second) { + result = true; + } else if ("0" == it->second) { + result = false; + } else { + LOGS_DEFAULT(VERBOSE) << "Invalid value for " << key << " (" << it->second << "). Only 0 or 1 allowed."; + } + LOGS_DEFAULT(VERBOSE) << "Using " << key << ": " << result; + } + return result; +} + qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned char level) { if (level == 5) { LOGS_DEFAULT(INFO) << "Overriding profiling to basic based on ETW level: " << static_cast(level); @@ -403,6 +420,15 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_; } + model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false, + provider_options_map); + + if (disable_cpu_ep_fallback_ && model_settings_.offload_graph_io_quantization) { + LOGS_DEFAULT(WARNING) << "Fallback to CPU EP is disabled, but user configured QNN EP to offload graph I/O " + << "quantization/dequantization to another EP. Session creation will fail if the CPU EP " + << "handles the graph I/O quantization/dequantization."; + } + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level_etw, @@ -499,7 +525,8 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, model_input_index_map, model_output_index_map, initializer_input_lookup, - qnn_backend_manager_->GetQnnBackendType()); + qnn_backend_manager_->GetQnnBackendType(), + model_settings_); std::vector> qnn_node_groups; qnn_node_groups.reserve(node_unit_size); @@ -845,7 +872,8 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vectorComposeGraph(graph_viewer, fused_node, logger, graph_configs_builder.GetQnnConfigs())); + ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, model_settings_, logger, + graph_configs_builder.GetQnnConfigs())); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e0eaf31c94a36..246ab1d5a6608 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -153,6 +153,7 @@ class QNNExecutionProvider : public IExecutionProvider { #ifdef _WIN32 onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; #endif + qnn::ModelSettings model_settings_ = {}; class PerThreadContext final { public: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index 5080737516c53..1857b366194ec 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,3 +1,4 @@ +huggingface_hub==0.25.2 diffusers==0.28.0 transformers==4.41.2 numpy>=1.24.1 diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 655c4951f262d..9ecaa16a2ab24 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -151,6 +151,20 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); } +TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput_Initializers) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), true); + test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f})); + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale_Bias) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -211,6 +225,21 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); } +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), true); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}), true); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); +} + // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. TEST(LayerNormTest, LayerNorm17_float) { OpTester test("LayerNormalization", 17); diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 89552da58b938..8675a997d29a1 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -135,8 +135,7 @@ static void RunTests(const std::vector& input_data, int max_sequence_length = 0, int64_t interleaved = 0, int64_t is_packed_batching = 0, - bool use_float16 = true, - bool disable_dml = false) { + bool use_float16 = true) { // FP32 test for CPU RunTest(input_data, position_ids, @@ -173,7 +172,7 @@ static void RunTests(const std::vector& input_data, TensorType::kFloat, false, /* disable_cpu */ false, /* disable_cuda */ - disable_dml || false /* disable_dml */); + false /* disable_dml */); // FP16 test for CUDA and DML if (use_float16) { @@ -193,7 +192,7 @@ static void RunTests(const std::vector& input_data, TensorType::kFloat16, true, /* disable_cpu */ false, /* disable_cuda*/ - disable_dml || false /* disable_dml */); + false /* disable_dml */); // RunTest(input_data, // position_ids, @@ -743,9 +742,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { num_heads, max_sequence_length, interleaved, - 0, // is_packed_batching - true, /*use_fp16*/ - true /*disable_dml*/); + 0, // is_packed_batching + true /*use_fp16*/); } TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) { @@ -785,9 +783,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_B num_heads, max_sequence_length, interleaved, - 1, // is_packed_batching - true, /*use_fp16*/ - true /*disable_dml*/); + 1, // is_packed_batching + true /*use_fp16*/); } } // namespace test diff --git a/onnxruntime/test/logging_apis/test_logging_apis.cc b/onnxruntime/test/logging_apis/test_logging_apis.cc index d72c47493d800..b98e5c34b4e1d 100644 --- a/onnxruntime/test/logging_apis/test_logging_apis.cc +++ b/onnxruntime/test/logging_apis/test_logging_apis.cc @@ -359,12 +359,16 @@ TEST_F(MockCAPITestsFixture, CppLogMacroBypassCApiCall) { #undef TEST_MAIN #define TEST_MAIN main_no_link_ // there is a UI test app for iOS. -// IOS tests require this function to be defined. +// iOS tests require ortenv_setup() and ortenv_teardown() to be defined. // See onnxruntime/test/xctest/xcgtest.mm -void ortenv_setup() { +extern "C" void ortenv_setup() { // Do nothing. These logging tests do not require an env to be setup initially. } +extern "C" void ortenv_teardown() { + // Do nothing. +} + #endif // TARGET_OS_SIMULATOR || TARGET_OS_IOS #endif // defined(__APPLE__) diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 6d86e4c35af85..93a1bf9f30651 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -77,6 +77,8 @@ void usage() { "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -587,20 +589,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), std::ostream_iterator(str_stream, ",")); std::string str = str_stream.str(); - ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str); + ORT_THROW("Wrong value for ", key, ". select from: ", str); } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', -'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision'])"); +'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc index 75ce7b77acd4e..f6158d8cbc12b 100644 --- a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc +++ b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc @@ -111,9 +111,20 @@ static void BM_LayerNormalization(benchmark::State& state) { OrtMemoryInfo memory_info(onnxruntime::CPU, OrtAllocatorType::OrtArenaAllocator); AllocatorPtr alloc = std::make_shared(memory_info); for (auto _ : state) { - auto status = layer_norm_impl.ComputeWithoutContext(x_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, - Y_data, mean_data, inv_std_dev_data, thread_pool.get(), axis, - epsilon, simplified, alloc); + auto status = layer_norm_impl.ComputeWithoutContext(x_data, + x_shape, + scale_data, + static_cast(scale_shape.Size()), + bias_data, + static_cast(bias_shape.Size()), + Y_data, + mean_data, + inv_std_dev_data, + thread_pool.get(), + axis, + epsilon, + simplified, + alloc); if (!status.IsOK()) { std::cout << "ComputeWithoutContext status not OK: " << status.ErrorMessage() << std::endl; break; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 94945c0393d08..e40544d950ed7 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -98,6 +98,8 @@ namespace perftest { "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index fcdef48eda56c..e69c87b2540e5 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -302,20 +302,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), std::ostream_iterator(str_stream, ",")); std::string str = str_stream.str(); - ORT_THROW("Wrong value for " + key + ". select from: " + str); + ORT_THROW("Wrong value for ", key, ". select from: ", str); } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', -'htp_arch', 'device_id', 'enable_htp_fp16_precision'])"); +'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 2af49a5e500d2..2773568dde717 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -188,7 +188,13 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_StaticBias_AU8_WU8_B ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { +// QNN 2.27 accuracy issue +// Inaccuracy detected for output 'output_0', element 0 +// output_range=1.2245157957077026, tolerance=0.40000000596046448%. +// Expected val (f32@CPU_EP): -0 +// qdq@QNN_EP val: 0.19133351743221283 (err: 0.19133351743221283, err/output_range: 15.625238418579102%) +// qdq@CPU_EP val: 0 (err: 0, err/output_range: 0%) +TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { // QNN 2.24 LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide an // explicit bias of all zeros to get around this bug. for (size_t i = 0; i < 15; i++) { // Run it multiple times since this is an intermittent bug. @@ -202,7 +208,13 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { } // Test accuracy of 16-bit QDQ LayerNorm with a static scale input. -TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { +// QNN 2.27 accuracy issue +// Inaccuracy detected for output 'output_0', element 0 +// output_range=1.224743127822876, tolerance=0.40000000596046448%. +// Expected val (f32@CPU_EP): -0 +// qdq@QNN_EP val: 0.19136904180049896 (err: 0.19136904180049896, err/output_range: 15.625238418579102%) +// qdq@CPU_EP val: 0 (err: 0, err/output_range: 0%) +TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), // Static TestInputDef(), diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 708aac03ceb2e..800457d906940 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -273,7 +273,9 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { } // Test QDQ per-channel MatMul with int8 act, int4 weights (static) -TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { +// QNN 2.27 regression +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_PerChannel_AS8_WeightInt4) { std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); std::vector input1_data = {-2.0f, -1.0f, -0.5f, 0.0f, 1.0f, 2.0f}; RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 236b66a2d8a78..e8282dbad9f72 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1023,6 +1023,81 @@ TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { &ep_graph_checker); } +// Test option for offloading quantization of graph inputs and dequantization of graph outputs to the CPU EP. +TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { + // Returns a function that checks that the Q/DQ ops at the graph IO boundary are offloaded to CPU + // if the corresponding provider option is enabled. + auto graph_checker_builder = [](bool offload_graph_io_quantization) -> std::function { + return [offload_graph_io_quantization](const Graph& graph) { + size_t num_q = 0; + size_t num_dq = 0; + size_t num_qnn_fused_node = 0; + + for (const Node& node : graph.Nodes()) { + const std::string& ep_name = node.GetExecutionProviderType(); + const std::string& op_type = node.OpType(); + + if (offload_graph_io_quantization && op_type == "QuantizeLinear") { + const bool consumes_graph_input = graph.IsInputsIncludingInitializers(node.InputDefs()[0]); + EXPECT_EQ(ep_name, kCpuExecutionProvider); + EXPECT_TRUE(consumes_graph_input); + num_q += 1; + } else if (offload_graph_io_quantization && op_type == "DequantizeLinear") { + const bool produces_graph_output = graph.IsOutput(node.OutputDefs()[0]); + EXPECT_EQ(ep_name, kCpuExecutionProvider); + EXPECT_TRUE(produces_graph_output); + num_dq += 1; + } else { + EXPECT_EQ(ep_name, kQnnExecutionProvider); + num_qnn_fused_node += 1; + } + } + + EXPECT_EQ(num_q, static_cast(offload_graph_io_quantization)); + EXPECT_EQ(num_dq, static_cast(offload_graph_io_quantization)); + EXPECT_EQ(num_qnn_fused_node, 1); + }; + }; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::vector op_types = { + "Sigmoid", + "Transpose", + "Softmax", + "Sqrt", + "Elu", + }; + + // Test various QDQ ops with offloading of I/O quantization enabled and disabled. + for (auto op_type : op_types) { + for (int offload_io_quant = 0; offload_io_quant <= 1; offload_io_quant++) { + provider_options["offload_graph_io_quantization"] = offload_io_quant ? "1" : "0"; + auto graph_checker = graph_checker_builder(offload_io_quant); + auto expected_ep_assignment = offload_io_quant ? ExpectedEPNodeAssignment::Some : ExpectedEPNodeAssignment::All; + + float min_val = (op_type == "Sqrt") ? 0.0f : -10.0f; + TestInputDef input_def({1, 2, 2, 2}, false, GetFloatDataInRange(min_val, 10.0f, 8)); + auto f32_model_build_fn = BuildOpTestCase(op_type, {input_def}, {}, {}); + auto qdq_model_build_fn = BuildQDQOpTestCase(op_type, {input_def}, {}, {}); + TestQDQModelAccuracy(f32_model_build_fn, + qdq_model_build_fn, + provider_options, + /*opset*/ 21, + expected_ep_assignment, + /*abs_err*/ QDQTolerance(), + logging::Severity::kERROR, + /*qnn_ctx_model_path*/ "", + /*session_option_pairs*/ {}, + &graph_checker); + } + } +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 8a4f7f2a1f6b5..79e7d39e85518 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -134,7 +134,8 @@ void InferenceModel(const std::string& model_data, const char* log_id, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals, bool is_qnn_ep, - const std::unordered_map& session_option_pairs) { + const std::unordered_map& session_option_pairs, + std::function* graph_checker) { SessionOptions so; so.session_logid = log_id; for (auto key_value : session_option_pairs) { @@ -166,6 +167,10 @@ void InferenceModel(const std::string& model_data, const char* log_id, ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type; } + if (graph_checker) { + (*graph_checker)(graph); + } + const auto& outputs = graph.GetOutputs(); std::vector output_names; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 7f55a44c748b6..a8670252ff9e0 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -457,13 +457,15 @@ DEF_QUANTIZE_VALUES_INT4_FUNC(UInt4x2, ParQuantizeLinearStdU4) * \param output_vals Initialized to the inference results. * \param is_qnn_ep Ture: QNN EP is used. False: CPU EP is used (default). * \param session_option_pairs extra session options. + * \param graph_checker Function called on the Graph. */ void InferenceModel(const std::string& model_data, const char* log_id, const ProviderOptions& provider_options, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals, bool is_qnn_ep = false, - const std::unordered_map& session_option_pairs = {}); + const std::unordered_map& session_option_pairs = {}, + std::function* graph_checker = nullptr); /** * If the ORT_UNIT_TEST_ENABLE_QNN_SAVER environment variable is enabled (set to 1), this function modifies @@ -515,6 +517,8 @@ struct QDQTolerance { * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the QDQ model * on CPU EP. This tolerance is a percentage of the output range. * \param log_severity The logger's severity setting. + * \param ep_graph_checker Function called on the Graph generated for the QNN EP's session. Used to check node + * EP assignment. */ template inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn& qdq_model_fn, @@ -523,7 +527,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe QDQTolerance tolerance = QDQTolerance(), logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", - const std::unordered_map& session_option_pairs = {}) { + const std::unordered_map& session_option_pairs = {}, + std::function* qnn_ep_graph_checker = nullptr) { // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; @@ -607,7 +612,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. // Only need to apply the extra session options to this QDQ model inference on QNN EP InferenceModel(qdq_model_data, "qdq_model_logger", qnn_options, expected_ep_assignment, - qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs); + qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs, qnn_ep_graph_checker); } if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 102846e08ac5f..5b3720992c542 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -48,6 +48,8 @@ namespace qnnctxgen { "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -143,7 +145,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing") { + } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" || + key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -154,7 +157,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode', - 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing'])"); + 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing', + 'offload_graph_io_quantization'])"); } test_config.run_config.qnn_options[key] = value; diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 1d89272680e47..b558a7f00f7bc 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -27,8 +27,8 @@ std::unique_ptr ort_env; -// ortenv_setup is used by /onnxruntime/test/xctest/xcgtest.mm so can't be file local -void ortenv_setup() { +// ortenv_setup() and ortenv_teardown() are used by onnxruntime/test/xctest/xcgtest.mm so can't be file local +extern "C" void ortenv_setup() { OrtThreadingOptions tpo; // allow verbose logging to be enabled by setting this environment variable to a numeric log level @@ -46,6 +46,10 @@ void ortenv_setup() { ort_env.reset(new Ort::Env(&tpo, log_level, "Default")); } +extern "C" void ortenv_teardown() { + ort_env.reset(); +} + #ifdef USE_TENSORRT #if defined(_MSC_VER) @@ -101,7 +105,7 @@ int TEST_MAIN(int argc, char** argv) { } // TODO: Fix the C API issue - ort_env.reset(); // If we don't do this, it will crash + ortenv_teardown(); // If we don't do this, it will crash #ifndef USE_ONNXRUNTIME_DLL // make memory leak checker happy diff --git a/onnxruntime/test/xctest/xcgtest.mm b/onnxruntime/test/xctest/xcgtest.mm index c02f18d906cbe..785c9cd937022 100644 --- a/onnxruntime/test/xctest/xcgtest.mm +++ b/onnxruntime/test/xctest/xcgtest.mm @@ -34,7 +34,8 @@ using testing::TestPartResult; using testing::UnitTest; -void ortenv_setup(); +extern "C" void ortenv_setup(); +extern "C" void ortenv_teardown(); static NSString* const GoogleTestDisabledPrefix = @"DISABLED_"; @@ -63,24 +64,51 @@ public: XCTestListener(XCTestCase* testCase) : _testCase(testCase) {} - void OnTestPartResult(const TestPartResult& test_part_result) { + void OnTestPartResult(const TestPartResult& test_part_result) override { if (test_part_result.passed() || test_part_result.skipped()) return; int lineNumber = test_part_result.line_number(); const char* fileName = test_part_result.file_name(); NSString* path = fileName ? [@(fileName) stringByStandardizingPath] : nil; + NSString* summary = @(test_part_result.summary()); NSString* description = @(test_part_result.message()); - [_testCase recordFailureWithDescription:description - inFile:path - atLine:(lineNumber >= 0 ? (NSUInteger)lineNumber : 0) - expected:YES]; + + XCTSourceCodeLocation* sourceCodeLocation = + [[XCTSourceCodeLocation alloc] initWithFilePath:path + lineNumber:lineNumber]; + + XCTSourceCodeContext* sourceCodeContext = + [[XCTSourceCodeContext alloc] initWithLocation:sourceCodeLocation]; + + XCTIssue* issue = [[XCTIssue alloc] initWithType:XCTIssueTypeAssertionFailure + compactDescription:summary + detailedDescription:description + sourceCodeContext:sourceCodeContext + associatedError:nil + attachments:@[]]; + + [_testCase recordIssue:issue]; } private: XCTestCase* _testCase; }; +/** + * A Google Test listener that manages the ORT env setup and teardown. + */ +class OrtEnvManagementListener : public testing::EmptyTestEventListener { + public: + void OnTestProgramStart(const UnitTest& unit_test) override { + ortenv_setup(); + } + + void OnTestProgramEnd(const UnitTest& unit_test) override { + ortenv_teardown(); + } +}; + /** * Registers an XCTestCase subclass for each Google Test case. * @@ -179,7 +207,6 @@ + (void)load { object:bundle queue:nil usingBlock:^(NSNotification* notification) { - ortenv_setup(); [self registerTestClasses]; }]; } @@ -201,6 +228,8 @@ + (void)registerTestClasses { delete listeners.Release(listeners.default_result_printer()); free(argv); + listeners.Append(new OrtEnvManagementListener()); + BOOL runDisabledTests = GTEST_FLAG_GET(also_run_disabled_tests); NSMutableDictionary* testFilterMap = [NSMutableDictionary dictionary]; NSCharacterSet* decimalDigitCharacterSet = [NSCharacterSet decimalDigitCharacterSet]; diff --git a/tools/ci_build/github/apple/get_simulator_device_info.py b/tools/ci_build/github/apple/get_simulator_device_info.py index 7de9aa13912e0..aa693038b4394 100755 --- a/tools/ci_build/github/apple/get_simulator_device_info.py +++ b/tools/ci_build/github/apple/get_simulator_device_info.py @@ -8,6 +8,7 @@ import functools import itertools import json +import os import subprocess @@ -37,7 +38,7 @@ def __lt__(self, other: Version) -> bool: def get_simulator_device_info( requested_runtime_platform: str = "iOS", requested_device_type_product_family: str = "iPhone", - max_runtime_version_str: str | None = None, + requested_runtime_version_str: str | None = None, ) -> dict[str, str]: """ Retrieves simulator device information from Xcode. @@ -45,11 +46,13 @@ def get_simulator_device_info( :param requested_runtime_platform: The runtime platform to select. :param requested_device_type_product_family: The device type product family to select. - :param max_runtime_version_str: The maximum runtime version to allow. + :param requested_runtime_version_str: The runtime version to select. If unspecified, selects the latest one. :return: A dictionary containing information about the selected simulator device. """ - max_runtime_version = Version(max_runtime_version_str) if max_runtime_version_str is not None else None + requested_runtime_version = ( + Version(requested_runtime_version_str) if requested_runtime_version_str is not None else None + ) simctl_proc = subprocess.run( ["xcrun", "simctl", "list", "--json", "--no-escape-slashes"], @@ -73,7 +76,7 @@ def runtime_filter(runtime) -> bool: if runtime["platform"] != requested_runtime_platform: return False - if max_runtime_version is not None and Version(runtime["version"]) > max_runtime_version: + if requested_runtime_version is not None and Version(runtime["version"]) != requested_runtime_version: return False return True @@ -108,6 +111,9 @@ def device_filter(device) -> bool: ): runtime_id_and_device_pairs.extend((runtime_id, device) for device in filter(device_filter, device_list)) + if len(runtime_id_and_device_pairs) == 0: + raise ValueError("Failed to find requested simulator device info.") + # sort key - tuple of (runtime version, device type min runtime version) # the secondary device type min runtime version value is to treat more recent device types as greater def runtime_id_and_device_pair_key(runtime_id_and_device_pair): @@ -137,13 +143,20 @@ def runtime_id_and_device_pair_key(runtime_id_and_device_pair): def main(): + requested_runtime_version_environment_variable_name = "ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION" + parser = argparse.ArgumentParser(description="Gets simulator info from Xcode and prints it in JSON format.") - parser.add_argument("--max-runtime-version", help="The maximum runtime version to allow.") + parser.add_argument( + "--requested-runtime-version", + default=os.environ.get(requested_runtime_version_environment_variable_name, None), + help="The requested runtime version. " + f"This may also be specified with the {requested_runtime_version_environment_variable_name} " + "environment variable. The command line option takes precedence. " + "An unspecified value means the latest available runtime version.", + ) args = parser.parse_args() - info = get_simulator_device_info( - max_runtime_version_str=args.max_runtime_version, - ) + info = get_simulator_device_info(requested_runtime_version_str=args.requested_runtime_version) print(json.dumps(info, indent=2)) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index abdcb1b7610c9..9362a8b0ee18c 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index a7ea5061e604e..ad763277c732e 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -42,7 +42,7 @@ parameters: variables: - template: templates/common-variables.yml - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 - name: linux_trt_version value: 10.3.0.26-1.cuda11.8 - name: Repository diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index e2d977bd60986..b12360d2710d0 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1f9b506ac451f..b0f40429c1a1e 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -49,9 +49,9 @@ parameters: variables: - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 - name: Repository ${{ if eq(parameters.CudaVersion, '11.8') }}: diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index e43cbd3413f2d..87d5c7bd824d2 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -39,9 +39,9 @@ parameters: variables: - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: 10.4.0.26-1.cuda11.8 diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index feb27e90085b8..41f6b6a8d6d80 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index c61beb63b8b40..9576aac182bbe 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -36,9 +36,16 @@ jobs: PROTO_CACHE_DIR: $(Pipeline.Workspace)/proto_ccache ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + # Note: Keep the Xcode version and iOS simulator version compatible. + # Check the table here to see what iOS simulator versions are supported by a particular Xcode version: + # https://developer.apple.com/support/xcode/ + XCODE_VERSION: 14.3.1 + IOS_SIMULATOR_RUNTIME_VERSION: 16.4 timeoutInMinutes: 150 steps: - template: templates/use-xcode-version.yml + parameters: + xcodeVersion: $(XCODE_VERSION) - template: templates/mac-build-step-with-cache.yml parameters: @@ -71,3 +78,4 @@ jobs: CCACHE_DEPEND: 1 CCACHE_SLOPPINESS: modules CCACHE_DIR: $(ORT_CACHE_DIR) + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: $(IOS_SIMULATOR_RUNTIME_VERSION) diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index 7fb4563a477fc..e946fedd07a27 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -18,7 +18,7 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241015.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 trt_version: '10.4.0.26-1.cuda12.6' cuda_version: '12.2' diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 2641ec6d56ffb..c458f0cf4bfe2 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -54,7 +54,7 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 trt_version: '10.4.0.26-1.cuda11.8' cuda_version: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 7263239c6c7f0..de17db216da9c 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -69,7 +69,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.26.0.240828 + default: 2.27.0.240926 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 98b5e47c0e2d7..fd3f31da4ab7e 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index 87fe920d8ecdd..a38486995478d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -148,9 +148,9 @@ stages: value: false - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 timeoutInMinutes: 60 steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 8c492c0153964..9289935b4ef9c 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -46,7 +46,7 @@ jobs: ${{ if eq(parameters.CudaVersion, '11.8') }}: value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241015.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: 10.4.0.26-1.cuda11.8 diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml index 466dbb2f21ec8..ae18687cb9e54 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -77,8 +77,8 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} ${{ if eq(parameters.cuda_version, '11.8') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 trt_version: 10.4.0.26-1.cuda11.8 ${{ if eq(parameters.cuda_version, '12.2') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241015.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 trt_version: 10.4.0.26-1.cuda12.6 diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 4aedd2f8564c1..f749f32456b25 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.26.0.240828' + default: '2.27.0.240926' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index eff49302eb33d..c56d81aefbec1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.26.0.240828' + default: '2.27.0.240926' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 6220a9a46c312..e663afb49dd99 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 jobs: - job: Linux_py_qnn_Wheels_x64 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index a5a440eb877e9..10d7ce04747d9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -73,7 +73,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.26.0.240828 + default: 2.27.0.240926 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: @@ -470,7 +470,7 @@ stages: parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241015.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} trt_version: '10.4.0.26-1.cuda11.8' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 6e573d79e4a72..f47108a2a48cd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 2c9218a059e0c..5839ee273c1fe 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 9cb82d65bcdce..9e01f4116b602 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 6fed0192d866d..30280c6e22c7e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.26.0.240828' + QnnSdk: '2.27.0.240926' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index e27de27036130..0d2330489279d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -18,7 +18,12 @@ stages: vmImage: "macOS-13" variables: + # Note: Keep the Xcode version and iOS simulator version compatible. + # Check the table here to see what iOS simulator versions are supported by a particular Xcode version: + # https://developer.apple.com/support/xcode/ xcodeVersion: "14.3.1" + iosSimulatorRuntimeVersion: "16.4" + ortPodVersion: $[stageDependencies.IosPackaging_SetCommonVariables.j.outputs['SetCommonVariables.ORT_POD_VERSION']] ${{ if eq(parameters.packageVariant, 'Full') }}: @@ -62,6 +67,8 @@ stages: architecture: "x64" - template: ../use-xcode-version.yml + parameters: + xcodeVersion: $(xcodeVersion) - template: ../install-appcenter.yml @@ -80,6 +87,8 @@ stages: --build-settings-file "${{ variables.buildSettingsFile }}" \ ${{ variables.optionalIncludeOpsByConfigOption }} displayName: "Build macOS/iOS framework and assemble pod package files" + env: + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: $(iosSimulatorRuntimeVersion) - script: | python tools/ci_build/github/apple/test_apple_packages.py \ diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 4c0003f31fea1..8f971612dbc6d 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 442f99a7f50e3..fdb6998f53d15 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.27.0.240926 jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index 0b39bea26c7de..3ff213b16f3d1 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241015.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241020.1 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index e6f38b5cbb76e..bf08a853fe7f4 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -40,7 +40,7 @@ cd /tmp/src CPU_ARCH=$(uname -m) echo "Installing cmake" -GetFile "https://github.com/Kitware/CMake/releases/download/v3.31.0-rc1/cmake-3.31.0-rc1-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" +GetFile "https://github.com/Kitware/CMake/releases/download/v3.31.0-rc2/cmake-3.31.0-rc2-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 933b56e4fd413..3f42b28497c7a 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12:20241015.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12:20241020.1 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh index 53a49a996ad2d..0cc48a720b8f4 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh @@ -39,7 +39,7 @@ mkdir -p /tmp/src cd /tmp/src CPU_ARCH=$(uname -m) echo "Installing cmake" -GetFile "https://github.com/Kitware/CMake/releases/download/v3.31.0-rc1/cmake-3.31.0-rc1-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" +GetFile "https://github.com/Kitware/CMake/releases/download/v3.31.0-rc2/cmake-3.31.0-rc2-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile index 238f0c9a0d922..6702474d75801 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20241015.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20241020.1 ARG TRT_VERSION RUN rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm && dnf install -y msopenjdk-11 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 24a4503c03f4c..4059de23b2480 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20241015.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20241020.1 ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index deea9db9aae91..76b31e71a7dea 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241015.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241020.1 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts From c4fb724e810bb496165b9015c77f402727392933 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:00:00 -0700 Subject: [PATCH 2/3] ORT 1.20.0 release preparation: Cherry pick round 2 (#22643) ORT 1.20.0 release preparation: Cherry pick round 2 Approved commits --------- Co-authored-by: Hector Li Co-authored-by: ivberg --- docs/OperatorKernels.md | 3 +- .../providers/cpu/cpu_execution_provider.cc | 20 +++-- .../quantization/quantize_linear_matmul.cc | 55 ++++++++---- .../qnn/builder/qnn_backend_manager.cc | 16 +++- .../qnn/builder/qnn_backend_manager.h | 3 + .../providers/qnn/qnn_execution_provider.cc | 90 ++++++++++--------- .../providers/qnn/qnn_execution_provider.h | 2 +- onnxruntime/test/onnx/TestCase.cc | 8 +- .../cpu/math/quantize_linear_matmul_test.cc | 20 +++-- 9 files changed, 138 insertions(+), 79 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d8de7756bae22..ddf37cfded77d 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -258,7 +258,8 @@ Do not modify directly.* |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**TS** = tensor(float)| +|||[10, 20]|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 424bee63511ad..a8284e4d88693 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -379,8 +379,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, uint8_t, + QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, int8_t, + QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger); @@ -1108,6 +1110,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear); @@ -1691,10 +1695,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, QuantizeLinear)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc index cb162ade44559..be448455194f6 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc @@ -14,10 +14,11 @@ namespace onnxruntime { // uint8_t kernel supports weight being either uint8_t or int8_t -ONNX_OPERATOR_TYPED_KERNEL_EX( +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( QLinearMatMul, kOnnxDomain, 10, + 20, uint8_t, kCpuExecutionProvider, KernelDefBuilder() @@ -26,21 +27,45 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), QLinearMatMul); +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 21, + uint8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TS", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); + // int8_t kernel only supports weight being int8_t -#define REGISTER_QLINEARMATMUL_INT8_KERNEL() \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - QLinearMatMul, \ - kOnnxDomain, \ - 10, \ - int8_t, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T3", DataTypeImpl::GetTensorType()), \ - QLinearMatMul); - -REGISTER_QLINEARMATMUL_INT8_KERNEL(); +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 10, + 20, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 21, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TS", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); Status QLinearMatMul::Compute(OpKernelContext* ctx) const { const auto* a = ctx->Input(IN_A); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index eaffe1e2ac224..34dcbd1d77fca 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -302,13 +302,21 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } Status QnnBackendManager::ResetQnnLogLevel() { - auto ort_log_level = logger_->GetSeverity(); - LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; - return UpdateQnnLogLevel(ort_log_level); + std::lock_guard lock(logger_mutex_); + + if (backend_setup_completed_ && logger_ != nullptr) { + auto ort_log_level = logger_->GetSeverity(); + LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; + return UpdateQnnLogLevel(ort_log_level); + } + return Status::OK(); } Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle."); + ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed."); + ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. Invalid logger."); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; @@ -686,6 +694,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { + std::lock_guard lock(logger_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; return Status::OK(); @@ -972,6 +981,7 @@ void QnnBackendManager::ReleaseResources() { ORT_THROW("Failed to ShutdownBackend."); } + std::lock_guard lock(logger_mutex_); result = TerminateQnnLog(); if (Status::OK() != result) { ORT_THROW("Failed to TerminateQnnLog."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b80f1374fcdc7..43007d4a5c244 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -12,9 +12,11 @@ #endif #include +#include #include #include #include + #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" #include "QnnTypes.h" @@ -233,6 +235,7 @@ class QnnBackendManager { private: const std::string backend_path_; + std::mutex logger_mutex_; const logging::Logger* logger_ = nullptr; QNN_INTERFACE_VER_TYPE qnn_interface_ = QNN_INTERFACE_VER_TYPE_INIT; QNN_SYSTEM_INTERFACE_VER_TYPE qnn_sys_interface_ = QNN_SYSTEM_INTERFACE_VER_TYPE_INIT; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 4cd5d403e95b8..ed193904fe7a8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -258,49 +258,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } -#ifdef _WIN32 - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - ORT_UNUSED_PARAMETER(SourceId); - ORT_UNUSED_PARAMETER(MatchAnyKeyword); - ORT_UNUSED_PARAMETER(MatchAllKeyword); - ORT_UNUSED_PARAMETER(FilterData); - ORT_UNUSED_PARAMETER(CallbackContext); - - if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); - } - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { - if (Level != 0) { - // Commenting out Dynamic QNN Profiling for now - // There seems to be a crash in 3rd party QC QnnHtp.dll with this. - // Repro Scenario - start ETW tracing prior to session creation. - // Then disable/enable ETW Tracing with the code below uncommented a few times - // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); - // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); - } - } - } - - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(); - } - }); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); -#endif - // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { @@ -440,6 +397,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio htp_arch, soc_model, enable_htp_weight_sharing_); + +#ifdef _WIN32 + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. + // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + } + } + } + + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(); + } + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); +#endif } QNNExecutionProvider::~QNNExecutionProvider() { @@ -453,7 +453,9 @@ QNNExecutionProvider::~QNNExecutionProvider() { // Unregister the ETW callback #ifdef _WIN32 - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 246ab1d5a6608..9422e54bd0035 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -151,7 +151,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 - onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 45aaca1ceae56..6b9b20faf8697 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1026,7 +1026,13 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, - {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}}); + {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"qlinearmatmul_2D_int8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_2D_int8_float32", "result diff", {}}, + {"qlinearmatmul_2D_uint8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_3D_int8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_3D_int8_float32", "result diff", {}}, + {"qlinearmatmul_3D_uint8_float16", "fp16 type ont supported by CPU EP", {}}}); // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. diff --git a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc index 8cdb837712e83..096263792727a 100644 --- a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc @@ -126,8 +126,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul3D_S8S8) { } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) { - auto run_test = [](bool only_t1_not_initializer) { - OpTester test("QLinearMatMul", 10); + auto run_test = [](bool only_t1_not_initializer, int opset_version) { + OpTester test("QLinearMatMul", opset_version); test.AddInput("T1", {2, 4}, {208, 236, 0, 238, 3, 214, 255, 29}); @@ -155,10 +155,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; - run_test(false); + run_test(false, 10); + run_test(false, 21); // NNAPI will require all inputs except T1 to be initializers - run_test(true); + run_test(true, 10); + run_test(true, 21); } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) { @@ -197,8 +199,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) { } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) { - auto run_test = [](bool only_t1_not_initializer) { - OpTester test("QLinearMatMul", 10); + auto run_test = [](bool only_t1_not_initializer, int opset_version) { + OpTester test("QLinearMatMul", opset_version); test.AddInput("T1", {2, 4}, {80, -2, -128, 110, -125, 86, 127, -99}); @@ -225,10 +227,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) { test.Run(); }; - run_test(false); + run_test(false, 10); + run_test(false, 21); // NNAPI will require all inputs except T1 to be initializers - run_test(true); + run_test(true, 10); + run_test(true, 21); } static void QLinearMatMul2DTest(bool only_t1_not_initializer) { From 269f9c63808452da035f2a96d4166f690d6002ee Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 18 Nov 2024 18:18:51 +0000 Subject: [PATCH 3/3] update --- .../tensorrt/tensorrt_execution_provider.cc | 74 ++++++++++++++----- .../tensorrt/tensorrt_execution_provider.h | 4 + 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 97d88786e4bcd..48834792d2364 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1725,6 +1725,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); } + trt_version_ = getInferLibVersion(); + + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: " << "device_id: " << device_id_ << ", trt_max_partition_iterations: " << max_partition_iterations_ @@ -2462,10 +2466,30 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - std::vector filtered_nodes_vector; + std::set exclude_ops_set; + + /* + * There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. + * TRT EP automatically excludes DDS ops from running on TRT. + */ + if (trt_version_ >= 100000 && trt_version_ < 110000) { + exclude_ops_set.insert("NonMaxSuppression"); + exclude_ops_set.insert("NonZero"); + exclude_ops_set.insert("RoiAlign"); + LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable" + } + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + bool new_subgraph = true; + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. It's a DDS op. + */ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); + bool supported_node = true; /* If current node is control flow op, we take different approach based on following four cases: * @@ -2477,29 +2501,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. */ if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { - auto sub_graphs = node->GetSubgraphs(); - if (sub_graphs.size() != 0) { - bool all_subgraphs_are_supported = true; - for (auto sub_graph : sub_graphs) { - // TRT EP should consider the empty subgraph is fully supported by TRT. - if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { - continue; - } - if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { - all_subgraphs_are_supported = false; - break; + auto supported_control_flow_op = [&](const Node* node) { + auto sub_graphs = node->GetSubgraphs(); + if (sub_graphs.size() != 0) { + for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } + if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } } } - if (!all_subgraphs_are_supported) { - // if not all its subgraphs are supported, we need to exclude this control flow op - continue; - } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_set.find(node->OpType()) != exclude_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; } - filtered_nodes_vector.push_back(index); } - SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}}; bool early_termination = false; supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); if (early_termination) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 97c9367b0bb61..0e9c11f7de968 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -329,6 +329,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool cuda_graph_enable_ = false; std::string cache_prefix_; bool engine_hw_compatible_ = false; + std::string op_types_to_exclude_; + + // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH + int32_t trt_version_; // The OrtAllocator object will be get during ep compute time // and should be kept for the lifetime of TRT EP object.