diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index 8b399bc5ca5a..38b17b0296e0 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -72,35 +72,6 @@ uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num1, uint32_t n return divisor; } -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype) { - return tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout( - std::move(conv_weight_tensor), in1_block_h, in1_block_w, output_dtype); -} - -// Converts convolution weights to tilized 2d matrix layout with special block height padding -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype) { - return tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( - std::move(conv_weight_tensor), in1_block_h, in1_block_w, output_dtype); -} - -// Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout( - const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype) { - return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout( - std::move(conv_weight_tensor), num_groups, output_dtype); -} - ParallelConfig determine_parallel_config_non_tile_mul_width( const TensorMemoryLayout shard_layout, uint32_t batch_size, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index 59e5e27a0c04..9a5758872c20 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -171,26 +171,6 @@ shard_or_reshard_tensor_if_required( bool auto_shard, bool is_non_tile_mul_width = false); -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype = std::nullopt); - -// Converts convolution weights to tilized 2d matrix layout with special block height padding -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype = std::nullopt); - -// Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout( - const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype); - std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config); } // namespace operations::conv diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 7e68416803ab..a3f39ce5c775 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -21,6 +21,475 @@ using sliding_window::SlidingWindowConfig; namespace conv2d { +template +Tensor convert_tensor(const Tensor& input_tensor, compute_& compute) { + auto convert_tensor = [&compute](const auto& input_tensor) { + return std::visit( + [&compute](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return compute(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return compute(borrowed_buffer::get_as(storage.buffer)); + } else { + TT_THROW("Unsupported storage type"); + } + }, + input_tensor.get_storage()); + }; + + return ttnn::distributed::is_multi_device_tensor(input_tensor) ? transform(input_tensor, convert_tensor) + : convert_tensor(input_tensor); +} + +template +Tensor convert_tensor_to_tiled_layout_common( + const Tensor& input_tensor, + std::optional output_dtype, + const std::unordered_map& function_map, + Args&&... args) { + TT_ASSERT( + input_tensor.get_layout() == Layout::ROW_MAJOR && + "Tensor(weight/bias) should be in row major layout for conversion to tilized layout."); + + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(input_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(input_tensor.get_dtype() == input_tensor.get_dtype()); + } + } + auto entry = function_map.find(input_tensor.get_dtype()); + if (entry == function_map.end()) { + TT_THROW("Unsupported data type"); + } + return entry->second(input_tensor, std::forward(args)..., output_dtype.value_or(input_tensor.get_dtype())); +} + +template +Tensor create_tensor_from_owned_buffer( + owned_buffer::Buffer& buf, DataType& output_dtype, ttnn::SimpleShape& output_shape) { + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + auto tensor = + Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, DataType::FLOAT32, Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); + auto output_packed_data = + output_dtype == DataType::BFLOAT8_B + ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false) + : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); + } + } else { + TT_FATAL( + (output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B), + "Unsupported output datatype"); + } + auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); +} + +template +Tensor to_weight_special_padding_tile_layout( + const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { + uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; + uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; + auto weight_matrix_cols = w_shape[0]; + // width padding + if (weight_matrix_cols % in1_block_w_datums != 0) { + weight_matrix_cols = + (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; + } + // height padding + assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); + uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); + auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; + ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c = 0; c < w_shape[1]; c++) { + for (auto k = 0; k < w_shape[0]; k++) { + auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + + r * ((w_shape[3] * w_shape[1]) + block_height_padding) * weight_matrix_cols; + auto idx = + k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } + } + } + return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); + }; + return convert_tensor(conv_weight_tensor, compute); +} + +template +Tensor to_weight_tile_layout( + const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { + auto weight_matrix_cols = w_shape[0]; + // width padding + uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; + if (weight_matrix_cols % in1_block_w_datums != 0) { + weight_matrix_cols = + (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; + } + // height padding + auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; + uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; + if (weight_matrix_rows % in1_block_h_datums != 0) { + weight_matrix_rows = + (uint32_t)std::ceil((double)weight_matrix_rows / (double)in1_block_h_datums) * in1_block_h_datums; + } + ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c = 0; c < w_shape[1]; c++) { + for (auto k = 0; k < w_shape[0]; k++) { + auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + + r * w_shape[3] * w_shape[1] * weight_matrix_cols; + auto idx = + k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } + } + } + return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); + }; + + return convert_tensor(conv_weight_tensor, compute); +} + +// Converts convolution weights to tilized 2d matrix layout. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout( + const Tensor& conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype) { + const static std::unordered_map> + to_w_tile_layout_map = { + {DataType::BFLOAT16, &to_weight_tile_layout}, + {DataType::FLOAT32, &to_weight_tile_layout}, + {DataType::UINT32, &to_weight_tile_layout}}; + + return convert_tensor_to_tiled_layout_common( + conv_weight_tensor, output_dtype, to_w_tile_layout_map, in1_block_h, in1_block_w); +} + +template +Tensor to_weight_tile_layout_block_sharded( + const Tensor& conv_weight_tensor, uint32_t num_channel_shards, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = [&w_shape, &num_channel_shards, &output_dtype](const auto& input_buffer) { + auto weight_matrix_cols = w_shape[0]; + TT_ASSERT(weight_matrix_cols % num_channel_shards == 0); + auto conv_output_shard_width = weight_matrix_cols / num_channel_shards; + auto conv_output_shard_width_padded = + (uint32_t)std::ceil((double)conv_output_shard_width / (double)constants::TILE_WIDTH) * + constants::TILE_WIDTH; + if (conv_output_shard_width < conv_output_shard_width_padded) { + // width padding for conv output shard padding + weight_matrix_cols = conv_output_shard_width_padded * num_channel_shards; + } + + auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; + TT_ASSERT(w_shape[1] % num_channel_shards == 0); + auto conv_input_shard_width = w_shape[1] / num_channel_shards; + auto weight_block_height = conv_input_shard_width * w_shape[2] * w_shape[3]; + auto weight_block_height_padded = + (uint32_t)std::ceil((double)weight_block_height / (double)constants::TILE_HEIGHT) * constants::TILE_HEIGHT; + if (weight_block_height < weight_block_height_padded) { + // height padding for non tile multiple block height + weight_matrix_rows = weight_block_height_padded * num_channel_shards; + } + ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); + for (auto ic = 0; ic < num_channel_shards; ic++) { + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c_s = 0; c_s < conv_input_shard_width; c_s++) { + for (auto oc = 0; oc < num_channel_shards; oc++) { + for (auto k_s = 0; k_s < conv_output_shard_width; k_s++) { + auto matrix_idx = (oc * conv_output_shard_width_padded + k_s) + + c_s * weight_matrix_cols + + s * conv_input_shard_width * weight_matrix_cols + + r * w_shape[3] * conv_input_shard_width * weight_matrix_cols + + ic * weight_block_height_padded * weight_matrix_cols; + auto idx = (oc * conv_output_shard_width + k_s) * w_shape[1] * w_shape[2] * w_shape[3] + + (ic * conv_input_shard_width + c_s) * w_shape[2] * w_shape[3] + + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } + } + } + } + } + return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); + }; + return convert_tensor(conv_weight_tensor, compute); +} + +// Converts convolution weights to tilized 2d matrix layout for block sharded conv. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded( + const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional output_dtype) { + const static std::unordered_map> + to_w_tile_layout_map = { + {DataType::BFLOAT16, &to_weight_tile_layout_block_sharded}, + {DataType::FLOAT32, &to_weight_tile_layout_block_sharded}, + {DataType::UINT32, &to_weight_tile_layout_block_sharded}}; + + return convert_tensor_to_tiled_layout_common( + conv_weight_tensor, output_dtype, to_w_tile_layout_map, num_channel_shards); +} + +template +Tensor to_bias_tile_layout_block_sharded( + const Tensor& conv_bias_tensor, uint32_t num_channel_shards, DataType output_dtype) { + auto b_shape = conv_bias_tensor.get_legacy_shape(); + TT_ASSERT(b_shape[0] == 1 && b_shape[1] == 1 && b_shape[2] == 1); + auto compute = [&b_shape, &num_channel_shards, &output_dtype](const auto& input_buffer) { + auto bias_matrix_cols = b_shape[3]; + /*TT_ASSERT(bias_matrix_cols % num_channel_shards == 0);*/ + auto conv_output_shard_width = bias_matrix_cols / num_channel_shards; + auto conv_output_shard_width_padded = + (uint32_t)std::ceil((double)conv_output_shard_width / (double)constants::TILE_WIDTH) * + constants::TILE_WIDTH; + if (conv_output_shard_width < conv_output_shard_width_padded) { + // width padding for conv output shard padding + bias_matrix_cols = conv_output_shard_width_padded * num_channel_shards; + } + + auto bias_matrix_rows = 32; + ttnn::SimpleShape output_shape{1, 1, bias_matrix_rows, bias_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); + for (auto oc = 0; oc < num_channel_shards; oc++) { + for (auto k_s = 0; k_s < conv_output_shard_width; k_s++) { + auto matrix_idx = oc * conv_output_shard_width_padded + k_s; + auto idx = oc * conv_output_shard_width + k_s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } + return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); + }; + + return convert_tensor(conv_bias_tensor, compute); +} + +// Converts convolution bias to tilized 2d matrix layout for block sharded conv. +// Returns a new tensor with layout=Tile +Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded( + const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional output_dtype) { + const static std::unordered_map< + DataType, + std::function> + to_b_tile_layout_map = { + {DataType::BFLOAT16, &to_bias_tile_layout_block_sharded}, + {DataType::FLOAT32, &to_bias_tile_layout_block_sharded}, + {DataType::UINT32, &to_bias_tile_layout_block_sharded}, + }; + return convert_tensor_to_tiled_layout_common( + conv_bias_tensor, output_dtype, to_b_tile_layout_map, num_channel_shards); +} + +// Converts convolution weights to tilized 2d matrix layout. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( + const Tensor& conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype) { + const static std::unordered_map> + to_w_tile_layout_map = { + {DataType::BFLOAT16, &to_weight_special_padding_tile_layout}, + {DataType::FLOAT32, &to_weight_special_padding_tile_layout}, + {DataType::UINT32, &to_weight_special_padding_tile_layout}}; + + return convert_tensor_to_tiled_layout_common( + conv_weight_tensor, output_dtype, to_w_tile_layout_map, in1_block_h, in1_block_w); +} + +/* +Helper function to aid in converting grouped weight tensor to ungrouped weight tensor with padded zero channels +*/ +template +static Tensor conv_group_weight_zero_pad_helper( + const Tensor& weight, + const ttnn::SimpleShape& original_weight_shape, + const ttnn::SimpleShape& output_weight_shape, + uint32_t num_groups, + DataType output_dtype) { + auto pad_weight = [&original_weight_shape, &output_weight_shape, &num_groups, &output_dtype]( + const auto& conv_weight_tensor_buffer) { + owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); + for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) { + int new_batch_idx = curr_batch_idx; + + // Find which group_id the filter belongs to - through this, we can compute the offset where the padding + // should be applied + auto group_size = original_weight_shape[0] / num_groups; + auto group_index = curr_batch_idx / group_size; + auto group_id = std::min(group_index, num_groups - 1); + int new_channel_start_idx = group_id * original_weight_shape[1]; + + for (int j = 0; j < original_weight_shape[1]; j++) { + for (int k = 0; k < original_weight_shape[2]; k++) { + for (int m = 0; m < original_weight_shape[3]; m++) { + // Get value from original weight tensor + auto value_flat_input_index = compute_flat_indices( + ttnn::SmallVector{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); + auto value = conv_weight_tensor_buffer[value_flat_input_index]; + + // Copy value to output tensor at the adjusted position + auto new_channel_idx = new_channel_start_idx + j; + auto output_flat_input_index = compute_flat_indices( + ttnn::SmallVector{new_batch_idx, new_channel_idx, k, m}, + compute_strides(output_weight_shape)); + output_buffer[output_flat_input_index] = value; + } + } + } + } + return Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); + }; + + return convert_tensor(weight, pad_weight); +} + +/* +Helper function to aid in converting depthwise weight tensor to broadcasted weight tensor with repeated input channels +*/ +template +static Tensor conv_depthwise_weight_bcast_helper( + const Tensor& conv_weight_tensor, + const ttnn::SimpleShape& original_weight_shape, + const ttnn::SimpleShape& output_weight_shape, + DataType output_dtype) { + owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); + auto conv_weight_tensor_buffer = borrowed_buffer::get_as(conv_weight_tensor); + // Copy the original weight tensor to the output tensor + for (int i = 0; i < output_weight_shape[0]; i++) { + for (int j = 0; j < output_weight_shape[1]; j++) { + for (int k = 0; k < output_weight_shape[2]; k++) { + for (int l = 0; l < output_weight_shape[3]; l++) { + auto value_flat_input_index = compute_flat_indices( + ttnn::SmallVector{i, 0, k, l}, compute_strides(original_weight_shape)); + auto value = conv_weight_tensor_buffer[value_flat_input_index]; + auto output_flat_input_index = + compute_flat_indices(ttnn::SmallVector{i, j, k, l}, compute_strides(output_weight_shape)); + output_buffer[output_flat_input_index] = value; + } + } + } + } + + auto output_tensor = + Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); + return output_tensor; +} + +/* +Converts convolution weights to grouped layout with padded zeros +This function will take in a weight tensor with shape [out_channels, in_channels // groups, H, W] and return a newly +allocated output tensor with shape [out_channels, in_channels, H, W] The extra channels in shape[1] will be padded with +0 - then the entire weight tensor is convolved with the input tensor - equivalent to convolution if the input tensor was +divided into num_groups for each groupped filter +*/ +Tensor convert_conv_weight_tensor_to_grouped_layout( + const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype) { + // Define output tensor shape. This is going to be channel dimension of weight tensor * num_groups - this value + // should match number of input channels being convolved with the weight tensor + auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); + ttnn::SimpleShape original_conv_weight_tensor_shape{ + original_conv_weight_tensor_shape_test[0], + original_conv_weight_tensor_shape_test[1], + original_conv_weight_tensor_shape_test[2], + original_conv_weight_tensor_shape_test[3]}; + ttnn::SimpleShape output_conv_weight_tensor_shape{ + original_conv_weight_tensor_shape[0], + original_conv_weight_tensor_shape[1] * num_groups, + original_conv_weight_tensor_shape[2], + original_conv_weight_tensor_shape[3]}; + + const static std::unordered_map< + DataType, + std::function> + to_w_tile_layout_map = { + {DataType::INT32, &conv_group_weight_zero_pad_helper}, + {DataType::FLOAT32, &conv_group_weight_zero_pad_helper}, + {DataType::BFLOAT16, &conv_group_weight_zero_pad_helper}, + {DataType::UINT16, &conv_group_weight_zero_pad_helper}, + {DataType::BFLOAT8_B, &conv_group_weight_zero_pad_helper}, + {DataType::UINT32, &conv_group_weight_zero_pad_helper}, + {DataType::BFLOAT4_B, &conv_group_weight_zero_pad_helper}, + }; + output_dtype = output_dtype == DataType::BFLOAT8_B ? DataType::FLOAT32 : output_dtype; + + return convert_tensor_to_tiled_layout_common( + conv_weight_tensor, + output_dtype, + to_w_tile_layout_map, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape, + num_groups); +} + +/* +Converts convolution weights to depthwise layout +This function will take in a weight tensor with shape [out_channels, 1, H, W] and return a newly +allocated output tensor with shape [out_channels, act_block_h, H, W] The extra channels in shape[1] are repeated +from the original weight tensor - it would be convolving act_block in conv_matrix in one go +*/ +Tensor convert_conv_weight_tensor_to_depthwise_layout( + const Tensor& conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype) { + auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); + uint32_t num_input_channels_to_repeat = act_block_h_ntiles * constants::TILE_HEIGHT; + ttnn::SimpleShape original_conv_weight_tensor_shape{ + original_conv_weight_tensor_shape_test[0], + original_conv_weight_tensor_shape_test[1], + original_conv_weight_tensor_shape_test[2], + original_conv_weight_tensor_shape_test[3]}; + ttnn::SimpleShape output_conv_weight_tensor_shape{ + original_conv_weight_tensor_shape[0], + num_input_channels_to_repeat, + original_conv_weight_tensor_shape[2], + original_conv_weight_tensor_shape[3]}; + + // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor + const static std:: + unordered_map> + to_w_tile_layout_map = { + {DataType::INT32, &conv_depthwise_weight_bcast_helper}, + {DataType::FLOAT32, &conv_depthwise_weight_bcast_helper}, + {DataType::BFLOAT16, &conv_depthwise_weight_bcast_helper}, + {DataType::UINT16, &conv_depthwise_weight_bcast_helper}, + {DataType::BFLOAT8_B, &conv_depthwise_weight_bcast_helper}, + {DataType::UINT32, &conv_depthwise_weight_bcast_helper}, + {DataType::BFLOAT4_B, &conv_depthwise_weight_bcast_helper}, + }; + output_dtype = ((output_dtype == DataType::BFLOAT8_B) || (output_dtype == DataType::BFLOAT4_B)) ? DataType::FLOAT32 + : output_dtype; + + return convert_tensor_to_tiled_layout_common( + conv_weight_tensor, + output_dtype, + to_w_tile_layout_map, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape); +} + void validate_weight_tensor(const ttnn::Tensor& weight_tensor) { TT_FATAL( !ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE), "conv weight should be placed on host"); @@ -219,16 +688,14 @@ std::pair> prepare_conv_weights_biases // Convert weight tensor to 0 padded shape if groups > 1 if (!is_conv1d and groups > 1) { - weight_tensor_ = - tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); + weight_tensor_ = convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); } else if (is_conv1d and groups > 1) { if (is_depthwise_conv) { weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype); weight_block_h_ntiles = act_block_h_ntiles; } else { - weight_tensor_ = - tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); + weight_tensor_ = convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); } } @@ -267,13 +734,13 @@ std::pair> prepare_conv_weights_biases // for conv op, pad the weights to block shape if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( + weight_tensor_ = convert_conv_weight_tensor_to_special_padding_tiled_layout( weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); } else if (parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded( + weight_tensor_ = convert_conv_weight_tensor_to_tiled_layout_block_sharded( weight_tensor_, num_cores_channels, weights_bias_dtype); } else { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout( + weight_tensor_ = convert_conv_weight_tensor_to_tiled_layout( weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp index d901177fe1d0..2c4b7f8eab13 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -19,6 +19,40 @@ namespace ttnn { namespace operations::conv { namespace conv2d { +// Converts convolution weights to tilized 2d matrix layout. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout( + const Tensor& conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype = std::nullopt); + +// Converts convolution weights to tilized 2d matrix layout for block sharded conv. Adds zero padding between weight +// blocks based on output shard width padding. Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded( + const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional output_dtype = std::nullopt); + +// Converts convolution bias to tilized layout for block sharded conv. Adds zero padding between bias blocks based on +// output shard width padding. Returns a new tensor with layout=Tile +Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded( + const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional output_dtype = std::nullopt); + +// Converts convolution weights to tilized 2d matrix layout with special block height padding +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( + const Tensor& conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype = std::nullopt); + +// Converts convolution weights to grouped layout with padded zeros +Tensor convert_conv_weight_tensor_to_grouped_layout( + const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype); + +// Converts convolution weights to depthwise layout with broadcasted weights +Tensor convert_conv_weight_tensor_to_depthwise_layout( + const Tensor& conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype); + template ttnn::Tensor conv_bias_layout_convert( const ttnn::Tensor& bias_tensor, diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index 79e9e47783ce..6c521ebf0d68 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -12,474 +12,6 @@ namespace tt { namespace tt_metal { -template -Tensor convert_tensor(const Tensor& input_tensor, compute_& compute) { - auto convert_tensor = [&compute](const auto& input_tensor) { - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); - } - }, - input_tensor.get_storage()); - }; - - return ttnn::distributed::is_multi_device_tensor(input_tensor) ? transform(input_tensor, convert_tensor) - : convert_tensor(input_tensor); -} -template -Tensor convert_tensor_to_tiled_layout_common( - const Tensor& input_tensor, - std::optional output_dtype, - const std::unordered_map& function_map, - Args&&... args) { - TT_ASSERT( - input_tensor.get_layout() == Layout::ROW_MAJOR && - "Tensor(weight/bias) should be in row major layout for conversion to tilized layout."); - - if (output_dtype.has_value()) { - if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(input_tensor.get_dtype() == DataType::FLOAT32); - } else { - TT_ASSERT(input_tensor.get_dtype() == input_tensor.get_dtype()); - } - } - auto entry = function_map.find(input_tensor.get_dtype()); - if (entry == function_map.end()) { - TT_THROW("Unsupported data type"); - } - return entry->second(input_tensor, std::forward(args)..., output_dtype.value_or(input_tensor.get_dtype())); -} - -template -Tensor create_tensor_from_owned_buffer( - owned_buffer::Buffer& buf, DataType& output_dtype, ttnn::SimpleShape& output_shape) { - if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { - auto tensor = - Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, DataType::FLOAT32, Layout::ROW_MAJOR) - .to(Layout::TILE); - auto output_float_data = owned_buffer::get_as(tensor).get(); - auto output_packed_data = - output_dtype == DataType::BFLOAT8_B - ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false) - : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); - } - } else { - TT_FATAL( - (output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B), - "Unsupported output datatype"); - } - auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); -} - -template -Tensor to_weight_special_padding_tile_layout( - const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { - uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; - uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; - auto weight_matrix_cols = w_shape[0]; - // width padding - if (weight_matrix_cols % in1_block_w_datums != 0) { - weight_matrix_cols = - (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; - } - // height padding - assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); - uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); - auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; - ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(output_shape.volume()); - for (auto r = 0; r < w_shape[2]; r++) { - for (auto s = 0; s < w_shape[3]; s++) { - for (auto c = 0; c < w_shape[1]; c++) { - for (auto k = 0; k < w_shape[0]; k++) { - auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + - r * ((w_shape[3] * w_shape[1]) + block_height_padding) * weight_matrix_cols; - auto idx = - k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; - } - } - } - } - return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); - }; - return convert_tensor(conv_weight_tensor, compute); -} - -template -Tensor to_weight_tile_layout( - const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { - auto weight_matrix_cols = w_shape[0]; - // width padding - uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; - if (weight_matrix_cols % in1_block_w_datums != 0) { - weight_matrix_cols = - (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; - } - // height padding - auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; - uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; - if (weight_matrix_rows % in1_block_h_datums != 0) { - weight_matrix_rows = - (uint32_t)std::ceil((double)weight_matrix_rows / (double)in1_block_h_datums) * in1_block_h_datums; - } - ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(output_shape.volume()); - for (auto r = 0; r < w_shape[2]; r++) { - for (auto s = 0; s < w_shape[3]; s++) { - for (auto c = 0; c < w_shape[1]; c++) { - for (auto k = 0; k < w_shape[0]; k++) { - auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + - r * w_shape[3] * w_shape[1] * weight_matrix_cols; - auto idx = - k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; - } - } - } - } - return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); - }; - - return convert_tensor(conv_weight_tensor, compute); -} - -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype) { - const static std::unordered_map> - to_w_tile_layout_map = { - {DataType::BFLOAT16, &to_weight_tile_layout}, - {DataType::FLOAT32, &to_weight_tile_layout}, - {DataType::UINT32, &to_weight_tile_layout}}; - - return convert_tensor_to_tiled_layout_common( - conv_weight_tensor, output_dtype, to_w_tile_layout_map, in1_block_h, in1_block_w); -} - -template -Tensor to_weight_tile_layout_block_sharded( - const Tensor& conv_weight_tensor, uint32_t num_channel_shards, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - auto compute = [&w_shape, &num_channel_shards, &output_dtype](const auto& input_buffer) { - auto weight_matrix_cols = w_shape[0]; - TT_ASSERT(weight_matrix_cols % num_channel_shards == 0); - auto conv_output_shard_width = weight_matrix_cols / num_channel_shards; - auto conv_output_shard_width_padded = - (uint32_t)std::ceil((double)conv_output_shard_width / (double)constants::TILE_WIDTH) * - constants::TILE_WIDTH; - if (conv_output_shard_width < conv_output_shard_width_padded) { - // width padding for conv output shard padding - weight_matrix_cols = conv_output_shard_width_padded * num_channel_shards; - } - - auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; - TT_ASSERT(w_shape[1] % num_channel_shards == 0); - auto conv_input_shard_width = w_shape[1] / num_channel_shards; - auto weight_block_height = conv_input_shard_width * w_shape[2] * w_shape[3]; - auto weight_block_height_padded = - (uint32_t)std::ceil((double)weight_block_height / (double)constants::TILE_HEIGHT) * constants::TILE_HEIGHT; - if (weight_block_height < weight_block_height_padded) { - // height padding for non tile multiple block height - weight_matrix_rows = weight_block_height_padded * num_channel_shards; - } - ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(output_shape.volume()); - for (auto ic = 0; ic < num_channel_shards; ic++) { - for (auto r = 0; r < w_shape[2]; r++) { - for (auto s = 0; s < w_shape[3]; s++) { - for (auto c_s = 0; c_s < conv_input_shard_width; c_s++) { - for (auto oc = 0; oc < num_channel_shards; oc++) { - for (auto k_s = 0; k_s < conv_output_shard_width; k_s++) { - auto matrix_idx = (oc * conv_output_shard_width_padded + k_s) + - c_s * weight_matrix_cols + - s * conv_input_shard_width * weight_matrix_cols + - r * w_shape[3] * conv_input_shard_width * weight_matrix_cols + - ic * weight_block_height_padded * weight_matrix_cols; - auto idx = (oc * conv_output_shard_width + k_s) * w_shape[1] * w_shape[2] * w_shape[3] + - (ic * conv_input_shard_width + c_s) * w_shape[2] * w_shape[3] + - r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; - } - } - } - } - } - } - return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); - }; - return convert_tensor(conv_weight_tensor, compute); -} - -// Converts convolution weights to tilized 2d matrix layout for block sharded conv. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded( - const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional output_dtype) { - const static std::unordered_map> - to_w_tile_layout_map = { - {DataType::BFLOAT16, &to_weight_tile_layout_block_sharded}, - {DataType::FLOAT32, &to_weight_tile_layout_block_sharded}, - {DataType::UINT32, &to_weight_tile_layout_block_sharded}}; - - return convert_tensor_to_tiled_layout_common( - conv_weight_tensor, output_dtype, to_w_tile_layout_map, num_channel_shards); -} - -template -Tensor to_bias_tile_layout_block_sharded( - const Tensor& conv_bias_tensor, uint32_t num_channel_shards, DataType output_dtype) { - auto b_shape = conv_bias_tensor.get_legacy_shape(); - TT_ASSERT(b_shape[0] == 1 && b_shape[1] == 1 && b_shape[2] == 1); - auto compute = [&b_shape, &num_channel_shards, &output_dtype](const auto& input_buffer) { - auto bias_matrix_cols = b_shape[3]; - /*TT_ASSERT(bias_matrix_cols % num_channel_shards == 0);*/ - auto conv_output_shard_width = bias_matrix_cols / num_channel_shards; - auto conv_output_shard_width_padded = - (uint32_t)std::ceil((double)conv_output_shard_width / (double)constants::TILE_WIDTH) * - constants::TILE_WIDTH; - if (conv_output_shard_width < conv_output_shard_width_padded) { - // width padding for conv output shard padding - bias_matrix_cols = conv_output_shard_width_padded * num_channel_shards; - } - - auto bias_matrix_rows = 32; - ttnn::SimpleShape output_shape{1, 1, bias_matrix_rows, bias_matrix_cols}; - auto output_buffer = owned_buffer::create(output_shape.volume()); - for (auto oc = 0; oc < num_channel_shards; oc++) { - for (auto k_s = 0; k_s < conv_output_shard_width; k_s++) { - auto matrix_idx = oc * conv_output_shard_width_padded + k_s; - auto idx = oc * conv_output_shard_width + k_s; - output_buffer[matrix_idx] = input_buffer[idx]; - } - } - return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape); - }; - - return convert_tensor(conv_bias_tensor, compute); -} - -// Converts convolution bias to tilized 2d matrix layout for block sharded conv. -// Returns a new tensor with layout=Tile -Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded( - const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional output_dtype) { - const static std::unordered_map< - DataType, - std::function> - to_b_tile_layout_map = { - {DataType::BFLOAT16, &to_bias_tile_layout_block_sharded}, - {DataType::FLOAT32, &to_bias_tile_layout_block_sharded}, - {DataType::UINT32, &to_bias_tile_layout_block_sharded}, - }; - return convert_tensor_to_tiled_layout_common( - conv_bias_tensor, output_dtype, to_b_tile_layout_map, num_channel_shards); -} - -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype) { - const static std::unordered_map> - to_w_tile_layout_map = { - {DataType::BFLOAT16, &to_weight_special_padding_tile_layout}, - {DataType::FLOAT32, &to_weight_special_padding_tile_layout}, - {DataType::UINT32, &to_weight_special_padding_tile_layout}}; - - return convert_tensor_to_tiled_layout_common( - conv_weight_tensor, output_dtype, to_w_tile_layout_map, in1_block_h, in1_block_w); -} - -/* -Helper function to aid in converting grouped weight tensor to ungrouped weight tensor with padded zero channels -*/ -template -static Tensor conv_group_weight_zero_pad_helper( - const Tensor& weight, - const ttnn::SimpleShape& original_weight_shape, - const ttnn::SimpleShape& output_weight_shape, - uint32_t num_groups, - DataType output_dtype) { - auto pad_weight = [&original_weight_shape, &output_weight_shape, &num_groups, &output_dtype]( - const auto& conv_weight_tensor_buffer) { - owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); - for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) { - int new_batch_idx = curr_batch_idx; - - // Find which group_id the filter belongs to - through this, we can compute the offset where the padding - // should be applied - auto group_size = original_weight_shape[0] / num_groups; - auto group_index = curr_batch_idx / group_size; - auto group_id = std::min(group_index, num_groups - 1); - int new_channel_start_idx = group_id * original_weight_shape[1]; - - for (int j = 0; j < original_weight_shape[1]; j++) { - for (int k = 0; k < original_weight_shape[2]; k++) { - for (int m = 0; m < original_weight_shape[3]; m++) { - // Get value from original weight tensor - auto value_flat_input_index = compute_flat_indices( - ttnn::SmallVector{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); - auto value = conv_weight_tensor_buffer[value_flat_input_index]; - - // Copy value to output tensor at the adjusted position - auto new_channel_idx = new_channel_start_idx + j; - auto output_flat_input_index = compute_flat_indices( - ttnn::SmallVector{new_batch_idx, new_channel_idx, k, m}, - compute_strides(output_weight_shape)); - output_buffer[output_flat_input_index] = value; - } - } - } - } - return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); - }; - - return convert_tensor(weight, pad_weight); -} - -/* -Helper function to aid in converting depthwise weight tensor to broadcasted weight tensor with repeated input channels -*/ -template -static Tensor conv_depthwise_weight_bcast_helper( - const Tensor& conv_weight_tensor, - const ttnn::SimpleShape& original_weight_shape, - const ttnn::SimpleShape& output_weight_shape, - DataType output_dtype) { - owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); - auto conv_weight_tensor_buffer = borrowed_buffer::get_as(conv_weight_tensor); - // Copy the original weight tensor to the output tensor - for (int i = 0; i < output_weight_shape[0]; i++) { - for (int j = 0; j < output_weight_shape[1]; j++) { - for (int k = 0; k < output_weight_shape[2]; k++) { - for (int l = 0; l < output_weight_shape[3]; l++) { - auto value_flat_input_index = compute_flat_indices( - ttnn::SmallVector{i, 0, k, l}, compute_strides(original_weight_shape)); - auto value = conv_weight_tensor_buffer[value_flat_input_index]; - auto output_flat_input_index = - compute_flat_indices(ttnn::SmallVector{i, j, k, l}, compute_strides(output_weight_shape)); - output_buffer[output_flat_input_index] = value; - } - } - } - } - - auto output_tensor = - Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); - return output_tensor; -} - -/* -Converts convolution weights to grouped layout with padded zeros -This function will take in a weight tensor with shape [out_channels, in_channels // groups, H, W] and return a newly -allocated output tensor with shape [out_channels, in_channels, H, W] The extra channels in shape[1] will be padded with -0 - then the entire weight tensor is convolved with the input tensor - equivalent to convolution if the input tensor was -divided into num_groups for each groupped filter -*/ -Tensor convert_conv_weight_tensor_to_grouped_layout( - const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype) { - // Define output tensor shape. This is going to be channel dimension of weight tensor * num_groups - this value - // should match number of input channels being convolved with the weight tensor - auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); - ttnn::SimpleShape original_conv_weight_tensor_shape{ - original_conv_weight_tensor_shape_test[0], - original_conv_weight_tensor_shape_test[1], - original_conv_weight_tensor_shape_test[2], - original_conv_weight_tensor_shape_test[3]}; - ttnn::SimpleShape output_conv_weight_tensor_shape{ - original_conv_weight_tensor_shape[0], - original_conv_weight_tensor_shape[1] * num_groups, - original_conv_weight_tensor_shape[2], - original_conv_weight_tensor_shape[3]}; - - const static std::unordered_map< - DataType, - std::function> - to_w_tile_layout_map = { - {DataType::INT32, &conv_group_weight_zero_pad_helper}, - {DataType::FLOAT32, &conv_group_weight_zero_pad_helper}, - {DataType::BFLOAT16, &conv_group_weight_zero_pad_helper}, - {DataType::UINT16, &conv_group_weight_zero_pad_helper}, - {DataType::BFLOAT8_B, &conv_group_weight_zero_pad_helper}, - {DataType::UINT32, &conv_group_weight_zero_pad_helper}, - {DataType::BFLOAT4_B, &conv_group_weight_zero_pad_helper}, - }; - output_dtype = output_dtype == DataType::BFLOAT8_B ? DataType::FLOAT32 : output_dtype; - - return convert_tensor_to_tiled_layout_common( - conv_weight_tensor, - output_dtype, - to_w_tile_layout_map, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - num_groups); -} - -/* -Converts convolution weights to depthwise layout -This function will take in a weight tensor with shape [out_channels, 1, H, W] and return a newly -allocated output tensor with shape [out_channels, act_block_h, H, W] The extra channels in shape[1] are repeated -from the original weight tensor - it would be convolving act_block in conv_matrix in one go -*/ -Tensor convert_conv_weight_tensor_to_depthwise_layout( - const Tensor& conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype) { - auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); - uint32_t num_input_channels_to_repeat = act_block_h_ntiles * constants::TILE_HEIGHT; - ttnn::SimpleShape original_conv_weight_tensor_shape{ - original_conv_weight_tensor_shape_test[0], - original_conv_weight_tensor_shape_test[1], - original_conv_weight_tensor_shape_test[2], - original_conv_weight_tensor_shape_test[3]}; - ttnn::SimpleShape output_conv_weight_tensor_shape{ - original_conv_weight_tensor_shape[0], - num_input_channels_to_repeat, - original_conv_weight_tensor_shape[2], - original_conv_weight_tensor_shape[3]}; - - // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor - const static std:: - unordered_map> - to_w_tile_layout_map = { - {DataType::INT32, &conv_depthwise_weight_bcast_helper}, - {DataType::FLOAT32, &conv_depthwise_weight_bcast_helper}, - {DataType::BFLOAT16, &conv_depthwise_weight_bcast_helper}, - {DataType::UINT16, &conv_depthwise_weight_bcast_helper}, - {DataType::BFLOAT8_B, &conv_depthwise_weight_bcast_helper}, - {DataType::UINT32, &conv_depthwise_weight_bcast_helper}, - {DataType::BFLOAT4_B, &conv_depthwise_weight_bcast_helper}, - }; - output_dtype = ((output_dtype == DataType::BFLOAT8_B) || (output_dtype == DataType::BFLOAT4_B)) ? DataType::FLOAT32 - : output_dtype; - - return convert_tensor_to_tiled_layout_common( - conv_weight_tensor, - output_dtype, - to_w_tile_layout_map, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape); -} - const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span shape) { int64_t old_volume = tensor.get_logical_volume(); int64_t new_volume = 1; diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 9621a4962445..d538568b63ea 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -12,40 +12,6 @@ namespace tt { namespace tt_metal { -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype = std::nullopt); - -// Converts convolution weights to tilized 2d matrix layout for block sharded conv. Adds zero padding between weight -// blocks based on output shard width padding. Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded( - const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional output_dtype = std::nullopt); - -// Converts convolution bias to tilized layout for block sharded conv. Adds zero padding between bias blocks based on -// output shard width padding. Returns a new tensor with layout=Tile -Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded( - const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional output_dtype = std::nullopt); - -// Converts convolution weights to tilized 2d matrix layout with special block height padding -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - const Tensor& conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype = std::nullopt); - -// Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout( - const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype); - -// Converts convolution weights to depthwise layout with broadcasted weights -Tensor convert_conv_weight_tensor_to_depthwise_layout( - const Tensor& conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype); - const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span shape); // TODO: Remove this once we switch to SimpleShape .volume()