From 1de5cffef0e5eaecf31e39a7f57d73834a503010 Mon Sep 17 00:00:00 2001 From: Jay Kruer Date: Tue, 5 Nov 2024 02:52:58 +0000 Subject: [PATCH] #0: support for multidevice tensors in concat with non-aligned last dim --- .../ttnn/operations/data_movement/concat/concat.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp index a5821989862..3c99bab7215 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp @@ -189,7 +189,17 @@ MassagedConcat build_non_aligned_last_dim_concat(const std::vector // not all aligned auto dim_aligned = [](const std::vector& tensors, int dim) -> bool { return std::all_of(tensors.begin(), tensors.end(), [&](const ttnn::Tensor& tensor) { - return tensor.get_padded_shape()[dim] * tensor.element_size() % tensor.buffer()->alignment() == 0; + auto storage_type = tensor.storage_type(); + if (storage_type == tt::tt_metal::StorageType::DEVICE) { + return tensor.get_padded_shape()[dim] * tensor.element_size() % tensor.buffer()->alignment() == 0; + } else if (storage_type == tt::tt_metal::StorageType::MULTI_DEVICE) { + auto buffers = tensor.buffers(); + return std::all_of(buffers.begin(), buffers.end(), [&](Buffer *buffer) { + return tensor.get_padded_shape()[dim] * tensor.element_size() % buffer->alignment() == 0; + }); + } else { + TT_THROW("ttnn.concat: expected a tensor with device storage, but got a tensor with storage type {}", tensor.storage_type()); + } }); };