diff --git a/tests/sweep_framework/sweeps/eltwise/binary/fmod/fmod_unary.py b/tests/sweep_framework/sweeps/eltwise/binary/fmod/fmod_unary.py
index 35426ab28a4f..721a52f6476e 100644
--- a/tests/sweep_framework/sweeps/eltwise/binary/fmod/fmod_unary.py
+++ b/tests/sweep_framework/sweeps/eltwise/binary/fmod/fmod_unary.py
@@ -97,7 +97,7 @@ def run(
     scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder.py b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder.py
index a2d3645d6707..9700107c8bd9 100644
--- a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder.py
+++ b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder.py
@@ -88,7 +88,7 @@ def run(
     )(input_shape)
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)
+    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b, device=device)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_forge.py b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_forge.py
index 496ffc0abe2b..b5252be9d470 100644
--- a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_forge.py
+++ b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_forge.py
@@ -68,7 +68,7 @@ def run(
     )(input_shape)
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)
+    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b, device=device)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py
index 59b5fa43506f..804f3efc2c8c 100644
--- a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py
+++ b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py
@@ -55,7 +55,7 @@ def run(
     )(input_shape)
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary.py b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary.py
index 647a9a8abf42..88e28e2da154 100644
--- a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary.py
+++ b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary.py
@@ -97,7 +97,7 @@ def run(
     scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary_sharded.py b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary_sharded.py
index 27ba3471cb85..b0bc3fb4066c 100644
--- a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary_sharded.py
+++ b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_unary_sharded.py
@@ -103,7 +103,7 @@ def run(
     scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)
 
     sharded_config = ttnn.create_sharded_memory_config_(
         shape=input_shape,
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index 97cff73907d3..c23d0cc5f28b 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -516,7 +516,7 @@ def test_binary_remainder_ttnn(input_shapes, device):
     in_data2, input_tensor2 = data_gen_with_range(input_shapes, -100, 100, device)
     output_tensor = ttnn.remainder(input_tensor1, input_tensor2)
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    golden_tensor = golden_function(in_data1, in_data2)
+    golden_tensor = golden_function(in_data1, in_data2, device=device)
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
@@ -541,7 +541,7 @@ def test_shape_remainder(device, shapes):
     torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16) * (high - low) + low
 
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)
+    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b, device=device)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
@@ -573,7 +573,7 @@ def test_remainder_ttnn(input_shapes, scalar, device):
     in_data1, input_tensor1 = data_gen_with_range(input_shapes, -150, 150, device)
     output_tensor = ttnn.remainder(input_tensor1, scalar)
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    golden_tensor = golden_function(in_data1, scalar)
+    golden_tensor = golden_function(in_data1, scalar, device=device)
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py b/tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py
index 55105be12af8..489315d15280 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py
@@ -21,7 +21,7 @@ def test_remainder_fp32(device, ttnn_function):
     x_torch = torch.rand([2, 3, 64, 64], dtype=torch.float32)
     y_torch = torch.rand([2, 3, 64, 64], dtype=torch.float32)
     golden_fn = ttnn.get_golden_function(ttnn_function)
-    z_torch = golden_fn(x_torch, y_torch)
+    z_torch = golden_fn(x_torch, y_torch, device=device)
     x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
     y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
     z_tt_div = ttnn.remainder(x_tt, y_tt)
@@ -93,7 +93,7 @@ def test_remainder_forge(device, ttnn_function):
     input2 = torch.randn(2, 32, 32)
 
     golden_fn = ttnn.get_golden_function(ttnn_function)
-    torch_output = golden_fn(input1, input2)
+    torch_output = golden_fn(input1, input2, device=device)
 
     input1 = ttnn.from_torch(input1, dtype=ttnn.float32)
     input2 = ttnn.from_torch(input2, dtype=ttnn.float32)
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_remainder.py b/tests/ttnn/unit_tests/operations/eltwise/test_remainder.py
new file mode 100644
index 000000000000..aee615dd8280
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_remainder.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+import ttnn
+from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from models.utility_functions import skip_for_grayskull
+
+
+@skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
+@pytest.mark.parametrize(
+    "input_shapes",
+    ((torch.Size([1, 1, 32, 32])),),
+)
+def test_broken_remainder(input_shapes, device):
+    torch_lhs = torch.ones(32, 32, dtype=torch.bfloat16)
+    torch_rhs = torch.zeros(32, 32, dtype=torch.bfloat16)
+
+    golden_function = ttnn.get_golden_function(ttnn.remainder)
+    golden = golden_function(torch_lhs, torch_rhs, device=device)
+
+    tt_lhs = ttnn.from_torch(torch_lhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    tt_rhs = ttnn.from_torch(torch_rhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    tt_result = ttnn.remainder(tt_lhs, tt_rhs)
+    result = ttnn.to_torch(tt_result)
+    assert torch.allclose(result, golden, atol=0.01, rtol=0)
+
+
+@skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
+@pytest.mark.parametrize(
+    "input_shapes",
+    ((torch.Size([1, 1, 32, 32])),),
+)
+def test_broken_remainder1(input_shapes, device):
+    torch_lhs = torch.ones(32, 32, dtype=torch.bfloat16) * 95
+    torch_rhs = torch.ones(32, 32, dtype=torch.bfloat16) * (-94.5)
+
+    golden_function = ttnn.get_golden_function(ttnn.remainder)  # all -94.0
+    golden = golden_function(torch_lhs, torch_rhs, device=device)
+
+    tt_lhs = ttnn.from_torch(torch_lhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    tt_rhs = ttnn.from_torch(torch_rhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+
+    tt_result = ttnn.remainder(tt_lhs, tt_rhs)
+    result = ttnn.to_torch(tt_result)  # all 0.5
+    assert torch.allclose(result, golden, atol=0.01, rtol=0)
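Note on the two new tests: they pin down divergent remainder conventions rather than ordinary behavior. A minimal sketch (plain torch plus the standard library, nothing device-specific) reproduces the values recorded in the inline comments above:

    import math
    import torch

    # x % 0 is NaN under torch.remainder for floating-point inputs.
    print(torch.remainder(torch.tensor(1.0), torch.tensor(0.0)))    # tensor(nan)

    # 95 % -94.5 depends on the convention:
    # floor-mod (torch.remainder) takes the sign of the divisor,
    # fmod (C-style, truncated division) takes the sign of the dividend.
    print(torch.remainder(torch.tensor(95.0), torch.tensor(-94.5)))  # tensor(-94.)
    print(math.fmod(95.0, -94.5))                                    # 0.5

The "# all -94.0" and "# all 0.5" comments in test_broken_remainder1 are exactly this floor-mod vs. fmod split between the golden reference and the device result.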
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py
index baea244c5820..f8dde956b7e1 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py
@@ -321,6 +321,22 @@ def run_unary_test_with_float(device, h, w, scalar, ttnn_function, pcc=0.9999):
     assert_with_pcc(torch_output_tensor, output_tensor, pcc)
 
 
+def run_unary_test_with_float_remainder(device, h, w, scalar, ttnn_function, pcc=0.9999):
+    torch.manual_seed(0)
+
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
+    golden_function = ttnn.get_golden_function(ttnn.remainder)
+    torch_output_tensor = golden_function(torch_input_tensor, scalar, device=device)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn_function(input_tensor, scalar)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc)
+
+
 @pytest.mark.parametrize("scalar", [1, 2])
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
@@ -389,7 +405,7 @@ def test_relu_max(device, h, w, upper_limit):
 @pytest.mark.parametrize("w", [128])
 @skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
 def test_remainder(device, h, w, scalar):
-    run_unary_test_with_float(device, h, w, scalar, ttnn.remainder)
+    run_unary_test_with_float_remainder(device, h, w, scalar, ttnn.remainder)
 
 
 @pytest.mark.parametrize("scalar", [1.5, 2.0])
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index 606176cded1c..98f04ed8221f 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -366,6 +366,22 @@ Tensor ExecutePrelu::invoke(
     Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
     return result;
 }
+
+Tensor run_remainder(
+    const Tensor& input_a, const Tensor& input_b, float t_nan, const std::optional<MemoryConfig>& output_mem_config) {
+    Tensor result = ttnn::subtract(
+        input_a,
+        ttnn::multiply(
+            input_b, ttnn::div(input_a, input_b, true, "floor", output_mem_config), std::nullopt, output_mem_config),
+        std::nullopt,
+        output_mem_config);
+    result = ttnn::where(ttnn::ge(result, input_b), ttnn::subtract(result, input_b), result);
+    result = ttnn::where(ttnn::ltz(input_b), ttnn::add(result, input_b), result);
+    result = ttnn::where(ttnn::eq(input_a, input_b, std::nullopt, output_mem_config), 0.0f, result);
+    result = ttnn::where(ttnn::eqz(input_a), 0.0f, ttnn::where(ttnn::eqz(input_b), t_nan, result));
+    result = ttnn::where(ttnn::logical_and(ttnn::eqz(input_a), ttnn::eqz(input_b)), t_nan, result);
+    return result;
+}
 // Binary remainder will be overloaded by unary remainder in another PR
 Tensor ExecuteBinaryRemainder::invoke(
     const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
@@ -373,38 +389,18 @@ Tensor ExecuteBinaryRemainder::invoke(
     TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Op is supported on Wormhole or Blackhole");
 
     DataType input_dtype = input_a.get_dtype();
-    float t_nan = std::nanf("");
+    float t_nan = tt::tt_metal::experimental::hal::get_nan();
 
     // No typecast for FP32 input
-    if (input_dtype == DataType::FLOAT32 && input_b.get_dtype() == DataType::FLOAT32) {
-        Tensor result = ttnn::subtract(
-            input_a,
-            ttnn::multiply(
-                input_b,
-                ttnn::div(input_a, input_b, true, "floor", output_mem_config),
-                std::nullopt,
-                output_mem_config),
-            std::nullopt,
-            output_mem_config);
-        result = ttnn::where(ttnn::ge(result, input_b), ttnn::subtract(result, input_b), result);
-        result = ttnn::where(ttnn::ltz(input_b), ttnn::add(result, input_b), result);
-        result = ttnn::where(ttnn::eq(input_a, input_b, std::nullopt, output_mem_config), 0.0f, result);
-        return result;
-    }
-    Tensor a = typecast(input_a, DataType::FLOAT32);
-    Tensor b = typecast(input_b, DataType::FLOAT32);
-    Tensor result = ttnn::subtract(
-        a,
-        ttnn::multiply(
-            b, ttnn::div(input_a, input_b, false, "floor", output_mem_config), std::nullopt, output_mem_config),
-        std::nullopt,
-        output_mem_config);
-    result = ttnn::where(ttnn::ge(result, b), ttnn::subtract(result, b), result);
-    result = ttnn::where(ttnn::ltz(b), ttnn::add(result, b), result);
-    result = ttnn::where(ttnn::eq(input_a, input_b, std::nullopt, output_mem_config), 0.0f, result);
-    result = typecast(result, input_dtype);
-    result = ttnn::where(ttnn::eqz(input_a), 0.0f, ttnn::where(ttnn::eqz(input_b), t_nan, result));
-    return result;
+    const auto do_typecast = input_dtype != DataType::FLOAT32 or input_b.get_dtype() != DataType::FLOAT32;
+    const auto& a = do_typecast ? typecast(input_a, DataType::FLOAT32) : input_a;
+    const auto& b = do_typecast ? typecast(input_b, DataType::FLOAT32) : input_b;
+
+    // Perform the remainder operation
+    Tensor result = run_remainder(a, b, t_nan, output_mem_config);
+
+    // Return the result, typecasted if necessary
+    return do_typecast ? typecast(result, input_dtype) : result;
 }
 
 Tensor ExecuteBinaryRemainder::invoke(
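The extracted run_remainder helper builds floor-mod from existing tt-metal ops: result = a - b * floor(a / b), followed by where-fixups for an off-by-one floor result, negative divisors, a == b, a == 0, and b == 0 (which yields t_nan). A minimal torch sketch of the same identity, for reference only (the real implementation is the C++ above):

    import torch

    def remainder_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Core identity: floor-mod via floor division.
        result = a - b * torch.floor(a / b)
        # Division by zero must surface as NaN, matching the t_nan branch.
        result = torch.where(b == 0, torch.full_like(result, float("nan")), result)
        return result

The refactor also removes the duplicated FP32/non-FP32 branches: both paths now call run_remainder, and only non-FP32 inputs pay the typecast round-trip to FLOAT32 and back.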
diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py
index 3c1a54db3289..91ad2471ee56 100644
--- a/ttnn/ttnn/operations/binary.py
+++ b/ttnn/ttnn/operations/binary.py
@@ -330,10 +330,15 @@ def _golden_function_floor_div(input_tensor_a, input_tensor_b, *args, **kwargs):
 ttnn.attach_golden_function(ttnn.floor_div, golden_function=_golden_function_floor_div)
 
 
-def _golden_function_remainder(input_tensor_a, input_tensor_b, *args, **kwargs):
+def _golden_function_remainder(input_tensor_a, input_tensor_b, *args, device, **kwargs):
     import torch
 
-    return torch.remainder(input_tensor_a, input_tensor_b)
+    return torch.nan_to_num(
+        torch.remainder(input_tensor_a, input_tensor_b),
+        nan=device.sfpu_nan(),
+        posinf=device.sfpu_inf(),
+        neginf=-device.sfpu_inf(),
+    )
 
 
 ttnn.attach_golden_function(ttnn.remainder, golden_function=_golden_function_remainder)
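The golden function now requires a device handle: device is keyword-only after *args, so every call site touched in this diff must pass device=... explicitly or get a TypeError. The intent, mirroring the C++ switch from std::nanf("") to hal::get_nan(), is that the torch reference's NaN/Inf values are remapped via torch.nan_to_num to whatever the device's SFPU actually produces before outputs are compared. A runnable sketch with a stand-in device object (FakeDevice is illustrative only; sfpu_nan()/sfpu_inf() are the real ttnn device methods used above):

    import torch

    class FakeDevice:
        # Stand-ins; a real ttnn device returns the values its SFPU produces.
        def sfpu_nan(self) -> float:
            return float("nan")

        def sfpu_inf(self) -> float:
            return float("inf")

    def golden_remainder(a, b, *, device):
        return torch.nan_to_num(
            torch.remainder(a, b),
            nan=device.sfpu_nan(),
            posinf=device.sfpu_inf(),
            neginf=-device.sfpu_inf(),
        )

    out = golden_remainder(torch.tensor(1.0), torch.tensor(0.0), device=FakeDevice())
    print(out)  # NaN remapped to the device's NaN encoding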