Skip to content

Commit

Permalink
Port all data movements ops to compute_output_specs
Browse files Browse the repository at this point in the history
  • Loading branch information
sminakov-tt committed Jan 13, 2025
1 parent 1a7e545 commit 5c25a41
Show file tree
Hide file tree
Showing 48 changed files with 294 additions and 342 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ void EltwiseBinaryBroadcast::validate_with_output_tensors(
TT_FATAL(is_floating_point(input_tensor_a.get_dtype()), "Unsupported data format");
if (!output_tensors.empty() && output_tensors.at(0).has_value()) {
TT_FATAL(is_floating_point(output_tensors.at(0).value().get_dtype()), "Unsupported data format");
const std::vector<ttnn::SimpleShape> output_shape_required = this->compute_output_shapes(input_tensors);
const auto output_spec_required = this->compute_output_specs(input_tensors, output_tensors);
const auto& out_tensor = output_tensors.at(0).value();
TT_FATAL(
out_tensor.get_logical_shape() == output_shape_required.at(0),
out_tensor.get_logical_shape() == output_spec_required.at(0).logical_shape(),
"The input tensors need a shape of {}, however the output tensor is only {}",
output_shape_required,
output_spec_required.at(0).logical_shape(),
out_tensor.get_legacy_shape());
}
if (this->in_place) {
Expand Down Expand Up @@ -122,16 +122,10 @@ void EltwiseBinaryBroadcast::validate_with_output_tensors(
}
}

// Returns the logical output shape of the broadcast op: it always matches the
// first input tensor's logical shape (the second operand is broadcast onto it).
std::vector<ttnn::SimpleShape> EltwiseBinaryBroadcast::compute_output_shapes(
    const std::vector<Tensor>& input_tensors) const {
    return {input_tensors.at(0).get_logical_shape()};
}

std::vector<Tensor> EltwiseBinaryBroadcast::create_output_tensors(
std::vector<ttnn::TensorSpec> EltwiseBinaryBroadcast::compute_output_specs(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
if (!output_tensors.empty() && output_tensors.at(0).has_value()) {
return {output_tensors.at(0).value()};
return {output_tensors.at(0)->get_tensor_spec()};
}
if (this->in_place) {
return {};
Expand All @@ -145,16 +139,36 @@ std::vector<Tensor> EltwiseBinaryBroadcast::create_output_tensors(
}
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
return {create_device_tensor(
input_tensor.get_legacy_shape(),
return {TensorSpec(
input_tensor.get_logical_shape(),
TensorLayout::fromPaddedShape(
input_tensor.get_dtype(),
PageConfig(Layout::TILE),
mem_config,
input_tensor.get_logical_shape(),
input_tensor.get_padded_shape()))};
}

return {TensorSpec(
input_tensor.get_logical_shape(),
TensorLayout::fromPaddedShape(
input_tensor.get_dtype(),
Layout::TILE,
input_tensor.device(),
mem_config)};
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
PageConfig(Layout::TILE),
output_mem_config,
input_tensor.get_logical_shape(),
input_tensor.get_padded_shape()))};
}

// Materializes the op's output tensor(s).
// Priority order: a caller-preallocated output is returned as-is; an in-place
// op returns no new tensor; otherwise a fresh device tensor is allocated from
// the spec produced by compute_output_specs.
std::vector<Tensor> EltwiseBinaryBroadcast::create_output_tensors(
    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
    const bool has_preallocated_output = !output_tensors.empty() && output_tensors.at(0).has_value();
    if (has_preallocated_output) {
        return {output_tensors.at(0).value()};
    }
    if (this->in_place) {
        // In-place: the first input doubles as the output, nothing to create.
        return {};
    }
    const auto output_spec = compute_output_specs(input_tensors, output_tensors)[0];
    return {create_device_tensor(output_spec, input_tensors.at(0).device())};
}

operation::ProgramWithCallbacks EltwiseBinaryBroadcast::create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ struct EltwiseBinaryBroadcast {

void validate_with_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<ttnn::TensorSpec> compute_output_specs(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<Tensor> create_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
tt::tt_metal::operation::ProgramWithCallbacks create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,18 @@ void CloneOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
};

CloneOperation::shape_return_value_t CloneOperation::compute_output_shapes(
CloneOperation::spec_return_value_t CloneOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
const auto& input = tensor_args.input;
return TensorSpec(
input.get_logical_shape(),
TensorLayout(operation_attributes.dtype, PageConfig(input.get_layout()), operation_attributes.memory_config));
};

CloneOperation::tensor_return_value_t CloneOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& input = tensor_args.input;
return create_device_tensor(
compute_output_shapes(operation_attributes, tensor_args),
operation_attributes.dtype,
input.get_layout(),
input.device(),
operation_attributes.memory_config);
auto spec = compute_output_specs(operation_attributes, tensor_args);
return create_device_tensor(spec, tensor_args.input.device());
}

std::tuple<CloneOperation::operation_attributes_t, CloneOperation::tensor_args_t> CloneOperation::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct CloneOperation {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -50,7 +50,7 @@ struct CloneOperation {
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,19 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor>& input_tensors) c
}
}

std::vector<ttnn::SimpleShape> ConcatDeviceOperation::compute_output_shapes(
std::vector<ttnn::TensorSpec> ConcatDeviceOperation::compute_output_specs(
const std::vector<Tensor>& input_tensors) const {
ttnn::SimpleShape shape_out = input_tensors[0].get_logical_shape();
const Tensor& ref_in_tensor = input_tensors.at(0);
ttnn::SimpleShape shape_out = ref_in_tensor.get_logical_shape();
shape_out[this->dim] = 0;
for (const Tensor& in_ref : input_tensors) {
ttnn::SimpleShape curr_shape = in_ref.get_logical_shape();
shape_out[this->dim] += curr_shape[this->dim];
}
return {shape_out};
}

std::vector<Tensor> ConcatDeviceOperation::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
const Tensor& ref_in_tensor = input_tensors.at(0);

if (this->output_mem_config.is_sharded()) {
return {create_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
ref_in_tensor.get_dtype(),
ref_in_tensor.get_layout(),
ref_in_tensor.device(),
this->output_mem_config)};
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, ref_in_tensor.get_dtype(), ref_in_tensor.get_layout(), this->output_mem_config);
}
return {TensorSpec(
shape_out,
TensorLayout(ref_in_tensor.get_dtype(), PageConfig(ref_in_tensor.get_layout()), this->output_mem_config))};
}

