Further removal of Shape/LegacyShape
sminakov-tt committed Dec 27, 2024
1 parent 2a86ff7 commit fda5027
Showing 7 changed files with 132 additions and 339 deletions.
45 changes: 21 additions & 24 deletions ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp
@@ -29,21 +29,21 @@ void Pad::validate_with_output_tensors(
             "On device padding only supports padding at end of dims");
     }
     TT_FATAL(
-        input_tensor.get_legacy_shape()[0] + this->input_tensor_start[0] <= this->output_tensor_shape[0],
+        input_tensor.get_padded_shape()[0] + this->input_tensor_start[0] <= this->output_padded_shape[0],
         "Output size cannot fit input with offset");
     TT_FATAL(
-        input_tensor.get_legacy_shape()[1] + this->input_tensor_start[1] <= this->output_tensor_shape[1],
+        input_tensor.get_padded_shape()[1] + this->input_tensor_start[1] <= this->output_padded_shape[1],
         "Output size cannot fit input with offset");
     TT_FATAL(
-        input_tensor.get_legacy_shape()[2] + this->input_tensor_start[2] <= this->output_tensor_shape[2],
+        input_tensor.get_padded_shape()[2] + this->input_tensor_start[2] <= this->output_padded_shape[2],
         "Output size cannot fit input with offset");
     TT_FATAL(
-        input_tensor.get_legacy_shape()[3] + this->input_tensor_start[3] <= this->output_tensor_shape[3],
+        input_tensor.get_padded_shape()[3] + this->input_tensor_start[3] <= this->output_padded_shape[3],
         "Output size cannot fit input with offset");
 
     if (input_tensor.get_layout() == Layout::TILE) {
-        TT_FATAL((this->output_tensor_shape[2] % TILE_HEIGHT == 0), "Can only pad tilized tensor with full tiles");
-        TT_FATAL((this->output_tensor_shape[3] % TILE_WIDTH == 0), "Can only pad tilized tensor with full tiles");
+        TT_FATAL((this->output_padded_shape[2] % TILE_HEIGHT == 0), "Can only pad tilized tensor with full tiles");
+        TT_FATAL((this->output_padded_shape[3] % TILE_WIDTH == 0), "Can only pad tilized tensor with full tiles");
         TT_FATAL(
             input_tensor.get_dtype() == DataType::FLOAT32 || input_tensor.get_dtype() == DataType::BFLOAT16,
             "Cannot pad tilized tensor with specified format");
@@ -62,19 +62,16 @@ void Pad::validate_with_output_tensors(
     }
 }
 
-std::vector<ttnn::SimpleShape> Pad::compute_output_shapes(const std::vector<Tensor>&) const {
-    return {this->output_tensor_shape.logical_shape()};
-}
-
-std::vector<Tensor> Pad::create_output_tensors(
-    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
+std::vector<ttnn::TensorSpec> Pad::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    return {create_device_tensor(
-        output_tensor_shape,
-        input_tensor.get_dtype(),
-        input_tensor.get_layout(),
-        input_tensor.device(),
-        this->output_mem_config)};
+    return {TensorSpec(
+        output_logical_shape,
+        TensorLayout::fromPaddedShape(
+            input_tensor.get_dtype(),
+            PageConfig(input_tensor.get_layout()),
+            output_mem_config,
+            output_logical_shape,
+            output_padded_shape))};
 }
 
 operation::ProgramWithCallbacks Pad::create_program(
@@ -104,22 +101,22 @@ operation::ProgramWithCallbacks Pad::create_program(
                 return {};
             } else if (input_w != output_w) {
                 return detail::pad_rm_sharded_width_only(
-                    input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                    input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
             } else if (input_tot_h != output_tot_h) {
                 return detail::pad_rm_sharded_height_only(
-                    input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                    input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
             } else {
                 // for no padding, we just use the height-only padding program
                 return detail::pad_rm_sharded_height_only(
-                    input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                    input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
             }
         } else {
             if (use_multicore) {
                 return detail::pad_rm_reader_writer_multi_core_v2(
-                    input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                    input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
             } else {
                 return detail::pad_rm_reader_writer(
-                    input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                    input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
             }
         }
     } else if (input_tensor.get_layout() == Layout::TILE) {
@@ -128,7 +125,7 @@ operation::ProgramWithCallbacks Pad::create_program(
                 tt::LogType::LogOp, "TILE layout does not have multicore implementation yet. Falling back to 1 core.");
         }
         return detail::pad_tile(
-            input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+            input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
     } else {
        TT_THROW("Unsupported layout for pad");
        return {};
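
The cpp change above replaces the compute_output_shapes/create_output_tensors pair with a single compute_output_specs: instead of allocating the output tensor itself, the op now returns a TensorSpec that carries the logical and padded output shapes separately. A minimal sketch of that construction, reusing only the calls visible in this diff (the wrapper name make_pad_output_spec and the include paths are illustrative assumptions, not code from this commit):

// Sketch of the TensorSpec-based output description. Headers/namespaces are
// assumed to match the ttnn tree; make_pad_output_spec is a hypothetical name.
ttnn::TensorSpec make_pad_output_spec(
    const Tensor& input_tensor,
    const ttnn::SimpleShape& output_logical_shape,
    const ttnn::SimpleShape& output_padded_shape,
    const tt::tt_metal::MemoryConfig& output_mem_config) {
    return ttnn::TensorSpec(
        output_logical_shape,                       // logical (unpadded) output extent
        TensorLayout::fromPaddedShape(
            input_tensor.get_dtype(),               // output keeps the input's dtype
            PageConfig(input_tensor.get_layout()),  // and its layout (ROW_MAJOR or TILE)
            output_mem_config,                      // target memory configuration
            output_logical_shape,
            output_padded_shape));                  // padded extent, tile-aligned for TILE
}

With the output fully described by a spec, the shared op infrastructure can materialize the device tensor itself, so the op no longer needs its own create_output_tensors override.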
17 changes: 11 additions & 6 deletions ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.hpp
@@ -12,24 +12,29 @@
 namespace ttnn::operations::data_movement {
 
 struct Pad {
-    const tt::tt_metal::LegacyShape output_tensor_shape;
+    const ttnn::SimpleShape output_logical_shape;
+    const ttnn::SimpleShape output_padded_shape;
     const ttnn::SimpleShape input_tensor_start;
     const float pad_value;
     const tt::tt_metal::MemoryConfig output_mem_config;
     const bool use_multicore;
 
     void validate_with_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(
-        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     tt::tt_metal::operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
     static constexpr auto attribute_names = std::forward_as_tuple(
-        "output_tensor_shape", "input_tensor_start", "pad_value", "output_mem_config", "use_multicore");
+        "output_logical_shape",
+        "output_padded_shape",
+        "input_tensor_start",
+        "pad_value",
+        "output_mem_config",
+        "use_multicore");
     const auto attribute_values() const {
         return std::forward_as_tuple(
-            this->output_tensor_shape,
+            this->output_logical_shape,
+            this->output_padded_shape,
             this->input_tensor_start,
             this->pad_value,
            this->output_mem_config,
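
On the header side, the single LegacyShape attribute becomes two SimpleShape attributes, so invoking code must now supply the logical and padded output shapes separately. A hedged call-site sketch (the round_up lambda, the dimension variables, and the aggregate-initialization style are illustrative assumptions, not code from this commit):

// Illustrative only: build a tile-aligned padded shape from the logical one,
// then populate the op attributes in member order. n/c/h/w, input_tensor_start,
// pad_value, and output_mem_config are assumed to be in scope.
auto round_up = [](uint32_t v, uint32_t multiple) { return ((v + multiple - 1) / multiple) * multiple; };
ttnn::SimpleShape output_logical_shape({n, c, h, w});
ttnn::SimpleShape output_padded_shape({n, c, round_up(h, TILE_HEIGHT), round_up(w, TILE_WIDTH)});

Pad pad_op{
    output_logical_shape,    // shape without tile padding
    output_padded_shape,     // rounded up to full tiles, per the checks in pad_op.cpp
    input_tensor_start,      // where the input lands inside the output
    pad_value,
    output_mem_config,
    /*use_multicore=*/true};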