[TensorRT EP] Exclude DDS ops from running on TRT #22874

Closed
9 changes: 7 additions & 2 deletions cmake/onnxruntime_unittests.cmake
@@ -134,9 +134,14 @@ function(AddTest)

if (IOS)
# target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m)

+set(_UT_IOS_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET})
+# replace any characters that are not valid in a bundle identifier with '-'
+string(REGEX REPLACE "[^a-zA-Z0-9\\.-]" "-" _UT_IOS_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER})

set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest"
MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}
-MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET}
+MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER}
MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION}
MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION}
MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION}
@@ -163,7 +168,7 @@ function(AddTest)

set_target_properties(${_UT_TARGET}_xc PROPERTIES FOLDER "ONNXRuntimeXCTest"
MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}_xc
-MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET}
+MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER}
MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION}
MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION}
MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION}
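(For example, a hypothetical unit-test target named onnxruntime_test_all would otherwise produce com.onnxruntime.utest.onnxruntime_test_all; '_' is not in the allowed set [a-zA-Z0-9.-], so the REGEX REPLACE rewrites it to com.onnxruntime.utest.onnxruntime-test-all, a valid bundle identifier.)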
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
@@ -258,7 +258,8 @@ Do not modify directly.*
|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 11]|**T** = tensor(double), tensor(float)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
-|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
+|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **TS** = tensor(float)|
+|||[10, 20]|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|||[19, 20]|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
18 changes: 11 additions & 7 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3651,13 +3651,17 @@ struct OrtApi {
* - "73"
* - "75"
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
"enable_htp_fp16_precision": Used for float32 model for HTP backend.
Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
- "0": With fp32 precision.
- "1": Default. With fp16 precision.
"enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context.
- "0": Default. Disabled.
- "1": Enabled.
* "enable_htp_fp16_precision": Used for float32 model for HTP backend.
* Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
* - "0": With fp32 precision.
* - "1": Default. With fp16 precision.
* "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context.
* - "0": Default. Disabled.
* - "1": Enabled.
* "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
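For illustration, a minimal sketch of passing these QNN options, including the new "offload_graph_io_quantization" key, through the C API (assumptions: a QNN-enabled build and the Windows HTP backend library name; error handling elided):

    const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    OrtSessionOptions* so = NULL;
    ort->CreateSessionOptions(&so); /* check the returned OrtStatus* in real code */

    /* Keys/values as documented above; the "backend_path" value is platform-specific. */
    const char* keys[] = {"backend_path", "enable_htp_fp16_precision", "offload_graph_io_quantization"};
    const char* vals[] = {"QnnHtp.dll", "1", "1"};
    ort->SessionOptionsAppendExecutionProvider(so, "QNN", keys, vals, 3);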
@@ -145,6 +145,20 @@ bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderI
LOGS(logger, VERBOSE) << "DepthToSpace: CRD mode requires static shape";
return false;
}

if (mode == "DCR" && input_params.coreml_version < 7) {
int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
GetType(*input_defs[0], input_type, logger);

if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
// In CoreML version 6 (e.g., on an iOS 16 simulator) with DCR mode and float16 input, the output is all zeros
// in this unit test: TensorOpTest/1.DepthToSpaceTest_4.
// However, CoreML version 7 is fine.
// Don't support CoreML version < 7, DCR mode, and float16 input.
LOGS(logger, VERBOSE) << "DepthToSpace: DCR mode with float16 input requires at least CoreML version 7.";
return false;
}
}
} else {
if (mode != "DCR") {
LOGS(logger, VERBOSE) << "DepthToSpace: " << mode << " mode is not supported";
20 changes: 14 additions & 6 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -379,8 +379,10 @@
QuantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t,
QuantizeLinear);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, uint8_t,
+QLinearMatMul);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, int8_t,
+QLinearMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger);
@@ -1108,6 +1110,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear);
@@ -1691,10 +1695,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12,
int8_t, QuantizeLinear)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t,
-QLinearMatMul)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t,
-QLinearMatMul)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20,
+uint8_t, QLinearMatMul)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20,
+int8_t, QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t,
MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t,
@@ -2769,6 +2773,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2,
DequantizeLinear)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t,
+QLinearMatMul)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t,
+QLinearMatMul)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN,
DequantizeLinear)>,
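For readers unfamiliar with these macros: the leading numbers are the opset range a kernel covers, so this change caps the original QLinearMatMul registration at opset 20 and adds a separate opset-21+ one (which, per the OperatorKernels.md rows above, gains the **TS** = tensor(float) scale constraint). A condensed sketch of the pattern, with arguments copied from this diff:

    // Models importing ONNX opsets 10-20 resolve to the pre-existing kernel...
    class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, uint8_t, QLinearMatMul);
    // ...while opset 21+ models resolve to the newly registered one.
    class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul);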
82 changes: 48 additions & 34 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -24,16 +24,16 @@ void ComputeJob(
const T* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
-IAllocatorUniquePtr<float>& scale_float_uptr,
-IAllocatorUniquePtr<float>& bias_float_uptr,
+const float* scale_float_ptr,
+const float* bias_float_ptr,
float epsilon,
bool simplified,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
AllocatorPtr alloc) {
-ORT_UNUSED_PARAMETER(scale_float_uptr); // only used in MLFloat16 overload
-ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload
+ORT_UNUSED_PARAMETER(scale_float_ptr); // only used in MLFloat16 overload
+ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(alloc);

const T* p_input = X_data + task_idx * norm_size;
@@ -82,14 +82,17 @@ void ComputeJob(
const MLFloat16* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
-IAllocatorUniquePtr<float>& scale_float_uptr,
-IAllocatorUniquePtr<float>& bias_float_uptr,
+const float* scale_float_ptr,
+const float* bias_float_ptr,
float epsilon,
bool simplified,
MLFloat16* Y_data,
U* mean_data,
U* inv_std_dev_data,
AllocatorPtr alloc) {
+ORT_UNUSED_PARAMETER(scale_data); // only used in float/double overload
+ORT_UNUSED_PARAMETER(bias_data); // only used in float/double overload

const MLFloat16* p_input = X_data + task_idx * norm_size;
MLFloat16* p_output = Y_data + task_idx * norm_size;

@@ -117,22 +120,10 @@ void ComputeJob(
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

-if (!scale_float_uptr) {
-scale_float_uptr = std::move(input_float_uptr); // overwrite input with scale values, since they have the same size
-MlasConvertHalfToFloatBuffer(scale_data, scale_float_uptr.get(), num_elems);
-}
-
-if (bias_data && !bias_float_uptr) {
-bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
-MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
-}
-
-const float* scale_float_ptr = scale_float_uptr.get();
-const float* bias_float_ptr = bias_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
-} else if (nullptr == bias_data) {
+} else if (nullptr == bias_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
@@ -166,7 +157,13 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
} // namespace

LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
-: OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr), bias_fp32_(nullptr) {
+: OpKernel(op_kernel_info),
+simplified_{simplified},
+contrib_op_{contrib_op},
+prepacked_scale_fp32_data_(nullptr),
+prepacked_scale_fp32_size_(0),
+prepacked_bias_fp32_data_(nullptr),
+prepacked_bias_fp32_size_(0) {
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
}
@@ -175,15 +172,15 @@ template <typename T, typename U>
Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
// Inputs
const Tensor* X = p_ctx->Input<Tensor>(0);
-const Tensor* scale = p_ctx->Input<Tensor>(1);
-const Tensor* bias = p_ctx->Input<Tensor>(2);
+const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
+const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const T* X_data = X->Data<T>();
-const T* scale_data = scale->Data<T>();
+const T* scale_data = scale ? scale->Data<T>() : nullptr;
const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

const TensorShape& x_shape = X->Shape();
-const TensorShape& scale_shape = scale->Shape();
-const TensorShape& bias_shape = bias->Shape();
+size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
+size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
Tensor* Y = p_ctx->Output(0, x_shape);
T* Y_data = Y->MutableData<T>();

@@ -218,7 +215,7 @@

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
+return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
}

@@ -237,9 +234,11 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr

is_packed = false;
if (input_idx == 1) { // scale
-ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, scale_fp32_, is_packed);
+prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
} else if (input_idx == 2) { // bias
-ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

return Status::OK();
@@ -250,9 +249,9 @@ Status LayerNormImpl::ComputeWithoutContext(
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
-const TensorShape& scale_shape,
+size_t scale_size,
const T* bias_data,
-const TensorShape& bias_shape,
+size_t bias_size,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
@@ -264,19 +263,34 @@ Status LayerNormImpl::ComputeWithoutContext(
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

-const auto scale_size = scale_shape.Size();
-const auto bias_size = (bias_data) ? bias_shape.Size() : 0;
-if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
+if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale and bias (if provided) must match this. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}

+IAllocatorUniquePtr<float> scale_fp32;
+IAllocatorUniquePtr<float> bias_fp32;
+if constexpr (std::is_same_v<T, MLFloat16>) {
+if (prepacked_scale_fp32_data_ == nullptr) {
+const size_t num_elems = static_cast<size_t>(norm_size);
+scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
+}
+if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
+const size_t num_elems = static_cast<size_t>(norm_size);
+bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
+}
+}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
[&](ptrdiff_t task_idx) {
-ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_,
+ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
+prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
},
0);
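A note on the MLFloat16 path above: previously the half-to-float conversion of scale/bias happened lazily inside ComputeJob, assigning into the shared member unique_ptrs from within TryBatchParallelFor tasks; hoisting the conversion (or using the PrePack'd buffers) ahead of the parallel loop means worker threads only read shared state. A condensed sketch of the resulting pattern, mirroring the diff:

    // Convert once, single-threaded, only when PrePack did not already cache fp32 data...
    if (prepacked_scale_fp32_data_ == nullptr) {
      scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(norm_size));
      MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), static_cast<size_t>(norm_size));
    }
    // ...then every parallel task receives the same immutable float pointer.
    const float* scale_ptr = prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get();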