Skip to content

Commit

Permalink
[TRANSFORMATIONS] Fix Optional to match even with no inputs
Browse files Browse the repository at this point in the history
The Optional pattern type may create a wrong pattern to match
if no inputs are provided to the Optional node. If no inputs
present to the Optional type, it will not create an alternative
branch(es) to check against resulting in the incorrect matching.

Fix that by adding a check for the number of inputs being 0.

Do a minor refactoring/renaming for the readability purposes.

Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
  • Loading branch information
CuriousPanCake committed Mar 15, 2024
1 parent 6f8b70f commit c3ff345
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@ std::vector<ov::DiscreteTypeInfo> ov::pass::pattern::op::Optional::get_optional_
bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) {
ov::OutputVector or_in_values = input_values();
auto wrap_node = std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, or_in_values);
or_in_values.push_back(wrap_node);
// Turn the Optional node into WrapType node to create a case where the Optional node is present
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
auto wrap_node = std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, input_values_to_optional);

if (matcher->match_value(std::make_shared<ov::pass::pattern::op::Or>(or_in_values), graph_value)) {
// Add the newly created WrapType node to the list containing its inputs and create an Or node with the list
input_values_to_optional.push_back(wrap_node);
auto or_node = std::make_shared<ov::pass::pattern::op::Or>(input_values_to_optional);

if (matcher->match_value(or_node, graph_value) || num_input_values_to_optional == 0) {
auto& pattern_map = matcher->get_pattern_value_map();
if (pattern_map.count(wrap_node)) {
pattern_map[shared_from_this()] = graph_value;
Expand Down

0 comments on commit c3ff345

Please sign in to comment.