Skip to content

Commit

Permalink
Add LSTMCellFusion transformation (openvinotoolkit#21594)
Browse files Browse the repository at this point in the history
* Add LSTMCellFusion transformation

Partially fixes: CVS-125605

* code style

* fix accuracy issue

* add headers
  • Loading branch information
mateusztabaka authored Dec 18, 2023
1 parent ec8765f commit d03dc4f
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API LSTMCellFusion;

} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief LSTMCellFusion transformation replaces a sequence of
* operations with LSTMCell op.
*/
class ov::pass::LSTMCellFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("LSTMCellFusion", "0");
LSTMCellFusion();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/lstm_cell_fusion.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/lstm_cell.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/sigmoid.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/tanh.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "validation_util.hpp"

static std::string get_activation_name(const std::shared_ptr<ov::Node>& node) {
std::string name = node->get_type_name();
name[0] = std::tolower(name[0]);
return name;
}

ov::pass::LSTMCellFusion::LSTMCellFusion() {
MATCHER_SCOPE(LSTMCellFusion);

auto x_label = pattern::any_input(pattern::rank_equals(2));
auto h_label = pattern::any_input(pattern::rank_equals(2));
auto concat_label = pattern::wrap_type<op::v0::Concat>({x_label, h_label});
auto weights_label = pattern::any_input([](const Output<Node>& output) {
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output);
});
auto matmul_label = pattern::wrap_type<op::v0::MatMul>({concat_label, weights_label});
auto bias_label = pattern::any_input([](const Output<Node>& output) {
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output);
});
auto bias_add_label = pattern::wrap_type<op::v1::Add>({matmul_label, bias_label});
auto axis_label = pattern::wrap_type<op::v0::Constant>();
auto split_label = pattern::wrap_type<op::v1::Split>({bias_add_label, axis_label});
auto it_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({split_label});
auto ct_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({split_label});
auto ft_additional_bias_label = pattern::wrap_type<op::v0::Constant>();
auto add_label = pattern::wrap_type<op::v1::Add>({split_label, ft_additional_bias_label});
auto ft_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({add_label});
auto ot_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({split_label});
auto mul_label = pattern::wrap_type<op::v1::Multiply>({it_label, ct_label});
auto c_label = pattern::any_input();
auto mul1_label = pattern::wrap_type<op::v1::Multiply>({ft_label, c_label});
auto Co_label = pattern::wrap_type<op::v1::Add>({mul_label, mul1_label});
auto Co_activation_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({Co_label});
auto Ho_label = pattern::wrap_type<op::v1::Multiply>({Co_activation_label, ot_label});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

const auto& X = pattern_map.at(x_label);
const auto& H = pattern_map.at(h_label);
const auto& C = pattern_map.at(c_label);
const auto& WR = pattern_map.at(weights_label);
const auto& B = pattern_map.at(bias_label);
const auto& ft_additional_bias = pattern_map.at(ft_additional_bias_label);
auto Ho = pattern_map.at(Ho_label);
auto Co = pattern_map.at(Co_label);

const auto& WR_shape = WR.get_shape();
const auto& B_shape = B.get_shape();
const auto& ft_additional_bias_shape = ft_additional_bias.get_shape();

if (WR_shape[0] % 4 != 0)
return false;
if (WR_shape[0] != B_shape[1])
return false;
if (B_shape[0] != 1)
return false;
if (shape_size(ft_additional_bias_shape) != 1)
return false;

size_t hidden_size = WR_shape[0] / 4;

if (WR_shape[1] <= hidden_size)
return false;

size_t input_size = WR_shape[1] - hidden_size;

const auto& X_shape = X.get_partial_shape();
const auto& H_shape = H.get_partial_shape();
const auto& C_shape = C.get_partial_shape();

if (!H_shape[0].compatible(X_shape[0]))
return false;

if (!C_shape[0].compatible(X_shape[0]))
return false;

if (!X_shape[1].compatible(input_size))
return false;

if (!H_shape[1].compatible(hidden_size))
return false;

if (!C_shape[1].compatible(hidden_size))
return false;

