forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LSTMCellFusion transformation (openvinotoolkit#21594)
* Add LSTMCellFusion transformation (partially fixes: CVS-125605)
* code style
* fix accuracy issue
* add headers
- Loading branch information
1 parent
ec8765f
commit d03dc4f
Showing
4 changed files
with
349 additions
and
0 deletions.
There are no files selected for viewing
27 changes: 27 additions & 0 deletions
27
src/common/transformations/include/transformations/common_optimizations/lstm_cell_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API LSTMCellFusion; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief LSTMCellFusion transformation replaces a sequence of | ||
* operations with LSTMCell op. | ||
*/ | ||
class ov::pass::LSTMCellFusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("LSTMCellFusion", "0"); | ||
LSTMCellFusion(); | ||
}; |
212 changes: 212 additions & 0 deletions
212
src/common/transformations/src/transformations/common_optimizations/lstm_cell_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/lstm_cell_fusion.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/lstm_cell.hpp" | ||
#include "openvino/op/matmul.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/relu.hpp" | ||
#include "openvino/op/sigmoid.hpp" | ||
#include "openvino/op/split.hpp" | ||
#include "openvino/op/squeeze.hpp" | ||
#include "openvino/op/tanh.hpp" | ||
#include "openvino/op/variadic_split.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "validation_util.hpp" | ||
|
||
/// Returns the LSTMCell activation name for an activation node: the node's type
/// name with its first character lower-cased (the fusion relies on this to turn
/// the Sigmoid/Tanh op types into "sigmoid"/"tanh").
///
/// @param node  Activation node; must not be null.
/// @return      The type name with a lower-case first letter; empty if the type
///              name is empty.
static std::string get_activation_name(const std::shared_ptr<ov::Node>& node) {
    std::string name = node->get_type_name();
    if (!name.empty()) {
        // std::tolower is only defined for values representable as unsigned char
        // (or EOF), so a plain (possibly signed) char must be converted first.
        name[0] = static_cast<char>(std::tolower(static_cast<unsigned char>(name[0])));
    }
    return name;
}
|
||
ov::pass::LSTMCellFusion::LSTMCellFusion() { | ||
MATCHER_SCOPE(LSTMCellFusion); | ||
|
||
auto x_label = pattern::any_input(pattern::rank_equals(2)); | ||
auto h_label = pattern::any_input(pattern::rank_equals(2)); | ||
auto concat_label = pattern::wrap_type<op::v0::Concat>({x_label, h_label}); | ||
auto weights_label = pattern::any_input([](const Output<Node>& output) { | ||
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output); | ||
}); | ||
auto matmul_label = pattern::wrap_type<op::v0::MatMul>({concat_label, weights_label}); | ||
auto bias_label = pattern::any_input([](const Output<Node>& output) { | ||
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output); | ||
}); | ||
auto bias_add_label = pattern::wrap_type<op::v1::Add>({matmul_label, bias_label}); | ||
auto axis_label = pattern::wrap_type<op::v0::Constant>(); | ||
auto split_label = pattern::wrap_type<op::v1::Split>({bias_add_label, axis_label}); | ||
auto it_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({split_label}); | ||
auto ct_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({split_label}); | ||
auto ft_additional_bias_label = pattern::wrap_type<op::v0::Constant>(); | ||
auto add_label = pattern::wrap_type<op::v1::Add>({split_label, ft_additional_bias_label}); | ||
auto ft_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({add_label}); | ||
auto ot_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({split_label}); | ||
auto mul_label = pattern::wrap_type<op::v1::Multiply>({it_label, ct_label}); | ||
auto c_label = pattern::any_input(); | ||
auto mul1_label = pattern::wrap_type<op::v1::Multiply>({ft_label, c_label}); | ||
auto Co_label = pattern::wrap_type<op::v1::Add>({mul_label, mul1_label}); | ||
auto Co_activation_label = pattern::wrap_type<op::v0::Relu, op::v0::Sigmoid, op::v0::Tanh>({Co_label}); | ||
auto Ho_label = pattern::wrap_type<op::v1::Multiply>({Co_activation_label, ot_label}); | ||
|
||
matcher_pass_callback callback = [=](pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
|
||
const auto& X = pattern_map.at(x_label); | ||
const auto& H = pattern_map.at(h_label); | ||
const auto& C = pattern_map.at(c_label); | ||
const auto& WR = pattern_map.at(weights_label); | ||
const auto& B = pattern_map.at(bias_label); | ||
const auto& ft_additional_bias = pattern_map.at(ft_additional_bias_label); | ||
auto Ho = pattern_map.at(Ho_label); | ||
auto Co = pattern_map.at(Co_label); | ||
|
||
const auto& WR_shape = WR.get_shape(); | ||
const auto& B_shape = B.get_shape(); | ||
const auto& ft_additional_bias_shape = ft_additional_bias.get_shape(); | ||
|
||
if (WR_shape[0] % 4 != 0) | ||
return false; | ||
if (WR_shape[0] != B_shape[1]) | ||
return false; | ||
if (B_shape[0] != 1) | ||
return false; | ||
if (shape_size(ft_additional_bias_shape) != 1) | ||
return false; | ||
|
||
size_t hidden_size = WR_shape[0] / 4; | ||
|
||
if (WR_shape[1] <= hidden_size) | ||
return false; | ||
|
||
size_t input_size = WR_shape[1] - hidden_size; | ||
|
||
const auto& X_shape = X.get_partial_shape(); | ||
const auto& H_shape = H.get_partial_shape(); | ||
const auto& C_shape = C.get_partial_shape(); | ||
|
||
if (!H_shape[0].compatible(X_shape[0])) | ||
return false; | ||
|
||
if (!C_shape[0].compatible(X_shape[0])) | ||
return false; | ||
|
||
if (!X_shape[1].compatible(input_size)) | ||
return false; | ||
|
||
if (!H_shape[1].compatible(hidden_size)) | ||
return false; | ||
|
||
if (!C_shape[1].compatible(hidden_size)) | ||
return false; | ||
|
||
NodeVector split_consumers{pattern_map.at(it_label).get_node_shared_ptr(), | ||
pattern_map.at(ct_label).get_node_shared_ptr(), | ||
pattern_map.at(ot_label).get_node_shared_ptr(), | ||
pattern_map.at(add_label).get_node_shared_ptr()}; | ||
|
||
std::shared_ptr<Node> it; | ||
std::shared_ptr<Node> ct; | ||
std::shared_ptr<Node> ot; | ||
std::shared_ptr<Node> add; | ||
|
||
// manually match split consumers to gates | ||
for (const auto& n : split_consumers) { | ||
if (n->input_value(0).get_index() == 0) | ||
it = n; | ||
else if (n->input_value(0).get_index() == 1) | ||
ct = n; | ||
else if (n->input_value(0).get_index() == 2) | ||
add = n; | ||
else if (n->input_value(0).get_index() == 3) | ||
ot = n; | ||
} | ||
|
||
auto ft = pattern_map.at(ft_label).get_node_shared_ptr(); | ||
|
||
std::string f_activation_name = ft->get_type_name(); | ||
|
||
if (f_activation_name != it->get_type_name() || f_activation_name != ot->get_type_name()) | ||
return false; | ||
|
||
f_activation_name[0] = std::tolower(f_activation_name[0]); | ||
std::string g_activation_name = get_activation_name(ct); | ||
|
||
auto Co_activation = pattern_map.at(Co_activation_label).get_node_shared_ptr(); | ||
std::string h_activation_name = get_activation_name(Co_activation); | ||
|
||
auto zero = op::v0::Constant::create(element::i32, Shape{}, {0}); | ||
auto WR_split = std::make_shared<op::v1::Split>(WR, zero /* axis */, 4); | ||
auto WR_fico = std::make_shared<op::v0::Concat>( | ||
OutputVector{WR_split->output(2), WR_split->output(0), WR_split->output(1), WR_split->output(3)}, | ||
0); | ||
auto one = op::v0::Constant::create(element::i32, Shape{}, {1}); | ||
auto split_lengths = op::v0::Constant::create(element::i32, Shape{2}, {input_size, hidden_size}); | ||
auto vsplit = std::make_shared<op::v1::VariadicSplit>(WR_fico, one /* axis */, split_lengths); | ||
Output<Node> W = vsplit->output(0); | ||
if (auto constant = ov::util::constantfold_subgraph(W)) | ||
W = constant; | ||
Output<Node> R = vsplit->output(1); | ||
if (auto constant = ov::util::constantfold_subgraph(R)) | ||
R = constant; | ||
|
||
auto B_split = std::make_shared<op::v1::Split>(std::make_shared<op::v0::Squeeze>(B, zero), zero /* axis */, 4); | ||
auto B_f = | ||
std::make_shared<op::v1::Add>(B_split->output(2), std::make_shared<op::v0::Squeeze>(ft_additional_bias)); | ||
|
||
Output<Node> B_fico = std::make_shared<op::v0::Concat>( | ||
OutputVector{B_f, B_split->output(0), B_split->output(1), B_split->output(3)}, | ||
0); | ||
if (auto constant = ov::util::constantfold_subgraph(B_fico)) | ||
B_fico = constant; | ||
|
||
auto lstm_cell = std::make_shared<op::v4::LSTMCell>( | ||
X, | ||
H, | ||
C, | ||
W, | ||
R, | ||
B_fico, | ||
hidden_size, | ||
std::vector<std::string>{f_activation_name, g_activation_name, h_activation_name}); | ||
lstm_cell->set_friendly_name(m.get_match_root()->get_friendly_name()); | ||
|
||
copy_runtime_info( | ||
{ | ||
pattern_map.at(concat_label).get_node_shared_ptr(), | ||
WR.get_node_shared_ptr(), | ||
pattern_map.at(matmul_label).get_node_shared_ptr(), | ||
B.get_node_shared_ptr(), | ||
pattern_map.at(bias_add_label).get_node_shared_ptr(), | ||
pattern_map.at(split_label).get_node_shared_ptr(), | ||
it, | ||
ct, | ||
ft, | ||
ot, | ||
pattern_map.at(add_label).get_node_shared_ptr(), | ||
pattern_map.at(mul_label).get_node_shared_ptr(), | ||
C.get_node_shared_ptr(), | ||
pattern_map.at(mul1_label).get_node_shared_ptr(), | ||
pattern_map.at(Co_label).get_node_shared_ptr(), | ||
Co.get_node_shared_ptr(), | ||
Co_activation, | ||
Ho.get_node_shared_ptr(), | ||
}, | ||
{W.get_node_shared_ptr(), R.get_node_shared_ptr(), B_fico.get_node_shared_ptr(), lstm_cell}); | ||
|
||
Ho.replace(lstm_cell->output(0)); | ||
Co.replace(lstm_cell->output(1)); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<pattern::Matcher>(Ho_label, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
108 changes: 108 additions & 0 deletions
108
src/common/transformations/tests/common_optimizations/lstm_cell_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/lstm_cell_fusion.hpp" | ||
|
||
#include "common_test_utils/ov_test_utils.hpp" | ||
#include "openvino/op/abs.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/lstm_cell.hpp" | ||
#include "openvino/op/matmul.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/parameter.hpp" | ||
#include "openvino/op/sigmoid.hpp" | ||
#include "openvino/op/split.hpp" | ||
#include "openvino/op/tanh.hpp" | ||
#include "openvino/pass/constant_folding.hpp" | ||
|
||
using namespace ov; | ||
|
||
// Builds a decomposed LSTM cell (Concat -> MatMul -> Add -> Split -> gates) and
// checks that LSTMCellFusion turns it into a single LSTMCell whose W, R and B
// constants are the reordered slices of the original weights and bias.
TEST_F(TransformationTestsF, LSTMCellFusion) {
    size_t input_size = 3;
    size_t hidden_size = 2;
    {
        auto X = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, input_size});
        auto H = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
        auto C = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
        auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
        // Combined weights: 8 rows (two per gate, gate order i,c,f,o) x 5 columns, filled with 0..39.
        Shape WR_shape{4 * hidden_size, input_size + hidden_size};
        std::vector<float> WR_values(shape_size(WR_shape));
        std::iota(WR_values.begin(), WR_values.end(), 0.0f);
        auto WR = op::v0::Constant::create(element::f32, WR_shape, WR_values);
        auto matmul = std::make_shared<op::v0::MatMul>(concat, WR, false, true);
        // Bias: 0..7, same gate order.
        Shape B_shape{1, 4 * hidden_size};
        std::vector<float> B_values(shape_size(B_shape));
        std::iota(B_values.begin(), B_values.end(), 0.0f);
        auto B = op::v0::Constant::create(element::f32, B_shape, B_values);
        auto biasadd = std::make_shared<op::v1::Add>(matmul, B);
        auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
        auto split = std::make_shared<op::v1::Split>(biasadd, one /* axis */, 4 /* num splits */);
        auto it = std::make_shared<op::v0::Sigmoid>(split->output(0));
        auto ct = std::make_shared<op::v0::Tanh>(split->output(1));
        // The forget gate gets an additional constant bias of 1 before its activation.
        auto ft = std::make_shared<op::v0::Sigmoid>(
            std::make_shared<op::v1::Add>(split->output(2), op::v0::Constant::create(element::f32, Shape{1, 1}, {1})));
        auto ot = std::make_shared<op::v0::Sigmoid>(split->output(3));
        auto mul = std::make_shared<op::v1::Multiply>(it, ct);
        auto mul1 = std::make_shared<op::v1::Multiply>(ft, C);
        auto Ct = std::make_shared<op::v1::Add>(mul, mul1);
        auto Ht = std::make_shared<op::v1::Multiply>(std::make_shared<op::v0::Tanh>(Ct), ot);
        auto C_abs = std::make_shared<op::v0::Abs>(Ct);
        auto H_abs = std::make_shared<op::v0::Abs>(Ht);
        model = std::make_shared<Model>(NodeVector{H_abs, C_abs}, ParameterVector{X, H, C});
        manager.register_pass<ov::pass::LSTMCellFusion>();
    }

    {
        auto X = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, input_size});
        auto H = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
        auto C = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
        Shape W_shape{4 * hidden_size, input_size};
        Shape R_shape{4 * hidden_size, hidden_size};
        // Rows of WR reordered to f,i,c,o (original rows 4,5,0,1,2,3,6,7);
        // W holds the first input_size columns of every row.
        std::vector<float> W_values{
            20, 21, 22, 25, 26, 27, 0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17, 30, 31, 32, 35, 36, 37,
        };
        auto W = op::v0::Constant::create(element::f32, W_shape, W_values);
        // R holds the remaining hidden_size columns of the same reordered rows.
        std::vector<float> R_values{
            23, 24, 28, 29, 3, 4, 8, 9, 13, 14, 18, 19, 33, 34, 38, 39,
        };
        auto R = op::v0::Constant::create(element::f32, R_shape, R_values);
        // Bias reordered to f,i,c,o, with the additional forget-gate bias (1) added to the f slice.
        Shape B_shape{4 * hidden_size};
        std::vector<float> B_values{5, 6, 0, 1, 2, 3, 6, 7};
        auto B = op::v0::Constant::create(element::f32, B_shape, B_values);
        auto lstm_cell = std::make_shared<op::v4::LSTMCell>(X,
                                                            H,
                                                            C,
                                                            W,
                                                            R,
                                                            B,
                                                            hidden_size,
                                                            std::vector<std::string>{"sigmoid", "tanh", "tanh"});
        auto C_abs = std::make_shared<op::v0::Abs>(lstm_cell->output(1));
        auto H_abs = std::make_shared<op::v0::Abs>(lstm_cell->output(0));
        model_ref = std::make_shared<Model>(NodeVector{H_abs, C_abs}, ParameterVector{X, H, C});
    }

    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}