#13127: Add support for new logical sharding + alignment in TensorLayout and tensor creation

- Add ShardMode enum to specify shard shape in shard spec as either physical or logical
- Update PageConfig create_default_alignment to use the physical shard shape from the shard spec, if provided, for logical sharding
- Add optional physical shard shape to shard spec for logical sharding (see the sketch after this list)
  * Add a ShardSpec constructor that takes a physical shard shape (which automatically sets shard mode to ShardMode::LOGICAL)
  * ShardMode::PHYSICAL: This is the current behaviour, which we will deprecate!
    ** It is less expressive than treating the shard shape as logical (i.e. it must be tile-aligned for TILE layout, etc.)
    ** It fundamentally operates on the padded shape, which is confusing and incompatible with the logical shape
  * ShardMode::LOGICAL: The shard shape cuts the 2D logical shape, and each shard is aligned afterwards
    ** Without alignment restrictions, you can cut the 2D logical shape more arbitrarily
    ** Existing sharding can be switched over to this entirely (we just need codeowners to help out and flip the switch)
    ** If a physical shard shape is provided, align to it instead of the default alignment
  * The default everywhere will be ShardMode::PHYSICAL (TODO: Add a warning message?)
- Switch tests/ttnn/unit_tests/operations/test_paged_update_cache.py to use logical shard shape as an example
  * Introduce tensor.logical_volume() (as opposed to tensor.volume(), which returns the physical volume based on the padded shape)
  * TODO: Rename volume() -> physical_volume() and logical_volume() -> volume()
- Add new C++ tests for tensor creation with logical shard shape + alignment
  * IMPORTANT: Host data manipulation still needs to be updated to be aware of the new logical sharding for use from Python!
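
For illustration, here is a minimal C++ sketch of constructing a shard spec in each mode; only the ShardSpec signatures come from this change, while the core grid and shard shapes below are made up:

    // Assumes tt_metal/impl/buffers/buffer.hpp (ShardSpec, ShardMode, ShardOrientation),
    // the core coordinate headers, and <set> are included.
    // Hypothetical 8x1 core grid; any CoreRangeSet works here.
    CoreRangeSet grid(std::set<CoreRange>{CoreRange(CoreCoord(0, 0), CoreCoord(7, 0))});

    // Current behaviour (to be deprecated): the shard shape is treated as physical.
    ShardSpec physical_spec(grid, {32, 64}, ShardOrientation::ROW_MAJOR, false, ShardMode::PHYSICAL);

    // New behaviour: the shard shape cuts the 2D logical shape and each shard is aligned afterwards.
    ShardSpec logical_spec(grid, {30, 20}, ShardOrientation::ROW_MAJOR, false, ShardMode::LOGICAL);

    // Logical shard shape with an explicit physical shard shape; the mode becomes ShardMode::LOGICAL
    // automatically and the constructor asserts that physical >= logical in both dims.
    ShardSpec logical_with_physical(grid, {30, 20}, /*physical_shard_shape=*/{32, 32}, ShardOrientation::ROW_MAJOR, false);

The first two calls use the existing constructor, whose new shard_mode parameter defaults to ShardMode::PHYSICAL; the third uses the new constructor that accepts a physical shard shape.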

To support these changes, TensorLayout is updated as follows:
- Update the private TensorLayout constructor that takes an alignment:
  * legacyShapeToAlignment will try to return a 2D alignment if possible (i.e. padding only on height/width)
    ** The goal is to transition alignment to be 2D only, if we can remove poor use cases of padding on non-height/width dims
  * legacyShapeToAlignment uses the default alignment for sharded tensors
    ** Previously, both interleaved and sharded tensors would just use the padded shape for alignment
    ** One exception is row-major sharded tensors, where we use the shard width if the shape is padded;
       otherwise, we only take the shard width for BLOCK/WIDTH sharded cases and the original physical shape for HEIGHT sharded
  * legacyShapeToAlignment (and alignment in general) will work iff there is only padding on height and/or width
    ** IMPORTANT: This means we expect tensors with arbitrary padding along non-height/width dims to be interleaved only!
- If ShardMode::LOGICAL (see the sketch after this list):
  * In TensorLayout::compute_shard_spec_buffer, calculate the physical shard shape from the shard shape + alignment if it is not provided
  * In TensorLayout::compute_physical_shape, calculate the physical shape based on the number of logical shards
- Clean up handling of sharded tensors and error messages in ttnn/cpp/ttnn/tensor/layout/page_config.cpp
- Add Size constructor for std::array<uint32_t, 2>
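
As a rough sketch of the ShardMode::LOGICAL arithmetic described above (illustrative only; the real logic lives in TensorLayout::compute_shard_spec_buffer and TensorLayout::compute_physical_shape and handles more cases, and these helper names are made up):

    #include <array>
    #include <cstdint>

    // Round value up to the nearest multiple of `multiple`.
    inline uint32_t round_up(uint32_t value, uint32_t multiple) {
        return ((value + multiple - 1) / multiple) * multiple;
    }

    // If no physical shard shape is provided, derive it by aligning the logical shard shape per dimension.
    std::array<uint32_t, 2> derive_physical_shard_shape(
        const std::array<uint32_t, 2>& logical_shard_shape, const std::array<uint32_t, 2>& alignment) {
        return {round_up(logical_shard_shape[0], alignment[0]), round_up(logical_shard_shape[1], alignment[1])};
    }

    // The physical shape then follows from the number of logical shards along each dimension
    // times the (aligned) physical shard shape.
    std::array<uint32_t, 2> derive_physical_shape(
        const std::array<uint32_t, 2>& logical_shape_2d,
        const std::array<uint32_t, 2>& logical_shard_shape,
        const std::array<uint32_t, 2>& physical_shard_shape) {
        const uint32_t num_shards_h = (logical_shape_2d[0] + logical_shard_shape[0] - 1) / logical_shard_shape[0];
        const uint32_t num_shards_w = (logical_shape_2d[1] + logical_shard_shape[1] - 1) / logical_shard_shape[1];
        return {num_shards_h * physical_shard_shape[0], num_shards_w * physical_shard_shape[1]};
    }

For example, a 2D logical shape of (60, 40) cut into logical shards of (30, 20) with tile alignment (32, 32) gives a physical shard shape of (32, 32) and a physical shape of (64, 64).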
TT-BrianLiu committed Nov 18, 2024
1 parent 02f5747 commit 36d9c9d
Showing 16 changed files with 511 additions and 97 deletions.
310 changes: 275 additions & 35 deletions tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp

Large diffs are not rendered by default.

28 changes: 17 additions & 11 deletions tests/ttnn/unit_tests/operations/test_paged_update_cache.py
@@ -41,11 +41,12 @@ def run_test_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
xt = xt.to(device, input_mem_config)
@@ -151,11 +152,12 @@ def test_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
@@ -234,11 +236,12 @@ def test_update_cache_decode_program_cache(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
@@ -276,11 +279,12 @@ def run_test_tensor_index_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
xt = xt.to(device, input_mem_config)
@@ -414,11 +418,12 @@ def run_test_paged_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
xt = xt.to(device, input_mem_config)
@@ -543,11 +548,12 @@ def test_paged_update_cache_decode_program_caching(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
3 changes: 2 additions & 1 deletion tt_metal/impl/buffers/buffer.cpp
@@ -526,6 +526,7 @@ tt_metal::ShardSpec from_json_t<tt_metal::ShardSpec>::operator()(const nlohmann:
from_json<CoreRangeSet>(json_object.at("grid")),
from_json<std::array<uint32_t, 2>>(json_object.at("shape")),
from_json<tt_metal::ShardOrientation>(json_object.at("orientation")),
from_json<bool>(json_object.at("halo"))};
from_json<bool>(json_object.at("halo")),
from_json<tt_metal::ShardMode>(json_object.at("mode"))};
}
}
21 changes: 18 additions & 3 deletions tt_metal/impl/buffers/buffer.hpp
@@ -49,12 +49,27 @@ struct ShardSpec {
ShardOrientation orientation = ShardOrientation::ROW_MAJOR;
bool halo = false;

// In ShardMode::PHYSICAL, physical_shard_shape will always be std::nullopt
ShardMode mode = ShardMode::PHYSICAL;
std::optional<std::array<uint32_t, 2>> physical_shard_shape = std::nullopt;

ShardSpec(
const CoreRangeSet &core_sets_,
const std::array<uint32_t, 2> &shard_shape_,
const ShardOrientation &shard_orientation_ = ShardOrientation::ROW_MAJOR,
const bool &halo_ = false,
const ShardMode &shard_mode_ = ShardMode::PHYSICAL) :
grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_), mode(shard_mode_), physical_shard_shape(std::nullopt) {
}

ShardSpec(
const CoreRangeSet &core_sets_,
const std::array<uint32_t, 2> &shard_shape_,
const std::array<uint32_t, 2> &physical_shard_shape_,
const ShardOrientation &shard_orientation_ = ShardOrientation::ROW_MAJOR,
const bool &halo_ = false) :
grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_) {
grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_), mode(ShardMode::LOGICAL), physical_shard_shape(physical_shard_shape_) {
TT_FATAL(physical_shard_shape_[0] >= shard_shape_[0] and physical_shard_shape_[1] >= shard_shape_[1], "Physical shard shape ({}, {}) must be greater or equal to logical shard shape ({}, {})!", physical_shard_shape_[0], physical_shard_shape_[1], shard_shape_[0], shard_shape_[1]);
}

const uint32_t num_cores() const { return this->grid.num_cores(); }
@@ -63,9 +78,9 @@ struct ShardSpec {
bool operator==(const ShardSpec& other) const;
bool operator!=(const ShardSpec& other) const;

static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo");
static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo", "mode");
constexpr auto attribute_values() const {
return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo);
return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo, this->mode);
}
};

5 changes: 5 additions & 0 deletions tt_metal/impl/buffers/buffer_constants.hpp
@@ -22,6 +22,11 @@ enum class ShardOrientation {
COL_MAJOR,
};

enum class ShardMode {
PHYSICAL, // TODO: Deprecate this option to treat shard shape as physical
LOGICAL,
};

enum class BufferType {
DRAM,
L1,
11 changes: 11 additions & 0 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -1505,13 +1505,24 @@ void pytensor_module(py::module &m_tensor) {
)doc")
.def(
// TODO: Rename to physical_volume
"volume", [](const Tensor &self) { return self.volume(); }, R"doc(
Get the volume of the tensor.
.. code-block:: python
volume = tt_tensor.volume()
)doc")
.def(
// TODO: Rename to volume
"logical_volume", [](const Tensor &self) { return self.get_logical_volume(); }, R"doc(
Get the logical volume of the tensor.
.. code-block:: python
volume = tt_tensor.logical_volume()
)doc")
.def(
"storage_type", [](const Tensor &self) { return self.storage_type(); }, R"doc(
6 changes: 5 additions & 1 deletion ttnn/cpp/pybind11/tensor.cpp
@@ -72,6 +72,7 @@ void tensor_mem_config_module_types(py::module& m_tensor) {
export_enum<MathFidelity>(m_tensor);
export_enum<TensorMemoryLayout>(m_tensor);
export_enum<ShardOrientation>(m_tensor);
export_enum<ShardMode>(m_tensor);

py::enum_<tt::tt_metal::BufferType>(m_tensor, "BufferType")
.value("DRAM", BufferType::DRAM)
@@ -266,10 +267,13 @@ void tensor_mem_config_module(py::module& m_tensor) {
.def(py::init<>([](const CoreRangeSet& core_sets,
const std::array<uint32_t, 2>& shard_shape,
const ShardOrientation& shard_orientation,
const bool& halo) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo); }))
const bool& halo,
const ShardMode& shard_mode) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo, shard_mode); }),
py::arg("grid"), py::arg("shard_shape"), py::arg("shard_orientation"), py::arg("halo"), py::arg("shard_mode") = ShardMode::PHYSICAL)
.def_readwrite("shape", &ShardSpec::shape, "Shape of shard.")
.def_readwrite("grid", &ShardSpec::grid, "Grid to layout shards.")
.def_readwrite("orientation", &ShardSpec::orientation, "Orientation of cores to read shards")
.def_readwrite("mode", &ShardSpec::mode, "Treat shard shape as physical (default) or logical")
.def("num_cores", &ShardSpec::num_cores, "Number of cores")
.def(py::self == py::self)
.def(py::self != py::self);
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/layout/alignment.hpp
@@ -37,4 +37,4 @@ class Alignment final : protected ShapeBase {

std::ostream &operator<<(std::ostream &os, const tt::tt_metal::Alignment &shape);

} // namespace ttnn
} // namespace tt::tt_metal
66 changes: 37 additions & 29 deletions ttnn/cpp/ttnn/tensor/layout/page_config.cpp
@@ -52,8 +52,8 @@ void PageConfig::validate_alignment(const Alignment& alignment, DataType dtype,
std::visit([&](const auto& config) constexpr { config.validate_alignment(alignment, dtype, memory_config); }, config_);
}

Size PageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const {
return std::visit([&](const auto& config) constexpr { return config.get_page_shape(physical_size, dtype, memory_config); }, config_);
Size PageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const {
return std::visit([&](const auto& config) constexpr { return config.get_page_shape(physical_size, dtype, memory_config, physical_shard_size); }, config_);
}

size_t PageConfig::get_page_size_bytes(const Size& page_shape, DataType dtype) const {
@@ -78,7 +78,13 @@ TilePageConfig::TilePageConfig(const Tile& tile)
: tile_(tile) {
}

Alignment TilePageConfig::create_default_alignment(DataType dtype, const MemoryConfig&) const {
Alignment TilePageConfig::create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const {
if (memory_config.shard_spec.has_value()) {
const auto& shard_spec = memory_config.shard_spec.value();
if (shard_spec.physical_shard_shape.has_value()) {
return Alignment(shard_spec.physical_shard_shape.value());
}
}
return Alignment({tile_.get_height(), tile_.get_width()});
}

@@ -92,7 +98,7 @@ void TilePageConfig::validate_alignment(const Alignment& alignment, DataType dty
"Wrong custom Tensor Layout alignment {}. For Tile layout second innermost dimension should be multiple of tile height {}.", alignment, tile_.get_height());
}

Size TilePageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const {
Size TilePageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>&) const {
if(memory_config.memory_layout == TensorMemoryLayout::SINGLE_BANK && physical_size.width() != 0 && physical_size.height() != 0) {
return physical_size;
}
@@ -116,20 +122,23 @@ Alignment RowMajorPageConfig::create_default_alignment(DataType dtype, const Mem
const auto element_size = CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype);
auto width_alignment = sizeof(uint32_t) / element_size;

if(memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
if (memory_config.shard_spec.has_value()) {
const auto& shard_spec = memory_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
const auto shard_width = shard_shape[1];
TT_FATAL(
(shard_width % width_alignment) == 0,
"Invalid sharding configuration: For Row Major layout with element size of {} bytes, the innermost dimension must align to {} bytes. "
"Buffer data is packed as uint32_t (4 bytes), so the provided shard shape {} does not meet alignment requirements.",
element_size, width_alignment, shard_shape
);

width_alignment = shard_width;
if (shard_spec.physical_shard_shape.has_value()) {
return Alignment(shard_spec.physical_shard_shape.value());
}
if (shard_spec.mode == ShardMode::PHYSICAL && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& physical_shard_shape = shard_spec.shape;
const auto physical_shard_width = physical_shard_shape[1];
TT_FATAL(
(physical_shard_width % width_alignment) == 0,
"For Row Major layout and shard mode {}, the width of shard shape {} is treated as physical shard width and must be aligned to {} since we pack buffer data as uint32_t.",
shard_spec.mode, physical_shard_shape, width_alignment
);

width_alignment = physical_shard_width;
}
}

return Alignment({width_alignment});}
}

@@ -140,21 +149,21 @@ void RowMajorPageConfig::validate_alignment(const Alignment& alignment, DataType
const uint32_t page_alignment = sizeof(uint32_t) / element_size;

TT_FATAL((width_alignment % page_alignment) == 0,
"Incorrect alignment configuration for Row Major layout: alignment {} requires innermost dimension alignment of {} bytes due to uint32_t (4-byte) packing, but the current alignment size is {}.",
alignment, element_size, page_alignment);
"Incorrect alignment configuration for Row Major layout: innermost dimension alignment must be aligned to {} bytes since we pack buffer data as uint32_t. With element size of {} byte(s), alignment {} must be a multiple of alignment {}.",
sizeof(uint32_t), element_size, alignment, page_alignment);

if(memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& shard_spec = memory_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
const auto shard_width = shard_shape[1];
// TODO: Do we need to validate sharded width here if we are guaranteed that physical_shard_width is set as width_alignment
if (memory_config.shard_spec.has_value() && memory_config.shard_spec.value().mode == ShardMode::PHYSICAL && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& physical_shard_shape = memory_config.shard_spec.value().shape;
const auto physical_shard_width = physical_shard_shape[1];
TT_FATAL(
width_alignment % shard_width == 0,
"Alignment mismatch for sharded tensor: Expected alignment width of {} to match shard shape {} for Row Major layout.",
width_alignment, shard_shape);
physical_shard_width % width_alignment == 0,
"Alignment mismatch for sharded tensor: Expected physical shard shape {} to be aligned to {} along the width for Row Major layout.",
physical_shard_width, width_alignment);
}
}

Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const {
Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const {
if (physical_size.height() == 0 || physical_size.width() == 0) {
return Size(1, sizeof(uint32_t) / CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype));
}
@@ -164,10 +173,9 @@ Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtyp
}

if (memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& shard_spec = memory_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
TT_FATAL(physical_shard_size.has_value(), "For width or block sharded tensors, Row Major page width comes from physical shard size so it must be provided!");

return Size(1, shard_shape[1]);
return Size(1, physical_shard_size.value().width());
}

return Size(1, physical_size.width());
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/tensor/layout/page_config.hpp
@@ -23,7 +23,7 @@ class RowMajorPageConfig {
Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const;
void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const;

Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const;
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const;
size_t get_page_size_bytes(const Size& page_size, DataType dtype) const;
};

@@ -34,7 +34,7 @@ class TilePageConfig {
Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const;
void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const;

Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const;
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const;
size_t get_page_size_bytes(const Size& page_size, DataType dtype) const;

const Tile& get_tile() const;
@@ -54,7 +54,7 @@ class PageConfig {
Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const;
void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const;

Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const;
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const;
size_t get_page_size_bytes(const Size& page_size, DataType dtype) const;

std::optional<Tile> get_tile() const;
3 changes: 3 additions & 0 deletions ttnn/cpp/ttnn/tensor/layout/size.cpp
@@ -16,6 +16,9 @@ Size::Size(const std::pair<size_t, size_t>& size)
Size::Size(const std::array<size_t, 2>& size)
: Size(size[0], size[1]) {}

Size::Size(const std::array<uint32_t, 2>& size)
: Size(size[0], size[1]) {}

Size Size::operator*(size_t scalar) const {
return Size(height_ * scalar, width_ * scalar);
}
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/tensor/layout/size.hpp
@@ -18,6 +18,7 @@ class Size final {
Size(std::size_t height, std::size_t width);
Size(const std::pair<std::size_t, std::size_t>& size);
Size(const std::array<std::size_t, 2>& size);
Size(const std::array<std::uint32_t, 2>& size);

operator std::pair<std::size_t, std::size_t>() const;
operator std::array<std::size_t, 2>() const;