Skip to content

Commit

Permalink
#5769: Fix CI fail issue
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Mar 22, 2024
1 parent 99e185c commit fca3190
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,18 @@ struct make_eltwise_binary {
Shape shape_b = input_tensor_b.get_legacy_shape();
Tensor in_a = input_tensor_a;
Tensor in_b = input_tensor_b;
if (shape_a[0] > shape_b[0])
{
Shape shape ({shape_a[0],1,1,1});
in_b = repeat(input_tensor_b, shape, output_mem_config);
}
else
if (shape_a[0] != shape_b[0])
{
Shape shape ({shape_b[0],1,1,1});
in_a = repeat(input_tensor_a, shape, output_mem_config);
if (shape_a[0] > shape_b[0])
{
Shape shape ({shape_a[0],1,1,1});
in_b = repeat(input_tensor_b, shape, output_mem_config);
}
else
{
Shape shape ({shape_b[0],1,1,1});
in_a = repeat(input_tensor_a, shape, output_mem_config);
}
}
TT_FATAL(
in_a.get_legacy_shape() == in_b.get_legacy_shape(), "Input shapes must be the same!");
Expand Down Expand Up @@ -174,17 +177,34 @@ inline Tensor add(
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<const DataType> output_dtype = std::nullopt,
bool in_place = false) {
TT_FATAL(input_tensor_a.get_legacy_shape() == input_tensor_b.get_legacy_shape(), "Input shapes must be the same!");
Shape shape_a = input_tensor_a.get_legacy_shape();
Shape shape_b = input_tensor_b.get_legacy_shape();
Tensor in_a = input_tensor_a;
Tensor in_b = input_tensor_b;
if (shape_a[0] != shape_b[0])
{
if (shape_a[0] > shape_b[0])
{
Shape shape ({shape_a[0],1,1,1});
in_b = repeat(input_tensor_b, shape, output_mem_config);
}
else
{
Shape shape ({shape_b[0],1,1,1});
in_a = repeat(input_tensor_a, shape, output_mem_config);
}
}
TT_FATAL(in_a.get_legacy_shape() == in_b.get_legacy_shape(), "Input shapes must be the same!");
auto output = operation::run(
EltwiseBinary{
BinaryOpType::ADD,
fused_activations,
output_mem_config,
output_dtype.value_or(input_tensor_a.get_dtype()),
in_place},
{input_tensor_a, input_tensor_b});
{in_a, in_b});
if (in_place) {
return input_tensor_a;
return in_a;
} else {
return output.at(0);
}
Expand Down

0 comments on commit fca3190

Please sign in to comment.