Change named attribute of reduce_avg to align with tt-mlir naming. (#1101)

Solves [#1574](tenstorrent/tt-mlir#1574)
- Avg pool2d decomposes to reduce_avg. Reduce avg op requires dim_arg
attribute in ttir. Therefore, changing named attribute of reduce_avg
from `dim` to `dim_arg`.
- Edit unit tests and remove the xfail marker from test_avg_pool2d_resnet.
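
A minimal sketch of what the renamed attribute encodes, using `torch.mean` as a stand-in for reduce_avg (shapes here are illustrative, not taken from the diff):

```python
import torch

# Stand-in for the decomposed avg_pool2d intermediate: (w, 1, y*x, cin).
w, cin, y, x = 1, 64, 7, 7
activations = torch.randn(w, 1, y * x, cin)

# Old Forge named attrs: {"dim": -2, "keep_dim": True}       (scalar dim)
# New Forge named attrs: {"dim_arg": [-2], "keep_dim": True}  (list, as in TTIR)
reduced = activations.mean(dim=-2, keepdim=True)  # what reduce_avg computes
assert reduced.shape == (w, 1, 1, cin)
```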
dgolubovicTT authored Jan 27, 2025
1 parent 8573421 commit f92e047
Showing 6 changed files with 12 additions and 19 deletions.
2 changes: 1 addition & 1 deletion forge/csrc/passes/commute_utils.cpp
@@ -1052,7 +1052,7 @@ void update_reduce_attr(graphlib::OpNode *reduce, int reduce_dim, bool keep_dim)
     attr.push_back(keep_dim);
 
     graphlib::OpType::Attrs named_attrs = reduce->named_attrs();
-    named_attrs["dim"] = reduce_dim;
+    named_attrs["dim_arg"] = reduce_dim;
     named_attrs["keep_dim"] = keep_dim;
 
     reduce->overwrite_op_named_attrs(attr, named_attrs);
12 changes: 6 additions & 6 deletions forge/csrc/test/passes/test_erase_inverse_ops.cpp
@@ -571,7 +571,7 @@ struct UpdateReduceSumAttrsTest : testing::Test
         },
         {input_node});
     auto &named_attrs = reduce_node->named_attrs();
-    named_attrs["dim"] = reduce_dim;
+    named_attrs["dim_arg"] = reduce_dim;
     named_attrs["keep_dim"] = keep_dim;
     reduce_node->overwrite_named_attrs(named_attrs);
     create_output(*graph, "out", reduce_node);
@@ -594,8 +594,8 @@ TEST_F(UpdateReduceSumAttrsTest, ReduceSumDim)
 
     auto updated_attrs = reduce_node->named_attrs();
 
-    ASSERT_TRUE(updated_attrs.count("dim"));
-    EXPECT_EQ(std::get<int>(updated_attrs["dim"]), reduce_dim);
+    ASSERT_TRUE(updated_attrs.count("dim_arg"));
+    EXPECT_EQ(std::get<int>(updated_attrs["dim_arg"]), reduce_dim);
 
     ASSERT_TRUE(updated_attrs.count("keep_dim"));
     EXPECT_EQ(std::get<bool>(updated_attrs["keep_dim"]), keep_dim);
@@ -628,7 +628,7 @@ struct UpdateReduceMaxAttrsTest : testing::Test
         },
         {input_node});
     auto &named_attrs = reduce_node->named_attrs();
-    named_attrs["dim"] = reduce_dim;
+    named_attrs["dim_arg"] = reduce_dim;
     named_attrs["stride"] = stride;
     named_attrs["keep_dim"] = keep_dim;
     reduce_node->overwrite_named_attrs(named_attrs);
@@ -653,8 +653,8 @@ TEST_F(UpdateReduceMaxAttrsTest, ReduceMaxDim)
 
     auto updated_attrs = reduce_node->named_attrs();
 
-    ASSERT_TRUE(updated_attrs.count("dim"));
-    EXPECT_EQ(std::get<int>(updated_attrs["dim"]), reduce_dim);
+    ASSERT_TRUE(updated_attrs.count("dim_arg"));
+    EXPECT_EQ(std::get<int>(updated_attrs["dim_arg"]), reduce_dim);
 
     ASSERT_TRUE(updated_attrs.count("stride"));
     EXPECT_EQ(std::get<int>(updated_attrs["stride"]), stride);
8 changes: 5 additions & 3 deletions forge/forge/op/eval/forge/pooling.py
@@ -568,14 +568,14 @@ def decompose(type, attr, dc, inputs):
             result = dc.op_with_named_attrs(
                 "reshape", [activations], {"shape": (w, 1, y * x, cin)}, (w, 1, y * x, cin)
             )
-            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim": -2, "keep_dim": True}, (-2, True))
+            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim_arg": [-2], "keep_dim": True}, (-2, True))
             result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, 1, 1, cin)}, (w, 1, 1, cin))
         else:
             result = dc.op_with_named_attrs(
                 "reshape", [activations], {"shape": (w, 1, cin, y * x)}, (w, 1, cin, y * x)
             )
             result = dc.op(TransposeTM.create(2, 3), [result])
-            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim": -2, "keep_dim": True}, (-2, True))
+            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim_arg": [-2], "keep_dim": True}, (-2, True))
             result = dc.op(TransposeTM.create(2, 3), [result])
             result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, cin, 1, 1)}, (w, cin, 1, 1))
         dc.fuse(result)
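
For context, the branch above lowers a global average pool by flattening the spatial dims, averaging over the flattened axis (now keyed by `dim_arg`), and reshaping back. A quick equivalence check, assuming the channel-last (w, y, x, cin) layout implied by the first reshape:

```python
import torch

w, y, x, cin = 2, 4, 4, 8
act = torch.randn(w, y, x, cin)  # channel-last layout assumed for illustration

# Decomposition path: reshape -> reduce_avg(dim_arg=[-2], keep_dim=True) -> reshape
step = act.reshape(w, 1, y * x, cin)
step = step.mean(dim=-2, keepdim=True)  # reduce_avg over the flattened spatial axis
decomposed = step.reshape(w, 1, 1, cin)

# Reference: average directly over both spatial dims
reference = act.mean(dim=(1, 2)).reshape(w, 1, 1, cin)
assert torch.allclose(decomposed, reference)
```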
@@ -718,7 +718,9 @@ def decompose(type, attr, dc, inputs):
             d_start = i * sD
 
             depth_slice = dc.op("index", [activations], (2, d_start, d_start + kD, activations.shape[2]))
-            depth_avg = dc.op_with_named_attrs("reduce_avg", [depth_slice], {"dim": 2, "keep_dim": True}, (2, True))
+            depth_avg = dc.op_with_named_attrs(
+                "reduce_avg", [depth_slice], {"dim_arg": [2], "keep_dim": True}, (2, True)
+            )
 
             named_attrs = {
                 "kernel_height": kernel_size[1],
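
The hunk above sits inside a depth loop that lowers a 3D average pool: each output depth position slices kD input frames with an "index" op and averages them via reduce_avg over dim 2, before the spatial dims are pooled. A rough stand-alone illustration of that depth step, with hypothetical shapes and the `kD`/`sD` names taken from the hunk:

```python
import torch

N, C, D, H, W = 1, 3, 6, 8, 8   # hypothetical input shape
kD, sD = 2, 2                   # depth kernel size and stride, as named in the hunk
activations = torch.randn(N, C, D, H, W)

depth_avgs = []
for i in range((D - kD) // sD + 1):
    d_start = i * sD
    depth_slice = activations[:, :, d_start : d_start + kD]   # the "index" op over dim 2
    depth_avgs.append(depth_slice.mean(dim=2, keepdim=True))  # reduce_avg, dim_arg=[2]

pooled_depth = torch.cat(depth_avgs, dim=2)
assert pooled_depth.shape == (N, C, (D - kD) // sD + 1, H, W)
```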
3 changes: 0 additions & 3 deletions forge/test/mlir/operators/nn/test_nn.py
@@ -349,9 +349,6 @@ def test_convtranspose2d(
     verify(inputs, framework_model, compiled_model)
 
 
-@pytest.mark.xfail(
-    reason="RuntimeError: TT_FATAL @ /tt-metal/src/tt-metal/ttnn/cpp/ttnn/tensor/tensor_utils.cpp:474: new_volume == old_volume. Invalid arguments to reshape. Tracking on: https://github.com/tenstorrent/tt-mlir/issues/1574"
-)
 @pytest.mark.push
 def test_avg_pool2d():
     class AvgPool2d(nn.Module):
3 changes: 0 additions & 3 deletions forge/test/mlir/resnet/test_resnet_inference.py
@@ -11,9 +11,6 @@
 
 
 @pytest.mark.push
-@pytest.mark.xfail(
-    reason=" Metal issue: Can only tilize bfloat16 tensors. tracked on: https://github.com/tenstorrent/tt-metal/issues/14570"
-)
 def test_resnet_inference():
     # Compiler configurations
     compiler_cfg = forge.config._get_global_compiler_config()
3 changes: 0 additions & 3 deletions forge/test/mlir/resnet/test_resnet_unique_ops.py
@@ -50,9 +50,6 @@ def forward(self, x):
     verify(inputs, framework_model, compiled_model)
 
 
-@pytest.mark.xfail(
-    reason="RuntimeError: TT_FATAL @ /tt-metal/src/tt-metal/ttnn/cpp/ttnn/tensor/tensor_utils.cpp:474: new_volume == old_volume. Invalid arguments to reshape. Tracking on: https://github.com/tenstorrent/tt-mlir/issues/1574"
-)
 @pytest.mark.push
 def test_avg_pool2d_resnet():
     class AvgPool2d(nn.Module):