NodeVector split_consumers{pattern_map.at(it_label).get_node_shared_ptr(),
pattern_map.at(ct_label).get_node_shared_ptr(),
pattern_map.at(ot_label).get_node_shared_ptr(),
pattern_map.at(add_label).get_node_shared_ptr()};

std::shared_ptr<Node> it;
std::shared_ptr<Node> ct;
std::shared_ptr<Node> ot;
std::shared_ptr<Node> add;

// manually match split consumers to gates
for (const auto& n : split_consumers) {
if (n->input_value(0).get_index() == 0)
it = n;
else if (n->input_value(0).get_index() == 1)
ct = n;
else if (n->input_value(0).get_index() == 2)
add = n;
else if (n->input_value(0).get_index() == 3)
ot = n;
}

auto ft = pattern_map.at(ft_label).get_node_shared_ptr();

std::string f_activation_name = ft->get_type_name();

if (f_activation_name != it->get_type_name() || f_activation_name != ot->get_type_name())
return false;

f_activation_name[0] = std::tolower(f_activation_name[0]);
std::string g_activation_name = get_activation_name(ct);

auto Co_activation = pattern_map.at(Co_activation_label).get_node_shared_ptr();
std::string h_activation_name = get_activation_name(Co_activation);

auto zero = op::v0::Constant::create(element::i32, Shape{}, {0});
auto WR_split = std::make_shared<op::v1::Split>(WR, zero /* axis */, 4);
auto WR_fico = std::make_shared<op::v0::Concat>(
OutputVector{WR_split->output(2), WR_split->output(0), WR_split->output(1), WR_split->output(3)},
0);
auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
auto split_lengths = op::v0::Constant::create(element::i32, Shape{2}, {input_size, hidden_size});
auto vsplit = std::make_shared<op::v1::VariadicSplit>(WR_fico, one /* axis */, split_lengths);
Output<Node> W = vsplit->output(0);
if (auto constant = ov::util::constantfold_subgraph(W))
W = constant;
Output<Node> R = vsplit->output(1);
if (auto constant = ov::util::constantfold_subgraph(R))
R = constant;

auto B_split = std::make_shared<op::v1::Split>(std::make_shared<op::v0::Squeeze>(B, zero), zero /* axis */, 4);
auto B_f =
std::make_shared<op::v1::Add>(B_split->output(2), std::make_shared<op::v0::Squeeze>(ft_additional_bias));

Output<Node> B_fico = std::make_shared<op::v0::Concat>(
OutputVector{B_f, B_split->output(0), B_split->output(1), B_split->output(3)},
0);
if (auto constant = ov::util::constantfold_subgraph(B_fico))
B_fico = constant;

auto lstm_cell = std::make_shared<op::v4::LSTMCell>(
X,
H,
C,
W,
R,
B_fico,
hidden_size,
std::vector<std::string>{f_activation_name, g_activation_name, h_activation_name});
lstm_cell->set_friendly_name(m.get_match_root()->get_friendly_name());

copy_runtime_info(
{
pattern_map.at(concat_label).get_node_shared_ptr(),
WR.get_node_shared_ptr(),
pattern_map.at(matmul_label).get_node_shared_ptr(),
B.get_node_shared_ptr(),
pattern_map.at(bias_add_label).get_node_shared_ptr(),
pattern_map.at(split_label).get_node_shared_ptr(),
it,
ct,
ft,
ot,
pattern_map.at(add_label).get_node_shared_ptr(),
pattern_map.at(mul_label).get_node_shared_ptr(),
C.get_node_shared_ptr(),
pattern_map.at(mul1_label).get_node_shared_ptr(),
pattern_map.at(Co_label).get_node_shared_ptr(),
Co.get_node_shared_ptr(),
Co_activation,
Ho.get_node_shared_ptr(),
},
{W.get_node_shared_ptr(), R.get_node_shared_ptr(), B_fico.get_node_shared_ptr(), lstm_cell});

Ho.replace(lstm_cell->output(0));
Co.replace(lstm_cell->output(1));

return true;
};

auto m = std::make_shared<pattern::Matcher>(Ho_label, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "transformations/common_optimizations/hswish_fusion.hpp"
#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
#include "transformations/common_optimizations/lstm_cell_fusion.hpp"
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "transformations/common_optimizations/matmul_multiply_fusion.hpp"
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
Expand Down Expand Up @@ -171,6 +172,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, PullThroughReduce)

