diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py index b035e58d29a3..78d333dfa810 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py @@ -9,9 +9,11 @@ from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, data_gen_with_range_int, + data_gen_with_val, compare_pcc, compare_equal, ) +from tests.ttnn.utils_for_testing import assert_with_pcc from models.utility_functions import is_grayskull, skip_for_grayskull @@ -998,32 +1000,43 @@ def test_binary_lcm_ttnn(input_shapes, device): @pytest.mark.parametrize( "input_shapes", ( - (torch.Size([1, 3, 32, 32])), - (torch.Size([1, 6, 32, 32])), - (torch.Size([1, 7, 320, 384])), - (torch.Size([1, 4, 320, 384])), + (torch.Size([1, 2, 32, 64, 64])), + (torch.Size([1, 3, 7, 29, 127])), + (torch.Size([1, 3, 2, 32])), + (torch.Size([1, 6, 49, 97])), + (torch.Size([1, 7, 320])), + (torch.Size([1, 49, 321])), + (torch.Size([4, 32])), + (torch.Size([49, 321])), ), ) def test_binary_prelu_ttnn(input_shapes, device): - in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device) + in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100 channels = input_shapes[1] in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100 + + input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device) input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.prelu(input_tensor1, input_tensor2) + output_tensor = ttnn.to_torch(output_tensor) golden_function = ttnn.get_golden_function(ttnn.prelu) golden_tensor = golden_function(in_data1, in_data2) - comp_pass = compare_pcc([output_tensor], [golden_tensor]) - assert comp_pass + assert_with_pcc(golden_tensor, output_tensor, 0.999) @pytest.mark.parametrize( 
"input_shapes", ( - (torch.Size([1, 3, 32, 32])), - (torch.Size([1, 6, 32, 32])), - (torch.Size([1, 7, 320, 384])), - (torch.Size([1, 4, 320, 384])), + (torch.Size([1, 2, 32, 64, 64])), + (torch.Size([1, 3, 7, 29, 127])), + (torch.Size([1, 3, 2, 32])), + (torch.Size([1, 6, 49, 97])), + (torch.Size([1, 7, 320])), + (torch.Size([1, 49, 321])), + (torch.Size([4, 32])), + (torch.Size([49, 321])), ), ) @pytest.mark.parametrize( @@ -1032,10 +1045,12 @@ def test_binary_prelu_ttnn(input_shapes, device): ) @skip_for_grayskull() def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device): - in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device) + in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100 + input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.prelu(input_tensor1, scalar) + output_tensor = ttnn.to_torch(output_tensor) golden_function = ttnn.get_golden_function(ttnn.prelu) golden_tensor = golden_function(in_data1, scalar) - comp_pass = compare_pcc([output_tensor], [golden_tensor]) - assert comp_pass + assert_with_pcc(golden_tensor, output_tensor, 0.999) diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index e6e149ffa0e0..d81303773b26 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -418,6 +418,7 @@ void bind_binary_composite_overload( const binary_operation_t& operation, const std::string& description, const std::string& supported_dtype = "BFLOAT16", + const std::string& supported_rank = "2, 3, 4", const std::string& note="") { auto doc = fmt::format( R"doc( @@ -447,9 +448,9 @@ void bind_binary_composite_overload( - Ranks * - {3} - TILE - - 2, 3, 4 + - {4} - {4} + {5} Example: >>> tensor1 = ttnn.from_torch(torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, 
device=device) @@ -460,6 +461,7 @@ void bind_binary_composite_overload( operation.python_fully_qualified_name(), description, supported_dtype, + supported_rank, note); bind_registered_operation( @@ -1168,7 +1170,9 @@ void py_module(py::module& module) { detail::bind_binary_composite_overload( module, ttnn::prelu, - R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc"); + R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc", + R"doc(BFLOAT16, BFLOAT8_B)doc", + R"doc(2, 3, 4, 5)doc"); detail::bind_binary_composite( module, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index 5933cc63db9b..e46b86fa0720 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -269,11 +269,16 @@ Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::option } Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) { - const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape(); - auto volume = input_b.get_logical_volume(); - // If volume = 1 Support for a single-value tensor yet to be handled. TODO(#14933) - TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size"); - Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}}); + const auto s_a = input_a.get_shape(); + const auto volume = input_b.get_logical_volume(); + + TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size. 
Found parameter numbers = {} and channel size = {}.", volume, s_a[1]); Tensor b = input_b; if(s_a.rank()>2){ SmallVector<uint32_t> reshape(s_a.rank(), 1); reshape[1] = s_a[1]; b = ttnn::reshape(input_b, ttnn::Shape(reshape)); } Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a); return result; }