refactor: extract shared util function ComputeBroadcastOutputShape (#21940)

### Description

The broadcast output-shape computation was duplicated across several providers; this change extracts it into a single shared utility, `ComputeBroadcastOutputShape`, in `core/providers/common.h`.
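For reference, the rule being centralized is multidirectional (NumPy-style) broadcasting: shapes align from the right, a missing or 1-valued dimension stretches to match, and ORT additionally propagates a dimension of 0. Below is a minimal standalone sketch of the same logic; `BroadcastShape` and the plain `std::vector<int64_t>` shapes are illustrative stand-ins, not part of this commit or ORT's API:

```cpp
// Standalone sketch of the broadcasting rule the shared helper implements.
// std::vector<int64_t> stands in for onnxruntime::TensorShape.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Returns the broadcast output shape, or std::nullopt if the inputs are
// incompatible. Shapes align from the right; a missing or 1-valued dim
// stretches, and a 0 dim propagates (the special case in the ORT code).
std::optional<std::vector<int64_t>> BroadcastShape(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
  size_t out_rank = std::max(lhs.size(), rhs.size());
  std::vector<int64_t> out(out_rank, 0);
  for (size_t i = 0; i < out_rank; ++i) {
    int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    int64_t hi = std::max(l, r);
    int64_t lo = std::min(l, r);
    int64_t d = (lo == 0) ? lo : hi;  // a 0 dim wins over the larger dim
    if ((l != d && l != 1) || (r != d && r != 1)) return std::nullopt;
    out[out_rank - 1 - i] = d;
  }
  return out;
}

int main() {
  using Shape = std::vector<int64_t>;
  const std::vector<std::pair<Shape, Shape>> cases = {
      {{2, 3, 4}, {3, 1}},  // -> {2, 3, 4}: 1 stretches, missing dim acts as 1
      {{2, 0, 4}, {1, 4}},  // -> {2, 0, 4}: the 0 dim propagates
      {{2, 3}, {4, 3}},     // -> incompatible: 2 vs 4, neither is 1
  };
  for (const auto& [lhs, rhs] : cases) {
    if (auto out = BroadcastShape(lhs, rhs)) {
      for (int64_t d : *out) std::cout << d << ' ';
      std::cout << '\n';
    } else {
      std::cout << "incompatible\n";
    }
  }
  return 0;
}
```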
fs-eire authored Sep 4, 2024
1 parent 628c0a8 commit decb385
Showing 10 changed files with 46 additions and 77 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
@@ -10,6 +10,7 @@

// ORT system.
#include "core/providers/cuda/tensor/expand.h"
#include "core/providers/common.h"

// std C++.
#include <iostream>
@@ -51,7 +52,7 @@ Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
TensorShape original_output_shape(original_output_dims);
ORT_ENFORCE(
-      onnxruntime::cuda::ComputeOutputShape(
+      onnxruntime::ComputeBroadcastOutputShape(
Node().Name(),
original_input_shape,
original_output_dims, original_output_shape)
29 changes: 0 additions & 29 deletions onnxruntime/core/providers/cann/cann_utils.cc
@@ -224,34 +224,5 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
}

-Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
-                          const TensorShape& rhs_shape, TensorShape& out_shape) {
-  size_t lhs_rank = lhs_shape.NumDimensions();
-  size_t rhs_rank = rhs_shape.NumDimensions();
-  size_t out_rank = std::max(lhs_rank, rhs_rank);
-
-  std::vector<int64_t> output_dims(out_rank, 0);
-  for (size_t i = 0; i < out_rank; ++i) {
-    int64_t lhs_dim = 1;
-    if (i < lhs_rank)
-      lhs_dim = lhs_shape[lhs_rank - 1 - i];
-    int64_t rhs_dim = 1;
-    if (i < rhs_rank)
-      rhs_dim = rhs_shape[rhs_rank - 1 - i];
-    int64_t max = std::max(lhs_dim, rhs_dim);
-    int64_t min = std::min(lhs_dim, rhs_dim);
-    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.
-    if (lhs_dim != out_dim && lhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    if (rhs_dim != out_dim && rhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    output_dims[out_rank - 1 - i] = out_dim;
-  }
-  out_shape = TensorShape(output_dims);
-  return Status::OK();
-}
-
} // namespace cann
} // namespace onnxruntime
2 changes: 0 additions & 2 deletions onnxruntime/core/providers/cann/cann_utils.h
@@ -124,8 +124,6 @@ Status aclrtblasGemmEx(aclTransType transA,

bool FileExist(const std::string& file_name);
void GenerateHashValue(const std::string string, HashValue& hash_value);
-Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
-                          const TensorShape& rhs_shape, TensorShape& out_shape);

std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);

4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc
@@ -2,6 +2,8 @@
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/common.h"
#include "core/providers/cann/math/binary_elementwise_ops.h"
#include <vector>
#include <algorithm>
@@ -20,7 +22,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
const Tensor* B = ctx->Input<Tensor>(1);

TensorShape output_shape;
-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape));
Tensor* C = ctx->Output(0, output_shape);

void* A_data = const_cast<void*>(A->DataRaw());
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/common.h
@@ -180,4 +180,38 @@ T Product(const Container<T>& c) {
return accumulate(c.cbegin(), c.cend(), static_cast<T>(1), std::multiplies<T>());
}

+/// <summary>
+/// Compute the output shape for broadcasting the given input shapes of lhs and rhs.
+/// </summary>
+inline Status ComputeBroadcastOutputShape(const std::string& node_name,
+                                          const TensorShape& lhs_shape,
+                                          const TensorShape& rhs_shape,
+                                          TensorShape& out_shape) {
+  size_t lhs_rank = lhs_shape.NumDimensions();
+  size_t rhs_rank = rhs_shape.NumDimensions();
+  size_t out_rank = std::max(lhs_rank, rhs_rank);
+
+  std::vector<int64_t> output_dims(out_rank, 0);
+  for (size_t i = 0; i < out_rank; ++i) {
+    int64_t lhs_dim = 1;
+    if (i < lhs_rank)
+      lhs_dim = lhs_shape[lhs_rank - 1 - i];
+    int64_t rhs_dim = 1;
+    if (i < rhs_rank)
+      rhs_dim = rhs_shape[rhs_rank - 1 - i];
+    int64_t max = std::max(lhs_dim, rhs_dim);
+    int64_t min = std::min(lhs_dim, rhs_dim);
+    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.
+    if (lhs_dim != out_dim && lhs_dim != 1)
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
+                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
+    if (rhs_dim != out_dim && rhs_dim != 1)
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
+                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
+    output_dims[out_rank - 1 - i] = out_dim;
+  }
+  out_shape = TensorShape(output_dims);
+  return Status::OK();
+}
+
} // namespace onnxruntime
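As a usage note (not part of this commit): call sites pass the node name for error messages and receive the result through the out-parameter. A hedged sketch, assuming an ORT build and the initializer-list `TensorShape` construction used elsewhere in the codebase:

```cpp
// Hypothetical call-site sketch; "demo_node" and Demo() are illustrative.
#include "core/providers/common.h"

using onnxruntime::ComputeBroadcastOutputShape;
using onnxruntime::TensorShape;

onnxruntime::Status Demo() {
  TensorShape out;
  // {2,3,4} x {3,1} -> {2,3,4}: shapes align from the right, 1 stretches.
  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(
      "demo_node", TensorShape({2, 3, 4}), TensorShape({3, 1}), out));
  // {2,0,4} x {1,4} -> {2,0,4}: a 0 dim wins over the larger dim.
  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(
      "demo_node", TensorShape({2, 0, 4}), TensorShape({1, 4}), out));
  // {2,3} x {4,3} -> non-OK Status ("cannot broadcast on dim ...").
  return ComputeBroadcastOutputShape(
      "demo_node", TensorShape({2, 3}), TensorShape({4, 3}), out);
}
```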
32 changes: 3 additions & 29 deletions onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/common.h"
#include "core/providers/cuda/math/binary_elementwise_ops.h"
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
@@ -21,34 +23,6 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
return Status::OK();
}

-Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
-  size_t lhs_rank = lhs_shape.NumDimensions();
-  size_t rhs_rank = rhs_shape.NumDimensions();
-  size_t out_rank = std::max(lhs_rank, rhs_rank);
-
-  std::vector<int64_t> output_dims(out_rank, 0);
-  for (size_t i = 0; i < out_rank; ++i) {
-    int64_t lhs_dim = 1;
-    if (i < lhs_rank)
-      lhs_dim = lhs_shape[lhs_rank - 1 - i];
-    int64_t rhs_dim = 1;
-    if (i < rhs_rank)
-      rhs_dim = rhs_shape[rhs_rank - 1 - i];
-    int64_t max = std::max(lhs_dim, rhs_dim);
-    int64_t min = std::min(lhs_dim, rhs_dim);
-    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.
-    if (lhs_dim != out_dim && lhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    if (rhs_dim != out_dim && rhs_dim != 1)
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
-                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
-    output_dims[out_rank - 1 - i] = out_dim;
-  }
-  out_shape = TensorShape(output_dims);
-  return Status::OK();
-}
-
Status BinaryElementwiseBroadcastPrepare(
const Tensor* lhs_tensor,
const Tensor* rhs_tensor,
@@ -77,7 +51,7 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
const auto& rhs_shape = rhs_tensor->Shape();

TensorShape output_shape;
-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
auto output_tensor = context->Output(0, output_shape);

ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h
@@ -108,12 +108,6 @@ struct BinaryElementwisePreparation {
}
};

-Status ComputeOutputShape(
-    const std::string& node_name,
-    const TensorShape& lhs_shape,
-    const TensorShape& rhs_shape,
-    TensorShape& out_shape);
-
Status BinaryElementwiseBroadcastPrepare(
const Tensor* lhs_tensor,
const Tensor* rhs_tensor,
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/common.h"
#include "core/providers/cuda/math/variadic_elementwise_ops.h"

#include <cassert>
@@ -209,7 +210,7 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
TensorShape output_shape;
TensorShape previous_output_shape = first_input_tensor.Shape();
for (int index = 1; index < input_count; index++) {
-    ORT_RETURN_IF_ERROR(ComputeOutputShape(
+    ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(
node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape));
previous_output_shape = output_shape;
}
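The loop above left-folds the pairwise broadcast across all inputs, so, for example, shapes {2,1,4}, {1,3,1}, {4} fold to {2,3,4}. A sketch of the same fold, reusing the standalone `BroadcastShape` from the sketch near the top (illustrative, not ORT's API):

```cpp
// Left-fold of the pairwise broadcast over N input shapes, mirroring the
// loop above. Assumes the standalone BroadcastShape from the first sketch.
std::optional<std::vector<int64_t>> BroadcastAll(
    const std::vector<std::vector<int64_t>>& shapes) {
  std::vector<int64_t> acc = shapes.front();
  for (size_t i = 1; i < shapes.size(); ++i) {
    auto next = BroadcastShape(acc, shapes[i]);
    if (!next) return std::nullopt;  // some pair is incompatible
    acc = std::move(*next);
  }
  return acc;
}
// BroadcastAll({{2, 1, 4}, {1, 3, 1}, {4}}) -> {2, 3, 4}
```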
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -95,7 +95,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
TensorShape output_shape(output_dims);

-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
auto& output_tensor = *ctx->Output(0, output_shape);
if (0 == output_shape.Size()) {
return Status::OK();
@@ -202,7 +202,7 @@ std::unique_ptr<Tensor> FuncExpand(
TensorShape output_shape(output_dims);

ORT_ENFORCE(
-      ComputeOutputShape(
+      ComputeBroadcastOutputShape(
cuda_kernel->Node().Name(),
input_data_tensor->Shape(),
output_dims, output_shape)
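Worth noting: Expand broadcasts bidirectionally between the input shape and the requested dims, which is why the shared binary helper applies here as well. In terms of the standalone sketch near the top (illustrative values):

```cpp
// Expand: an input of shape {3, 1} with a requested shape of {2, 1, 4}
// yields {2, 3, 4} under the same pairwise rule.
auto expanded = BroadcastShape({3, 1}, {2, 1, 4});  // -> {2, 3, 4}
```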
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/cuda/tensor/expand.h
@@ -14,12 +14,6 @@ class Expand final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};

-Status ComputeOutputShape(
-    const std::string& node_name,
-    const TensorShape& lhs_shape,
-    const TensorShape& rhs_shape,
-    TensorShape& out_shape);
-
Status FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