// GRUCellFusion and SequenceFusion should be before NopElimination
REGISTER_PASS(manager, LSTMCellFusion)
REGISTER_PASS(manager, GRUCellFusion)
REGISTER_PASS(manager, SequenceFusion)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/lstm_cell_fusion.hpp"

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/lstm_cell.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/sigmoid.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/tanh.hpp"
#include "openvino/pass/constant_folding.hpp"

using namespace ov;

TEST_F(TransformationTestsF, LSTMCellFusion) {
size_t input_size = 3;
size_t hidden_size = 2;
{
auto X = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, input_size});
auto H = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
auto C = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
Shape WR_shape{4 * hidden_size, input_size + hidden_size};
std::vector<float> WR_values(shape_size(WR_shape));
std::iota(WR_values.begin(), WR_values.end(), 0.0f);
auto WR = op::v0::Constant::create(element::f32, WR_shape, WR_values);
auto matmul = std::make_shared<op::v0::MatMul>(concat, WR, false, true);
Shape B_shape{1, 4 * hidden_size};
std::vector<float> B_values(shape_size(B_shape));
std::iota(B_values.begin(), B_values.end(), 0.0f);
auto B = op::v0::Constant::create(element::f32, B_shape, B_values);
auto biasadd = std::make_shared<op::v1::Add>(matmul, B);
auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
auto split = std::make_shared<op::v1::Split>(biasadd, one /* axis */, 4 /* num splits */);
auto it = std::make_shared<op::v0::Sigmoid>(split->output(0));
auto ct = std::make_shared<op::v0::Tanh>(split->output(1));
auto ft = std::make_shared<op::v0::Sigmoid>(
std::make_shared<op::v1::Add>(split->output(2), op::v0::Constant::create(element::f32, Shape{1, 1}, {1})));
auto ot = std::make_shared<op::v0::Sigmoid>(split->output(3));
auto mul = std::make_shared<op::v1::Multiply>(it, ct);
auto mul1 = std::make_shared<op::v1::Multiply>(ft, C);
auto Ct = std::make_shared<op::v1::Add>(mul, mul1);
auto Ht = std::make_shared<op::v1::Multiply>(std::make_shared<op::v0::Tanh>(Ct), ot);
auto C_abs = std::make_shared<op::v0::Abs>(Ct);
auto H_abs = std::make_shared<op::v0::Abs>(Ht);
model = std::make_shared<Model>(NodeVector{H_abs, C_abs}, ParameterVector{X, H, C});
manager.register_pass<ov::pass::LSTMCellFusion>();
}

{
auto X = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, input_size});
auto H = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
auto C = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
Shape W_shape{4 * hidden_size, input_size};
Shape R_shape{4 * hidden_size, hidden_size};
std::vector<float> W_values{
20, 21, 22, 25, 26, 27, 0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17, 30, 31, 32, 35, 36, 37,
};
auto W = op::v0::Constant::create(element::f32, W_shape, W_values);
std::vector<float> R_values{
23,
24,
28,
29,
3,
4,
8,
9,
13,
14,
18,
19,
33,
34,
38,
39,
};
auto R = op::v0::Constant::create(element::f32, R_shape, R_values);
Shape B_shape{4 * hidden_size};
std::vector<float> B_values{5, 6, 0, 1, 2, 3, 6, 7};
auto B = op::v0::Constant::create(element::f32, B_shape, B_values);
auto lstm_cell = std::make_shared<op::v4::LSTMCell>(X,
H,
C,
W,
R,
B,
hidden_size,
std::vector<std::string>{"sigmoid", "tanh", "tanh"});
auto C_abs = std::make_shared<op::v0::Abs>(lstm_cell->output(1));
auto H_abs = std::make_shared<op::v0::Abs>(lstm_cell->output(0));
model_ref = std::make_shared<Model>(NodeVector{H_abs, C_abs}, ParameterVector{X, H, C});
manager.register_pass<ov::pass::LSTMCellFusion>();
}

comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

0 comments on commit d03dc4f

Please sign in to comment.