Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Support complex tensors for Reciprocal operations #23355

Merged
merged 22 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/frontends/tensorflow_common/src/op/reciprocal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
//

#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
rkazants marked this conversation as resolved.
Show resolved Hide resolved
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/negative.hpp"
#include "openvino/op/power.hpp"
rkazants marked this conversation as resolved.
Show resolved Hide resolved
#include "utils.hpp"

Expand All @@ -16,8 +22,34 @@ namespace op {

OutputVector translate_reciprocal_op(const NodeContext& node) {
// computes element-wise 1/x, where x - input
default_op_checks(node, 1, {"Reciprocal"});
default_op_checks(node, 1, {"Reciprocal"}, true);
auto x = node.get_input(0);
auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
if (complex_type_mark_x) {
x = complex_type_mark_x->input_value(0);
auto minus_one = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
auto two = create_same_type_const_scalar<int32_t>(x, 2);
auto gather_index_real = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto gather_index_imag = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto x_real = make_shared<v8::Gather>(x, gather_index_real, minus_one)->output(0);
auto x_imag = make_shared<v8::Gather>(x, gather_index_imag, minus_one)->output(0);

// compute (a^2+b^2)
auto real_squared_norm = make_shared<v1::Power>(x_real, two);
auto img_squared_norm = make_shared<v1::Power>(x_imag, two);
auto squared_norm = make_shared<v1::Add>(real_squared_norm, img_squared_norm);
rkazants marked this conversation as resolved.
Show resolved Hide resolved

// compute 1/(a+bi) = (a-bi)/(a^2+b^2)
auto complex_reciprocal = make_shared<v1::Divide>(
make_shared<v0::Concat>(OutputVector{x_real, make_shared<ov::op::v0::Negative>(x_imag)}, -1),
squared_norm);
auto complex_result =
make_shared<ComplexTypeMark>(complex_reciprocal, complex_type_mark_x->get_complex_part_type());
set_node_name(node.get_name(), complex_reciprocal);
rkazants marked this conversation as resolved.
Show resolved Hide resolved
return {complex_result};
}

// For real numbers, computes element-wise 1/x, where x - input
auto minus_one_const = create_same_type_const_scalar<int32_t>(x, -1);
auto reciprocal = make_shared<v1::Power>(x, minus_one_const);
set_node_name(node.get_name(), reciprocal);
Expand Down
43 changes: 43 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Reciprocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,46 @@ def test_reciprocal_basic(self, params, ie_device, precision, ir_version, temp_d
self._test(*self.create_reciprocal_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

class TestComplexReciprocal(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'param_real_1:0' in inputs_info
assert 'param_imag_1:0' in inputs_info
param_real_shape_1 = inputs_info['param_real_1:0']
param_imag_shape_1 = inputs_info['param_imag_1:0']
inputs_data = {}
inputs_data['param_real_1:0'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag_1:0'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2

return inputs_data

def create_complex_reciprocal_net(self, x_shape,x_type):
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real1 = tf.compat.v1.placeholder(np.float32, x_shape, 'param_real_1')
param_imag1 = tf.compat.v1.placeholder(np.float32, x_shape, 'param_imag_1')
complex_x = tf.raw_ops.Complex(real=param_real1, imag=param_imag1)
reciprocal = tf.raw_ops.Reciprocal(x=complex_x)
real = tf.raw_ops.Real(input=reciprocal)
img = tf.raw_ops.Imag(input=reciprocal)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

test_data_basic = [
dict(x_shape=[], x_type=np.float32),
dict(x_shape=[2, 3], x_type=np.float32),
dict(x_shape=[4, 1, 3], x_type=np.float32),
]

@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_reciprocal(self, params, ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
self._test(*self.create_complex_reciprocal_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)
Loading