Commit
#6361: Add native RM implementation of concat. Only restriction is the requirement of aligned pages
tt-aho committed Mar 18, 2024
1 parent 1267b58 commit a6ec598
Showing 11 changed files with 909 additions and 741 deletions.
216 changes: 120 additions & 96 deletions tests/tt_eager/python_api_testing/unit_testing/misc/test_concat.py
@@ -19,37 +19,32 @@
)


@pytest.mark.parametrize(
"memcfg",
(
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
),
ids=["out_DRAM", "out_L1"],
)
@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B)))
@pytest.mark.parametrize("nChannels", ((2, 3, 4)))
def test_tile_simple_concat(memcfg, dtype, nChannels, device, function_level_defaults):
input_shape = torch.Size([nChannels, nChannels, 32, 32])
x = torch.arange(0, input_shape.numel()).reshape(input_shape).bfloat16()

y = (1 + torch.arange(0, input_shape.numel()).reshape(input_shape)).bfloat16()
def run_concat(shapes, dim, device, layout, dtype, input_mem_config, output_mem_config):
if layout == ttl.tensor.Layout.ROW_MAJOR and dtype == ttl.tensor.DataType.BFLOAT8_B:
pytest.skip("Illegal config")
if layout == ttl.tensor.Layout.TILE:
for shape in shapes:
if shape[-2] % 32 != 0 or shape[-1] % 32 != 0:
pytest.skip("Illegal config")
inputs = []
tt_inputs = []
for i in range(len(shapes)):
shape = torch.Size(shapes[i])
inputs.append(torch.rand(shape).to(torch.bfloat16))
tt_inputs.append(
ttl.tensor.Tensor(
inputs[i],
dtype,
)
.to(layout)
.to(device, input_mem_config)
)

xtt = (
ttl.tensor.Tensor(x, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(y, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
tt_cpu = torch.concat(inputs, dim)

dim = 3
output_shape = list(x.shape)
output_shape[3] = y.shape[3] + x.shape[3]
tt_cpu = torch.concat([x, y], dim)
assert tt_cpu.shape == torch.Size(output_shape)
tt = ttl.tensor.concat(tt_inputs, dim, output_mem_config)

tt = ttl.tensor.concat(xtt, dim)
assert list(tt.get_legacy_shape()) == output_shape
xtt_data = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR)
tt_dev = xtt_data.to_torch()
tt_dev = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

if dtype == ttl.tensor.DataType.BFLOAT8_B:
passing, output = comp_pcc(tt_cpu, tt_dev)
@@ -59,89 +54,118 @@ def test_tile_simple_concat(memcfg, dtype, nChannels, device, function_level_defaults):
assert passing


# @pytest.mark.skip(reason="For Stable Diffusion Sizes only")
@pytest.mark.parametrize(
"shape_a, shape_b, dim",
"shapes, dim",
(
((1, 1, 32, 32), (1, 1, 32, 32), 3),
((1, 1, 32, 64), (1, 1, 32, 128), 3),
((1, 1, 32, 128), (1, 1, 32, 64), 3),
((1, 1, 64, 128), (1, 1, 64, 256), 3),
((1, 32, 32, 32), (1, 32, 32, 32), 2),
((1, 1, 32, 32), (1, 1, 32, 32), 3),
((2, 4, 32, 1280), (2, 4, 32, 1280), 3),
(((1, 2, 64, 64),), -1),
(((1, 1, 64, 64), (1, 1, 128, 64)), -2),
(((1, 1, 32, 128), (1, 1, 32, 64), (1, 1, 32, 256)), -1),
(((2, 4, 32, 1280), (2, 3, 32, 1280), (2, 5, 32, 1280), (2, 8, 32, 1280)), 1),
(((1, 1, 32, 32), (1, 1, 32, 32)), 3),
(((1, 1, 32, 64), (1, 1, 32, 128)), 3),
(((1, 1, 32, 128), (1, 1, 32, 64)), 3),
(((1, 1, 64, 128), (1, 1, 64, 256)), 3),
(((1, 32, 32, 32), (1, 32, 32, 32)), 2),
(((2, 4, 32, 1280), (2, 4, 32, 1280)), 3),
# SD Shapes
((2, 1280, 4, 4), (2, 1280, 4, 4), 1),
((2, 640, 32, 32), (2, 320, 32, 32), 1),
((2, 1280, 8, 8), (2, 1280, 8, 8), 1),
((2, 640, 16, 16), (2, 640, 16, 16), 1),
((2, 320, 32, 32), (2, 320, 32, 32), 1),
(((2, 1280, 4, 4), (2, 1280, 4, 4)), 1),
(((2, 640, 32, 32), (2, 320, 32, 32)), 1),
(((2, 1280, 8, 8), (2, 1280, 8, 8)), 1),
(((2, 640, 16, 16), (2, 640, 16, 16)), 1),
(((2, 320, 32, 32), (2, 320, 32, 32)), 1),
),
)
def test_tile_concat(shape_a, shape_b, dim, device, function_level_defaults):
shape_a = torch.Size(shape_a)

x = torch.arange(0, shape_a.numel()).reshape(shape_a).to(torch.bfloat16)

shape_b = torch.Size(shape_b)
y = torch.arange(0, shape_b.numel()).reshape(shape_b).to(torch.bfloat16)

xtt = (
ttl.tensor.Tensor(
x,
ttl.tensor.DataType.BFLOAT16,
).to(device),
ttl.tensor.Tensor(
y,
ttl.tensor.DataType.BFLOAT16,
).to(device),
)

output_shape = list(x.shape)
output_shape[dim] = y.shape[dim] + x.shape[dim]
tt_cpu = torch.concat([x, y], dim)
assert tt_cpu.shape == torch.Size(output_shape)

tt = ttl.tensor.concat([xtt[0], xtt[1]], dim)
assert list(tt.get_legacy_shape()) == output_shape
tt_dev = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

passing, output = comp_equal(tt_cpu, tt_dev)
logger.info(output)
assert passing
@pytest.mark.parametrize(
"layout, dtype",
(
(ttl.tensor.Layout.TILE, ttl.tensor.DataType.BFLOAT16),
(ttl.tensor.Layout.TILE, ttl.tensor.DataType.BFLOAT8_B),
(ttl.tensor.Layout.ROW_MAJOR, ttl.tensor.DataType.BFLOAT16),
),
)
@pytest.mark.parametrize(
"input_mem_config",
(
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.DRAM,
),
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.L1,
),
),
)
@pytest.mark.parametrize(
"output_mem_config",
(
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.DRAM,
),
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.L1,
),
),
)
def test_concat(shapes, dim, device, layout, dtype, input_mem_config, output_mem_config, function_level_defaults):
run_concat(shapes, dim, device, layout, dtype, input_mem_config, output_mem_config)


@pytest.mark.parametrize(
"shapes, dim",
(
(((1, 2, 64, 64),), -1),
(((1, 1, 64, 64), (1, 1, 128, 64)), -2),
(((1, 1, 32, 128), (1, 1, 32, 64), (1, 1, 32, 256)), -1),
(((2, 4, 32, 1280), (2, 3, 32, 1280), (2, 5, 32, 1280), (2, 8, 32, 1280)), 1),
(((2, 4, 32, 1280), (2, 4, 32, 1280)), 3),
# SD Shapes
(((2, 1280, 4, 4), (2, 1280, 4, 4)), 1),
(((2, 320, 32, 32), (2, 320, 32, 32)), 1),
(((2, 1280, 8, 8), (2, 1280, 8, 8)), 1),
),
)
def test_multi_input_concat(shapes, dim, device, function_level_defaults):
inputs = []
tt_inputs = []
for i in range(len(shapes)):
shape = torch.Size(shapes[i])
inputs.append(i + torch.arange(0, shape.numel()).reshape(shape).to(torch.bfloat16))
tt_inputs.append(
ttl.tensor.Tensor(
inputs[i],
ttl.tensor.DataType.BFLOAT16,
).to(device)
)

tt_cpu = torch.concat(inputs, dim)

tt = ttl.tensor.concat(tt_inputs, dim)

tt_dev = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

passing, output = comp_equal(tt_cpu, tt_dev)
logger.info(output)
assert passing
@pytest.mark.parametrize(
"layout, dtype",
(
(ttl.tensor.Layout.TILE, ttl.tensor.DataType.BFLOAT16),
(ttl.tensor.Layout.TILE, ttl.tensor.DataType.BFLOAT8_B),
(ttl.tensor.Layout.ROW_MAJOR, ttl.tensor.DataType.BFLOAT16),
),
)
@pytest.mark.parametrize(
"input_mem_config",
(
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.DRAM,
),
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.L1,
),
),
)
@pytest.mark.parametrize(
"output_mem_config",
(
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.DRAM,
),
ttl.tensor.MemoryConfig(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.L1,
),
),
)
def test_concat_with_program_cache(
shapes, dim, device, layout, dtype, input_mem_config, output_mem_config, use_program_cache, function_level_defaults
):
run_concat(shapes, dim, device, layout, dtype, input_mem_config, output_mem_config)
tmp = ttl.tensor.empty([1, 256, 32, 32], ttl.tensor.DataType.BFLOAT16, ttl.tensor.Layout.TILE, device)
run_concat(shapes, dim, device, layout, dtype, input_mem_config, output_mem_config)


@pytest.mark.parametrize(
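For context, a minimal sketch of driving the new row-major concat path directly, mirroring run_concat above. It assumes the same ttl API used in this test file (imported as tt_lib), an already-open device handle, and last dims that satisfy the aligned-page requirement; it is an illustration of the test flow, not additional committed code.

```python
import torch
import tt_lib as ttl

# Assumes `device` was opened elsewhere (the tests get it from a fixture).
mem_config = ttl.tensor.MemoryConfig(
    memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
    buffer_type=ttl.tensor.BufferType.DRAM,
)

# Two row-major BFLOAT16 inputs; 64 * 2 bytes = 128 bytes per row, which is aligned.
a = torch.rand((1, 1, 32, 64)).to(torch.bfloat16)
b = torch.rand((1, 1, 32, 64)).to(torch.bfloat16)

tt_a = ttl.tensor.Tensor(a, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.ROW_MAJOR).to(device, mem_config)
tt_b = ttl.tensor.Tensor(b, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.ROW_MAJOR).to(device, mem_config)

# Width (last-dim) concat on device, then read back and compare against torch.
tt_out = ttl.tensor.concat([tt_a, tt_b], 3, mem_config)
out = tt_out.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

assert torch.equal(torch.concat([a, b], 3), out)
```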
2 changes: 1 addition & 1 deletion tt_eager/tensor/tensor.cpp
@@ -385,7 +385,7 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layo
TT_ASSERT((shard_shape[0] % TILE_HEIGHT == 0 && shard_shape[1] % TILE_WIDTH == 0), "Shard shape must be tile sized");
} else if (layout == Layout::ROW_MAJOR) {
// Require alignment for now
TT_ASSERT(shard_shape[1] * tensor_impl::element_size_bytes_wrapper(data_type) % 32 == 0);
TT_ASSERT(shard_shape[1] * tensor_impl::element_size_bytes_wrapper(data_type) % ADDRESS_ALIGNMENT == 0);
}

auto element_size = tensor_impl::element_size_bytes_wrapper(data_type);
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/all_gather/all_gather_op.cpp
@@ -23,7 +23,7 @@ void AllGather::validate(const std::vector<Tensor> &input_tensors) const {
const auto& dtype = input_tensors[0].get_dtype();
const auto& page_size = input_tensors[0].buffer()->page_size();
TT_FATAL(page_size <= all_gather_buffer_params::eth_buffer_size, "Page size too large");
TT_FATAL(page_size % 32 == 0, "All Gather currently requires aligned pages");
TT_FATAL(page_size % ADDRESS_ALIGNMENT == 0, "All Gather currently requires aligned pages");

// TODO: This can be removed by passing two page sizes, actual and aligned to be used for address offsets
// Buffer sizes also need to take this aligned page size into consideration
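The two hunks above (tensor.cpp and all_gather_op.cpp) swap a hard-coded 32 for the ADDRESS_ALIGNMENT constant; the condition being guarded is simply that one row-major page (a row of the last dim) occupies an aligned number of bytes. A standalone sketch of that arithmetic follows, assuming ADDRESS_ALIGNMENT is 32 bytes like the literal it replaces; the names here are illustrative, not the C++ API.

```python
# Assumed to match the literal value replaced in this commit.
ADDRESS_ALIGNMENT = 32

# Element sizes in bytes for a few common dtypes.
ELEMENT_SIZE_BYTES = {"BFLOAT16": 2, "FLOAT32": 4, "UINT32": 4}

def row_major_page_is_aligned(last_dim: int, dtype: str) -> bool:
    """A row-major page is one row of the last dim; its byte size must be aligned."""
    page_size_bytes = last_dim * ELEMENT_SIZE_BYTES[dtype]
    return page_size_bytes % ADDRESS_ALIGNMENT == 0

assert row_major_page_is_aligned(64, "BFLOAT16")      # 128 B -> accepted
assert not row_major_page_is_aligned(33, "BFLOAT16")  # 66 B  -> rejected
```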
42 changes: 25 additions & 17 deletions tt_eager/tt_dnn/op_library/concat/concat_op.cpp
@@ -18,8 +18,13 @@ ConcatOpParallelizationStrategy Concat::get_parallelization_strategy(const std::
if (input_tensors[0].is_sharded()) {
return ConcatOpParallelizationStrategy::SHARDED_MULTI_CORE;
} else {
uint32_t num_tiles = tt_metal::compute_volume(this->compute_output_shapes(input_tensors).at(0)) / TILE_HW;
if (num_tiles > 1) {
uint32_t num_pages = tt_metal::compute_volume(this->compute_output_shapes(input_tensors).at(0));
if (input_tensors[0].get_layout() == Layout::ROW_MAJOR) {
num_pages /= input_tensors[0].get_legacy_shape()[-1];
} else {
num_pages /= TILE_HW;
}
if (num_pages > 1) {
return ConcatOpParallelizationStrategy::MULTI_CORE;
} else {
return ConcatOpParallelizationStrategy::SINGLE_CORE;
@@ -30,7 +35,8 @@ ConcatOpParallelizationStrategy Concat::get_parallelization_strategy(const std::
void Concat::validate(const std::vector<Tensor> &input_tensors) const {
const auto &first_input = input_tensors[0];
tt::tt_metal::Shape shape_first = first_input.get_legacy_shape();
shape_first[dim] = 0;
TT_FATAL(this->dim < shape_first.rank(), "Concat dim specified is larger than input tensor rank.");
shape_first[this->dim] = 0;
bool shard_first = input_tensors[0].is_sharded();

for (const Tensor &in_ref : input_tensors) {
@@ -40,27 +46,31 @@ void Concat::validate(const std::vector<Tensor> &input_tensors) const {
TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts.");
TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes.");
tt::tt_metal::Shape curr_shape = in_ref.get_legacy_shape();
curr_shape[dim] = 0;
TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal");
curr_shape[this->dim] = 0;
TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions.");
if (in_ref.get_layout() == Layout::ROW_MAJOR && this->dim == shape_first.rank() - 1) {
TT_FATAL(
(in_ref.get_legacy_shape()[this->dim] * in_ref.element_size()) % ADDRESS_ALIGNMENT == 0,
"Current concat implementation requires aligned last dim when concatting on last dim");
}
TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");
if (shard_first) {
TT_FATAL((in_ref.get_layout() == Layout::ROW_MAJOR), "Only row major supported for sharded concat.");
} else {
TT_FATAL((in_ref.get_layout() == Layout::TILE), "Only tile layout supported.");
}
}
if (shard_first) {
TT_FATAL(dim == 3, "Only width concat on sharded tensors");
TT_FATAL(this->dim == shape_first.rank() - 1, "Only width concat on sharded tensors");
TT_FATAL(this->output_mem_config.is_sharded(), "Output must be sharded if input is sharded");
}
}

std::vector<tt::tt_metal::Shape> Concat::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
tt::tt_metal::Shape shape_out = input_tensors[0].get_legacy_shape();
shape_out[dim] = 0;
shape_out[this->dim] = 0;
for (const Tensor &in_ref : input_tensors) {
tt::tt_metal::Shape curr_shape = in_ref.get_legacy_shape();
shape_out[dim] += curr_shape[dim];
shape_out[this->dim] += curr_shape[this->dim];
}
return {shape_out};
}
@@ -93,7 +103,7 @@ operation::ProgramWithCallbacks Concat::create_program(
};
}

Tensor concat(std::vector<Tensor> &input_tensors, std::int64_t dim, const MemoryConfig &output_mem_config) {
Tensor concat(std::vector<Tensor> &input_tensors, const std::int64_t dim, const MemoryConfig &output_mem_config) {
TT_FATAL(input_tensors.size() > 0, "need 1 or more tensors");
if (input_tensors.size() == 1) {
return AutoFormat::move_tensor_to_mem_config(input_tensors[0], output_mem_config);
@@ -104,16 +114,14 @@ Tensor concat(std::vector<Tensor> &input_tensors, std::int64_t dim, const Memory
if (input_tensors[0].is_sharded()) {
return operation::run(Concat{normalized_dim, output_mem_config}, {input_tensors}).at(0);
} else {
if (normalized_dim == ref_rank - 1) {
for (const auto &input_tensor : input_tensors) {
TT_FATAL(input_tensor.get_legacy_shape()[dim] % TILE_WIDTH == 0, "Current concat implementation requires tile sized last dim when concatting on last dim");
}
} else if (normalized_dim == ref_rank - 2) {
if (input_tensors[0].get_layout() == Layout::ROW_MAJOR && normalized_dim == ref_rank - 1) {
for (const auto &input_tensor : input_tensors) {
TT_FATAL(input_tensor.get_legacy_shape()[dim] % TILE_HEIGHT == 0, "Current concat implementation requires tile sized second last dim when concatting on second last dim");
TT_FATAL(
(input_tensor.get_legacy_shape()[dim] * input_tensor.element_size()) % ADDRESS_ALIGNMENT == 0,
"Current concat implementation requires aligned last dim when concatting on last dim");
}
}
return operation::run_with_autoformat(Concat{normalized_dim}, {input_tensors}).at(0);
return operation::run_without_autoformat(Concat{normalized_dim}, {input_tensors}).at(0);
}
}

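A rough sketch of the page-count logic that Concat::get_parallelization_strategy above uses once the input is not sharded: row-major tensors are paged per row of the last dim, tiled tensors per 32x32 tile, and anything with more than one page goes multi-core. Names and types here are illustrative Python, not the C++ API.

```python
TILE_HW = 32 * 32  # elements per tile

def pick_strategy(output_shape, layout, sharded):
    # Mirrors the branch order in the C++ above: sharded wins, then page count.
    if sharded:
        return "SHARDED_MULTI_CORE"
    volume = 1
    for d in output_shape:
        volume *= d
    if layout == "ROW_MAJOR":
        num_pages = volume // output_shape[-1]
    else:
        num_pages = volume // TILE_HW
    return "MULTI_CORE" if num_pages > 1 else "SINGLE_CORE"

assert pick_strategy([1, 1, 32, 64], "ROW_MAJOR", False) == "MULTI_CORE"   # 32 row pages
assert pick_strategy([1, 1, 1, 64], "ROW_MAJOR", False) == "SINGLE_CORE"   # 1 row page
assert pick_strategy([1, 1, 32, 32], "TILE", False) == "SINGLE_CORE"       # 1 tile
```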
18 changes: 11 additions & 7 deletions tt_eager/tt_dnn/op_library/concat/concat_op.hpp
@@ -11,9 +11,7 @@ namespace tt {

namespace tt_metal {

enum class ConcatOpParallelizationStrategy {
SINGLE_CORE = 0, MULTI_CORE = 1, SHARDED_MULTI_CORE = 2
};
enum class ConcatOpParallelizationStrategy { SINGLE_CORE = 0, MULTI_CORE = 1, SHARDED_MULTI_CORE = 2 };

struct Concat {
uint32_t dim;
@@ -31,13 +29,19 @@ struct Concat {
}
};

operation::ProgramWithCallbacks sharded_concat_multi_core(const std::vector<Tensor> &input_tensors, uint32_t dim, Tensor &output);
operation::ProgramWithCallbacks concat_multi_core(const std::vector<Tensor> &input_tensors, uint32_t dim, Tensor &output);
operation::ProgramWithCallbacks concat_single_core(const std::vector<Tensor> &input_tensors, uint32_t dim, Tensor &output);
operation::ProgramWithCallbacks sharded_concat_multi_core(
const std::vector<Tensor> &input_tensors, uint32_t dim, Tensor &output);
operation::ProgramWithCallbacks concat_multi_core(
const std::vector<Tensor> &input_tensors, const uint32_t dim, const Tensor &output);
operation::ProgramWithCallbacks concat_single_core(
const std::vector<Tensor> &input_tensors, const uint32_t dim, const Tensor &output);

// Ref: https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat
// Notes: Non-empty tensors provided must have the same shape, except in the cat dimension.
Tensor concat(std::vector<Tensor> &input_tensors, std::int64_t dim = 0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
Tensor concat(
std::vector<Tensor> &input_tensors,
const std::int64_t dim = 0,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

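Per the commit message, the only restriction of the native RM path is aligned pages. A hedged illustration of a case the new validation should reject follows; the exact exception type surfaced to Python by TT_FATAL is an assumption, as is the availability of the same device handle used in the sketches above.

```python
import pytest
import torch
import tt_lib as ttl

# Width 33 in BFLOAT16 gives 66-byte rows, which is not ADDRESS_ALIGNMENT-aligned,
# so a last-dim concat of row-major tensors should be rejected by the checks added
# in this commit. The RuntimeError type is an assumption about how TT_FATAL
# propagates to Python.
a = torch.rand((1, 1, 32, 33)).to(torch.bfloat16)
b = torch.rand((1, 1, 32, 33)).to(torch.bfloat16)
tt_a = ttl.tensor.Tensor(a, ttl.tensor.DataType.BFLOAT16).to(device)
tt_b = ttl.tensor.Tensor(b, ttl.tensor.DataType.BFLOAT16).to(device)

with pytest.raises(RuntimeError):
    ttl.tensor.concat([tt_a, tt_b], 3)
```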
