From dc668abca98306e18cb1cd8018728fd5406367c4 Mon Sep 17 00:00:00 2001
From: Shwetank Singh
Date: Fri, 22 Nov 2024 10:16:59 +0000
Subject: [PATCH] #14049: adding weight config to avoid duplicate computation.

---
 .../ttnn/operations/conv/conv2d/conv2d.cpp    |  9 ++-
 .../operations/conv/conv2d/conv2d_pybind.cpp  |  6 +-
 .../operations/conv/conv2d/conv2d_utils.hpp   |  7 +++
 .../conv/conv2d/prepare_conv2d_weights.cpp    | 62 +++++++++++--------
 .../conv/conv2d/prepare_conv2d_weights.hpp    |  3 +-
 .../conv_transpose2d/conv_transpose2d.cpp     |  9 ++-
 6 files changed, 66 insertions(+), 30 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index f8b147b573f..af3c7ac6b79 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -106,6 +106,12 @@ Result conv2d(
     std::optional<ttnn::Tensor> bias_tensor_on_device = bias_tensor;
     if (!weight_is_on_device) {
         // prepare weights in desired layout and move to device
+        WeightConfig weight_config = WeightConfig{
+            .act_block_w_ntiles = opt_conv_op_block_config.act_block_w_ntiles,
+            .out_subblock_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles,
+            .act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles,
+            .shard_layout = parallel_config.shard_scheme,
+        };
         tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights(
             weight_tensor,
             input_tensor_post_tm.memory_config(),
@@ -123,7 +129,8 @@ Result conv2d(
             groups,
             device,
             bias_tensor,
-            conv_config
+            conv_config,
+            weight_config
         );
         weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt);
         if(bias_tensor.has_value()){
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
index 01f62870d9b..ce902c79c39 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -143,7 +143,8 @@ void py_bind_conv2d(py::module& module) {
         py::arg("groups"),
         py::arg("device"),
         py::arg("bias_tensor") = std::nullopt,
-        py::arg("conv_config") = std::nullopt);
+        py::arg("conv_config") = std::nullopt,
+        py::arg("weight_config") = std::nullopt);


     module.def(
@@ -166,7 +167,8 @@ void py_bind_conv2d(py::module& module) {
         py::arg("groups"),
         py::arg("device"),
         py::arg("bias_tensor") = std::nullopt,
-        py::arg("conv_config") = std::nullopt);
+        py::arg("conv_config") = std::nullopt,
+        py::arg("weight_config") = std::nullopt);

     module.def(
         "prepare_conv_bias",
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
index a07a8635002..c3c4097fcd7 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -103,6 +103,13 @@ struct Conv2dConfig {
     }
 };

+struct WeightConfig{
+    uint32_t act_block_w_ntiles;
+    uint32_t out_subblock_w_ntiles;
+    uint32_t act_block_h_ntiles;
+    TensorMemoryLayout shard_layout;
+};
+
 uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor);

 uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor);
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index 2511b56d2b1..f0ecbb2da91 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -71,7 +71,7 @@ OptimizedConvBlockConfig get_opt_block_config(
         conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
     bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 &&
                                conv_config.act_block_h_override == 0 &&
-                               conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR;
+                               (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR;
     use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16;

     ParallelConfig parallel_config = determine_parallel_config(
@@ -135,7 +135,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
     uint32_t groups,
     T *device,
     std::optional<const ttnn::Tensor>& bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_) {
+    std::optional<const Conv2dConfig> conv_config_,
+    std::optional<const WeightConfig> weight_config_) {
     ttnn::Tensor bias_tensor_;

     TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation.");
@@ -143,27 +144,36 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
     const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
     const uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
-
+    uint32_t weight_block_h_ntiles=0, weight_block_w_ntiles=0, act_block_h_ntiles=0;
     Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
-    auto opt_conv_op_block_config = get_opt_block_config(
-        mm_conv,
-        in_channels,
-        out_channels,
-        output_height,
-        output_width,
-        batch_size,
-        input_width,
-        kernel_size,
-        stride,
-        device,
-        conv_config,
-        input_tensor_layout,
-        input_memory_config
-    );
-
-    uint32_t weight_block_h_ntiles = opt_conv_op_block_config.act_block_w_ntiles;
-    uint32_t weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles;
-    uint32_t act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles;
+    TensorMemoryLayout tensor_layout;
+    if(conv_config_.has_value() && weight_config_.has_value()) {
+        WeightConfig weight_config = weight_config_.value();
+        weight_block_h_ntiles = weight_config.act_block_w_ntiles;
+        weight_block_w_ntiles = weight_config.out_subblock_w_ntiles;
+        act_block_h_ntiles = weight_config.act_block_h_ntiles;
+        tensor_layout = weight_config.shard_layout;
+    }else{
+        auto opt_conv_op_block_config = get_opt_block_config(
+            mm_conv,
+            in_channels,
+            out_channels,
+            output_height,
+            output_width,
+            batch_size,
+            input_width,
+            kernel_size,
+            stride,
+            device,
+            conv_config,
+            input_tensor_layout,
+            input_memory_config
+        );
+        weight_block_h_ntiles = opt_conv_op_block_config.act_block_w_ntiles;
+        weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles;
+        act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles;
+        tensor_layout = conv_config.shard_layout.value();
+    }

     validate_weight_tensor(weight_tensor);
     ttnn::Tensor weight_tensor_ = weight_tensor;  // tensor to return
@@ -213,7 +223,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
     weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

     // for conv op, pad the weights to block shape
-    if (conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED) {
+    if (tensor_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, conv_config.weights_dtype);
     } else {
@@ -362,7 +372,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     uint32_t groups,
     Device *device,
     std::optional<const ttnn::Tensor>& bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_);
+    std::optional<const Conv2dConfig> conv_config_,
+    std::optional<const WeightConfig> weight_config);

 template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
     const ttnn::Tensor& weight_tensor,
@@ -381,7 +392,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     uint32_t groups,
     MeshDevice *device,
     std::optional<const ttnn::Tensor>& bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_);
+    std::optional<const Conv2dConfig> conv_config_,
+    std::optional<const WeightConfig> weight_config);

 template ttnn::Tensor prepare_conv_bias(
     const ttnn::Tensor& bias_tensor,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
index 2e9834ad47d..4cf349da240 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
@@ -44,7 +44,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
     uint32_t groups,
     T *device,
     std::optional<const ttnn::Tensor>& bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_);
+    std::optional<const Conv2dConfig> conv_config_,
+    std::optional<const WeightConfig> weight_config_);

 template <typename T>
 ttnn::Tensor prepare_conv_bias(
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
index 5861c53957a..ca3e4ecb9a2 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
@@ -255,6 +255,12 @@ Result conv_transpose2d(
         //     input_width);

         // prepare weights in desired layout and move to device
+        WeightConfig weight_config = WeightConfig{
+            .act_block_w_ntiles = opt_conv_op_block_config.act_block_w_ntiles,
+            .out_subblock_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles,
+            .act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles,
+        };
+
         tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights(
             transform_weights_for_conv_transpose2d(weight_tensor),
             input_tensor_post_tm.memory_config(),
@@ -272,7 +278,8 @@ Result conv_transpose2d(
             groups,
             device,
             bias_tensor,
-            conv_config
+            conv_config,
+            weight_config
         );
         weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt);
         if(bias_tensor.has_value()){
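
Not part of the patch: WeightConfig is just the three block-config values that weight preparation needs, plus the shard layout. A caller that has already run get_opt_block_config() can forward those values, and prepare_conv_weights() then skips its own get_opt_block_config() call (the else branch above); that is the duplicate computation the subject line refers to. A minimal sketch of that pattern follows; the include path and the ttnn::operations::conv::conv2d namespace are assumptions, and make_weight_config is a hypothetical helper that this patch does not add.

    // Hypothetical helper (sketch): package an already-computed
    // OptimizedConvBlockConfig, plus the chosen shard scheme, as a WeightConfig
    // so prepare_conv_weights() does not recompute the block config.
    #include "ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp"  // assumed include path

    namespace ttnn::operations::conv::conv2d {  // assumed namespace

    inline WeightConfig make_weight_config(
        const OptimizedConvBlockConfig& block_config, TensorMemoryLayout shard_scheme) {
        // Members are listed in declaration order, matching the struct added in
        // conv2d_utils.hpp by this patch.
        return WeightConfig{
            .act_block_w_ntiles = block_config.act_block_w_ntiles,
            .out_subblock_w_ntiles = block_config.out_subblock_w_ntiles,
            .act_block_h_ntiles = block_config.act_block_h_ntiles,
            .shard_layout = shard_scheme,
        };
    }

    }  // namespace ttnn::operations::conv::conv2d

When weight_config is left as std::nullopt (the new pybind default), prepare_conv_weights() falls back to the previous behaviour and computes the block config itself.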