Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousPanCake committed Mar 15, 2024
1 parent c3ff345 commit 9d93ac5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
// Turn the Optional node into WrapType node to create a case where the Optional node is present
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
auto wrap_node = std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, input_values_to_optional);
auto wrap_node =
std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, input_values_to_optional);

// Add the newly created WrapType node to the list containing its inputs and create an Or node with the list
input_values_to_optional.push_back(wrap_node);
Expand Down
62 changes: 62 additions & 0 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,68 @@ TEST(pattern, matching_optional) {
std::make_shared<op::v0::Abs>(c)));
}

TEST(pattern, optional_full_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));

auto pattern_add = ov::pass::pattern::optional<op::v1::Add>();
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_add->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu, model_relu));
}

TEST(pattern, optional_half_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));

auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu1, model_relu));
}

TEST(pattern, optional_new_test) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
auto model_abs = std::make_shared<op::v0::Abs>(model_add->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v1::Divide, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v1::Multiply>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v1::Divide, op::v1::Multiply>(model_add), model_add));

ASSERT_TRUE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Abs>(model_abs)));
ASSERT_FALSE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Relu>(model_abs)));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_abs),
std::make_shared<op::v0::Relu>(model_abs)));

ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v1::Divide>(model_add), model_abs));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v1::Divide, op::v0::Abs>(model_add), model_abs));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));
}

TEST(pattern, mean) {
// construct mean
TestMatcher n;
Expand Down

0 comments on commit 9d93ac5

Please sign in to comment.