Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Support complex tensors for Reciprocal operations #23355

Merged
merged 22 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/frontends/tensorflow_common/src/op/reciprocal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/power.hpp"
rkazants marked this conversation as resolved.
Show resolved Hide resolved
#include "utils.hpp"

Expand All @@ -16,13 +17,27 @@ namespace op {

OutputVector translate_reciprocal_op(const NodeContext& node) {
// computes element-wise 1/x, where x - input
default_op_checks(node, 1, {"Reciprocal"});
default_op_checks(node, 1, {"Reciprocal"}, true);
auto x = node.get_input(0);
auto minus_one_const = create_same_type_const_scalar<int32_t>(x, -1);
auto reciprocal = make_shared<v1::Power>(x, minus_one_const);

auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());

if (complex_type_mark_x) {
x = complex_type_mark_x->input_value(0);
}

auto reciprocal = make_shared<v1::Power>(complex_type_mark_x, minus_one_const);

set_node_name(node.get_name(), reciprocal);
if (complex_type_mark_x) {
auto complex_reciprocal = make_shared<ComplexTypeMark>(reciprocal, complex_type_mark_x);
return {complex_reciprocal->output(0)};
}
rkazants marked this conversation as resolved.
Show resolved Hide resolved

return {reciprocal};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
Expand Down
47 changes: 47 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Reciprocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,50 @@ def test_reciprocal_basic(self, params, ie_device, precision, ir_version, temp_d
self._test(*self.create_reciprocal_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

class TestComplexReciprocal(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'param_real_1:0' in inputs_info
assert 'param_imag_1:0' in inputs_info
param_real_shape_1 = inputs_info['param_real_1:0']
param_imag_shape_1 = inputs_info['param_imag_1:0']
inputs_data = {}
inputs_data['param_real_1:0'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag_1:0'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
return inputs_data

def create_complex_reciprocal_net(self, input_shape):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_1')
param_imag1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_1')
complex_x = tf.raw_ops.Complex(real=param_real1, imag=param_imag1)
reciprocal = tf.raw_ops.Reciprocal(x=complex_x)
real = tf.raw_ops.Real(input=reciprocal)
img = tf.raw_ops.Imag(input=reciprocal)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None


test_data_basic = [
dict(input_shape=[]),
dict(input_shape=[2]),
dict(input_shape=[1, 3]),
dict(input_shape=[2, 3, 4]),
dict(input_shape=[3, 4, 5, 6]),
]

@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_reciprocal(self, params, ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
self._test(
*self.create_complex_reciprocal_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)
Loading