Skip to content

Commit

Permalink
#12835: Fix binary fmod (#13568)
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw authored Nov 4, 2024
1 parent 5ec25b0 commit d46ba83
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
27 changes: 25 additions & 2 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,8 @@ def test_remainder_ttnn(input_shapes, scalar, device):
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
# (torch.Size([1, 1, 320, 384])),
# (torch.Size([1, 3, 320, 384])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@skip_for_grayskull("#ToDo: GS implementation needs to be done for fmod")
Expand All @@ -532,6 +532,29 @@ def test_binary_fmod_ttnn(input_shapes, device):
assert comp_pass


@pytest.mark.parametrize(
    "input_shapes",
    (
        torch.Size([1, 1, 32, 32]),
        torch.Size([1, 1, 320, 384]),
        torch.Size([1, 3, 320, 384]),
    ),
)
@skip_for_grayskull("#ToDo: GS implementation needs to be done for fmod")
def test_binary_fmod_decimal_ttnn(input_shapes, device):
    """Check ttnn.fmod against its golden reference for bfloat16 operands
    with fractional (decimal) divisors.

    NOTE: inputs carrying more than two decimal places suffer precision
    loss in bfloat16, hence the relaxed 0.97 PCC threshold below.
    """
    # Dividend spread over roughly [-9, 9); divisor confined to [-2, -1).
    lhs_host = torch.randn(input_shapes, dtype=torch.bfloat16) * 9
    rhs_host = torch.rand(input_shapes, dtype=torch.bfloat16) - 2

    lhs_dev = ttnn.Tensor(lhs_host, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device)
    rhs_dev = ttnn.Tensor(rhs_host, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device)

    device_result = ttnn.fmod(lhs_dev, rhs_dev)

    # Reference result computed on host via the registered golden function.
    golden_fn = ttnn.get_golden_function(ttnn.fmod)
    reference = golden_fn(lhs_host, rhs_host)

    assert compare_pcc([device_result], [reference], 0.97)


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ Tensor ExecuteBinaryFmod::invoke(const Tensor& input_a, const Tensor& input_b, c
DataType input_dtype = input_a.get_dtype();
Tensor a = typecast(input_a, DataType::FLOAT32);
Tensor b = typecast(input_b, DataType::FLOAT32);
Tensor result = ttnn::subtract(a, ttnn::multiply(ttnn::div(input_a, input_b, true, "trunc", output_mem_config), b, std::nullopt, output_mem_config), std::nullopt, output_mem_config);
Tensor div_res = typecast(ttnn::div(input_a, input_b, true, "trunc", output_mem_config), DataType::FLOAT32);
Tensor result = ttnn::subtract(a, ttnn::multiply(div_res, b, std::nullopt, output_mem_config), std::nullopt, output_mem_config);
result = ttnn::where(ttnn::eq(a, b, std::nullopt, output_mem_config), ttnn::full_like(input_a, 0.0f), result);
return typecast(result, input_dtype);
}
Expand Down

0 comments on commit d46ba83

Please sign in to comment.