Skip to content

Commit

Permalink
#0: support for multidevice tensors in concat with non-aligned last dim
Browse files Browse the repository at this point in the history
  • Loading branch information
jaykru-tt committed Nov 5, 2024
1 parent e549ca7 commit 1de5cff
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,17 @@ MassagedConcat build_non_aligned_last_dim_concat(const std::vector<ttnn::Tensor>
// not all aligned
auto dim_aligned = [](const std::vector<ttnn::Tensor>& tensors, int dim) -> bool {
return std::all_of(tensors.begin(), tensors.end(), [&](const ttnn::Tensor& tensor) {
return tensor.get_padded_shape()[dim] * tensor.element_size() % tensor.buffer()->alignment() == 0;
auto storage_type = tensor.storage_type();
if (storage_type == tt::tt_metal::StorageType::DEVICE) {
return tensor.get_padded_shape()[dim] * tensor.element_size() % tensor.buffer()->alignment() == 0;
} else if (storage_type == tt::tt_metal::StorageType::MULTI_DEVICE) {
auto buffers = tensor.buffers();
return std::all_of(buffers.begin(), buffers.end(), [&](Buffer *buffer) {
return tensor.get_padded_shape()[dim] * tensor.element_size() % buffer->alignment() == 0;
});
} else {
TT_THROW("ttnn.concat: expected a tensor with device storage, but got a tensor with storage type {}", tensor.storage_type());
}
});
};

Expand Down

0 comments on commit 1de5cff

Please sign in to comment.