operation::ProgramWithCallbacks ConcatDeviceOperation::create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ struct ConcatDeviceOperation {
unsigned int groups;
const tt::tt_metal::MemoryConfig output_mem_config;
void validate(const std::vector<Tensor>& input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
tt::tt_metal::operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
ConcatOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor>& input_tensors) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,24 @@ void CopyDeviceOperation::validate_with_output_tensors(
out_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Copy does not currently support sharding");
}

std::vector<ttnn::SimpleShape> CopyDeviceOperation::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
std::vector<ttnn::TensorSpec> CopyDeviceOperation::compute_output_specs(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
if (!output_tensors.empty() && output_tensors[0].has_value()) {
return {output_tensors[0]->get_tensor_spec()};
}
if (input_tensors.size() == 2) {
return {input_tensors[1].get_logical_shape()};
} else {
const auto& input_tensor = input_tensors.at(0);
return {input_tensor.get_logical_shape()};
return {input_tensors[1].get_tensor_spec()};
}

const auto& input_tensor = input_tensors.at(0);
return {TensorSpec(
input_tensor.get_logical_shape(),
TensorLayout::fromPaddedShape(
output_dtype,
PageConfig(input_tensor.get_layout()),
output_mem_config,
input_tensor.get_logical_shape(),
input_tensor.get_padded_shape()))};
}

std::vector<Tensor> CopyDeviceOperation::create_output_tensors(
Expand All @@ -77,17 +87,10 @@ std::vector<Tensor> CopyDeviceOperation::create_output_tensors(
}
if (input_tensors.size() == 2) {
return {input_tensors[1]};
} else {
const auto& input_tensor = input_tensors.at(0);
std::vector<Tensor> output_tensors;
output_tensors.emplace_back(create_device_tensor(
input_tensor.get_legacy_shape(),
output_dtype,
input_tensors.at(0).get_layout(),
input_tensor.device(),
output_mem_config));
return output_tensors;
}
const auto& input_tensor = input_tensors.at(0);
auto spec = compute_output_specs(input_tensors, output_tensors)[0];
return {create_device_tensor(spec, input_tensor.device())};
}

operation::ProgramWithCallbacks CopyDeviceOperation::create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct CopyDeviceOperation {

void validate_with_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<ttnn::TensorSpec> compute_output_specs(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;

std::vector<Tensor> create_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,17 @@ void ExpandOperation::validate_on_program_cache_hit(
validate(operation_attributes, tensor_args);
};

ExpandOperation::shape_return_value_t ExpandOperation::compute_output_shapes(
ExpandOperation::spec_return_value_t ExpandOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return SimpleShape{operation_attributes.output_shape};
if (tensor_args.output.has_value()) {
return tensor_args.output->get_tensor_spec();
}
return TensorSpec(
SimpleShape{operation_attributes.output_shape},
TensorLayout(
tensor_args.input.get_dtype(),
PageConfig(tensor_args.input.get_layout()),
operation_attributes.memory_config));
};

ExpandOperation::tensor_return_value_t ExpandOperation::create_output_tensors(
Expand All @@ -63,12 +71,7 @@ ExpandOperation::tensor_return_value_t ExpandOperation::create_output_tensors(
return {tensor_args.output.value()};
}

return create_device_tensor(
compute_output_shapes(operation_attributes, tensor_args),
tensor_args.input.get_dtype(),
tensor_args.input.get_layout(),
tensor_args.input.device(),
operation_attributes.memory_config);
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
}

std::tuple<ExpandOperation::operation_attributes_t, ExpandOperation::tensor_args_t> ExpandOperation::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct ExpandOperation {
const std::optional<Tensor>& output;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ExpandRowMajorFactory {
Expand Down Expand Up @@ -52,7 +52,7 @@ struct ExpandOperation {
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,11 @@ void FillRM::validate(const std::vector<Tensor>& input_tensors) const {
"FillRM does not currently support sharding");
}

std::vector<ttnn::SimpleShape> FillRM::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
return {ttnn::SimpleShape({this->N, this->C, this->H, this->W})};
}

std::vector<Tensor> FillRM::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
std::vector<ttnn::TensorSpec> FillRM::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
ttnn::SimpleShape shape({this->N, this->C, this->H, this->W});
const auto& input_tensor = input_tensors.at(0);
return operation::generic_create_output_tensors(
*this, input_tensors, input_tensor.get_dtype(), Layout::ROW_MAJOR, this->output_mem_config);
return {
TensorSpec(shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::ROW_MAJOR), output_mem_config))};
}

operation::ProgramWithCallbacks FillRM::create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ struct FillRM {
const tt::tt_metal::MemoryConfig output_mem_config;

void validate(const std::vector<Tensor>& input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
tt::tt_metal::operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,34 @@ void Fold::validate_on_program_cache_hit(const operation_attributes_t& op_attr,
return validate_fold({tensors.input_tensor}, op_attr.is_sharded, op_attr.stride_h, op_attr.stride_w);
}

Fold::shape_return_value_t Fold::compute_output_shapes(
Fold::spec_return_value_t Fold::compute_output_specs(
const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
auto input_tensor = tensors.input_tensor;
const ttnn::SimpleShape input_shape = input_tensor.get_logical_shape();
// we concatenate (stride_h sticks in H-dim) * (stride_w in W-dim) into 1 stick along C-dim
return ttnn::SimpleShape(
ttnn::SimpleShape output_shape(
{1,
1,
input_shape[0] * input_shape[1] * input_shape[2] / (op_attr.stride_h * op_attr.stride_w),
input_shape[3] * op_attr.stride_h * op_attr.stride_w});
}

Fold::tensor_return_value_t Fold::create_output_tensors(
const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
const Tensor& input_tensor = tensors.input_tensor;
DataType output_dtype = input_tensor.get_dtype();

auto output_shape = compute_output_shapes(op_attr, tensors);

if (op_attr.is_sharded) {
MemoryConfig mem_config = input_tensor.memory_config();
mem_config.shard_spec->shape[0] /= op_attr.stride_h * op_attr.stride_w;
mem_config.shard_spec->shape[1] *= op_attr.stride_h * op_attr.stride_w;

return {create_device_tensor(
output_shape, output_dtype, input_tensor.get_layout(), input_tensor.device(), mem_config)};
} else {
return {create_device_tensor(
output_shape, output_dtype, Layout::ROW_MAJOR, input_tensor.device(), input_tensor.memory_config())};
return {TensorSpec(
output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(input_tensor.get_layout()), mem_config))};
}

return {TensorSpec(
output_shape,
TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::ROW_MAJOR), input_tensor.memory_config()))};
}

// Allocates the fold op's output on the same device as the input tensor,
// using the layout/shape described by compute_output_specs.
Fold::tensor_return_value_t Fold::create_output_tensors(
    const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
    const auto output_spec = compute_output_specs(op_attr, tensors);
    return create_device_tensor(output_spec, tensors.input_tensor.device());
}

std::tuple<Fold::operation_attributes_t, Fold::tensor_args_t> Fold::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Fold {
const Tensor& input_tensor;
};

using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct SingleCore {
Expand Down Expand Up @@ -73,7 +73,7 @@ struct Fold {
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Loading

0 comments on commit 5c25a41

Please sign in to comment.