Add integer support for eltwise ops
mouliraj-mcw committed Nov 12, 2024
1 parent fbc8a9d commit 10356b8
Showing 2 changed files with 107 additions and 2 deletions.
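The change threads an optional output dtype through the relational and logical binary eltwise ops, so comparison results can be returned as integer tensors (ttnn.uint16 or ttnn.uint32) instead of the default floating-point dtype. A minimal usage sketch, assuming an available Tenstorrent device and tiled bfloat16 inputs (the setup calls are illustrative, not part of this diff):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
a = ttnn.from_torch(torch.randn(1, 1, 32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.randn(1, 1, 32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# New in this commit: request an integer output dtype from a comparison op.
mask = ttnn.gt(a, b, dtype=ttnn.uint16)  # 0/1 mask carried as uint16
print(ttnn.to_torch(mask))

ttnn.close_device(device)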
101 changes: 101 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_comp_init.py
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range
from models.utility_functions import is_grayskull, skip_for_blackhole


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize(
    "input_shapes",
    ((torch.Size([1, 1, 32, 32])),),
)
@pytest.mark.parametrize(
    "mem_configs",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
)
@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16))
@pytest.mark.parametrize(
    "ttnn_function",
    (ttnn.lt, ttnn.gt, ttnn.eq, ttnn.le, ttnn.ge, ttnn.ne, ttnn.logical_and, ttnn.logical_or, ttnn.logical_xor),
)
def test_binary_comp_ops(input_shapes, out_dtype, mem_configs, ttnn_function, device):
    if is_grayskull():
        pytest.skip("GS does not support fp32/uint32/uint16 data types")

    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)

    cq_id = 0
    mem_cfg = mem_configs

    tt_output_tensor_on_device = ttnn_function(
        input_tensor, other_tensor, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id
    )

    golden_fn = ttnn.get_golden_function(ttnn_function)
    golden_tensor = golden_fn(in_data, other_data)
    golden_tensor = golden_tensor.int()

    output_tensor = ttnn.to_torch(tt_output_tensor_on_device)

    are_equal = torch.equal(output_tensor, golden_tensor)
    assert are_equal


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize(
    "input_shapes",
    ((torch.Size([1, 1, 32, 32])),),
)
@pytest.mark.parametrize(
    "mem_configs",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
)
@pytest.mark.parametrize(
    "scalar",
    {2.3, 15.6, 55.4, 72.5, 120.6},
)
@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16))
@pytest.mark.parametrize(
    "ttnn_function",
    (
        ttnn.lt,
        ttnn.gt,
        ttnn.eq,
        ttnn.le,
        ttnn.ge,
        ttnn.ne,
    ),
)
def test_binary_comp_ops_scalar(input_shapes, scalar, out_dtype, mem_configs, ttnn_function, device):
    if is_grayskull():
        pytest.skip("GS does not support fp32/uint32/uint16 data types")

    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)

    cq_id = 0
    mem_cfg = mem_configs

    tt_output_tensor_on_device = ttnn_function(
        input_tensor, scalar, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id
    )

    golden_fn = ttnn.get_golden_function(ttnn_function)
    golden_tensor = golden_fn(in_data, scalar)
    golden_tensor = golden_tensor.int()

    output_tensor = ttnn.to_torch(tt_output_tensor_on_device)

    are_equal = torch.equal(output_tensor, golden_tensor)
    assert are_equal
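Both tests compare the device output to the golden tensor with torch.equal rather than a PCC tolerance: once the comparison result is an integer 0/1 mask, the match against the integer-cast golden output is expected to be bit-exact.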
8 changes: 6 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -9,6 +9,7 @@
 #include "ttnn/device_operation.hpp"
 #include "ttnn/operations/data_movement/repeat/repeat.hpp"
 #include "ttnn/operations/eltwise/unary/unary.hpp"
+#include "ttnn/operations/copy.hpp"

 namespace ttnn::operations::binary {

@@ -26,6 +27,7 @@ inline Tensor binary_impl(
     BinaryOpType binary_op_type,
     const ttnn::Tensor &input_tensor,
     const float scalar,
+    const std::optional<const DataType> &dtype = std::nullopt,
     const std::optional<ttnn::MemoryConfig> &memory_config = std::nullopt,
     const std::optional<Tensor> &optional_output_tensor = std::nullopt) {
     auto output_memory_config = optional_output_tensor.has_value()
@@ -59,6 +61,8 @@
     } else {
         TT_THROW("Unsupported operation");
     }
+    if(dtype.has_value())
+        output_tensor = ttnn::typecast(queue_id, output_tensor, dtype.value(), std::nullopt, optional_output_tensor);
     return output_tensor;
 }
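In effect, the scalar relational path computes the comparison in its default output dtype and only then typecasts when the caller supplies one. A rough Python-level sketch of the same flow (the helper name is illustrative, not part of the diff):

def comparison_with_out_dtype(tensor, scalar, op, dtype=None):
    # Run the comparison in the op's default output dtype...
    result = op(tensor, scalar)
    # ...then cast only if a specific dtype was requested, mirroring
    # the ttnn::typecast call added above.
    if dtype is not None:
        result = ttnn.typecast(result, dtype)
    return result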

@@ -295,7 +299,7 @@ Tensor RelationalBinary<binary_op_type>::invoke(
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
     return detail::binary_impl(
-        DefaultQueueId, binary_op_type, input_tensor_a, scalar, memory_config, optional_output_tensor);
+        DefaultQueueId, binary_op_type, input_tensor_a, scalar, dtype, memory_config, optional_output_tensor);
 }

 template <BinaryOpType binary_op_type>
@@ -309,7 +313,7 @@ Tensor RelationalBinary<binary_op_type>::invoke(
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
     return detail::binary_impl(
-        DefaultQueueId, binary_op_type, input_tensor_a, scalar, memory_config, optional_output_tensor);
+        DefaultQueueId, binary_op_type, input_tensor_a, scalar, dtype, memory_config, optional_output_tensor);
 }
 // scalar - tensor combination not available on Pytorch for this op
 template <BinaryOpType binary_op_type>
