Skip to content

Commit

Permalink
Feedback from review
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Dec 16, 2024
1 parent 19274c7 commit 981a34f
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 57 deletions.
15 changes: 7 additions & 8 deletions tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

#include "ttnn/distributed/api.hpp"
#include "ttnn/operations/functions.hpp"
#include "ttnn/tensor/xtensor/conversion_utils.hpp"
#include "ttnn_test_fixtures.hpp"
#include <ttnn/distributed/types.hpp>
#include <ttnn/distributed/distributed_tensor.hpp>
Expand Down Expand Up @@ -97,18 +96,18 @@ TEST_F(TensorDistributionTest, Shard2DInvalidMeshShape) {
ASSERT_EQ(num_cols, 4);

EXPECT_ANY_THROW(
shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{3, 1}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{3, 1}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));

EXPECT_ANY_THROW(
shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 5}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 5}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
}

TEST_F(TensorDistributionTest, Shard2DInvalidShardConfig) {
EXPECT_ANY_THROW(shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{}));
EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{}));
}

TEST_F(TensorDistributionTest, Concat2DInvalidConfig) {
EXPECT_ANY_THROW(concat_mesh_2d_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2}));
EXPECT_ANY_THROW(concat_2d_mesh_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2}));
}

TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
Expand All @@ -122,7 +121,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32));
input_tensor.print();

