#16406: Fix negative divisor issue in remainder
umadevimcw committed Jan 22, 2025
1 parent a964ce3 commit 1d8664f
Showing 12 changed files with 110 additions and 44 deletions.
@@ -97,7 +97,7 @@ def run(
     scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
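This hunk is representative of the sweep-test changes that follow: every call into the remainder golden function now forwards the test's `device` fixture, because the golden (PyTorch reference) output must be normalized to device-representable NaN/Inf values (see the `ttnn/ttnn/operations/binary.py` hunk at the end of this commit). A minimal sketch of the updated call pattern, assuming a pytest `device` fixture as in these tests:

```python
golden_function = ttnn.get_golden_function(ttnn.remainder)
# device= is now required: the golden function uses it to rewrite IEEE
# NaN/Inf in the torch reference output to the device's SFPU encodings.
torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)
```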
@@ -88,7 +88,7 @@ def run(
     )(input_shape)

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)
+    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b, device=device)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
@@ -68,7 +68,7 @@ def run(
     )(input_shape)

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)
+    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b, device=device)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
@@ -55,7 +55,7 @@ def run(
     )(input_shape)

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
@@ -97,7 +97,7 @@ def run(
     scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
@@ -103,7 +103,7 @@ def run(
     scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, scalar)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar, device=device)

     sharded_config = ttnn.create_sharded_memory_config_(
         shape=input_shape,
@@ -516,7 +516,7 @@ def test_binary_remainder_ttnn(input_shapes, device):
     in_data2, input_tensor2 = data_gen_with_range(input_shapes, -100, 100, device)
     output_tensor = ttnn.remainder(input_tensor1, input_tensor2)
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    golden_tensor = golden_function(in_data1, in_data2)
+    golden_tensor = golden_function(in_data1, in_data2, device=device)

     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
@@ -541,7 +541,7 @@ def test_shape_remainder(device, shapes):
     torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16) * (high - low) + low

     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)
+    torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b, device=device)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
@@ -573,7 +573,7 @@ def test_remainder_ttnn(input_shapes, scalar, device):
     in_data1, input_tensor1 = data_gen_with_range(input_shapes, -150, 150, device)
     output_tensor = ttnn.remainder(input_tensor1, scalar)
     golden_function = ttnn.get_golden_function(ttnn.remainder)
-    golden_tensor = golden_function(in_data1, scalar)
+    golden_tensor = golden_function(in_data1, scalar, device=device)

     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py
@@ -21,7 +21,7 @@ def test_remainder_fp32(device, ttnn_function):
     x_torch = torch.rand([2, 3, 64, 64], dtype=torch.float32)
     y_torch = torch.rand([2, 3, 64, 64], dtype=torch.float32)
     golden_fn = ttnn.get_golden_function(ttnn_function)
-    z_torch = golden_fn(x_torch, y_torch)
+    z_torch = golden_fn(x_torch, y_torch, device=device)
     x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
     y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
     z_tt_div = ttnn.remainder(x_tt, y_tt)
@@ -93,7 +93,7 @@ def test_remainder_forge(device, ttnn_function):
     input2 = torch.randn(2, 32, 32)

     golden_fn = ttnn.get_golden_function(ttnn_function)
-    torch_output = golden_fn(input1, input2)
+    torch_output = golden_fn(input1, input2, device=device)

     input1 = ttnn.from_torch(input1, dtype=ttnn.float32)
     input2 = ttnn.from_torch(input2, dtype=ttnn.float32)
49 changes: 49 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_remainder.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+import ttnn
+from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from models.utility_functions import skip_for_grayskull
+
+
+@skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
+@pytest.mark.parametrize(
+    "input_shapes",
+    ((torch.Size([1, 1, 32, 32])),),
+)
+def test_broken_remainder(input_shapes, device):
+    torch_lhs = torch.ones(32, 32, dtype=torch.bfloat16)
+    torch_rhs = torch.zeros(32, 32, dtype=torch.bfloat16)
+
+    golden_function = ttnn.get_golden_function(ttnn.remainder)
+    golden = golden_function(torch_lhs, torch_rhs, device=device)
+
+    tt_lhs = ttnn.from_torch(torch_lhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    tt_rhs = ttnn.from_torch(torch_rhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    tt_result = ttnn.remainder(tt_lhs, tt_rhs)
+    result = ttnn.to_torch(tt_result)
+    assert torch.allclose(result, golden, atol=0.01, rtol=0)
+
+
+@skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
+@pytest.mark.parametrize(
+    "input_shapes",
+    ((torch.Size([1, 1, 32, 32])),),
+)
+def test_broken_remainder1(input_shapes, device):
+    torch_lhs = torch.ones(32, 32, dtype=torch.bfloat16) * 95
+    torch_rhs = torch.ones(32, 32, dtype=torch.bfloat16) * (-94.5)
+
+    golden_function = ttnn.get_golden_function(ttnn.remainder)  # all -94.0
+    golden = golden_function(torch_lhs, torch_rhs, device=device)
+
+    tt_lhs = ttnn.from_torch(torch_lhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    tt_rhs = ttnn.from_torch(torch_rhs, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+
+    tt_result = ttnn.remainder(tt_lhs, tt_rhs)
+    result = ttnn.to_torch(tt_result)  # all 0.5
+    assert torch.allclose(result, golden, atol=0.01, rtol=0)
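The second test pins down the sign convention the fix targets; the inline comments record the pre-fix behavior. With a negative divisor, `torch.remainder` follows floored division, so the result carries the divisor's sign; a truncation-based residue gives the stale `0.5` noted in the comment. A worked check in plain PyTorch (no device needed):

```python
import torch

a, b = torch.tensor(95.0), torch.tensor(-94.5)

# Floored division: floor(95 / -94.5) = floor(-1.005...) = -2,
# so 95 - (-94.5) * (-2) = 95 - 189 = -94.0 (sign of the divisor).
print(torch.remainder(a, b))  # tensor(-94.)

# Truncated division: trunc(-1.005...) = -1,
# so 95 - (-94.5) * (-1) = 95 - 94.5 = 0.5 — the old, wrong answer.
print(a - b * torch.trunc(a / b))  # tensor(0.5000)
```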
18 changes: 17 additions & 1 deletion tests/ttnn/unit_tests/operations/eltwise/test_unary.py
@@ -321,6 +321,22 @@ def run_unary_test_with_float(device, h, w, scalar, ttnn_function, pcc=0.9999):
     assert_with_pcc(torch_output_tensor, output_tensor, pcc)


+def run_unary_test_with_float_remainder(device, h, w, scalar, ttnn_function, pcc=0.9999):
+    torch.manual_seed(0)
+
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
+    golden_function = ttnn.get_golden_function(ttnn.remainder)
+    torch_output_tensor = golden_function(torch_input_tensor, scalar, device=device)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn_function(input_tensor, scalar)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc)
+
+
 @pytest.mark.parametrize("scalar", [1, 2])
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
@@ -389,7 +405,7 @@ def test_relu_max(device, h, w, upper_limit):
 @pytest.mark.parametrize("w", [128])
 @skip_for_grayskull("Op not supported for Grayskull, supported for wormhole_b0")
 def test_remainder(device, h, w, scalar):
-    run_unary_test_with_float(device, h, w, scalar, ttnn.remainder)
+    run_unary_test_with_float_remainder(device, h, w, scalar, ttnn.remainder)


 @pytest.mark.parametrize("scalar", [1.5, 2.0])
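The dedicated `run_unary_test_with_float_remainder` helper exists because the generic `run_unary_test_with_float` invokes its golden function as `golden_function(torch_input_tensor, scalar)` with no `device` keyword. Since `device` is a required keyword-only parameter of the remainder golden (see `binary.py` below), routing remainder through the generic helper would fail:

```python
golden_function = ttnn.get_golden_function(ttnn.remainder)
golden_function(torch_input_tensor, scalar)                 # TypeError: missing keyword-only 'device'
golden_function(torch_input_tensor, scalar, device=device)  # OK — what the new helper does
```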
@@ -366,45 +366,41 @@ Tensor ExecutePrelu::invoke(
     Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
     return result;
 }

+Tensor run_remainder(
+    const Tensor& input_a, const Tensor& input_b, float t_nan, const std::optional<MemoryConfig>& output_mem_config) {
+    Tensor result = ttnn::subtract(
+        input_a,
+        ttnn::multiply(
+            input_b, ttnn::div(input_a, input_b, true, "floor", output_mem_config), std::nullopt, output_mem_config),
+        std::nullopt,
+        output_mem_config);
+    result = ttnn::where(ttnn::ge(result, input_b), ttnn::subtract(result, input_b), result);
+    result = ttnn::where(ttnn::ltz(input_b), ttnn::add(result, input_b), result);
+    result = ttnn::where(ttnn::eq(input_a, input_b, std::nullopt, output_mem_config), 0.0f, result);
+    result = ttnn::where(ttnn::eqz(input_a), 0.0f, ttnn::where(ttnn::eqz(input_b), t_nan, result));
+    result = ttnn::where(ttnn::logical_and(ttnn::eqz(input_a), ttnn::eqz(input_b)), t_nan, result);
+    return result;
+}
 // Binary remainder will be overloaded by unary remainder in another PR
 Tensor ExecuteBinaryRemainder::invoke(
     const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
     auto arch = input_a.device()->arch();
     TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Op is supported on Wormhole or Blackhole");
     DataType input_dtype = input_a.get_dtype();

-    float t_nan = std::nanf("");
+    float t_nan = tt::tt_metal::experimental::hal::get_nan();

-    // No typecast for FP32 input
-    if (input_dtype == DataType::FLOAT32 && input_b.get_dtype() == DataType::FLOAT32) {
-        Tensor result = ttnn::subtract(
-            input_a,
-            ttnn::multiply(
-                input_b,
-                ttnn::div(input_a, input_b, true, "floor", output_mem_config),
-                std::nullopt,
-                output_mem_config),
-            std::nullopt,
-            output_mem_config);
-        result = ttnn::where(ttnn::ge(result, input_b), ttnn::subtract(result, input_b), result);
-        result = ttnn::where(ttnn::ltz(input_b), ttnn::add(result, input_b), result);
-        result = ttnn::where(ttnn::eq(input_a, input_b, std::nullopt, output_mem_config), 0.0f, result);
-        return result;
-    }
-    Tensor a = typecast(input_a, DataType::FLOAT32);
-    Tensor b = typecast(input_b, DataType::FLOAT32);
-    Tensor result = ttnn::subtract(
-        a,
-        ttnn::multiply(
-            b, ttnn::div(input_a, input_b, false, "floor", output_mem_config), std::nullopt, output_mem_config),
-        std::nullopt,
-        output_mem_config);
-    result = ttnn::where(ttnn::ge(result, b), ttnn::subtract(result, b), result);
-    result = ttnn::where(ttnn::ltz(b), ttnn::add(result, b), result);
-    result = ttnn::where(ttnn::eq(input_a, input_b, std::nullopt, output_mem_config), 0.0f, result);
-    result = typecast(result, input_dtype);
-    result = ttnn::where(ttnn::eqz(input_a), 0.0f, ttnn::where(ttnn::eqz(input_b), t_nan, result));
-    return result;
+    const auto do_typecast = input_dtype != DataType::FLOAT32 or input_b.get_dtype() != DataType::FLOAT32;
+    const auto& a = do_typecast ? typecast(input_a, DataType::FLOAT32) : input_a;
+    const auto& b = do_typecast ? typecast(input_b, DataType::FLOAT32) : input_b;
+
+    // Perform the remainder operation
+    Tensor result = run_remainder(a, b, t_nan, output_mem_config);
+
+    // Return the result, typecasted if necessary
+    return do_typecast ? typecast(result, input_dtype) : result;
 }

 Tensor ExecuteBinaryRemainder::invoke(
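For intuition, `run_remainder` computes the floored-division remainder r = a − b·floor(a/b), then patches the edge cases with a chain of `where` ops. A loose PyTorch transcription of the same sequence (`remainder_reference` and `t_nan` are illustrative names, not the ttnn API):

```python
import torch

def remainder_reference(a: torch.Tensor, b: torch.Tensor, t_nan: float) -> torch.Tensor:
    # Floored division keeps the divisor's sign: r = a - b * floor(a / b).
    r = a - b * torch.floor(a / b)
    r = torch.where(r >= b, r - b, r)                # fold r back into range
    r = torch.where(b < 0, r + b, r)                 # negative-divisor correction
    r = torch.where(a == b, torch.zeros_like(r), r)  # a == b -> exactly 0
    r = torch.where(a == 0, torch.zeros_like(r),     # 0 % b -> 0 ...
                    torch.where(b == 0, torch.full_like(r, t_nan), r))  # ... x % 0 -> NaN
    r = torch.where((a == 0) & (b == 0), torch.full_like(r, t_nan), r)  # 0 % 0 -> NaN
    return r

# 95 % -94.5 lands on -94.0, matching torch.remainder:
print(remainder_reference(torch.tensor([95.0]), torch.tensor([-94.5]), float("nan")))
```

Note the two sign fixups cancel correctly for a negative divisor: for a = 95, b = −94.5, floored division gives r = −94, the `r >= b` fold moves it to 0.5, and the `b < 0` correction moves it back to −94.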
9 changes: 7 additions & 2 deletions ttnn/ttnn/operations/binary.py
@@ -330,10 +330,15 @@ def _golden_function_floor_div(input_tensor_a, input_tensor_b, *args, **kwargs):
 ttnn.attach_golden_function(ttnn.floor_div, golden_function=_golden_function_floor_div)


-def _golden_function_remainder(input_tensor_a, input_tensor_b, *args, **kwargs):
+def _golden_function_remainder(input_tensor_a, input_tensor_b, *args, device, **kwargs):
     import torch

-    return torch.remainder(input_tensor_a, input_tensor_b)
+    return torch.nan_to_num(
+        torch.remainder(input_tensor_a, input_tensor_b),
+        nan=device.sfpu_nan(),
+        posinf=device.sfpu_inf(),
+        neginf=-device.sfpu_inf(),
+    )


 ttnn.attach_golden_function(ttnn.remainder, golden_function=_golden_function_remainder)
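This signature change is why every golden-function call site above gained `device=device`: `torch.remainder` yields IEEE NaN for a zero divisor, while the device's SFPU uses its own NaN/Inf encodings, so the reference output is normalized with `torch.nan_to_num` before comparison. A hedged usage sketch (tensor values are illustrative; `device` is an open ttnn device handle, e.g. a pytest fixture):

```python
import torch
import ttnn

a = torch.tensor([[5.0, 95.0]], dtype=torch.bfloat16)
b = torch.tensor([[0.0, -94.5]], dtype=torch.bfloat16)

# x % 0 is NaN in torch; the golden function rewrites it to device.sfpu_nan()
# so that allclose/PCC checks against the ttnn output can pass.
golden_function = ttnn.get_golden_function(ttnn.remainder)
golden = golden_function(a, b, device=device)
```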