Skip to content

Commit

Permalink
[Transformations] Add downgrade transformation for AvgPool-14 and `…
Browse files Browse the repository at this point in the history
…MaxPool-14` (#23381)

### Details
 - Add downgrade transformation `AvgPool-14` -> `AvgPool-1`
 - Add downgrade transformation `MaxPool-14` -> `MaxPool-8`
 - This PR unblocks PT FE extension

### Related PRs
 - #22796
 - #22930
 - #22966
 - #23027
 - #23582

### Tickets:
 - 133918

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
  • Loading branch information
p-wysocki and mlukasze authored May 8, 2024
1 parent 0470cff commit bf959b7
Show file tree
Hide file tree
Showing 8 changed files with 603 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
/**
* @ingroup ie_transformation_common_api
* @brief Converts AvgPool v14 to AvgPool v1
*/
class TRANSFORMATIONS_API ConvertAvgPool14ToAvgPool1 : public MatcherPass {
public:
OPENVINO_RTTI("ConvertAvgPool14ToAvgPool1", "0");
ConvertAvgPool14ToAvgPool1();
};
} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertMaxPool8ToMaxPool1;
class TRANSFORMATIONS_API ConvertMaxPool14ToMaxPool8;

} // namespace pass
} // namespace ov
Expand All @@ -24,3 +25,13 @@ class ov::pass::ConvertMaxPool8ToMaxPool1 : public ov::pass::MatcherPass {
OPENVINO_RTTI("ConvertMaxPool8ToMaxPool1");
ConvertMaxPool8ToMaxPool1();
};

/**
* @ingroup ov_transformation_common_api
* @brief ConvertMaxPool14ToMaxPool8 converts v14::MaxPool into v8::MaxPool.
*/
class ov::pass::ConvertMaxPool14ToMaxPool8 : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertMaxPool14ToMaxPool8", "0");
ConvertMaxPool14ToMaxPool8();
};
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#include "transformations/init_node_info.hpp"
#include "transformations/op_conversions/batch_norm_decomposition.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_avgpool_downgrade.hpp"
#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"
#include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
Expand Down Expand Up @@ -209,6 +210,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_DISABLED_PASS(manager, ConvertSoftMax1ToSoftMax8)
REGISTER_PASS(manager, ConvertMaxPool8ToMaxPool1)
REGISTER_DISABLED_PASS(manager, ConvertMaxPool1ToMaxPool8)
REGISTER_PASS(manager, ConvertMaxPool14ToMaxPool8)
REGISTER_PASS(manager, ConvertPriorBox8To0)
REGISTER_DISABLED_PASS(manager, ConvertDetectionOutput1ToDetectionOutput8)
REGISTER_PASS(manager, ConvertDetectionOutput8ToDetectionOutput1)
Expand All @@ -221,6 +223,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertPad12ToPad1)
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
REGISTER_PASS(manager, ConcatFusion)
REGISTER_PASS(manager, ConvertAvgPool14ToAvgPool1)

auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_avgpool_downgrade.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/avg_pool.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::ConvertAvgPool14ToAvgPool1::ConvertAvgPool14ToAvgPool1() {
MATCHER_SCOPE(ConvertAvgPool14ToAvgPool1);

const auto avg_pool_v14_pattern = pattern::wrap_type<ov::op::v14::AvgPool>();

const matcher_pass_callback callback = [](pattern::Matcher& m) {
const auto avg_pool_v14 = std::dynamic_pointer_cast<ov::op::v14::AvgPool>(m.get_match_root());
const auto rounding_type_v14 = avg_pool_v14->get_rounding_type();
const auto rounding_type_v1 =
rounding_type_v14 == ov::op::RoundingType::CEIL_TORCH ? ov::op::RoundingType::CEIL : rounding_type_v14;

const auto exclude_pad = avg_pool_v14->get_exclude_pad();
const auto input = avg_pool_v14->input_value(0);
NodeRegistry node_registry;
ov::Shape pads_begin;
ov::Shape pads_end;
ov::Output<ov::Node> new_input;

using ov::op::v0::Constant;
using ov::op::v0::Concat;
using ov::op::v1::Pad;
using ov::op::v1::Subtract;
using ov::op::v1::ConvertLike;
using ov::op::v3::Broadcast;
using ov::op::v3::ShapeOf;
using ov::op::v4::Range;

if (!exclude_pad && rounding_type_v14 == ov::op::RoundingType::CEIL_TORCH) {
const auto zero = node_registry.make<Constant>(element::f32, Shape{}, 0);
const auto zero_node = node_registry.make<ConvertLike>(zero, input);
const auto zero_i64 = node_registry.make<Constant>(element::i64, Shape{}, 0);
const auto shape = node_registry.make<ShapeOf>(input, element::i64);
const auto rank = node_registry.make<ShapeOf>(shape, element::i64);
const auto pads_begin_v14 = avg_pool_v14->get_pads_begin();
const auto pads_begin_node =
node_registry.make<Constant>(element::i64, Shape{pads_begin_v14.size()}, pads_begin_v14);
const auto pads_end_v14 = avg_pool_v14->get_pads_end();
const auto pads_end_node =
node_registry.make<Constant>(element::i64, Shape{pads_end_v14.size()}, pads_end_v14);
const auto pads_len = node_registry.make<Constant>(element::i64, Shape{}, pads_begin_v14.size());
const auto pads_diff = node_registry.make<Subtract>(rank, pads_len);
const auto pads_remaining = node_registry.make<Broadcast>(zero_i64, pads_diff);
const auto pads_begin_v1 = node_registry.make<ov::op::v0::Concat>(
OutputVector{std::move(pads_remaining), std::move(pads_begin_node)},
0);
const auto pads_end_v1 = node_registry.make<ov::op::v0::Concat>(
OutputVector{std::move(pads_remaining), std::move(pads_begin_node)},
0);
const auto pad_node =
node_registry.make<Pad>(input, pads_begin_v1, pads_end_v1, zero_node, ov::op::PadMode::CONSTANT);
pads_begin = Shape(pads_begin_v14.size(), 0);
pads_end = Shape(pads_begin_v14.size(), 0);
new_input = pad_node;
} else {
pads_begin = avg_pool_v14->get_pads_begin();
pads_end = avg_pool_v14->get_pads_end();
new_input = input;
}
const auto avg_pool_v1 = node_registry.make<ov::op::v1::AvgPool>(new_input,
avg_pool_v14->get_strides(),
pads_begin,
pads_end,
avg_pool_v14->get_kernel(),
exclude_pad,
rounding_type_v1,
avg_pool_v14->get_auto_pad());
avg_pool_v1->set_friendly_name(avg_pool_v14->get_friendly_name());
copy_runtime_info(avg_pool_v14, node_registry.get());
replace_node(avg_pool_v14, avg_pool_v1);
return true;
};

auto m = std::make_shared<pattern::Matcher>(avg_pool_v14_pattern, matcher_name);
register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,26 @@

#include "itt.hpp"
#include "openvino/core/descriptor/tensor.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/avg_pool.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/max_pool.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "transformations/utils/utils.hpp"

using namespace std;
using namespace ov;

pass::ConvertMaxPool8ToMaxPool1::ConvertMaxPool8ToMaxPool1() {
ov::pass::ConvertMaxPool8ToMaxPool1::ConvertMaxPool8ToMaxPool1() {
MATCHER_SCOPE(ConvertMaxPool8ToMaxPool1);

auto maxpool_v8_pattern = pattern::wrap_type<ov::op::v8::MaxPool>();
Expand All @@ -29,13 +40,13 @@ pass::ConvertMaxPool8ToMaxPool1::ConvertMaxPool8ToMaxPool1() {
if (dilation != 1)
return false;

auto maxpool_v1_node = make_shared<ov::op::v1::MaxPool>(maxpool_v8_node->input_value(0),
maxpool_v8_node->get_strides(),
maxpool_v8_node->get_pads_begin(),
maxpool_v8_node->get_pads_end(),
maxpool_v8_node->get_kernel(),
maxpool_v8_node->get_rounding_type(),
maxpool_v8_node->get_auto_pad());
auto maxpool_v1_node = std::make_shared<ov::op::v1::MaxPool>(maxpool_v8_node->input_value(0),
maxpool_v8_node->get_strides(),
maxpool_v8_node->get_pads_begin(),
maxpool_v8_node->get_pads_end(),
maxpool_v8_node->get_kernel(),
maxpool_v8_node->get_rounding_type(),
maxpool_v8_node->get_auto_pad());

OPENVINO_SUPPRESS_DEPRECATED_START
auto out_name = ov::op::util::create_ie_output_name(maxpool_v8_node->output(0));
Expand All @@ -53,6 +64,120 @@ pass::ConvertMaxPool8ToMaxPool1::ConvertMaxPool8ToMaxPool1() {
return true;
};

auto m = make_shared<pattern::Matcher>(maxpool_v8_pattern, matcher_name);
auto m = std::make_shared<pattern::Matcher>(maxpool_v8_pattern, matcher_name);
register_matcher(m, callback);
}

ov::pass::ConvertMaxPool14ToMaxPool8::ConvertMaxPool14ToMaxPool8() {
MATCHER_SCOPE(ConvertMaxPool14ToMaxPool8);
const auto max_pool_v14_pattern = pattern::wrap_type<ov::op::v14::MaxPool>();

const matcher_pass_callback callback = [](pattern::Matcher& m) {
using ov::op::v0::Constant;
using ov::op::v0::Concat;
using ov::op::v1::Subtract;
using ov::op::v1::Multiply;
using ov::op::v1::Greater;
using ov::op::v1::Select;
using ov::op::v1::ConvertLike;
using ov::op::v1::Add;
using ov::op::v3::ShapeOf;
using ov::op::v4::Range;
using ov::op::v8::Gather;
using ov::op::v12::Pad;

const auto max_pool_v14 = std::dynamic_pointer_cast<ov::op::v14::MaxPool>(m.get_match_root());
const auto rounding_type_v14 = max_pool_v14->get_rounding_type();
std::shared_ptr<ov::op::v8::MaxPool> max_pool_v8;
NodeRegistry node_registry;
if (rounding_type_v14 == ov::op::RoundingType::CEIL_TORCH) {
if (max_pool_v14->is_dynamic()) {
return false;
}
auto input = max_pool_v14->input_value(0);
const auto strides = max_pool_v14->get_strides();
const auto padding_begin = max_pool_v14->get_pads_begin();
const auto padding_begin_node =
node_registry.make<Constant>(element::i64, Shape{padding_begin.size()}, padding_begin);
const auto padding_end = max_pool_v14->get_pads_end();
const auto padding_end_node =
node_registry.make<Constant>(element::i64, Shape{padding_end.size()}, padding_end);
const auto zero = node_registry.make<Constant>(element::i64, Shape{}, 0);
const auto one = node_registry.make<Constant>(element::i64, Shape{}, 1);
const auto two = node_registry.make<Constant>(element::i64, Shape{}, 2);

const auto pads_size = max_pool_v14->get_pads_begin().size();
const auto pads_len = node_registry.make<Constant>(element::i64, Shape{}, pads_size);
const auto pads_remaining =
node_registry.make<Constant>(element::i64, Shape{2}, std::vector<int64_t>{0, 0});

// gather input spatial dims and prepare for compare as values (in_dim + pad)
const auto end = node_registry.make<Constant>(element::i64, Shape{}, pads_size + 2);
const auto dim_idxs = node_registry.make<Range>(two, end, one, element::i64);
const auto shape = node_registry.make<ShapeOf>(input, element::i64);
const auto gth_in_dims = node_registry.make<Gather>(shape, dim_idxs, zero);
const auto in_left_padded = node_registry.make<Add>(gth_in_dims, padding_begin_node);

// gather output spatial dims and prepare it for compare as values (out_dim - 1) * stride
const auto mp = node_registry.make<ov::op::v8::MaxPool>(input,
max_pool_v14->get_strides(),
max_pool_v14->get_dilations(),
max_pool_v14->get_pads_begin(),
max_pool_v14->get_pads_end(),
max_pool_v14->get_kernel(),
ov::op::RoundingType::CEIL);
const auto shape_of_mp = node_registry.make<ShapeOf>(mp, element::i64);
const auto gth_out_dims = node_registry.make<Gather>(shape_of_mp, dim_idxs, zero);
const auto out_sub_one = node_registry.make<Subtract>(gth_out_dims, one);
const auto stride_node = node_registry.make<Constant>(element::i64, Shape{strides.size()}, strides);
const auto out_mul_stride = node_registry.make<Multiply>(out_sub_one, stride_node);

// if (in_dim + pad) > ((out_dim - 1) * stride) sliding window in bound use end padding.
const auto in_gt_out = node_registry.make<Greater>(in_left_padded, out_mul_stride);
const auto selected_pads = node_registry.make<Select>(in_gt_out, padding_end_node, zero);

// apply padding on input clear pads attribute
const auto pb = node_registry.make<Concat>(OutputVector{pads_remaining->output(0), padding_end_node}, 0);
const auto pe = node_registry.make<Concat>(OutputVector{pads_remaining, selected_pads}, 0);
auto minus_inf =
node_registry.make<Constant>(element::f32, Shape{}, -std::numeric_limits<float>::infinity());
std::shared_ptr<ov::Node> convert_like_node = node_registry.make<ConvertLike>(minus_inf, input);
const auto pad_node = node_registry.make<Pad>(input, pb, pe, convert_like_node, op::PadMode::CONSTANT);
auto pads_begin = max_pool_v14->get_pads_begin();
auto pads_end = max_pool_v14->get_pads_end();
std::fill_n(pads_begin.begin(), pads_begin.size(), 0);
std::fill_n(pads_end.begin(), pads_end.size(), 0);

max_pool_v8 = node_registry.make<ov::op::v8::MaxPool>(pad_node,
max_pool_v14->get_strides(),
max_pool_v14->get_dilations(),
pads_begin,
pads_end,
max_pool_v14->get_kernel(),
ov::op::RoundingType::CEIL,
ov::op::PadType::EXPLICIT,
max_pool_v14->get_index_element_type(),
max_pool_v14->get_axis());
copy_runtime_info(max_pool_v14, node_registry.get());
} else {
max_pool_v8 = std::make_shared<ov::op::v8::MaxPool>(max_pool_v14->input_value(0),
max_pool_v14->get_strides(),
max_pool_v14->get_dilations(),
max_pool_v14->get_pads_begin(),
max_pool_v14->get_pads_end(),
max_pool_v14->get_kernel(),
rounding_type_v14,
max_pool_v14->get_auto_pad(),
max_pool_v14->get_index_element_type(),
max_pool_v14->get_axis());
copy_runtime_info(max_pool_v14, max_pool_v8);
}
max_pool_v8->set_friendly_name(max_pool_v14->get_friendly_name());
replace_node(max_pool_v14, max_pool_v8);
max_pool_v14->clear_control_dependencies();
return true;
};

auto m = std::make_shared<pattern::Matcher>(max_pool_v14_pattern, matcher_name);
register_matcher(m, callback);
}
Loading

0 comments on commit bf959b7

Please sign in to comment.