auto mapper = shard_tensor_2d_to_mesh_mapper(
auto mapper = shard_tensor_to_2d_mesh_mapper(
*mesh_device_,
MeshShape{num_rows, num_cols},
Shard2dConfig{
Expand Down Expand Up @@ -156,7 +155,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
Tensor input_tensor =
Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32));

auto mapper = shard_tensor_2d_to_mesh_mapper(
auto mapper = shard_tensor_to_2d_mesh_mapper(
*mesh_device_,
MeshShape{num_rows, num_cols},
Shard2dConfig{
Expand All @@ -171,7 +170,7 @@ TEST_F(TensorDistributionTest, Shard2D) {
EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
}

auto composer = concat_mesh_2d_to_tensor_composer(
auto composer = concat_2d_mesh_to_tensor_composer(
*mesh_device_,
Concat2dConfig{
.row_dim = 0,
Expand Down
16 changes: 8 additions & 8 deletions tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace {
using ::testing::SizeIs;
using ::tt::tt_metal::Tensor;
using ::ttnn::experimental::xtensor::chunk;
using ::ttnn::experimental::xtensor::concatenate;
using ::ttnn::experimental::xtensor::concat;

TEST(PartitionTest, ChunkBasicNonDivisible3) {
// Create a 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Expand Down Expand Up @@ -51,7 +51,7 @@ TEST(PartitionTest, DefaultAxis) {
xt::xarray<double> b = {{5.0, 6.0}, {7.0, 8.0}};
std::vector<xt::xarray<double>> input = {a, b};

xt::xarray<double> result = concatenate(input); // axis=0 by default
xt::xarray<double> result = concat(input); // axis=0 by default
xt::xarray<double> expected = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}};

xt::allclose(result, expected);
Expand All @@ -62,7 +62,7 @@ TEST(PartitionTest, AxisOne) {
xt::xarray<int> y = {{7, 8}, {9, 10}};
std::vector<xt::xarray<int>> input = {x, y};

xt::xarray<int> result = concatenate(input, 1);
xt::xarray<int> result = concat(input, 1);
xt::xarray<int> expected = {{1, 2, 3, 7, 8}, {4, 5, 6, 9, 10}};

xt::allclose(result, expected);
Expand All @@ -74,7 +74,7 @@ TEST(PartitionTest, MultipleArraysAxis0) {
xt::xarray<float> c = {5.0f, 6.0f};
std::vector<xt::xarray<float>> input = {a, b, c};

xt::xarray<float> result = concatenate(input, 0);
xt::xarray<float> result = concat(input, 0);
xt::xarray<float> expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

xt::allclose(result, expected);
Expand All @@ -85,7 +85,7 @@ TEST(PartitionTest, EmptyArray) {
xt::xarray<int> b; // Empty
std::vector<xt::xarray<int>> input = {a, b};

EXPECT_ANY_THROW({ xt::xarray<int> result = concatenate(input, 0); });
EXPECT_ANY_THROW({ xt::xarray<int> result = concat(input, 0); });
}

TEST(PartitionTest, HigherDimensions) {
Expand All @@ -95,10 +95,10 @@ TEST(PartitionTest, HigherDimensions) {
arr2.reshape({2, 2, 2});

std::vector<xt::xarray<int>> input = {arr1, arr2};
xt::xarray<int> result = concatenate(input, 0);
xt::xarray<int> result = concat(input, 0);

// Expected: shape (4,2,2) with arr1 stacked over arr2 along axis 0
xt::xarray<int> expected = concatenate(xt::xtuple(arr1, arr2), 0);
xt::xarray<int> expected = xt::concatenate(xt::xtuple(arr1, arr2), 0);

xt::allclose(result, expected);
}
Expand All @@ -109,7 +109,7 @@ TEST(PartitionTest, HigherAxis) {
// Both have shape (2,2,2)

std::vector<xt::xarray<int>> input = {arr1, arr2};
xt::xarray<int> result = concatenate(input, 2);
xt::xarray<int> result = concat(input, 2);
// Expected shape: (2,2,4)
xt::xarray<int> expected = {{{1, 2, 9, 10}, {3, 4, 11, 12}}, {{5, 6, 13, 14}, {7, 8, 15, 16}}};

Expand Down
6 changes: 3 additions & 3 deletions tt-train/sources/ttml/core/distributed_mapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ class ConcatMesh2dToTensor : public MeshToXTensor<ConcatMesh2dToTensor<T>, T> {
auto row_end = row_start + cols;
std::vector<xt::xarray<T>> row_tensors(row_start, row_end);

auto concatenated_row = core::concatenate(row_tensors, col_dim);
auto concatenated_row = core::concat(row_tensors, col_dim);
row_concatenated.push_back(std::move(concatenated_row));
}

auto result = core::concatenate(row_concatenated, row_dim);
auto result = core::concat(row_concatenated, row_dim);
return {result};
}

Expand Down Expand Up @@ -216,7 +216,7 @@ class ConcatMeshToXTensor : public MeshToXTensor<ConcatMeshToXTensor<T>, T> {
}

std::vector<xt::xarray<T>> compose_impl(const std::vector<xt::xarray<T>>& tensors) const {
return {core::concatenate(tensors, m_concat_dim)};
return {core::concat(tensors, m_concat_dim)};
}

private:
Expand Down
4 changes: 2 additions & 2 deletions tt-train/sources/ttml/core/xtensor_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ auto xtensor_to_span(const xt::xarray<T>& xtensor) {
}

template <typename T>
xt::xarray<T> concatenate(const std::vector<xt::xarray<T>>& v, size_t axis = 0) {
return ttnn::experimental::xtensor::concatenate(v, axis);
xt::xarray<T> concat(const std::vector<xt::xarray<T>>& v, size_t axis = 0) {
return ttnn::experimental::xtensor::concat(v, axis);
}

} // namespace ttml::core
2 changes: 1 addition & 1 deletion tt-train/tests/core/distributed_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ TYPED_TEST(MeshOpsTest, ConcatenateSameParametersAsCompose) {

std::vector<xt::xarray<TypeParam>> shards = {s1, s2, s3};
ttml::core::ConcatMeshToXTensor<TypeParam> composer(mesh_shape, 0);
auto composed = ttml::core::concatenate(shards);
auto composed = ttml::core::concat(shards);

xt::xarray<TypeParam> expected = {
TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3), TypeParam(4), TypeParam(5)};
Expand Down
28 changes: 14 additions & 14 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class ShardTensorToMesh : public TensorToMesh {
int shard_dim_ = -1;
};

class Shard2dTensorToMesh : public TensorToMesh {
class ShardTensorTo2dMesh : public TensorToMesh {
public:
Shard2dTensorToMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) :
ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) :
mesh_shape_(mesh_shape), config_(config) {}

std::vector<Tensor> map(const Tensor& tensor) override {
Expand Down Expand Up @@ -85,7 +85,7 @@ class Shard2dTensorToMesh : public TensorToMesh {

TT_FATAL(
static_cast<int>(tensor_shards.size()) == rows * cols,
"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh "
"ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh "
"dimensions. Size: {}, rows: {}, cols: {}",
tensor_shards.size(),
rows,
Expand All @@ -106,16 +106,16 @@ class ConcatMeshToTensor : public MeshToTensor {
ConcatMeshToTensor(int dim) : concat_dim_(dim) {}

Tensor compose(const std::vector<Tensor>& tensors) override {
return experimental::xtensor::concatenate(tensors, concat_dim_);
return experimental::xtensor::concat(tensors, concat_dim_);
}

private:
int concat_dim_ = -1;
};

class ConcatMesh2dToTensor : public MeshToTensor {
class Concat2dMeshToTensor : public MeshToTensor {
public:
ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) :
Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) :
mesh_shape_(mesh_device.shape()), config_(config) {}

Tensor compose(const std::vector<Tensor>& tensors) override {
Expand All @@ -128,10 +128,10 @@ class ConcatMesh2dToTensor : public MeshToTensor {
auto row_start = tensors.begin() + i * cols;
auto row_end = row_start + cols;
std::vector<Tensor> row_tensors(row_start, row_end);
row_concatenated.push_back(experimental::xtensor::concatenate(row_tensors, col_dim));
row_concatenated.push_back(experimental::xtensor::concat(row_tensors, col_dim));
}

return experimental::xtensor::concatenate(row_concatenated, row_dim);
return experimental::xtensor::concat(row_concatenated, row_dim);
}

private:
Expand All @@ -149,29 +149,29 @@ std::unique_ptr<TensorToMesh> shard_tensor_to_mesh_mapper(MeshDevice& mesh_devic
return std::make_unique<ShardTensorToMesh>(mesh_device.num_devices(), dim);
}

std::unique_ptr<TensorToMesh> shard_tensor_2d_to_mesh_mapper(
std::unique_ptr<TensorToMesh> shard_tensor_to_2d_mesh_mapper(
MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) {
TT_FATAL(
config.row_dim.has_value() || config.col_dim.has_value(),
"ShardTensor2dMesh requires at least one dimension to shard");
"Sharding a tensor to 2D mesh requires at least one dimension to shard");
TT_FATAL(
mesh_shape.num_rows <= mesh_device.shape().num_rows && //
mesh_shape.num_cols <= mesh_device.shape().num_cols,
"ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.");
return std::make_unique<Shard2dTensorToMesh>(mesh_shape, config);
"Device mesh shape does not match the provided mesh shape.");
return std::make_unique<ShardTensorTo2dMesh>(mesh_shape, config);
}

std::unique_ptr<MeshToTensor> concat_mesh_to_tensor_composer(int dim) {
return std::make_unique<ConcatMeshToTensor>(dim);
}

std::unique_ptr<MeshToTensor> concat_mesh_2d_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config) {
std::unique_ptr<MeshToTensor> concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config) {
TT_FATAL(
config.row_dim != config.col_dim,
"Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}",
config.row_dim,
config.col_dim);
return std::make_unique<ConcatMesh2dToTensor>(mesh_device, config);
return std::make_unique<Concat2dMeshToTensor>(mesh_device, config);
}

Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper) {
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct Shard2dConfig {
std::optional<int> row_dim;
std::optional<int> col_dim;
};
std::unique_ptr<TensorToMesh> shard_tensor_2d_to_mesh_mapper(
std::unique_ptr<TensorToMesh> shard_tensor_to_2d_mesh_mapper(
MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config);

// Creates a composer that concatenates a tensor across a single dimension.
Expand All @@ -47,7 +47,7 @@ struct Concat2dConfig {
int row_dim = -1;
int col_dim = -1;
};
std::unique_ptr<MeshToTensor> concat_mesh_2d_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config);
std::unique_ptr<MeshToTensor> concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config);

// Distributes a host tensor onto multi-device configuration according to the `mapper`.
Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper);
Expand Down
29 changes: 24 additions & 5 deletions ttnn/cpp/ttnn/tensor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,20 @@ struct Tensor {
std::vector<Device*> get_workers(bool blocking = false) const;

// Converts a buffer of elements of type `T` to a `Tensor`.
// Elements are assumed to be stored in row-major order. The size of the span and the type have to match `spec`.
// Elements in the buffer are assumed to be stored in row-major order. The size of the buffer and the type of the
// elements have to match `spec`.
//
// TODO: tilized layouts and reduced precision types are currently not supported.
// The data in the buffer is copied into a tensor with an owned storage.
//
// IMPORTANT: this function supports a limited subset of types (float32, bfloat16, uint32_t, int32_t),
// and only row-major layout.
//
// TODO:
// 1. add support for returning a tensor with a borrowed storage based off the buffer.
// 2. add support for sharding.
// 3. add support for block float formats.
// 4. add support for tilized layouts.
// 5. add support for on-device tensor creation.
template <typename T>
static Tensor from_span(tt::stl::Span<const T> buffer, const TensorSpec& spec);

Expand All @@ -152,9 +163,17 @@ struct Tensor {
return from_span(tt::stl::Span<const T>(buffer.data(), buffer.size()), spec);
}

// Converts a `Tensor` to a buffer of elements of type `T`.
// Elements in the buffer will be stored in row-major order. The type of the elements has to match that of the
// `Tensor`.
// Converts a `Tensor` to a `std::vector<T>`.
// Elements in the vector will be stored in row-major order. The type of the requested vector has to match that of
// the `Tensor`.
//
// If the tensor resides on a device, it will be brough back to host.
//
// IMPORTANT: this function supports a limited subset of types (float32, bfloat16, uint32_t, int32_t).
//
// TODO:
// 1. add support for sharding.
// 2. add support for block float formats.
template <typename T>
std::vector<T> to_vector() const;

Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ auto xtensor_to_span(const xt::xarray<T>& xtensor) {
}

// Converts an xtensor to a Tensor.
// IMPORTANT: this copies the data into the returned Tensor, which can be an expensive operation.
template <typename T>
tt::tt_metal::Tensor from_xtensor(const xt::xarray<T>& buffer, const TensorSpec& spec) {
auto shape = get_shape_from_xarray(buffer);
Expand All @@ -54,6 +55,7 @@ tt::tt_metal::Tensor from_xtensor(const xt::xarray<T>& buffer, const TensorSpec&
}

// Converts a Tensor to an xtensor.
// IMPORTANT: this copies the data into the returned Tensor, which can be an expensive operation.
template <typename T>
xt::xarray<T> to_xtensor(const tt::tt_metal::Tensor& tensor) {
auto vec = tensor.to_vector<T>();
Expand Down
24 changes: 12 additions & 12 deletions ttnn/cpp/ttnn/tensor/xtensor/partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ std::vector<xt::xarray<T>> chunk(const xt::xarray<T>& xtensor, int num_chunks, i
}

template <typename T>
xt::xarray<T> concatenate(const std::vector<xt::xarray<T>>& v, int dim) {
xt::xarray<T> concat(const std::vector<xt::xarray<T>>& v, int dim) {
constexpr size_t MAX_TUPLE_SIZE = 64;

if (v.empty()) {
Expand All @@ -112,22 +112,22 @@ xt::xarray<T> concatenate(const std::vector<xt::xarray<T>>& v, int dim) {
}
}

template xt::xarray<double> concatenate(const std::vector<xt::xarray<double>>& v, int dim);
template xt::xarray<float> concatenate(const std::vector<xt::xarray<float>>& v, int dim);
template xt::xarray<uint32_t> concatenate(const std::vector<xt::xarray<uint32_t>>& v, int dim);
template xt::xarray<int32_t> concatenate(const std::vector<xt::xarray<int32_t>>& v, int dim);
template xt::xarray<double> concat(const std::vector<xt::xarray<double>>& v, int dim);
template xt::xarray<float> concat(const std::vector<xt::xarray<float>>& v, int dim);
template xt::xarray<uint32_t> concat(const std::vector<xt::xarray<uint32_t>>& v, int dim);
template xt::xarray<int32_t> concat(const std::vector<xt::xarray<int32_t>>& v, int dim);

// Adaptor APIs from xtensor to ttnn::Tensor.
namespace adaptor {
namespace {

template <typename T>
Tensor concatenate_impl(const std::vector<Tensor>& tensors, const TensorLayout& layout, int dim) {
Tensor concat_impl(const std::vector<Tensor>& tensors, const TensorLayout& layout, int dim) {
std::vector<xt::xarray<T>> xtensors;
for (const auto& tensor : tensors) {
xtensors.push_back(to_xtensor<T>(tensor));
}
xt::xarray<T> result = concatenate(xtensors, dim);
xt::xarray<T> result = concat(xtensors, dim);
return from_xtensor<T>(result, TensorSpec(get_shape_from_xarray(result), layout));
}

Expand Down Expand Up @@ -159,14 +159,14 @@ std::vector<Tensor> chunk(const Tensor& tensor, int num_chunks, int dim) {
}
}

Tensor concatenate(const std::vector<Tensor>& tensors, int dim) {
Tensor concat(const std::vector<Tensor>& tensors, int dim) {
TT_FATAL(tensors.size() > 0, "Cannot concatenate an empty list of tensors");
const auto& reference_layout = tensors.front().tensor_spec().tensor_layout();
switch (reference_layout.get_data_type()) {
case DataType::BFLOAT16: return adaptor::concatenate_impl<float>(tensors, reference_layout, dim);
case DataType::FLOAT32: return adaptor::concatenate_impl<float>(tensors, reference_layout, dim);
case DataType::INT32: return adaptor::concatenate_impl<int32_t>(tensors, reference_layout, dim);
case DataType::UINT32: return adaptor::concatenate_impl<uint32_t>(tensors, reference_layout, dim);
case DataType::BFLOAT16: return adaptor::concat_impl<float>(tensors, reference_layout, dim);
case DataType::FLOAT32: return adaptor::concat_impl<float>(tensors, reference_layout, dim);
case DataType::INT32: return adaptor::concat_impl<int32_t>(tensors, reference_layout, dim);
case DataType::UINT32: return adaptor::concat_impl<uint32_t>(tensors, reference_layout, dim);
default: TT_THROW("Unsupported data type: {}", reference_layout.get_data_type());
}
}
Expand Down
Loading

0 comments on commit 981a34f

Please sign in to comment.