diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
index a0906934751..50b9f2fa55a 100644
--- a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
@@ -7,6 +7,7 @@
 
 #include "ttnn/distributed/api.hpp"
 #include "ttnn/operations/functions.hpp"
+#include "ttnn/tensor/tensor_utils.hpp"
 #include "ttnn_test_fixtures.hpp"
 #include
 #include
@@ -21,12 +22,23 @@ TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype) {
     return TensorSpec(shape, TensorLayout(dtype, Layout::ROW_MAJOR, MemoryConfig{}));
 }
 
+TEST_F(TensorDistributionTest, DistributeToDevice) {
+    Tensor input_tensor = Tensor::from_vector(
+        std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));
+
+    auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
+
+    // If no device is provided, the tensor is kept on host.
+    EXPECT_TRUE(distribute_tensor(input_tensor, *mapper).storage_type() == StorageType::MULTI_DEVICE_HOST);
+    EXPECT_TRUE(distribute_tensor(input_tensor, *mapper, *mesh_device_).storage_type() == StorageType::MULTI_DEVICE);
+}
+
 TEST_F(TensorDistributionTest, Replication) {
     Tensor input_tensor = Tensor::from_vector(
         std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));
 
     auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
-    Tensor replicated_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    Tensor replicated_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
 
     std::vector<Tensor> device_tensors = get_device_tensors(replicated_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
@@ -43,12 +55,12 @@ TEST_F(TensorDistributionTest, Shard1DInvalidDim) {
 
     EXPECT_ANY_THROW({
         auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, -1);
-        Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+        Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
     });
 
     EXPECT_ANY_THROW({
         auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 4);
-        Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+        Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
     });
 }
 
@@ -60,7 +72,7 @@ TEST_F(TensorDistributionTest, Shard1DTooFewShards) {
 
     EXPECT_ANY_THROW({
         auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 3);
-        Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+        Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
     });
 }
 
@@ -74,7 +86,7 @@ TEST_F(TensorDistributionTest, Shard1D) {
         Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));
 
     auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1);
-    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
 
     std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
@@ -127,7 +139,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
         Shard2dConfig{
             .row_dim = 1,
         });
-    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
     sharded_tensor.print();
 
     std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
@@ -162,7 +174,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
             .row_dim = 1,
             .col_dim = 2,
         });
-    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
 
     std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
index e8716199a63..a4821a17c38 100644
--- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
+++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
@@ -18,7 +18,7 @@ class ReplicateTensorToMesh : public TensorToMesh {
 public:
     ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {}
 
-    std::vector<Tensor> map(const Tensor& tensor) override {
+    std::vector<Tensor> map(const Tensor& tensor) const override {
         std::vector<Tensor> tensors;
         tensors.reserve(num_devices_);
         std::fill_n(std::back_inserter(tensors), num_devices_, tensor);
@@ -37,7 +37,7 @@ class ShardTensorToMesh : public TensorToMesh {
 public:
     ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {}
 
-    std::vector<Tensor> map(const Tensor& tensor) override {
+    std::vector<Tensor> map(const Tensor& tensor) const override {
         return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_);
     }
 
@@ -55,7 +55,7 @@ class ShardTensorTo2dMesh : public TensorToMesh {
     ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) :
         mesh_shape_(mesh_shape), config_(config) {}
 
-    std::vector<Tensor> map(const Tensor& tensor) override {
+    std::vector<Tensor> map(const Tensor& tensor) const override {
         const auto [rows, cols] = mesh_shape_;
         const auto [row_dim, col_dim] = config_;
 
@@ -111,7 +111,7 @@ class ConcatMeshToTensor : public MeshToTensor {
 public:
     ConcatMeshToTensor(int dim) : concat_dim_(dim) {}
 
-    Tensor compose(const std::vector<Tensor>& tensors) override {
+    Tensor compose(const std::vector<Tensor>& tensors) const override {
         return experimental::xtensor::concat(tensors, concat_dim_);
     }
 
@@ -124,7 +124,7 @@ class Concat2dMeshToTensor : public MeshToTensor {
     Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) :
         mesh_shape_(mesh_device.shape()), config_(config) {}
 
-    Tensor compose(const std::vector<Tensor>& tensors) override {
+    Tensor compose(const std::vector<Tensor>& tensors) const override {
         const auto [rows, cols] = mesh_shape_;
         const auto [row_dim, col_dim] = config_;
 
@@ -180,7 +180,8 @@ std::unique_ptr<MeshToTensor> concat_2d_mesh_to_tensor_composer(MeshDevice& mesh
     return std::make_unique<Concat2dMeshToTensor>(mesh_device, config);
 }
 
-Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper) {
+Tensor distribute_tensor(
+    const Tensor& tensor, const TensorToMesh& mapper, std::optional<std::reference_wrapper<MeshDevice>> mesh_device) {
     TT_FATAL(
         tensor.storage_type() != tt::tt_metal::StorageType::MULTI_DEVICE &&
             tensor.storage_type() != tt::tt_metal::StorageType::MULTI_DEVICE_HOST,
@@ -188,10 +189,13 @@ Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorTo
         tensor.storage_type());
     std::vector<Tensor> tensors = mapper.map(tensor);
     Tensor output = aggregate_as_tensor(tensors, mapper.config());
-    return output.to(&mesh_device);
+    if (mesh_device.has_value()) {
+        return output.to(&(mesh_device->get()));
+    }
+    return output;
 }
 
-Tensor aggregate_tensor(const Tensor& tensor, MeshToTensor& composer) {
+Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) {
     return is_multi_device_tensor(tensor) ? composer.compose(get_tensors_from_multi_device_storage(tensor))
                                           : composer.compose({tensor});
 }
diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp
index d8c8b060cf6..7d49ca932f4 100644
--- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp
+++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp
@@ -13,7 +13,7 @@ namespace ttnn::distributed {
 class TensorToMesh {
 public:
     virtual ~TensorToMesh() = default;
-    virtual std::vector<Tensor> map(const Tensor& tensor) = 0;
+    virtual std::vector<Tensor> map(const Tensor& tensor) const = 0;
     virtual tt::tt_metal::DistributedTensorConfig config() const = 0;
 };
 
@@ -21,7 +21,7 @@ class TensorToMesh {
 class MeshToTensor {
 public:
     virtual ~MeshToTensor() = default;
-    virtual Tensor compose(const std::vector<Tensor>& tensors) = 0;
+    virtual Tensor compose(const std::vector<Tensor>& tensors) const = 0;
 };
 
 // Creates a mapper that replicates a tensor across all devices.
@@ -50,9 +50,12 @@ struct Concat2dConfig {
 std::unique_ptr<MeshToTensor> concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config);
 
 // Distributes a host tensor onto multi-device configuration according to the `mapper`.
-Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper);
+Tensor distribute_tensor(
+    const Tensor& tensor,
+    const TensorToMesh& mapper,
+    std::optional<std::reference_wrapper<MeshDevice>> mesh_device = std::nullopt);
 
 // Aggregates a multi-device tensor into a host tensor according to the `composer`.
-Tensor aggregate_tensor(const Tensor& tensor, MeshToTensor& composer);
+Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer);
 
 }  // namespace ttnn::distributed
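
Reviewer note: the only call-site change is the argument order and the now-optional device. A minimal, illustrative sketch of the reworked API (not part of the patch), assuming the `mesh_device_` fixture member and a host-resident `input_tensor` as in the tests above:

```cpp
// Illustrative only: mirrors the new DistributeToDevice test.
using namespace ttnn::distributed;

auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);

// No device argument: the replicated shards stay on host (MULTI_DEVICE_HOST storage).
Tensor on_host = distribute_tensor(input_tensor, *mapper);

// With a device argument: the shards are uploaded to the mesh (MULTI_DEVICE storage).
Tensor on_device = distribute_tensor(input_tensor, *mapper, *mesh_device_);
```

Passing the device as `std::optional<std::reference_wrapper<MeshDevice>>` lets callers build the multi-device host tensor once and defer the upload until a mesh device is available.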