#0: Make device an optional parameter in the tensor distribution API (#16746)

### Ticket
N/A

### Problem description
Per a request from tt-mlir, `device` should be an optional parameter in the tensor distribution API.

### What's changed
Made `device` an optional parameter of `distribute_tensor`: when it is omitted, the distributed tensor is kept on host.

Also made the mapper / composer interfaces `const`-qualified, because we can.
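
A minimal sketch of the new call shape, based on the diffs below (`input_tensor` and `mesh_device` are placeholders for a caller-owned host tensor and `MeshDevice` pointer):

```cpp
// Build a mapper as before.
auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device);

// Without a device, the distributed tensor is kept on host (MULTI_DEVICE_HOST storage).
Tensor on_host = distribute_tensor(input_tensor, *mapper);

// With a device, the shards are moved onto the mesh (MULTI_DEVICE storage).
Tensor on_mesh = distribute_tensor(input_tensor, *mapper, *mesh_device);
```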

### Checklist
- [X] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12778917625/job/35659006737)
- [X] [T3K tests](https://github.com/tenstorrent/tt-metal/actions/runs/12778914441)
- [X] New/Existing tests provide coverage for changes
omilyutin-tt authored Jan 15, 2025
1 parent 64c9b5e commit abc55d2
Showing 3 changed files with 38 additions and 19 deletions.
26 changes: 19 additions & 7 deletions tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
@@ -7,6 +7,7 @@

#include "ttnn/distributed/api.hpp"
#include "ttnn/operations/functions.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn_test_fixtures.hpp"
#include <ttnn/distributed/types.hpp>
#include <ttnn/distributed/distributed_tensor.hpp>
@@ -21,12 +22,23 @@ TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype) {
return TensorSpec(shape, TensorLayout(dtype, Layout::ROW_MAJOR, MemoryConfig{}));
}

TEST_F(TensorDistributionTest, DistributeToDevice) {
Tensor input_tensor = Tensor::from_vector(
std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);

// If no device is provided, the tensor is kept on host.
EXPECT_TRUE(distribute_tensor(input_tensor, *mapper).storage_type() == StorageType::MULTI_DEVICE_HOST);
EXPECT_TRUE(distribute_tensor(input_tensor, *mapper, *mesh_device_).storage_type() == StorageType::MULTI_DEVICE);
}

TEST_F(TensorDistributionTest, Replication) {
Tensor input_tensor = Tensor::from_vector(
std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
Tensor replicated_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor replicated_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);

std::vector<Tensor> device_tensors = get_device_tensors(replicated_tensor);
EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
@@ -43,12 +55,12 @@ TEST_F(TensorDistributionTest, Shard1DInvalidDim) {

EXPECT_ANY_THROW({
auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, -1);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
});

EXPECT_ANY_THROW({
auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 4);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
});
}

@@ -60,7 +72,7 @@ TEST_F(TensorDistributionTest, Shard1DTooFewShards) {

EXPECT_ANY_THROW({
auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 3);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
});
}

@@ -74,7 +86,7 @@ TEST_F(TensorDistributionTest, Shard1D) {
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));

auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);

std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
@@ -127,7 +139,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
Shard2dConfig{
.row_dim = 1,
});
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);
sharded_tensor.print();

std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
@@ -162,7 +174,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
.row_dim = 1,
.col_dim = 2,
});
Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
Tensor sharded_tensor = distribute_tensor(input_tensor, *mapper, *mesh_device_);

std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
20 changes: 12 additions & 8 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
@@ -18,7 +18,7 @@ class ReplicateTensorToMesh : public TensorToMesh {
public:
ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {}

std::vector<Tensor> map(const Tensor& tensor) override {
std::vector<Tensor> map(const Tensor& tensor) const override {
std::vector<Tensor> tensors;
tensors.reserve(num_devices_);
std::fill_n(std::back_inserter(tensors), num_devices_, tensor);
@@ -37,7 +37,7 @@ class ShardTensorToMesh : public TensorToMesh {
public:
ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {}

std::vector<Tensor> map(const Tensor& tensor) override {
std::vector<Tensor> map(const Tensor& tensor) const override {
return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_);
}

@@ -55,7 +55,7 @@ class ShardTensorTo2dMesh : public TensorToMesh {
ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) :
mesh_shape_(mesh_shape), config_(config) {}

std::vector<Tensor> map(const Tensor& tensor) override {
std::vector<Tensor> map(const Tensor& tensor) const override {
const auto [rows, cols] = mesh_shape_;
const auto [row_dim, col_dim] = config_;

@@ -111,7 +111,7 @@ class ConcatMeshToTensor : public MeshToTensor {
public:
ConcatMeshToTensor(int dim) : concat_dim_(dim) {}

Tensor compose(const std::vector<Tensor>& tensors) override {
Tensor compose(const std::vector<Tensor>& tensors) const override {
return experimental::xtensor::concat(tensors, concat_dim_);
}

@@ -124,7 +124,7 @@ class Concat2dMeshToTensor : public MeshToTensor {
Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) :
mesh_shape_(mesh_device.shape()), config_(config) {}

Tensor compose(const std::vector<Tensor>& tensors) override {
Tensor compose(const std::vector<Tensor>& tensors) const override {
const auto [rows, cols] = mesh_shape_;
const auto [row_dim, col_dim] = config_;

@@ -180,18 +180,22 @@ std::unique_ptr<MeshToTensor> concat_2d_mesh_to_tensor_composer(MeshDevice& mesh
return std::make_unique<Concat2dMeshToTensor>(mesh_device, config);
}

Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper) {
Tensor distribute_tensor(
const Tensor& tensor, const TensorToMesh& mapper, std::optional<std::reference_wrapper<MeshDevice>> mesh_device) {
TT_FATAL(
tensor.storage_type() != tt::tt_metal::StorageType::MULTI_DEVICE &&
tensor.storage_type() != tt::tt_metal::StorageType::MULTI_DEVICE_HOST,
"TensorToMesh does not support multi-device or multi-device host tensors; got storage type: {}",
tensor.storage_type());
std::vector<Tensor> tensors = mapper.map(tensor);
Tensor output = aggregate_as_tensor(tensors, mapper.config());
return output.to(&mesh_device);
if (mesh_device.has_value()) {
return output.to(&(mesh_device->get()));
}
return output;
}

Tensor aggregate_tensor(const Tensor& tensor, MeshToTensor& composer) {
Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) {
return is_multi_device_tensor(tensor) ? composer.compose(get_tensors_from_multi_device_storage(tensor))
: composer.compose({tensor});
}
11 changes: 7 additions & 4 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.hpp
@@ -13,15 +13,15 @@ namespace ttnn::distributed {
class TensorToMesh {
public:
virtual ~TensorToMesh() = default;
virtual std::vector<Tensor> map(const Tensor& tensor) = 0;
virtual std::vector<Tensor> map(const Tensor& tensor) const = 0;
virtual tt::tt_metal::DistributedTensorConfig config() const = 0;
};

// Composer interface that aggregates a multi-device tensor into a host tensor.
class MeshToTensor {
public:
virtual ~MeshToTensor() = default;
virtual Tensor compose(const std::vector<Tensor>& tensors) = 0;
virtual Tensor compose(const std::vector<Tensor>& tensors) const = 0;
};

// Creates a mapper that replicates a tensor across all devices.
@@ -50,9 +50,12 @@ struct Concat2dConfig {
std::unique_ptr<MeshToTensor> concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config);

// Distributes a host tensor onto multi-device configuration according to the `mapper`.
Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper);
Tensor distribute_tensor(
const Tensor& tensor,
const TensorToMesh& mapper,
std::optional<std::reference_wrapper<MeshDevice>> mesh_device = std::nullopt);

// Aggregates a multi-device tensor into a host tensor according to the `composer`.
Tensor aggregate_tensor(const Tensor& tensor, MeshToTensor& composer);
Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer);

} // namespace ttnn::distributed
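
A note on the new `distribute_tensor` declaration above: `std::optional` cannot hold a plain reference, so the optional device is expressed as `std::optional<std::reference_wrapper<MeshDevice>>`. A `MeshDevice&` binds to it implicitly, so callers can pass the device directly, as in this sketch (`tensor`, `mapper`, and `mesh_device` are assumed to exist as in the tests above):

```cpp
// Omitting the argument leaves it at std::nullopt; the result stays in MULTI_DEVICE_HOST storage.
Tensor host_shards = distribute_tensor(tensor, *mapper);

// A MeshDevice& converts implicitly to std::optional<std::reference_wrapper<MeshDevice>>,
// so the dereferenced device can be passed as-is; the result lands in MULTI_DEVICE storage.
Tensor mesh_shards = distribute_tensor(tensor, *mapper, *mesh_device);
```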
