#14049: adding weight config to avoid duplicate computation.
shwetankTT committed Nov 22, 2024
1 parent 3dfbf32 commit dc668ab
Showing 6 changed files with 66 additions and 30 deletions.
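The change lets conv2d() (and conv_transpose2d()) hand the block configuration they have already computed down to prepare_conv_weights(), which previously re-derived it with a second call to get_opt_block_config(). Below is a minimal, self-contained sketch of that "compute once, pass down" pattern; the types and function bodies are simplified stand-ins, not the actual ttnn signatures.

// Sketch only: simplified stand-ins for the ttnn types, illustrating the
// "compute once, pass down" pattern this commit introduces.
#include <cstdint>
#include <iostream>
#include <optional>

enum class TensorMemoryLayout { INTERLEAVED, HEIGHT_SHARDED, BLOCK_SHARDED };

struct BlockConfig {                 // stand-in for OptimizedConvBlockConfig
    uint32_t act_block_h_ntiles;
    uint32_t act_block_w_ntiles;
    uint32_t out_subblock_w_ntiles;
};

struct WeightConfig {                // mirrors the struct added in conv2d_utils.hpp
    uint32_t act_block_w_ntiles;
    uint32_t out_subblock_w_ntiles;
    uint32_t act_block_h_ntiles;
    TensorMemoryLayout shard_layout;
};

BlockConfig compute_block_config() {
    // Placeholder for the relatively expensive get_opt_block_config() call.
    std::cout << "computing block config\n";
    return {.act_block_h_ntiles = 4, .act_block_w_ntiles = 2, .out_subblock_w_ntiles = 1};
}

void prepare_weights(const std::optional<WeightConfig>& weight_config) {
    uint32_t weight_block_h_ntiles = 0, weight_block_w_ntiles = 0;
    if (weight_config.has_value()) {
        // The caller already derived the block config; reuse its values.
        weight_block_h_ntiles = weight_config->act_block_w_ntiles;
        weight_block_w_ntiles = weight_config->out_subblock_w_ntiles;
    } else {
        // Standalone call: fall back to recomputing, as the old code always did.
        BlockConfig block = compute_block_config();
        weight_block_h_ntiles = block.act_block_w_ntiles;
        weight_block_w_ntiles = block.out_subblock_w_ntiles;
    }
    std::cout << "weight block ntiles: " << weight_block_h_ntiles << " x " << weight_block_w_ntiles << "\n";
}

int main() {
    BlockConfig block = compute_block_config();      // computed once in conv2d()
    WeightConfig weight_config{
        .act_block_w_ntiles = block.act_block_w_ntiles,
        .out_subblock_w_ntiles = block.out_subblock_w_ntiles,
        .act_block_h_ntiles = block.act_block_h_ntiles,
        .shard_layout = TensorMemoryLayout::HEIGHT_SHARDED,
    };
    prepare_weights(weight_config);   // no recomputation
    prepare_weights(std::nullopt);    // recomputes, matching the pre-commit behaviour
}

Passing the values through an optional keeps prepare_conv_weights() usable on its own: standalone callers simply omit the argument and pay for the recomputation, exactly as before.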
9 changes: 8 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -106,6 +106,12 @@ Result conv2d(
std::optional<ttnn::Tensor> bias_tensor_on_device = bias_tensor;
if (!weight_is_on_device) {
// prepare weights in desired layout and move to device
WeightConfig weight_config = WeightConfig{
.act_block_w_ntiles = opt_conv_op_block_config.act_block_w_ntiles,
.out_subblock_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles,
.act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles,
.shard_layout = parallel_config.shard_scheme,
};
tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights(
weight_tensor,
input_tensor_post_tm.memory_config(),
@@ -123,7 +129,8 @@ Result conv2d(
groups,
device,
bias_tensor,
conv_config
conv_config,
weight_config
);
weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt);
if(bias_tensor.has_value()){
6 changes: 4 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -143,7 +143,8 @@ void py_bind_conv2d(py::module& module) {
py::arg("groups"),
py::arg("device"),
py::arg("bias_tensor") = std::nullopt,
py::arg("conv_config") = std::nullopt);
py::arg("conv_config") = std::nullopt,
py::arg("weight_config") = std::nullopt);


module.def(
@@ -166,7 +167,8 @@ void py_bind_conv2d(py::module& module) {
py::arg("groups"),
py::arg("device"),
py::arg("bias_tensor") = std::nullopt,
py::arg("conv_config") = std::nullopt);
py::arg("conv_config") = std::nullopt,
py::arg("weight_config") = std::nullopt);

module.def(
"prepare_conv_bias",
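The binding change only appends an optional keyword argument; because it defaults to std::nullopt, existing Python call sites that never pass weight_config keep working unchanged. A generic pybind11 sketch of that idiom follows (hypothetical module and function names, not the actual ttnn binding):

// Sketch: an optional keyword argument that defaults to std::nullopt.
#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>   // enables std::optional <-> None conversion

namespace py = pybind11;

PYBIND11_MODULE(example, m) {
    m.def(
        "prepare",
        [](int groups, std::optional<int> weight_config) {
            // weight_config holds std::nullopt when the Python caller omits it.
            return weight_config.value_or(groups);
        },
        py::arg("groups"),
        py::arg("weight_config") = std::nullopt);
}

// Python usage:
//   example.prepare(4)                    -> 4 (argument omitted, nullopt)
//   example.prepare(4, weight_config=7)   -> 7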
7 changes: 7 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -103,6 +103,13 @@ struct Conv2dConfig {
}
};

struct WeightConfig{
uint32_t act_block_w_ntiles;
uint32_t out_subblock_w_ntiles;
uint32_t act_block_h_ntiles;
TensorMemoryLayout shard_layout;
};

uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor);

uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor);
62 changes: 37 additions & 25 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -71,7 +71,7 @@ OptimizedConvBlockConfig get_opt_block_config(
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;

bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 &&
conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR;
(conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR;
use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16;

ParallelConfig parallel_config = determine_parallel_config(
@@ -135,35 +135,45 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
uint32_t groups,
T *device,
std::optional<const ttnn::Tensor>& bias_tensor,
std::optional<const Conv2dConfig> conv_config_) {
std::optional<const Conv2dConfig> conv_config_,
std::optional<const WeightConfig> weight_config_) {
ttnn::Tensor bias_tensor_;
TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation.");

const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
const uint32_t output_width =
((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;

uint32_t weight_block_h_ntiles=0, weight_block_w_ntiles=0, act_block_h_ntiles=0;
Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
auto opt_conv_op_block_config = get_opt_block_config(
mm_conv,
in_channels,
out_channels,
output_height,
output_width,
batch_size,
input_width,
kernel_size,
stride,
device,
conv_config,
input_tensor_layout,
input_memory_config
);

uint32_t weight_block_h_ntiles = opt_conv_op_block_config.act_block_w_ntiles;
uint32_t weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles;
uint32_t act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles;
TensorMemoryLayout tensor_layout;
if(conv_config_.has_value() && weight_config_.has_value()) {
WeightConfig weight_config = weight_config_.value();
weight_block_h_ntiles = weight_config.act_block_w_ntiles;
weight_block_w_ntiles = weight_config.out_subblock_w_ntiles;
act_block_h_ntiles = weight_config.act_block_h_ntiles;
tensor_layout = weight_config.shard_layout;
}else{
auto opt_conv_op_block_config = get_opt_block_config(
mm_conv,
in_channels,
out_channels,
output_height,
output_width,
batch_size,
input_width,
kernel_size,
stride,
device,
conv_config,
input_tensor_layout,
input_memory_config
);
weight_block_h_ntiles = opt_conv_op_block_config.act_block_w_ntiles;
weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles;
act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles;
tensor_layout = conv_config.shard_layout.value();
}

validate_weight_tensor(weight_tensor);
ttnn::Tensor weight_tensor_ = weight_tensor; // tensor to return
@@ -213,7 +223,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

// for conv op, pad the weights to block shape
if (conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED) {
if (tensor_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, conv_config.weights_dtype);
} else {
@@ -362,7 +372,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
uint32_t groups,
Device *device,
std::optional<const ttnn::Tensor>& bias_tensor,
std::optional<const Conv2dConfig> conv_config_);
std::optional<const Conv2dConfig> conv_config_,
std::optional<const WeightConfig> weight_config);

template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights<MeshDevice>(
const ttnn::Tensor& weight_tensor,
@@ -381,7 +392,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
uint32_t groups,
MeshDevice *device,
std::optional<const ttnn::Tensor>& bias_tensor,
std::optional<const Conv2dConfig> conv_config_);
std::optional<const Conv2dConfig> conv_config_,
std::optional<const WeightConfig> weight_config);

template ttnn::Tensor prepare_conv_bias<Device>(
const ttnn::Tensor& bias_tensor,
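The two blocks above are explicit template instantiations, so the new parameter has to be repeated in each of them: prepare_conv_weights is templated on the device type, and an explicit instantiation that still names the old parameter list would no longer match any declared specialization. A tiny generic sketch of that requirement, with hypothetical names:

#include <optional>

struct Config {};
struct Device {};
struct MeshDevice {};

// Templated on the device type, like prepare_conv_weights.
template <typename T>
int prepare(T* device, std::optional<Config> cfg) {
    return (device != nullptr && cfg.has_value()) ? 1 : 0;
}

// Explicit instantiations must repeat the current parameter list; after the
// signature gains std::optional<Config>, both lines below have to be updated.
template int prepare<Device>(Device* device, std::optional<Config> cfg);
template int prepare<MeshDevice>(MeshDevice* device, std::optional<Config> cfg);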
@@ -44,7 +44,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
uint32_t groups,
T *device,
std::optional<const ttnn::Tensor>& bias_tensor,
std::optional<const Conv2dConfig> conv_config_);
std::optional<const Conv2dConfig> conv_config_,
std::optional<const WeightConfig> weight_config_);

template <typename T>
ttnn::Tensor prepare_conv_bias(
@@ -255,6 +255,12 @@ Result conv_transpose2d(
// input_width);

// prepare weights in desired layout and move to device
WeightConfig weight_config = WeightConfig{
.act_block_w_ntiles = opt_conv_op_block_config.act_block_w_ntiles,
.out_subblock_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles,
.act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles,
};

tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights(
transform_weights_for_conv_transpose2d(weight_tensor),
input_tensor_post_tm.memory_config(),
@@ -272,7 +278,8 @@ Result conv_transpose2d(
groups,
device,
bias_tensor,
conv_config
conv_config,
weight_config
);
weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt);
if(bias_tensor.has_value()){
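One difference from the conv2d.cpp hunk: this initializer leaves .shard_layout unset, so that member is value-initialized under the designated-initializer rules. A small stand-alone sketch of that behaviour (the enum and its ordering here are stand-ins; the real TensorMemoryLayout enumerator values are an assumption):

// Sketch: members omitted from a designated initializer are value-initialized;
// for an enum member without a default initializer that means value 0.
#include <cassert>
#include <cstdint>

enum class TensorMemoryLayout { INTERLEAVED, HEIGHT_SHARDED, WIDTH_SHARDED, BLOCK_SHARDED };

struct WeightConfig {
    uint32_t act_block_w_ntiles;
    uint32_t out_subblock_w_ntiles;
    uint32_t act_block_h_ntiles;
    TensorMemoryLayout shard_layout;
};

int main() {
    WeightConfig wc{
        .act_block_w_ntiles = 2,
        .out_subblock_w_ntiles = 1,
        .act_block_h_ntiles = 4,
        // .shard_layout omitted on purpose, as in the hunk above.
    };
    assert(wc.shard_layout == TensorMemoryLayout::INTERLEAVED);  // 0 in this sketch enum
    return 0;
}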
