add kernel tests for ops that changed in opset18 (#19767)
### Description

- [x] Pad has introduced a new input called "axes", which specifies which axes to pad. When axes is not provided, it defaults to all axes of the input (i.e. [0, ..., input_rank - 1]), matching the behavior before the opset upgrade (see the first sketch after this list).
- [x] ReduceMean
- [x] ReduceL2
- [x] ReduceLogSumExp
- [x] ReduceSum
- The reduction ops all had the "axes" attribute switched to an input, and a new attribute called "noop_with_empty_axes" was added to define what to do when axes is not specified (see the first sketch after this list).
- [x] Resize has had two new attributes introduced: antialias and keep_aspect_ratio_policy. From Operators.md I've gathered: "Antialiasing is achieved by stretching the resampling filter by a factor max(1, 1 / scale), which means that when downsampling, more input pixels contribute to an output pixel." keep_aspect_ratio_policy "describes how to interpret the `sizes` input with regard to keeping the original aspect ratio of the input." There are a few enum-style options that specify different policies and what to do in each case (see the second sketch after this list).
- NOTE: Baiju already included opset18 tests in
#17772
- [x] ScatterElements/ScatterND has had a new attribute introduced called "reduction", which specifies the type of reduction to apply when writing updates: none (default), add, mul, max, or min (see the sketch at the end of this page).
- [x] Split introduced a new attribute called "num_outputs", which specifies how many outputs to split the input tensor into. This contrasts with the previous default behavior of supplying a "split" input, which gives the explicit size of each output tensor (see the last sketch at the end of this page).
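
To make the new reduction semantics concrete, here is a minimal standalone sketch (my own illustration — `EffectiveAxes` is a hypothetical helper, not ONNX Runtime code) of how the effective axes are resolved under opset 18; the empty-axes default is also the behavior Pad falls back to for its new input:

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Resolve which axes an opset-18 reduction actually reduces over,
// given the optional "axes" input and the new attribute.
std::vector<int64_t> EffectiveAxes(const std::vector<int64_t>& axes,
                                   int64_t input_rank,
                                   bool noop_with_empty_axes) {
  if (!axes.empty()) {
    std::vector<int64_t> resolved = axes;
    for (int64_t& a : resolved) {
      if (a < 0) a += input_rank;  // negative axes count from the back
    }
    return resolved;
  }
  if (noop_with_empty_axes) return {};  // empty axes: act as the identity
  // Default, matching pre-opset-18 behavior (and Pad's default): all axes.
  std::vector<int64_t> all(static_cast<size_t>(input_rank));
  std::iota(all.begin(), all.end(), int64_t{0});
  return all;
}

int main() {
  for (int64_t a : EffectiveAxes({-1}, 3, false)) std::cout << a << ' ';  // 2
  std::cout << '\n';
  for (int64_t a : EffectiveAxes({}, 3, false)) std::cout << a << ' ';    // 0 1 2
  std::cout << '\n';
  std::cout << EffectiveAxes({}, 3, true).size() << '\n';                 // 0: no-op
}
```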

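And a second sketch of my reading of keep_aspect_ratio_policy (the function name and shapes are hypothetical): one uniform scale is chosen so the requested `sizes` act as an upper bound ("not_larger") or a lower bound ("not_smaller") while the aspect ratio is preserved.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Pick one uniform scale: the smallest per-axis scale keeps every output
// dim <= sizes ("not_larger"); the largest keeps every dim >= sizes
// ("not_smaller").
std::vector<int64_t> AdjustSizes(const std::vector<int64_t>& in_shape,
                                 const std::vector<int64_t>& sizes,
                                 bool not_larger) {
  double scale = not_larger ? std::numeric_limits<double>::max() : 0.0;
  for (size_t i = 0; i < in_shape.size(); ++i) {
    const double s = static_cast<double>(sizes[i]) / static_cast<double>(in_shape[i]);
    scale = not_larger ? std::min(scale, s) : std::max(scale, s);
  }
  std::vector<int64_t> out(in_shape.size());
  for (size_t i = 0; i < in_shape.size(); ++i) {
    out[i] = static_cast<int64_t>(std::round(scale * static_cast<double>(in_shape[i])));
  }
  return out;
}

int main() {
  // A 400x300 input with a square 300x300 target: "not_larger" fits inside.
  for (int64_t d : AdjustSizes({400, 300}, {300, 300}, true)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 300 225
}
```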
prathikr authored Mar 19, 2024
1 parent 4c6a6a3 commit 26cd3c1
Showing 2 changed files with 55 additions and 12 deletions.
37 changes: 28 additions & 9 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1112,6 +1112,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
 
   ArgDef grad = GO(0);
   if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
     if (attributes.find("axes") != attributes.end()) {
       std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
       grad = IA("Unqueezed_Grad");
@@ -1122,6 +1123,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
       result.push_back(axes_values_node);
       result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
-    }
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      grad = IA("Unqueezed_Grad");
+      result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
+    }
   }

@@ -1152,12 +1156,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
   }
 
   ArgDef grad = GO(0);
-  if (!keepdims && attributes.find("axes") != attributes.end()) {
-    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
-    grad = IA("Unsqueezed_Grad");
-    result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
-    result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
+    if (attributes.find("axes") != attributes.end()) {
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      grad = IA("Unsqueezed_Grad");
+
+      result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+
+      result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      grad = IA("Unsqueezed_Grad");
+      result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
+
+      result.push_back(NodeDef("Unsqueeze", {O(0), I(1)}, {IA("Unsqueezed_Output")}));
+    }
     result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
   } else {
     result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
@@ -1188,11 +1201,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
   ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY");
   result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def}));
 
-  if (!keepdims && attributes.find("axes") != attributes.end()) {
-    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
     scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY");
-    result.emplace_back(
-        NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
+    if (attributes.find("axes") != attributes.end()) {
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      result.emplace_back(
+          NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      result.emplace_back(
+          NodeDef("Unsqueeze", {IA("Masked_Scaled_dY"), I(1)}, {scaled_dy_arg_def}));
+    }
   }
 
   result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)}));
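A note on the pattern above: with keepdims=0 the upstream gradient GO(0) is missing the reduced dimensions, so each builder unsqueezes it (by the "axes" attribute, or by the new optional input I(1)) before it can broadcast against the input. Below is a standalone, shape-level sketch of that step — `UnsqueezeShape` is my own illustration of the assumed Unsqueeze semantics, not ORT code:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Re-insert size-1 dimensions at the given axes, as Unsqueeze does.
std::vector<int64_t> UnsqueezeShape(std::vector<int64_t> shape,
                                    std::vector<int64_t> axes) {
  const int64_t out_rank = static_cast<int64_t>(shape.size() + axes.size());
  for (int64_t& a : axes) {
    if (a < 0) a += out_rank;  // axes are relative to the *output* rank
  }
  std::sort(axes.begin(), axes.end());
  for (int64_t a : axes) shape.insert(shape.begin() + a, 1);
  return shape;
}

int main() {
  // ReduceMean over axes {1} on a {2, 3, 4} input with keepdims=0 produces
  // a {2, 4} gradient; unsqueezing restores {2, 1, 4}, which broadcasts
  // against the original input shape.
  for (int64_t d : UnsqueezeShape({2, 4}, {1})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 1 4
}
```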
30 changes: 27 additions & 3 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -607,6 +607,10 @@ TEST(GradientCheckerTest, ReduceMeanGrad) {
 
   OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
   RunReductionTests(op_def_opset13);
+
+  // axes is input from opset 18.
+  OpDef op_def_opset18{"ReduceMean", kOnnxDomain, 18};
+  RunReductionTests(op_def_opset18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceSumGrad) {
@@ -619,6 +623,10 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
   OpDef op_def_13{"ReduceSum", kOnnxDomain, 13};
 
   RunReductionTests(op_def_13, true, true);
+
+  OpDef op_def_18{"ReduceSum", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceL2Grad) {
@@ -641,13 +649,22 @@ TEST(GradientCheckerTest, ReduceL2Grad) {
                                                            {MakeAttribute("axes", axes)}));
     EXPECT_IS_TINY(max_error);
   }
+
+  // axes is input from opset 18
+  OpDef op_def_18{"ReduceL2", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
 
   RunReductionTests(op_def);
+
+  OpDef op_def_opset18{"ReduceLogSumExp", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_opset18, true, true);
 }
 
 TEST(GradientCheckerTest, ReluGrad) {
@@ -698,6 +715,13 @@ TEST(GradientCheckerTest, SplitGrad) {
   ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_13, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
                                                          {MakeAttribute("axis", int64_t(0))}));
   EXPECT_IS_TINY(max_error);
+
+  // opset18 test
+  OpDef op_def_18{"Split", kOnnxDomain, 18};
+  ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_18, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
+                                                         {MakeAttribute("axis", int64_t(0)),
+                                                          MakeAttribute("num_outputs", int64_t(3))}));
+  EXPECT_IS_TINY(max_error);
 }
 
 template <typename T>
@@ -2733,7 +2757,7 @@ TEST(GradientCheckerTest, TileGrad) {
 TEST(GradientCheckerTest, PadGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"Pad", kOnnxDomain, 11};
+  OpDef op_def{"Pad", kOnnxDomain, 18};
 
   {
     TensorInfo x_info({2, 4}, true);
Expand Down Expand Up @@ -2803,7 +2827,7 @@ TEST(GradientCheckerTest, PadGrad) {
TEST(GradientCheckerTest, ScatterNDGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"ScatterND", kOnnxDomain, 11};
OpDef op_def{"ScatterND", kOnnxDomain, 18};

{
TensorInfo data_info({8}, true);
@@ -2887,7 +2911,7 @@ TEST(GradientCheckerTest, ScatterNDGrad) {
 TEST(GradientCheckerTest, ScatterElementsGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"ScatterElements", kOnnxDomain, 13};
+  OpDef op_def{"ScatterElements", kOnnxDomain, 18};
 
   { // without axis
     TensorInfo data_info({3, 3}, true);
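For the Scatter tests above, the new "reduction" attribute changes what happens when updates collide. A minimal 1-D sketch of the assumed semantics — `Scatter1D` is my own illustration, not the kernel implementation:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Apply updates at the given indices, combining collisions per "reduction".
void Scatter1D(std::vector<float>& data, const std::vector<int64_t>& indices,
               const std::vector<float>& updates, const std::string& reduction) {
  for (size_t i = 0; i < indices.size(); ++i) {
    float& d = data[static_cast<size_t>(indices[i])];
    const float u = updates[i];
    if (reduction == "add")      d += u;
    else if (reduction == "mul") d *= u;
    else if (reduction == "max") d = std::max(d, u);
    else if (reduction == "min") d = std::min(d, u);
    else                         d = u;  // "none" (default): overwrite
  }
}

int main() {
  std::vector<float> data{1, 2, 3, 4};
  Scatter1D(data, {1, 1}, {10, 20}, "add");    // duplicate index accumulates
  for (float v : data) std::cout << v << ' ';  // prints: 1 32 3 4
  std::cout << '\n';
}
```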

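Likewise for the Split test: "num_outputs" implies the split sizes instead of taking them from a "split" input. A small sketch of the assumed sizing rule (chunks of ceil(dim / num_outputs), with a smaller final chunk when the dimension is not evenly divisible) — `SplitSizes` is hypothetical:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Split sizes implied by num_outputs along the split axis.
std::vector<int64_t> SplitSizes(int64_t dim, int64_t num_outputs) {
  const int64_t chunk = (dim + num_outputs - 1) / num_outputs;  // ceil division
  std::vector<int64_t> sizes;
  int64_t remaining = dim;
  for (int64_t i = 0; i < num_outputs; ++i) {
    sizes.push_back(std::min(chunk, remaining));
    remaining -= sizes.back();
  }
  return sizes;
}

int main() {
  for (int64_t s : SplitSizes(9, 3)) std::cout << s << ' ';   // 3 3 3
  std::cout << '\n';
  for (int64_t s : SplitSizes(10, 3)) std::cout << s << ' ';  // 4 4 2
  std::cout << '\n';
}
```

In the test added above, dim 9 with num_outputs=3 splits evenly into the same {3, 5} x3 shapes the "split"-input test expects.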