Skip to content

Commit

Permalink
Add padding cases for all_reduce
Browse files — browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Jan 9, 2025
1 parent 21394d0 commit 86ef41c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def run_all_reduce_test(
([1, 1, 32, 8192]),
([1, 1, 32, 1024]),
([1, 1, 32, 2048]),
([1, 1, 1, 32]),
([1, 1, 32, 1]),
([1, 1, 3, 37]),
([1, 1, 83, 22]),
([1, 1, 4096, 32]),
([1, 1, 8192, 32]),
([1, 1, 1024, 32]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "ttnn/cpp/ttnn/operations/copy.hpp"
#include <cstdint>

namespace ttnn {
Expand Down Expand Up @@ -172,6 +174,7 @@ static Tensor reduce_scatter_all_gather(
const ttnn::ccl::Topology& topology) {
auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();
auto ccl_input_tensor = input_tensor;

uint32_t all_reduce_dim = -1;
for (uint32_t i = 0; i < rank; ++i) {
Expand All @@ -180,9 +183,29 @@ static Tensor reduce_scatter_all_gather(
}
}

ttnn::SmallVector<uint32_t> unpad_elements = {
ccl_input_tensor.get_logical_shape()[-4],
ccl_input_tensor.get_logical_shape()[-3],
ccl_input_tensor.get_logical_shape()[-2],
ccl_input_tensor.get_logical_shape()[-1]};
bool needs_padding = ccl_input_tensor.get_layout() == Layout::TILE &&
(ccl_input_tensor.get_logical_shape()[-2] % tt::constants::TILE_HEIGHT != 0 ||
ccl_input_tensor.get_logical_shape()[-1] % tt::constants::TILE_WIDTH != 0);
if (needs_padding) {
ttnn::SmallVector<std::pair<uint32_t, uint32_t>> padding = {{0, 0}, {0, 0}, {0, 0}, {0, 0}};
DataType original_dtype = ccl_input_tensor.get_dtype();
if (ccl_input_tensor.get_dtype() != DataType::BFLOAT16 && ccl_input_tensor.get_dtype() != DataType::FLOAT32) {
ccl_input_tensor = ttnn::typecast(ccl_input_tensor, DataType::BFLOAT16);
}
ccl_input_tensor = ttnn::pad(0, ccl_input_tensor, padding, 0, false, std::nullopt);
if (original_dtype != ccl_input_tensor.get_dtype()) {
ccl_input_tensor = ttnn::typecast(ccl_input_tensor, original_dtype);
}
}

const auto& reduced_tensor = operation::run(
ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct(
input_tensor,
ccl_input_tensor,
binary_op_type,
all_reduce_dim,
num_links,
Expand All @@ -191,7 +214,7 @@ static Tensor reduce_scatter_all_gather(
user_defined_num_buffers_per_channel,
devices,
topology),
{input_tensor});
{ccl_input_tensor});

const auto& gathered_tensor = operation::run(
ttnn::ccl::all_gather_detail::create_all_gather_struct(
Expand All @@ -205,7 +228,11 @@ static Tensor reduce_scatter_all_gather(
topology),
{reduced_tensor.at(0)});

return gathered_tensor.at(0);
if (needs_padding) {
return ttnn::ccl::unpad_output_tensor(gathered_tensor, num_devices, unpad_elements, all_reduce_dim).at(0);
} else {
return gathered_tensor.at(0);
}
}

Tensor run_all_reduce(
Expand Down

0 comments on commit 86ef41c

Please sign in to comment.