#5769: Add batch mul support in composite
umadevimcw committed Mar 15, 2024
1 parent 5b1f1d3 commit 56e7868
Showing 5 changed files with 79 additions and 0 deletions.
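
For orientation, a minimal reference sketch in plain PyTorch of what the new batch_mul composite computes (the shapes and variable names below are illustrative, not part of this diff): the op multiplies two tensors elementwise when their batch sizes differ, repeating the smaller-batch operand along dim 0.

import torch

# Illustrative operands: one has batch 32, the other batch 1 (mirrors the test cases below).
a = torch.randn(32, 1, 32, 32).bfloat16()
b = torch.randn(1, 1, 32, 32).bfloat16()

# Reference semantics: repeat the smaller-batch operand along dim 0, then multiply elementwise.
out = torch.mul(a, b.repeat(32, 1, 1, 1))

# This matches torch's ordinary broadcasting multiply.
assert torch.equal(out, a * b)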
2 changes: 2 additions & 0 deletions docs/source/ttnn/dependencies/tt_lib.rst
@@ -825,6 +825,8 @@ Other Operations

.. autofunction:: tt_lib.tensor.argmin

.. autofunction:: tt_lib.tensor.batch_mul

Backward Operations
===================

@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from loguru import logger
from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs


@pytest.mark.parametrize(
    ("input_shape_a", "input_shape_b"),
    [
        (torch.Size([1, 1, 32, 32]), torch.Size([32, 1, 32, 32])),
        (torch.Size([4, 3, 320, 384]), torch.Size([1, 3, 320, 384])),
        (torch.Size([16, 3, 320, 384]), torch.Size([1, 3, 320, 384])),
        (torch.Size([32, 1, 32, 1024]), torch.Size([1, 1, 32, 1024])),
        # (torch.Size([1, 1, 320, 384]), torch.Size([64, 1, 320, 384])),  # error: arg is more than 1 KB (issue #6361)
    ],
)
class TestBatchMul:
    def test_batch_mul(self, input_shape_a, input_shape_b, device):
        torch.manual_seed(0)

        # Build bfloat16 inputs on host and move them to the device in TILE layout.
        input_data_a = torch.randn(input_shape_a).bfloat16()
        input_tensor_a = (
            tt_lib.tensor.Tensor(input_data_a, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
        )
        input_data_b = torch.randn(input_shape_b).bfloat16()
        input_tensor_b = (
            tt_lib.tensor.Tensor(input_data_b, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
        )
        tt_output_tensor_on_device = tt_lib.tensor.batch_mul(input_tensor_a, input_tensor_b)

        # Compare against torch.mul, which broadcasts across the mismatched batch dimension.
        tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
        pt_out_tensor = torch.mul(input_data_a, input_data_b)
        comp_pass, comp_out = comparison_funcs.comp_pcc(pt_out_tensor, tt_out_tensor, pcc=0.99)
        comp_all, _ = comparison_funcs.comp_allclose(pt_out_tensor, tt_out_tensor, atol=4, rtol=1e-1)
        logger.info(comp_pass)
        logger.info(comp_all)
        logger.info(comp_out)
        status = comp_pass | comp_all
        assert status
27 changes: 27 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1570,6 +1570,33 @@ Tensor argmin(
    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
    return operation::decorate_as_composite(__func__, _argmin)(input_a, dim, all, output_mem_config);
}


// Composite batch_mul: eltwise multiply of two tensors whose batch sizes differ.
// The tensor with the smaller batch is repeated along dim 0 to match the other
// (this assumes the smaller batch size is 1), then a regular eltwise mul is run.
Tensor _batch_mul(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
    const Shape shape_a = input_a.get_legacy_shape();
    const Shape shape_b = input_b.get_legacy_shape();
    Tensor in_a = input_a;
    Tensor in_b = input_b;
    Shape repeat_shape({1, 1, 1, 1});
    if (shape_a[0] > shape_b[0]) {
        repeat_shape[0] = shape_a[0];
        in_b = repeat(input_b, repeat_shape, output_mem_config);
    } else {
        repeat_shape[0] = shape_b[0];
        in_a = repeat(input_a, repeat_shape, output_mem_config);
    }
    return mul(in_a, in_b, std::nullopt, output_mem_config);
}

Tensor batch_mul(
    const Tensor& input_a,
    const Tensor& input_b,
    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
    return operation::decorate_as_composite(__func__, _batch_mul)(input_a, input_b, output_mem_config);
}
} // namespace tt_metal

} // namespace tt
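
A note on the composite above: the repeat factor is set to the larger batch size, so the shapes only line up when the smaller batch is 1 (which holds for every case in the test). A plain-PyTorch sketch of that constraint, with illustrative shapes:

import torch

a = torch.randn(4, 3, 320, 384)
b = torch.randn(1, 3, 320, 384)

# Smaller batch is 1: repeating by the larger batch (4) aligns the shapes for the eltwise mul.
assert b.repeat(4, 1, 1, 1).shape == a.shape

# If the smaller batch were 2, repeating by 4 would produce batch 8, not 4,
# and the eltwise mul inside the composite would see mismatched shapes.
c = torch.randn(2, 3, 320, 384)
assert c.repeat(4, 1, 1, 1).shape[0] == 8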
5 changes: 5 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -513,6 +513,11 @@ Tensor argmin(
    bool all = false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

Tensor batch_mul(
    const Tensor& input_a,
    const Tensor& input_b,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

} // namespace tt
@@ -931,6 +931,7 @@ namespace tt::tt_metal::detail{
detail::bind_binary_op<false, true, false>(m_tensor, "xlogy", &xlogy, R"doc(Performs eltwise-binary xlogy (``{0} * log( {1} )``) on two tensors.)doc");
detail::bind_binary_op<false, true, false>(m_tensor, "atan2", &atan2, R"doc(Returns tensor with the atan2 activation on elements of the input tensors ``{0}`` and ``{1}``.)doc");
detail::bind_binary_op<false, true, false>(m_tensor, "nextafter", &nextafter, R"doc(Returns the next floating-point value after input_a towards input_b of the input tensors ``{0}`` and ``{1}``.)doc");
detail::bind_binary_op<false, true, false>(m_tensor, "batch_mul", &batch_mul, R"doc(Returns the result tensor of matmul the input tensors ``{0}`` and ``{1}`` where batch size is not equal. )doc");

// *** type-2 complex operations in new submodule 'type2_complex' ***
auto m_type2_cplx = m_tensor.def_submodule("complex", "Complex type2");