#13794: Conv2d BS arbitrary kernel dims
Allow arbitrary kernel dimensions in block-sharded conv2d.

This enables more torch traces to pass, since block sharding is now a more
viable option in the auto-shard codepath for convs.

The logic that arbitrarily double-buffered activations and weights in the
block-sharded codepath is removed. It caused issues with some torch trace
examples that used to pass with height sharding: auto-shard would now pick
block sharding as the better option, but the arbitrary double buffering led
to out-of-memory failures.
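For illustration only, a minimal Python sketch of how the new options could be exercised, assuming the ttnn.conv2d keyword signature at this commit; the tensor names (tt_input, tt_weight), the device handle, and the 5x5 shapes are hypothetical:

import ttnn

# Hypothetical config: block sharding with a 5x5 kernel (previously rejected),
# plus the new enable_weights_double_buffer flag added by this commit.
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat8_b,
    shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
    enable_act_double_buffer=False,
    enable_weights_double_buffer=True,
)

output = ttnn.conv2d(
    input_tensor=tt_input,    # assumed: activations already on device
    weight_tensor=tt_weight,  # assumed: weights in the layout ttnn.conv2d expects
    device=device,
    in_channels=64,
    out_channels=64,
    batch_size=1,
    input_height=56,
    input_width=56,
    kernel_size=(5, 5),
    stride=(1, 1),
    padding=(2, 2),
    groups=1,
    conv_config=conv_config,
)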
Pavle Josipovic committed Nov 1, 2024
1 parent bffe65a commit 7063540
Showing 10 changed files with 53 additions and 37 deletions.
@@ -185,7 +185,12 @@ def run_downsample_if_req(
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_accum_enabled,
enable_act_double_buffer=enable_act_double_buffer,
enable_act_double_buffer=enable_act_double_buffer
if height_sharding
else True
if input_width < 56
else False,
enable_weights_double_buffer=True if input_width < 56 else False,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
@@ -330,6 +335,7 @@ def __call__(
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
enable_act_double_buffer=enable_act_double_buffer,
enable_weights_double_buffer=True,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
@@ -835,6 +841,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt

reshard = False
height_shard = False
is_gs = is_grayskull()
if is_wormhole_b0() and self.batch_size == 20:
if is_first_run:
reshard = True if not is_wormhole_b0() else False
@@ -853,7 +860,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
reshard_if_not_optimal=reshard,
height_sharding=height_shard,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -877,7 +884,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=False,
enable_act_double_buffer=True,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -891,7 +898,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -905,7 +912,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -931,7 +938,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
reshard_if_not_optimal=reshard,
height_sharding=height_shard,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -955,7 +962,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -969,7 +976,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -983,7 +990,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -997,7 +1004,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -1012,7 +1019,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
conv_op_cache,
eltwise_binary_out_in_place=True,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -1054,7 +1061,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
reshard_if_not_optimal=reshard,
height_sharding=height_shard,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -1078,7 +1085,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
@@ -1092,7 +1099,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x_width,
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
enable_split_reader=False,
enable_subblock_padding=False,
)
2 changes: 2 additions & 0 deletions models/demos/vgg/tt/ttnn_vgg.py
@@ -104,6 +104,7 @@ def ttnn_vgg16(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=True,
enable_weights_double_buffer=True,
)

tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight
@@ -224,6 +225,7 @@ def ttnn_vgg11(
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
enable_weights_double_buffer=True,
)

tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight
@@ -457,7 +457,6 @@ def test_conv2d_localrun(device, input_spec):
[1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 6
[1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], # 14
[1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 15
[1, 192, 192, 99, 99, 5, 5, 2, 2, 0, 0, 192, False, 1], # 100
[1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], # 141
[1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 170
[1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 171
@@ -472,8 +471,6 @@ def test_conv2d_localrun(device, input_spec):
[1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1], # 292
[1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 293
[1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1], # 294
[1, 576, 576, 19, 19, 5, 5, 1, 1, 2, 2, 576, False, 1], # 300
[1, 672, 672, 24, 24, 5, 5, 1, 1, 2, 2, 672, False, 1], # 341
[1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1], # 347
[1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1], # 348
[1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1], # 363
3 changes: 0 additions & 3 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -385,9 +385,6 @@ def test_conv_features(
fp32_accum,
packer_l1_acc,
):
if shard_layout == ttnn.TensorMemoryLayout.BLOCK_SHARDED and filter > 3:
pytest.skip("Block sharding only supports filter size <= 3")

if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")

2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -945,6 +945,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
conv_config.input_channels_alignment == 16,
compute_kernel_config,
conv_config.enable_act_double_buffer,
conv_config.enable_weights_double_buffer,
conv_config.enable_split_reader,
conv_config.enable_subblock_padding);
if (conv_config.deallocate_activation) {
@@ -990,6 +991,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
conv_config.input_channels_alignment == 16,
compute_kernel_config,
conv_config.enable_act_double_buffer,
conv_config.enable_weights_double_buffer,
conv_config.enable_split_reader,
conv_config.enable_subblock_padding,
use_non_tile_height);
3 changes: 3 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -46,6 +46,7 @@ struct Conv2dConfig {
bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false
Layout output_layout = Layout::TILE;
bool enable_act_double_buffer = false;
bool enable_weights_double_buffer = false; // Used only for block sharded convolutions
bool enable_split_reader = false;
bool enable_subblock_padding = false;
static constexpr auto attribute_names = std::make_tuple(
@@ -68,6 +69,7 @@ struct Conv2dConfig {
"transpose_shards",
"output_layout",
"enable_act_double_buffer",
"enable_weights_double_buffer",
"enable_split_reader",
"enable_subblock_padding");
const auto attribute_values() const {
@@ -91,6 +93,7 @@ struct Conv2dConfig {
std::cref(this->transpose_shards),
std::cref(this->output_layout),
std::cref(this->enable_act_double_buffer),
std::cref(this->enable_weights_double_buffer),
std::cref(this->enable_split_reader),
std::cref(this->enable_subblock_padding));
}
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -237,7 +237,7 @@ void py_bind_conv2d(py::module& module) {

auto py_conv_config = py::class_<Conv2dConfig>(module, "Conv2dConfig");
py_conv_config.def(
py::init<MathFidelity, DataType, DataType, bool, bool, bool, string, uint32_t, bool, bool, uint32_t, uint32_t, bool, bool, std::optional<TensorMemoryLayout>, std::optional<CoreRangeSet>, bool, Layout, bool, bool, bool>(),
py::init<MathFidelity, DataType, DataType, bool, bool, bool, string, uint32_t, bool, bool, uint32_t, uint32_t, bool, bool, std::optional<TensorMemoryLayout>, std::optional<CoreRangeSet>, bool, Layout, bool, bool, bool, bool>(),
py::kw_only(),
py::arg("math_fidelity") = MathFidelity::HiFi4,
py::arg("dtype") = DataType::BFLOAT16,
@@ -258,6 +258,7 @@ void py_bind_conv2d(py::module& module) {
py::arg("transpose_shards") = true,
py::arg("output_layout") = Layout::TILE,
py::arg("enable_act_double_buffer") = false,
py::arg("enable_weights_double_buffer") = false,
py::arg("enable_split_reader") = false,
py::arg("enable_subblock_padding") = false
);
@@ -280,6 +281,7 @@ void py_bind_conv2d(py::module& module) {
py_conv_config.def_readwrite("transpose_shards", &Conv2dConfig::transpose_shards);
py_conv_config.def_readwrite("output_layout", &Conv2dConfig::output_layout);
py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer);
py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer);
py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);

6 changes: 4 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
@@ -60,13 +60,14 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional<const
bool use_shallow_conv_variant,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
bool enable_act_double_buffer,
bool enable_weights_double_buffer,
bool enable_split_reader,
bool enable_subblock_padding,
bool use_non_tile_height
) {
std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))};
operation::launch_op(
[sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height]
[sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height]
(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
using ttnn::operations::experimental::auto_format::FormatParams;
auto& a = input_tensors.at(0);
@@ -86,7 +87,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional<const
bool fp32_accum = a.device()->arch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? compute_kernel_config.value().fp32_dest_acc_en : false;
auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false);
return operation::run_without_autoformat(
OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height
OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height
),
input_tensors,
optional_input_tensors);
@@ -235,6 +236,7 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vect
compute_kernel_config,
output_tensor,
enable_act_double_buffer,
enable_weights_double_buffer,
enable_split_reader,
enable_subblock_padding,
use_non_tile_height);
8 changes: 7 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
@@ -58,6 +58,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T
std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
Tensor& output,
bool enable_act_double_buffer,
bool enable_weights_double_buffer,
bool enable_split_reader,
bool enable_subblock_padding,
bool use_non_tile_height);
@@ -77,6 +78,7 @@ struct OptimizedConvNew {
bool use_shallow_conv_variant;
const DeviceComputeKernelConfig compute_kernel_config;
bool enable_act_double_buffer;
bool enable_weights_double_buffer;
bool enable_split_reader;
bool enable_subblock_padding;
bool use_non_tile_height;
@@ -89,7 +91,7 @@ struct OptimizedConvNew {
MemoryConfig out_mem_config,
DataType dtype,
std::array<std::uint32_t, 4> input_tensor_shape, bool use_shallow_conv_variant,
const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) :
const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) :
output_channels(output_channels),
groups(groups),
sliding_window_config(sliding_window_config),
@@ -104,6 +106,7 @@ struct OptimizedConvNew {
use_shallow_conv_variant(use_shallow_conv_variant),
compute_kernel_config(compute_kernel_config),
enable_act_double_buffer(enable_act_double_buffer),
enable_weights_double_buffer(enable_weights_double_buffer),
enable_split_reader(enable_split_reader),
enable_subblock_padding(enable_subblock_padding),
use_non_tile_height(use_non_tile_height) {}
@@ -128,6 +131,7 @@ struct OptimizedConvNew {
"input_tensor_shape",
"use_shallow_conv_variant",
"enable_act_double_buffer",
"enable_weights_double_buffer",
"enable_split_reader",
"enable_subblock_padding");
const auto attribute_values() const {
@@ -144,6 +148,7 @@ struct OptimizedConvNew {
std::cref(this->input_tensor_shape),
std::cref(this->use_shallow_conv_variant),
std::cref(this->enable_act_double_buffer),
std::cref(this->enable_weights_double_buffer),
std::cref(this->enable_split_reader),
std::cref(this->enable_subblock_padding));
}
@@ -162,6 +167,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional<const
bool use_shallow_conv_variant,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
bool enable_act_double_buffer = false,
bool enable_weights_double_buffer = false,
bool enable_split_reader = false,
bool enable_subblock_padding = false,
bool use_non_tile_height = false