Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TRANSFORMATIONS] Fix Optional to match even with no inputs #23471

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,52 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

using namespace ov::pass::pattern::op;

/*
┌──────────────┐
│ Relu │
┌──────────────┐ └──────┬───────┘
│ Relu │ │
└──────┬───────┘ ┌──────┴───────┐ ┌──────────────┐
│ │WrapType<Relu>│ │ Relu │
┌──────┴───────┐ └──────┬───────┘ └───────┬──────┘
│Optional<Relu>│ Unfolds into │ │
└──────┬───────┘ └────────┐ ┌────────┘
│ │ │
┌─┴─┐ ┌┴──────┴┐
│ABS│ │ Or │
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: ABS -> Abs

└───┘ └────┬───┘
┌─┴─┐
│ABS│
└───┘

In case there're no inputs to the Optional, there's no second branch hence no need in the
Or node and we may omit it leaving only the WrapType node with the Optional entry inside.
*/

std::vector<ov::DiscreteTypeInfo> ov::pass::pattern::op::Optional::get_optional_types() const {
return optional_types;
}

bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) {
ov::OutputVector or_in_values = input_values();
auto wrap_node = std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, or_in_values);
or_in_values.push_back(wrap_node);
// Turn the Optional node into WrapType node to create a case where the Optional node is present
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
auto wrap_node = std::make_shared<WrapType>(optional_types, m_predicate, input_values_to_optional);

// Either continue using the WrapType if there're no inputs to it or create an Or node,
// if there're other inputs to Optional creating another "branch" for matching.
// Use only the 0th input as a "data" input. (To be changed or considered when Optional
// starts supporting multiple inputs)
auto pattern = num_input_values_to_optional == 0 ? std::static_pointer_cast<Pattern>(wrap_node)
: std::static_pointer_cast<Pattern>(std::make_shared<Or>(
OutputVector{wrap_node, input_values_to_optional[0]}));

if (matcher->match_value(std::make_shared<ov::pass::pattern::op::Or>(or_in_values), graph_value)) {
if (matcher->match_value(pattern, graph_value) || num_input_values_to_optional == 0) {
auto& pattern_map = matcher->get_pattern_value_map();
if (pattern_map.count(wrap_node)) {
pattern_map[shared_from_this()] = graph_value;
Expand Down
64 changes: 64 additions & 0 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reduce_sum.hpp"
Expand Down Expand Up @@ -508,6 +510,68 @@ TEST(pattern, matching_optional) {
std::make_shared<op::v0::Abs>(c)));
}

TEST(pattern, optional_full_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));

auto pattern_add = ov::pass::pattern::optional<op::v1::Add>();
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_add->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu, model_relu));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend to test pattern_map in these tests
Please check that pattern_map contains pattern_add/pattern_relu keys and the corresponding values are equal to model_add and model_relu.

We can do it in the next PR, not a problem.

}

TEST(pattern, optional_half_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));

auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu1, model_relu));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please check that pattern_map[pattern_relu] is nullptr in this case

}

TEST(pattern, optional_testing) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in OV we have some operations with "optional" inputs, e.g. Interpolate-11 operation
https://docs.openvino.ai/2023.3/openvino_docs_ops_image_Interpolate_11.html#

image

optional class might be used to match Interpolate with different number of inputs: 2 inputs and 3 inputs using 1 instance:
Interpolate(wrap_type, wrap_type, optional)

could you check if we have such a test? if not could you add it?

Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
auto model_abs = std::make_shared<op::v0::Abs>(model_add->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Exp>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Cos>(model_add), model_add));

ASSERT_TRUE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Abs>(model_abs)));
ASSERT_FALSE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Relu>(model_abs)));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_abs),
std::make_shared<op::v0::Relu>(model_abs)));

ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Exp>(model_add), model_abs));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Abs>(model_add), model_abs));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));
}

TEST(pattern, mean) {
// construct mean
TestMatcher n;
Expand Down
Loading