Cleanup
sminakov-tt committed Jan 13, 2025
1 parent 5c25a41 commit 0e874ec
Showing 1 changed file with 6 additions and 5 deletions.
@@ -48,11 +48,11 @@ void Untilize::validate(const std::vector<Tensor>& input_tensors) const {
         TT_FATAL(this->use_multicore == true, "Error");
         TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error");
         uint32_t ntiles = input_tensor_a.volume() / TILE_HW;
-        uint32_t ntiles_per_block = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH;
+        uint32_t ntiles_per_block = input_tensor_a.get_padded_shape()[-1] / TILE_WIDTH;
         uint32_t nblocks = std::ceil((float)ntiles / ntiles_per_block);
         auto num_cores =
             untilize_helpers::get_num_cores(input_tensor_a.device()->compute_with_storage_grid_size(), nblocks);
-        uint32_t fused_height = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1];
+        uint32_t fused_height = input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1];
         TT_FATAL(fused_height % num_cores == 0, "Error");
     } else {
         TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
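
For intuition, here is a minimal standalone sketch of the block math this validation performs. The shape values, `volume`, `last_dim`, and `num_cores` are hypothetical stand-ins for the real Tensor and helper calls; `TILE_WIDTH = 32` and `TILE_HW = 32 * 32` match tt-metal's standard tile size, but treat the whole thing as an illustration rather than the op's actual code path:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
    // Standard tt-metal tile dimensions: 32x32 elements per tile.
    constexpr uint32_t TILE_WIDTH = 32;
    constexpr uint32_t TILE_HW = 32 * 32;

    // Hypothetical padded shape [1, 1, 64, 128]: volume and last dim stand in
    // for input_tensor_a.volume() and input_tensor_a.get_padded_shape()[-1].
    constexpr uint32_t volume = 1 * 1 * 64 * 128;
    constexpr uint32_t last_dim = 128;

    uint32_t ntiles = volume / TILE_HW;                              // 8 tiles total
    uint32_t ntiles_per_block = last_dim / TILE_WIDTH;               // 4 tiles per row-block
    uint32_t nblocks = std::ceil((float)ntiles / ntiles_per_block);  // 2 blocks

    // The validate() check: the fused height (every dim but the last,
    // flattened together) must divide evenly across the chosen cores.
    uint32_t fused_height = volume / last_dim;  // 64 rows
    uint32_t num_cores = 2;                     // stand-in for get_num_cores(...)
    std::cout << "ntiles=" << ntiles << " nblocks=" << nblocks
              << " fused_height % num_cores = " << (fused_height % num_cores) << "\n";
}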
@@ -61,6 +61,7 @@ void Untilize::validate(const std::vector<Tensor>& input_tensors) const {
 }
 
 std::vector<ttnn::TensorSpec> Untilize::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    using namespace tt::constants;
     const auto& input_tensor = input_tensors.at(0);
     DataType output_dtype =
         input_tensor.get_dtype() == DataType::BFLOAT8_B ? DataType::BFLOAT16 : input_tensor.get_dtype();
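
(The added using-declaration is presumably what keeps the unqualified TILE_HW and TILE_WIDTH references further down resolving: in tt-metal those tile constants are declared in the tt::constants namespace.)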
@@ -79,14 +80,14 @@ std::vector<ttnn::TensorSpec> Untilize::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     }
 
     uint32_t ntiles = input_tensor.volume() / TILE_HW;
-    uint32_t ntiles_per_block = input_tensor.get_legacy_shape()[-1] / TILE_WIDTH;
+    uint32_t ntiles_per_block = input_tensor.get_padded_shape()[-1] / TILE_WIDTH;
     uint32_t nblocks = std::ceil((float)ntiles / ntiles_per_block);
     auto num_cores =
         untilize_helpers::get_num_cores(input_tensor.device()->compute_with_storage_grid_size(), nblocks);
     auto shard_grid = tt::tt_metal::num_cores_to_corerangeset(
         num_cores, input_tensor.device()->compute_with_storage_grid_size(), true);
-    uint32_t fused_height = input_tensor.volume() / input_tensor.get_legacy_shape()[-1];
-    std::array<uint32_t, 2> shard_shape = {fused_height / num_cores, input_tensor.get_legacy_shape()[-1]};
+    uint32_t fused_height = input_tensor.volume() / input_tensor.get_padded_shape()[-1];
+    std::array<uint32_t, 2> shard_shape = {fused_height / num_cores, input_tensor.get_padded_shape()[-1]};
     ShardSpec shard_spec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR};
     auto mem_config = this->output_mem_config;
     mem_config.shard_spec = shard_spec;
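
The substantive change in this commit is the swap from get_legacy_shape() to get_padded_shape(). For a tile-layout tensor the padded shape rounds the trailing dims up to the tile size, and the shard geometry has to follow that padded width, since each core's shard covers whole tiles rather than the logical row length. A minimal standalone sketch of that relationship; the shape values and the pad_to helper are hypothetical, not tt-metal API:

#include <array>
#include <cstdint>
#include <iostream>

// Round a dimension up to the next tile boundary, as tile padding does.
constexpr uint32_t pad_to(uint32_t dim, uint32_t tile) {
    return ((dim + tile - 1) / tile) * tile;
}

int main() {
    constexpr uint32_t TILE_WIDTH = 32;

    // Hypothetical logical shape [1, 1, 64, 100]: the last dim is not
    // tile-aligned, so the padded shape used for tiled storage is
    // [1, 1, 64, 128].
    uint32_t logical_w = 100;
    uint32_t padded_w = pad_to(logical_w, TILE_WIDTH);  // 128

    // The shard shape is derived from the padded width: using the logical
    // (unpadded) 100 here would describe shards that cut through tiles.
    uint32_t fused_height = 64;
    uint32_t num_cores = 2;
    std::array<uint32_t, 2> shard_shape = {fused_height / num_cores, padded_w};

    std::cout << "shard_shape = {" << shard_shape[0] << ", " << shard_shape[1] << "}\n";
}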