diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 0a7bebd8763215..a1071c1af0e3b8 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -29,104 +29,217 @@ class OperatorSupport(OperatorSupport): def __init__(self, options): support_dict = { "_operator.getitem": None, + "torch.ops.aten._adaptive_avg_pool1d.default": None, "torch.ops.aten._adaptive_avg_pool2d.default": None, + "torch.ops.aten._adaptive_avg_pool3d.default": None, + "torch.ops.aten._convolution.default": None, + "torch.ops.aten._embedding_bag.default": None, + "torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None, + "torch.ops.aten._local_scalar_dense.default": None, "torch.ops.aten._log_softmax.default": None, + "torch.ops.aten._native_batch_norm_legit.default": None, + "torch.ops.aten._native_batch_norm_legit.no_stats": None, + "torch.ops.aten._native_batch_norm_legit_functional.default": None, + "torch.ops.aten._native_batch_norm_legit_no_training.default": None, + "torch.ops.aten._scaled_dot_product_flash_attention.default": None, + "torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default": None, "torch.ops.aten._softmax.default": None, "torch.ops.aten._to_copy.default": None, "torch.ops.aten._unsafe_view.default": None, - "torch.ops.aten._unsafe_view.default": None, + "torch.ops.aten.abs.default": None, + "torch.ops.aten.acos.default": None, + "torch.ops.aten.acosh.default": None, + "torch.ops.aten.adaptive_max_pool1d.default": None, + "torch.ops.aten.adaptive_max_pool2d.default": None, + "torch.ops.aten.adaptive_max_pool3d.default": None, "torch.ops.aten.add.Scalar": None, "torch.ops.aten.add.Tensor": None, "torch.ops.aten.add_.Tensor": None, + "torch.ops.aten.addcmul.default": None, "torch.ops.aten.addmm.default": None, + "torch.ops.aten.alias.default": None, + "torch.ops.aten.all.default": None, "torch.ops.aten.amax.default": None, - "torch.ops.aten.arange.start": None, + "torch.ops.aten.amin.default": None, + "torch.ops.aten.any.default": None, + "torch.ops.aten.any.dim": None, "torch.ops.aten.arange.default": None, + "torch.ops.aten.arange.start": None, + "torch.ops.aten.arange.start_step": None, "torch.ops.aten.argmax.default": None, + "torch.ops.aten.argmin.default": None, + "torch.ops.aten.as_strided.default": None, + "torch.ops.aten.asin.default": None, + "torch.ops.aten.asinh.default": None, + "torch.ops.aten.asinh.default": None, + "torch.ops.aten.atanh.default": None, "torch.ops.aten.avg_pool2d.default": None, + "torch.ops.aten.avg_pool3d.default": None, "torch.ops.aten.baddbmm.default": None, "torch.ops.aten.bitwise_and.Tensor": None, + "torch.ops.aten.bitwise_not.default": None, + "torch.ops.aten.bitwise_or.Tensor": None, + "torch.ops.aten.bitwise_xor.Tensor": None, "torch.ops.aten.bmm.default": None, "torch.ops.aten.cat.default": None, + "torch.ops.aten.ceil.default": None, + "torch.ops.aten.clamp.default": None, + "torch.ops.aten.clamp_max.default": None, + "torch.ops.aten.clamp_max.Tensor": None, "torch.ops.aten.clamp_min.default": None, + "torch.ops.aten.clamp_min.Tensor": None, "torch.ops.aten.clone.default": None, + "torch.ops.aten.constant_pad_nd.default": None, "torch.ops.aten.convolution.default": None, + "torch.ops.aten.copy.default": None, "torch.ops.aten.copy_.default": None, "torch.ops.aten.cos.default": None, + "torch.ops.aten.cosh.default": None, "torch.ops.aten.cumsum.default": None, "torch.ops.aten.detach.default": None, + "torch.ops.aten.detach_.default": None, "torch.ops.aten.div.Scalar": None, "torch.ops.aten.div.Tensor": None, + "torch.ops.aten.div.Tensor_mode": None, + "torch.ops.aten.div_.Tensor": None, + "torch.ops.aten.elu.default": None, + "torch.ops.aten.elu_.default": None, "torch.ops.aten.embedding.default": None, "torch.ops.aten.empty.memory_format": None, - "torch.ops.aten.erf.default": None, "torch.ops.aten.eq.Scalar": None, "torch.ops.aten.eq.Tensor": None, + "torch.ops.aten.erf.default": None, "torch.ops.aten.exp.default": None, "torch.ops.aten.expand.default": None, + "torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None, "torch.ops.aten.fill.Scalar": None, + "torch.ops.aten.fill_.Scalar": None, + "torch.ops.aten.fill.Tensor": None, + "torch.ops.aten.fill_.Tensor": None, + "torch.ops.aten.flip.default": None, + "torch.ops.aten.floor.default": None, + "torch.ops.aten.floor.default": None, + "torch.ops.aten.fmod.Scalar": None, + "torch.ops.aten.fmod.Tensor": None, "torch.ops.aten.full.default": None, + "torch.ops.aten.full.names": None, + "torch.ops.aten.full_like.default": None, "torch.ops.aten.gather.default": None, + "torch.ops.aten.ge.Scalar": None, + "torch.ops.aten.ge.Tensor": None, "torch.ops.aten.gelu.default": None, + "torch.ops.aten.glu.default": None, + "torch.ops.aten.grid_sampler_2d.default": None, "torch.ops.aten.gt.Scalar": None, + "torch.ops.aten.gt.Tensor": None, "torch.ops.aten.hardsigmoid.default": None, + "torch.ops.aten.hardswish.default": None, "torch.ops.aten.hardswish_.default": None, + "torch.ops.aten.hardtanh.default": None, "torch.ops.aten.hardtanh_.default": None, "torch.ops.aten.index.Tensor": None, + "torch.ops.aten.index_select.default": None, + "torch.ops.aten.isfinite.default": None, + "torch.ops.aten.isinf.default": None, + "torch.ops.aten.isnan.default": None, + "torch.ops.aten.le.Scalar": None, + "torch.ops.aten.le.Tensor": None, + "torch.ops.aten.leaky_relu.default": None, "torch.ops.aten.leaky_relu_.default": None, "torch.ops.aten.lift_fresh_copy.default": None, "torch.ops.aten.linalg_vector_norm.default": None, - "torch.ops.aten.lt.Tensor": None, "torch.ops.aten.log.default": None, "torch.ops.aten.log_sigmoid_forward.default": None, + "torch.ops.aten.log10.default": None, + "torch.ops.aten.log1p.default": None, + "torch.ops.aten.log2.default": None, + "torch.ops.aten.logical_not.default": None, "torch.ops.aten.logsumexp.default": None, - "torch.ops.aten.masked_fill_.Scalar": None, + "torch.ops.aten.lt.Scalar": None, + "torch.ops.aten.lt.Tensor": None, + "torch.ops.aten.masked_fill.Scalar": None, "torch.ops.aten.masked_fill.Tensor": None, + "torch.ops.aten.masked_fill_.Scalar": None, + "torch.ops.aten.masked_fill_.Tensor": None, + "torch.ops.aten.max.default": None, "torch.ops.aten.max.dim": None, "torch.ops.aten.max_pool2d_with_indices.default": None, + "torch.ops.aten.max_pool3d_with_indices.default": None, + "torch.ops.aten.maximum.default": None, + "torch.ops.aten.mean.default": None, "torch.ops.aten.mean.dim": None, + "torch.ops.aten.min.default": None, + "torch.ops.aten.min.dim": None, + "torch.ops.aten.minimum.default": None, "torch.ops.aten.mm.default": None, "torch.ops.aten.mul.Scalar": None, "torch.ops.aten.mul.Tensor": None, "torch.ops.aten.native_batch_norm.default": None, - "torch.ops.aten._native_batch_norm_legit.default": None, - "torch.ops.aten._native_batch_norm_legit_no_training.default": None, + "torch.ops.aten.native_dropout.default": None, "torch.ops.aten.native_group_norm.default": None, "torch.ops.aten.native_layer_norm.default": None, - "torch.ops.aten.new_full.default": None, + "torch.ops.aten.ne.Scalar": None, + "torch.ops.aten.ne.Tensor": None, "torch.ops.aten.neg.default": None, + "torch.ops.aten.new_full.default": None, "torch.ops.aten.new_ones.default": None, + "torch.ops.aten.new_zeros.default": None, + "torch.ops.aten.ones.default": None, "torch.ops.aten.permute.default": None, + "torch.ops.aten.pow.Scalar": None, "torch.ops.aten.pow.Tensor_Scalar": None, + "torch.ops.aten.pow.Tensor_Tensor": None, + "torch.ops.aten.rand.default": None, + "torch.ops.aten.reciprocal.default": None, "torch.ops.aten.relu.default": None, "torch.ops.aten.relu_.default": None, + "torch.ops.aten.repeat.default": None, + "torch.ops.aten.roll.default": None, "torch.ops.aten.rsqrt.default": None, "torch.ops.aten.rsub.Scalar": None, - "torch.ops.aten._scaled_dot_product_flash_attention.default": None, + "torch.ops.aten.rsub.Tensor": None, "torch.ops.aten.scalar_tensor.default": None, + "torch.ops.aten.scatter.src": None, + "torch.ops.aten.scatter.value": None, "torch.ops.aten.select.int": None, + "torch.ops.aten.select_scatter.default": None, "torch.ops.aten.sigmoid.default": None, + "torch.ops.aten.sign.default": None, "torch.ops.aten.silu.default": None, "torch.ops.aten.silu_.default": None, "torch.ops.aten.sin.default": None, + "torch.ops.aten.sinh.default": None, "torch.ops.aten.slice.Tensor": None, + "torch.ops.aten.slice_scatter.default": None, + "torch.ops.aten.sort.default": None, "torch.ops.aten.split.Tensor": None, + "torch.ops.aten.split_with_sizes.default": None, + "torch.ops.aten.sqrt.default": None, "torch.ops.aten.squeeze.dim": None, "torch.ops.aten.squeeze.dims": None, "torch.ops.aten.stack.default": None, "torch.ops.aten.sub.default": None, "torch.ops.aten.sub.Tensor": None, + "torch.ops.aten.sum.default": None, "torch.ops.aten.sum.dim_IntList": None, "torch.ops.aten.t.default": None, + "torch.ops.aten.tan.default": None, "torch.ops.aten.tanh.default": None, + "torch.ops.aten.topk.default": None, "torch.ops.aten.transpose.int": None, + "torch.ops.aten.tril.default": None, + "torch.ops.aten.tril_.default": None, "torch.ops.aten.unbind.int": None, + "torch.ops.aten.unfold.default": None, "torch.ops.aten.unsqueeze.default": None, "torch.ops.aten.upsample_nearest2d.default": None, + "torch.ops.aten.var.correction": None, "torch.ops.aten.var_mean.correction": None, "torch.ops.aten.view.default": None, "torch.ops.aten.where.self": None, "torch.ops.aten.zeros_like.default": None, + "torch.ops.torchvision.deform_conv2d.default": None, + "torch.ops.torchvision.roi_align.default": None, } for op in _get_disabled_ops(options): diff --git a/src/frontends/pytorch/src/input_model.cpp b/src/frontends/pytorch/src/input_model.cpp index 1c06003855c39e..bd7927228b9980 100644 --- a/src/frontends/pytorch/src/input_model.cpp +++ b/src/frontends/pytorch/src/input_model.cpp @@ -24,7 +24,7 @@ InputModel::InputModel(const std::shared_ptr& model_decoder) : m_m const auto& outputs = m_model_decoder->outputs(); for (size_t i = 0; i < outputs.size(); ++i) { auto out_place = std::make_shared(*this, outputs[i]); - m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast(out_place)); + m_name_to_place.emplace(std::to_string(outputs[i]), std::dynamic_pointer_cast(out_place)); for (const auto& name : out_place->get_names()) { m_name_to_place.emplace(name, std::dynamic_pointer_cast(out_place)); } diff --git a/src/frontends/pytorch/src/op/any.cpp b/src/frontends/pytorch/src/op/any.cpp new file mode 100644 index 00000000000000..b2e24ec818c1b2 --- /dev/null +++ b/src/frontends/pytorch/src/op/any.cpp @@ -0,0 +1,38 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/not_equal.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reduce_logical_or.hpp" +#include "openvino/op/reshape.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_any_fx(const NodeContext& context) { + num_inputs_check(context, 1, 3); + auto x = context.get_input(0); + + Output dims; + if (!context.input_is_none(1)) { + dims = context.get_input(1); + } else { + dims = get_axes_range(context, 0); + } + bool keep_dims = false; + if (!context.input_is_none(2)) + keep_dims = context.const_input(2); + auto any = context.mark_node(std::make_shared(x, dims, keep_dims)); + return {any}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/argmax_argmin.cpp b/src/frontends/pytorch/src/op/argmax_argmin.cpp index e052479d5c2202..edb7a938c30b52 100644 --- a/src/frontends/pytorch/src/op/argmax_argmin.cpp +++ b/src/frontends/pytorch/src/op/argmax_argmin.cpp @@ -31,7 +31,8 @@ OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode) } if (!context.input_is_none(1)) { auto axis = context.const_input(1); - auto topk = context.mark_node(std::make_shared(input, k, axis, mode, TopKSortType::NONE)); + auto topk = context.mark_node( + std::make_shared(input, k, axis, mode, TopKSortType::SORT_VALUES, element::i32, true)); indices = context.mark_node(std::make_shared(topk->output(1), element::i64)); if (!keep_dims) { auto axis_to_remove = context.mark_node(v0::Constant::create(element::i32, Shape{}, {axis})); @@ -41,7 +42,8 @@ OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode) int64_t axis = 0; auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); auto flatten_input = context.mark_node(std::make_shared(input, minus_one, false)); - auto topk = context.mark_node(std::make_shared(flatten_input, k, axis, mode, TopKSortType::NONE)); + auto topk = context.mark_node( + std::make_shared(flatten_input, k, axis, mode, TopKSortType::SORT_VALUES, element::i32, true)); indices = context.mark_node(std::make_shared(topk->output(1), element::i64)); if (keep_dims) { auto input_shape = context.mark_node(std::make_shared(input, element::i32)); diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp index 7dfb7ccd796ad7..7a926a38836c0f 100644 --- a/src/frontends/pytorch/src/op/cat.cpp +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -102,20 +102,24 @@ OutputVector translate_quantized_cat(const NodeContext& context) { }; OutputVector translate_stack_fx(const NodeContext& context) { - num_inputs_check(context, 2, context.get_input_size()); + num_inputs_check(context, 1, context.get_input_size()); auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); std::deque> list_elems; auto num_elements = context.get_input_size(); - if (num_elements > 2) - num_elements = num_elements - 1; - for (size_t i = 0; i < num_elements; i++) { + for (size_t i = 0; i < num_elements - 1; i++) { auto stack_input = context.mark_node(std::make_shared(context.get_input(static_cast(i)), dim)); list_elems.push_back(stack_input); } int64_t axis = 0; - if (context.get_input_size() > 2) - axis = context.const_input(context.get_input_size() - 1); + if (!context.get_input_type(num_elements - 1).is()) { + // axis can be not present and that means that last input will have List type + axis = context.const_input(num_elements - 1); + } else { + auto stack_input = context.mark_node( + std::make_shared(context.get_input(static_cast(num_elements - 1)), dim)); + list_elems.push_back(stack_input); + } return translate_cat_common(context, list_elems, axis, true); } diff --git a/src/frontends/pytorch/src/op/div.cpp b/src/frontends/pytorch/src/op/div.cpp index f8640f2693f90d..7c091f2c2cb8da 100644 --- a/src/frontends/pytorch/src/op/div.cpp +++ b/src/frontends/pytorch/src/op/div.cpp @@ -90,6 +90,17 @@ OutputVector translate_div_fx(const NodeContext& context) { return translate_div_common(context, x, y, rounding_mode, false); }; +OutputVector translate_div_fx_(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto x = context.get_input(0); + auto y = context.get_input(1); + std::string rounding_mode = ""; + if (context.has_attribute("rounding_mode")) { + rounding_mode = context.get_attribute("rounding_mode"); + } + return translate_div_common(context, x, y, rounding_mode, true); +}; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/embedding_bag.cpp b/src/frontends/pytorch/src/op/embedding_bag.cpp index a1094265eaf789..633eac2a100ca1 100644 --- a/src/frontends/pytorch/src/op/embedding_bag.cpp +++ b/src/frontends/pytorch/src/op/embedding_bag.cpp @@ -15,10 +15,9 @@ namespace frontend { namespace pytorch { namespace op { -OutputVector translate_embedding_bag(const NodeContext& context) { +OutputVector translate_embedding_bag_common(const NodeContext& context) { // aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False, // per_sample_weights=None, include_last_offset=False, padding_idx=None) - num_inputs_check(context, 9, 9); // we have only EmbeddingBagSum case support, check it before translation auto mode = context.const_input(4); PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation"); @@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) { // with offsets case auto offsets = context.get_input(2); offsets = context.mark_node(std::make_shared(offsets, element::i32)); - auto include_last_offset = context.const_input(7); + bool include_last_offset = false; + if (!context.input_is_none(7)) + include_last_offset = context.const_input(7); PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported"); // no per_sample_wights if (context.input_is_none(6)) { @@ -63,7 +64,18 @@ OutputVector translate_embedding_bag(const NodeContext& context) { return {result, zero, zero, zero}; }; +OutputVector translate_embedding_bag(const NodeContext& context) { + num_inputs_check(context, 9, 9); + return translate_embedding_bag_common(context); +} + +OutputVector translate_embedding_bag_fx(const NodeContext& context) { + num_inputs_check(context, 7, 9); + ov::OutputVector output = translate_embedding_bag_common(context); + return {context.mark_node(make_list_construct(output))}; +} + } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op/log.cpp b/src/frontends/pytorch/src/op/log.cpp index 3ac3554a286efd..80a526a73e0a65 100644 --- a/src/frontends/pytorch/src/op/log.cpp +++ b/src/frontends/pytorch/src/op/log.cpp @@ -22,7 +22,7 @@ namespace op { using namespace ov::op; -OutputVector translate_log_sigmoid(const NodeContext& context) { +std::shared_ptr translate_log_sigmoid_common(const NodeContext& context) { num_inputs_check(context, 1, 1); auto op_vector = op::translate_1to1_match_1_inputs_with_fp32_type_alignment(context); PYTORCH_OP_CONVERSION_CHECK(op_vector.size() == 1, @@ -30,7 +30,16 @@ OutputVector translate_log_sigmoid(const NodeContext& context) { op_vector.size()); auto sigmoid = op_vector[0]; auto log = context.mark_node(std::make_shared(sigmoid)); - return {log}; + return log; +}; + +OutputVector translate_log_sigmoid(const NodeContext& context) { + return {translate_log_sigmoid_common(context)}; +}; + +OutputVector translate_log_sigmoid_fx(const NodeContext& context) { + auto log = translate_log_sigmoid_common(context); + return {context.mark_node(make_list_construct(log->outputs()))}; }; OutputVector translate_log2(const NodeContext& context) { diff --git a/src/frontends/pytorch/src/op/sort.cpp b/src/frontends/pytorch/src/op/sort.cpp index 02cd6214af1eda..b2f48cf002e925 100644 --- a/src/frontends/pytorch/src/op/sort.cpp +++ b/src/frontends/pytorch/src/op/sort.cpp @@ -9,22 +9,8 @@ namespace frontend { namespace pytorch { namespace op { -OutputVector translate_sort(const NodeContext& context) { - num_inputs_check(context, 3, 4); +OutputVector translate_sort_common(const NodeContext& context, bool stable, int64_t dim, bool descending) { const auto input_tensor = context.get_input(0); - bool stable, descending; - int64_t dim; - - if (context.get_input_size() == 4) { - stable = context.const_input(1); - dim = context.const_input(2); - descending = context.const_input(3); - } else { - stable = false; - dim = context.const_input(1); - descending = context.const_input(2); - } - auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN; auto zero_axis = context.mark_node(opset11::Constant::create(element::i32, Shape{1}, {0})); auto dim_axis = context.mark_node(opset11::Constant::create(element::i64, Shape{1}, {dim})); @@ -39,6 +25,42 @@ OutputVector translate_sort(const NodeContext& context) { element::i64, stable)); return topk->outputs(); +} + +OutputVector translate_sort(const NodeContext& context) { + num_inputs_check(context, 3, 4); + bool stable, descending; + int64_t dim; + + if (context.get_input_size() == 4) { + stable = context.const_input(1); + dim = context.const_input(2); + descending = context.const_input(3); + } else { + stable = false; + dim = context.const_input(1); + descending = context.const_input(2); + } + + return translate_sort_common(context, stable, dim, descending); +}; + +OutputVector translate_sort_fx(const NodeContext& context) { + // aten.sort.default(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + num_inputs_check(context, 1, 3); + bool descending = false; + bool stable = false; + int64_t dim = -1; + + if (!context.input_is_none(1)) { + dim = context.const_input(1); + } + if (!context.input_is_none(2)) { + descending = context.const_input(2); + } + + auto topk_outputs = translate_sort_common(context, stable, dim, descending); + return {context.mark_node(make_list_construct(OutputVector({topk_outputs[0], topk_outputs[1]})))}; }; OutputVector translate_argsort(const NodeContext& context) { diff --git a/src/frontends/pytorch/src/op/split.cpp b/src/frontends/pytorch/src/op/split.cpp index b58c05c1f5bb47..01e14c3b57c2fc 100644 --- a/src/frontends/pytorch/src/op/split.cpp +++ b/src/frontends/pytorch/src/op/split.cpp @@ -25,11 +25,11 @@ OutputVector translate_chunk_fx(const NodeContext& context) { std::shared_ptr chunk; auto dim_val = context.const_input(2); - auto shape = context.get_input(0).get_shape(); + auto shape = context.get_input(0).get_partial_shape(); if (dim_val < 0) { - dim_val = static_cast(shape.size()) + dim_val; + dim_val = static_cast(shape.rank().get_length()) + dim_val; } - int num_splits = static_cast(shape[dim_val]) / num_chunks; + int num_splits = static_cast(shape[dim_val].get_length()) / num_chunks; chunk = context.mark_node(std::make_shared(context.get_input(0), dim, num_splits)); @@ -37,12 +37,17 @@ OutputVector translate_chunk_fx(const NodeContext& context) { } OutputVector translate_unbind_int_fx(const NodeContext& context) { - num_inputs_check(context, 2, 3); + num_inputs_check(context, 1, 3); auto input = context.get_input(0); - auto dim = context.get_input(1); - auto dim_val = context.const_input(1); + Output dim; + int64_t dim_val = 0; + if (context.input_is_none(1)) { + dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + } else { + dim = context.get_input(1); + dim_val = context.const_input(1); + } auto shape = input.get_shape(); - if (dim_val < 0) { dim_val = static_cast(shape.size()) + dim_val; } diff --git a/src/frontends/pytorch/src/op/topk.cpp b/src/frontends/pytorch/src/op/topk.cpp index 4a1943a2ae4dae..2fd79f3c3f92a4 100644 --- a/src/frontends/pytorch/src/op/topk.cpp +++ b/src/frontends/pytorch/src/op/topk.cpp @@ -41,6 +41,39 @@ OutputVector translate_topk(const NodeContext& context) { return {topk->output(0), indices}; }; +OutputVector translate_topk_fx(const NodeContext& context) { + // aten.topk.default(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> Tuple[Tensor, Tensor] + num_inputs_check(context, 2, 5); + const auto input_tensor = context.get_input(0); + auto k = context.get_input(1); + int64_t axis{-1}; + bool largest = true; + bool sorted = true; + auto mode = TopKMode::MIN; + auto sort = TopKSortType::NONE; + + if (!context.input_is_none(2)) { + axis = context.const_input(2); + } + if (!context.input_is_none(3)) { + largest = context.const_input(3); + } + if (!context.input_is_none(4)) { + sorted = context.const_input(4); + } + if (largest) { + mode = TopKMode::MAX; + } + if (sorted) { + sort = TopKSortType::SORT_VALUES; + } + + auto topk = context.mark_node(std::make_shared(input_tensor, k, axis, mode, sort)); + auto indices = context.mark_node(std::make_shared(topk->output(1), element::i64)); + + return {context.mark_node(make_list_construct(OutputVector({topk->output(0), indices})))}; +}; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 26dc9ef018100a..55d218df430e43 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -245,6 +245,7 @@ OP_CONVERTER(translate_adaptive_max_pool2d_fx); OP_CONVERTER(translate_adaptive_max_pool3d_fx); OP_CONVERTER(translate_addcmul_fx); OP_CONVERTER(translate_addmm_fx); +OP_CONVERTER(translate_any_fx); OP_CONVERTER(translate_arange_fx); OP_CONVERTER(translate_batch_norm_legit_fx); OP_CONVERTER(translate_batch_norm_legit_no_training_fx); @@ -254,6 +255,8 @@ OP_CONVERTER(translate_constant_pad_nd_fx); OP_CONVERTER(translate_cumsum_fx); OP_CONVERTER(translate_chunk_fx); OP_CONVERTER(translate_div_fx); +OP_CONVERTER(translate_div_fx_); +OP_CONVERTER(translate_embedding_bag_fx); OP_CONVERTER(translate_expand_fx); OP_CONVERTER(translate_fake_quantize_per_channel_affine_fx); OP_CONVERTER(translate_fake_quantize_per_tensor_affine_fx); @@ -264,6 +267,7 @@ OP_CONVERTER(translate_group_norm_fx); OP_CONVERTER(translate_index_fx); OP_CONVERTER(translate_layer_norm_fx); OP_CONVERTER(translate_leaky_relu_fx); +OP_CONVERTER(translate_log_sigmoid_fx); OP_CONVERTER(translate_log_softmax_fx); OP_CONVERTER(translate_max_dim_fx); OP_CONVERTER(translate_max_poolnd_fx); @@ -282,10 +286,12 @@ OP_CONVERTER(translate_select_scatter_fx); OP_CONVERTER(translate_slice_fx); OP_CONVERTER(translate_slice_scatter_fx); OP_CONVERTER(translate_softmax_fx); +OP_CONVERTER(translate_sort_fx); OP_CONVERTER(translate_split_with_sizes_fx); OP_CONVERTER(translate_stack_fx); OP_CONVERTER(translate_sub_fx); OP_CONVERTER(translate_sum_fx); +OP_CONVERTER(translate_topk_fx); OP_CONVERTER(translate_to_fx); OP_CONVERTER(translate_transpose_fx); OP_CONVERTER(translate_var_fx); @@ -710,6 +716,7 @@ const std::map get_supported_ops_fx() { {"aten._adaptive_avg_pool2d.default", op::translate_adaptive_avg_pool2d}, {"aten._adaptive_avg_pool3d.default", op::translate_adaptive_avg_pool3d}, {"aten._convolution.default", op::translate_convolution}, + {"aten._embedding_bag.default", op::translate_embedding_bag_fx}, {"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default", op::translate_fake_quantize_per_tensor_affine_fx}, {"aten._local_scalar_dense.default", op::skip_node}, @@ -735,8 +742,11 @@ const std::map get_supported_ops_fx() { {"aten.addcmul.default", op::translate_addcmul_fx}, {"aten.addmm.default", op::translate_addmm_fx}, {"aten.alias.default", op::skip_node}, + {"aten.all.default", op::translate_all}, {"aten.amax.default", op::translate_amax}, {"aten.amin.default", op::translate_amin}, + {"aten.any.default", op::translate_any_fx}, + {"aten.any.dim", op::translate_any_fx}, {"aten.arange.default", op::translate_arange_fx}, {"aten.arange.start", op::translate_arange_fx}, {"aten.arange.start_step", op::translate_arange_fx}, @@ -773,10 +783,13 @@ const std::map get_supported_ops_fx() { {"aten.cumsum.default", op::translate_cumsum_fx}, {"aten.channel_shuffle.default", op::translate_channel_shuffle}, {"aten.detach.default", op::skip_node}, + {"aten.detach_.default", op::skip_node}, {"aten.div.Scalar", op::translate_div_fx}, {"aten.div.Tensor", op::translate_div_fx}, {"aten.div.Tensor_mode", op::translate_div_fx}, + {"aten.div_.Tensor", op::translate_div_fx_}, {"aten.elu.default", op::translate_elu}, + {"aten.elu_.default", op::inplace_op}, {"aten.embedding.default", op::translate_embedding}, {"aten.empty.memory_format", op::translate_empty}, {"aten.eq.Scalar", op::translate_1to1_match_2_inputs_align_types}, @@ -788,7 +801,9 @@ const std::map get_supported_ops_fx() { {"aten.expand.default", op::translate_expand_fx}, {"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx}, {"aten.fill.Scalar", op::translate_fill}, + {"aten.fill_.Scalar", op::inplace_op}, {"aten.fill.Tensor", op::translate_fill}, + {"aten.fill_.Tensor", op::inplace_op}, {"aten.flip.default", op::translate_flip}, {"aten.floor.default", op::translate_1to1_match_1_inputs}, {"aten.floor_divide.default", op::translate_floor_divide}, @@ -802,6 +817,7 @@ const std::map get_supported_ops_fx() { {"aten.ge.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.gelu.default", op::translate_gelu_fx}, {"aten.glu.default", op::translate_glu}, + {"aten.grid_sampler_2d.default", op::translate_grid_sampler}, {"aten.gt.Scalar", op::translate_1to1_match_2_inputs_align_types}, {"aten.gt.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.hardsigmoid.default", op::translate_1to1_match_1_inputs}, @@ -811,6 +827,9 @@ const std::map get_supported_ops_fx() { {"aten.hardtanh_.default", op::inplace_op}, {"aten.index.Tensor", op::translate_index_fx}, {"aten.index_select.default", op::translate_index_select}, + {"aten.isfinite.default", op::inplace_op>}, + {"aten.isinf.default", op::inplace_op>}, + {"aten.isnan.default", op::inplace_op>}, {"aten.le.Scalar", op::translate_1to1_match_2_inputs_align_types}, {"aten.le.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.leaky_relu.default", op::translate_leaky_relu_fx}, @@ -818,15 +837,17 @@ const std::map get_supported_ops_fx() { {"aten.lift_fresh_copy.default", op::skip_node}, {"aten.linalg_vector_norm.default", op::translate_linalg_vector_norm}, {"aten.log.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, - {"aten.log_sigmoid_forward.default", op::translate_log_sigmoid}, + {"aten.log_sigmoid_forward.default", op::translate_log_sigmoid_fx}, {"aten.log10.default", op::translate_log10}, {"aten.log1p.default", op::translate_log1p}, {"aten.log2.default", op::translate_log2}, {"aten.logsumexp.default", op::translate_logsumexp}, {"aten.lt.Scalar", op::translate_1to1_match_2_inputs_align_types}, {"aten.lt.Tensor", op::translate_1to1_match_2_inputs_align_types}, + {"aten.masked_fill.Scalar", op::translate_masked_fill}, {"aten.masked_fill.Tensor", op::translate_masked_fill}, {"aten.masked_fill_.Scalar", op::inplace_op}, + {"aten.masked_fill_.Tensor", op::inplace_op}, {"aten.max.default", op::translate_max}, {"aten.max.dim", op::translate_max_dim_fx}, {"aten.max_pool2d_with_indices.default", op::translate_max_poolnd_fx}, @@ -872,6 +893,7 @@ const std::map get_supported_ops_fx() { {"aten.rsub.Scalar", op::translate_rsub_fx}, {"aten.rsub.Tensor", op::translate_rsub_fx}, {"aten.scalar_tensor.default", op::translate_scalar_tensor_fx}, + {"aten.scatter.src", op::translate_scatter}, {"aten.scatter.value", op::translate_scatter}, {"aten.select.int", op::translate_select}, {"aten.select_scatter.default", op::translate_select_scatter_fx}, @@ -883,6 +905,7 @@ const std::map get_supported_ops_fx() { {"aten.sinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.slice.Tensor", op::translate_slice_fx}, {"aten.slice_scatter.default", op::translate_slice_scatter_fx}, + {"aten.sort.default", op::translate_sort_fx}, {"aten.split.Tensor", op::translate_chunk_fx}, {"aten.split_with_sizes.default", op::translate_split_with_sizes_fx}, {"aten.sqrt.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -896,7 +919,9 @@ const std::map get_supported_ops_fx() { {"aten.t.default", op::translate_t}, {"aten.tan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.tanh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, + {"aten.topk.default", op::translate_topk_fx}, {"aten.transpose.int", op::translate_transpose}, + {"aten.tril.default", op::translate_tril}, {"aten.unbind.int", op::translate_unbind_int_fx}, {"aten.unfold.default", op::translate_unfold}, {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs}, @@ -909,20 +934,8 @@ const std::map get_supported_ops_fx() { {"aten.zeros.names", op::translate_zeros_fx}, {"aten.zeros_like.default", op::translate_zeros_like_fx}, {"get_attr", op::translate_constant}, - {"prim::Constant", op::translate_constant}, - {"prim::device", op::translate_constant}, - {"prim::GetAttr", op::translate_get_attr}, - {"prim::If", op::translate_if}, - {"prim::is_cuda", op::return_false_scalar}, - {"prim::ListConstruct", op::translate_list_construct}, - {"prim::Loop", op::translate_loop}, - {"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape [] - {"prim::PythonOp", op::translate_pythonop}, - {"prim::requires_grad", op::return_false_scalar}, - {"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode. - {"torchvision::deform_conv2d", op::translate_deform_conv}, - {"torchvision::nms", op::translate_nms}, - {"torchvision::roi_align", op::translate_roi_align}, + {"torchvision.deform_conv2d.default", op::translate_deform_conv}, + {"torchvision.roi_align.default", op::translate_roi_align}, }; }; diff --git a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py index 5f750896963935..4bb291c9d51178 100644 --- a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py +++ b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py @@ -479,11 +479,11 @@ def test_pytorch_decoder_can_convert_empty_list(): class aten_roll(torch.nn.Module): def __init__(self, shifts): super(aten_roll, self).__init__() - self.shits = shifts + self.shifts = shifts def forward(self, x): # roll has optional input dim, which is empty int list by default - return torch.roll(x, self.shits) + return torch.roll(x, self.shifts) model = get_scripted_model(aten_roll(1)) consts = [n for n in model.inlined_graph.nodes() if n.kind() == diff --git a/tests/layer_tests/pytorch_tests/test_addcmul.py b/tests/layer_tests/pytorch_tests/test_addcmul.py index 5dde37fb609812..9812bbb1b329c9 100644 --- a/tests/layer_tests/pytorch_tests/test_addcmul.py +++ b/tests/layer_tests/pytorch_tests/test_addcmul.py @@ -47,6 +47,7 @@ def forward(self, x, y, z): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_addcmul(self, input_type, value, ie_device, precision, ir_version): self.input_type = input_type self._test(*self.create_model(value), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_all.py b/tests/layer_tests/pytorch_tests/test_all.py index de91f90cb69d60..6e4b1c494302fc 100644 --- a/tests/layer_tests/pytorch_tests/test_all.py +++ b/tests/layer_tests/pytorch_tests/test_all.py @@ -77,6 +77,7 @@ def _prepare_input(self, out=False): @pytest.mark.parametrize("out", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_all_noparams(self, input_shape, d_type, out, ie_device, precision, ir_version): if type(input_shape) is list: self.input_tensor = np.random.randint(0, 2, input_shape, dtype=d_type) @@ -104,6 +105,7 @@ def test_all_noparams(self, input_shape, d_type, out, ie_device, precision, ir_v @pytest.mark.parametrize("out", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ('arm', 'armv7l', 'aarch64', 'arm64', 'ARM64'), diff --git a/tests/layer_tests/pytorch_tests/test_any.py b/tests/layer_tests/pytorch_tests/test_any.py new file mode 100644 index 00000000000000..0bc0ca79a55d70 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_any.py @@ -0,0 +1,50 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestAny(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return ((np.random.randint(2, size=(3,3,10,10)) > 0),) + + def create_model(self, dim=None, keep_dim=None): + + import torch + class aten_any(torch.nn.Module): + def __init__(self, dim=None, keep_dim=None): + super(aten_any, self).__init__() + + if dim == None: + self.forward = self.forward_default + else: + self.forward = self.forward_dim + self.dim = dim + self.keep_dim = keep_dim + + def forward_default(self, x): + return torch.any(x) + + def forward_dim(self, x): + return torch.any(x, dim=self.dim, keepdim=self.keep_dim) + + + ref_net = None + + return aten_any(dim, keep_dim), ref_net, "aten::any" + + + @pytest.mark.precommit_fx_backend + def test_any_default(self, ie_device, precision, ir_version): + self._test(*self.create_model(), + ie_device, precision, ir_version) + + @pytest.mark.parametrize(("dim", "keep_dim"), + [(0, False), (0, True), (-1, True)]) + @pytest.mark.precommit_fx_backend + def test_any_dim(self, dim, keep_dim, ie_device, precision, ir_version): + self._test(*self.create_model(dim, keep_dim), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_arange.py b/tests/layer_tests/pytorch_tests/test_arange.py index fb8cea4b14b9c1..d871b09c0e4686 100644 --- a/tests/layer_tests/pytorch_tests/test_arange.py +++ b/tests/layer_tests/pytorch_tests/test_arange.py @@ -109,6 +109,7 @@ def forward(self, x, y, z, d): @pytest.mark.nightly @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uin8"]) @pytest.mark.parametrize("end", [1, 2, 3]) @pytest.mark.parametrize("use_out", [skip_if_export(True), False]) @@ -117,6 +118,7 @@ def test_arange_end_only(self, dtype, end, use_out, ie_device, precision, ir_ver kwargs_to_prepare_input={"end": end}) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)]) def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_version): @@ -125,6 +127,7 @@ def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_vers @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)]) def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precision, ir_version): @@ -133,6 +136,7 @@ def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precisi @pytest.mark.nightly @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8"]) @pytest.mark.parametrize("end", [1, 2, 3]) def test_arange_end_only_with_prim_dtype(self, dtype, end, ie_device, precision, ir_version): @@ -140,6 +144,7 @@ def test_arange_end_only_with_prim_dtype(self, dtype, end, ie_device, precision, kwargs_to_prepare_input={"end": end, "ref_dtype": dtype}) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)]) def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, precision, ir_version): @@ -148,6 +153,7 @@ def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, pr @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"]) @pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)]) def test_arange_start_end_step_with_prim_dtype(self, dtype, end, start, step, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_argmax_argmin.py b/tests/layer_tests/pytorch_tests/test_argmax_argmin.py index b033a2980de8af..db609894a10c3e 100644 --- a/tests/layer_tests/pytorch_tests/test_argmax_argmin.py +++ b/tests/layer_tests/pytorch_tests/test_argmax_argmin.py @@ -74,6 +74,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ('arm', 'armv7l', 'aarch64', 'arm64', 'ARM64'), diff --git a/tests/layer_tests/pytorch_tests/test_as_strided.py b/tests/layer_tests/pytorch_tests/test_as_strided.py index 964e9319ef5278..254084c89648dd 100644 --- a/tests/layer_tests/pytorch_tests/test_as_strided.py +++ b/tests/layer_tests/pytorch_tests/test_as_strided.py @@ -41,6 +41,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_as_strided(self, size, stride, offset, ie_device, precision, ir_version): self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True) @@ -92,6 +93,7 @@ def forward_size_const(self, x, size_shape_tensor, stride_shape_tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_as_strided_list_construct(self, size, stride, offset, mode, ie_device, precision, ir_version): inp_kwargs = {"size_shape_tensor": size, "stride_shape_tensor": stride} self._test( @@ -124,5 +126,6 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_as_strided_lf(self, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, freeze_model=False) diff --git a/tests/layer_tests/pytorch_tests/test_bitwise_ops.py b/tests/layer_tests/pytorch_tests/test_bitwise_ops.py index 248eb44e35402e..72578bd75f2625 100644 --- a/tests/layer_tests/pytorch_tests/test_bitwise_ops.py +++ b/tests/layer_tests/pytorch_tests/test_bitwise_ops.py @@ -55,6 +55,7 @@ def forward_not_out(self, tensor_a, out): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"]) @pytest.mark.parametrize("lhs_dtype", ["bool", "int32", "uint8", "int64"]) @pytest.mark.parametrize("rhs_dtype", ["bool", "int32", "uint8", "int64"]) @@ -107,6 +108,7 @@ def forward(self, lhs, rhs): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("lhs_dtype", ["bool", "int32"]) @pytest.mark.parametrize("rhs_dtype", ["bool", "int32"]) @pytest.mark.parametrize( diff --git a/tests/layer_tests/pytorch_tests/test_clamp.py b/tests/layer_tests/pytorch_tests/test_clamp.py index 947f9197f72d35..b8a977a897046b 100644 --- a/tests/layer_tests/pytorch_tests/test_clamp.py +++ b/tests/layer_tests/pytorch_tests/test_clamp.py @@ -48,6 +48,7 @@ def forward_clip_(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_clamp(self, minimum, maximum, as_tensors, op_type, ie_device, precision, ir_version): self._test(*self.create_model(minimum, maximum, as_tensors, op_type), ie_device, precision, ir_version) @@ -76,6 +77,7 @@ def forward(self, x): @pytest.mark.parametrize("minimum", [0., 1., -1., 0.5, 2]) @pytest.mark.parametrize("as_tensor", [True, False]) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend def test_clamp_min(self, minimum, as_tensor, ie_device, precision, ir_version): self._test(*self.create_model(minimum, as_tensor), ie_device, precision, ir_version, use_convert_model=True, trace_model=True) @@ -106,6 +108,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_clamp(self, maximum, as_tensor, ie_device, precision, ir_version): self._test(*self.create_model(maximum, as_tensor), ie_device, precision, ir_version, use_convert_model=True, trace_model=True) diff --git a/tests/layer_tests/pytorch_tests/test_comparision.py b/tests/layer_tests/pytorch_tests/test_comparision.py index 8267a95d1c73a4..dee86407bb6051 100644 --- a/tests/layer_tests/pytorch_tests/test_comparision.py +++ b/tests/layer_tests/pytorch_tests/test_comparision.py @@ -55,6 +55,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_comp(self, op, ie_device, precision, ir_version): self._test(*self.create_model(op), ie_device, precision, ir_version, use_convert_model=True) @@ -127,6 +128,7 @@ def forward3(self, lhs, rhs): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_eq_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape, op): self.lhs_type = lhs_type self.lhs_shape = lhs_shape diff --git a/tests/layer_tests/pytorch_tests/test_constant_pad_nd.py b/tests/layer_tests/pytorch_tests/test_constant_pad_nd.py new file mode 100644 index 00000000000000..7a92983bb1819d --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_constant_pad_nd.py @@ -0,0 +1,37 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestConstantPadND(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 5, 3, 4).astype(np.float32),) + + def create_model(self, pad, value): + + import torch + class aten_constant_pad_nd(torch.nn.Module): + def __init__(self, pad=None, value=None): + super(aten_constant_pad_nd, self).__init__() + self.pad = pad + self.value = value + + def forward(self, x): + return torch.constant_pad_nd(x, self.pad, self.value); + + + ref_net = None + + return aten_constant_pad_nd(pad, value), ref_net, "aten::constant_pad_nd" + + @pytest.mark.parametrize(("pad", "value"), + [((1,1,1,1), 0),((0,2,0,2), -1.0),((3,1,5,2), 0.5),((0,0,0,0), 0),]) + + @pytest.mark.precommit_fx_backend + def test_constant_pad_nd(self, pad, value, ie_device, precision, ir_version): + self._test(*self.create_model(pad, value), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_copy.py b/tests/layer_tests/pytorch_tests/test_copy.py index 1adff2f36d6536..c2a387a5358b00 100644 --- a/tests/layer_tests/pytorch_tests/test_copy.py +++ b/tests/layer_tests/pytorch_tests/test_copy.py @@ -28,6 +28,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("value", [1, [2.5], range(224)]) def test_copy_(self, value, ie_device, precision, ir_version): self._test(*self.create_model(value), ie_device, precision, ir_version) @@ -63,4 +64,4 @@ def forward_out(self, x, y): @pytest.mark.precommit @pytest.mark.parametrize("out", [True, False]) def test_copy_(self, out, ie_device, precision, ir_version): - self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out}) \ No newline at end of file + self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out}) diff --git a/tests/layer_tests/pytorch_tests/test_deformable_convolution.py b/tests/layer_tests/pytorch_tests/test_deformable_convolution.py index 6e21d41a86df9f..45fac8ea0b8c43 100644 --- a/tests/layer_tests/pytorch_tests/test_deformable_convolution.py +++ b/tests/layer_tests/pytorch_tests/test_deformable_convolution.py @@ -170,6 +170,7 @@ def forward(self, x): @pytest.mark.parametrize("mask", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_deformable_convolution2d(self, params, bias, mask, ie_device, precision, ir_version): self._test( *self.create_model(**params, bias=bias, mask=mask), ie_device, precision, ir_version, trace_model=True diff --git a/tests/layer_tests/pytorch_tests/test_div.py b/tests/layer_tests/pytorch_tests/test_div.py index 3ae112de0e2699..ad6769ded6504e 100644 --- a/tests/layer_tests/pytorch_tests/test_div.py +++ b/tests/layer_tests/pytorch_tests/test_div.py @@ -44,6 +44,7 @@ def forward(self, input_tensor, other_tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_div_pt_spec(self, input_array, other_array, rounding_mode, ie_device, precision, ir_version): self.input_array = input_array self.input_type = np.float32 diff --git a/tests/layer_tests/pytorch_tests/test_embedding_bag.py b/tests/layer_tests/pytorch_tests/test_embedding_bag.py index 2c8d289f7e0035..e02eb8f7866a0a 100644 --- a/tests/layer_tests/pytorch_tests/test_embedding_bag.py +++ b/tests/layer_tests/pytorch_tests/test_embedding_bag.py @@ -42,6 +42,7 @@ def forward_offsets_per_sample_weights(self, indicies, weight, offsets, per_samp @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) @pytest.mark.parametrize("per_sample_weights", [True, False]) @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', @@ -86,6 +87,7 @@ def forward_per_sample_weights(self, indicies, weight, per_sample_wights): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("indicies_size", [[1, 1], [2, 5], [3, 10], [4, 7]]) @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) @pytest.mark.parametrize("per_sample_weights", [True, False]) diff --git a/tests/layer_tests/pytorch_tests/test_fake_quantize.py b/tests/layer_tests/pytorch_tests/test_fake_quantize.py index c3283279bb4aa3..0963b646d4d526 100644 --- a/tests/layer_tests/pytorch_tests/test_fake_quantize.py +++ b/tests/layer_tests/pytorch_tests/test_fake_quantize.py @@ -37,6 +37,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize( "scale, zero_point, quant_min, quant_max", [ @@ -61,6 +62,58 @@ def test_fake_quantize_per_tensor_affine( freeze_model=False ) +class TestFakeQuantizePerTensorAffineCacheMaskTensorQParams(PytorchLayerTest): + def _prepare_input(self): + return (np.random.randn(3, 2, 2).astype(np.float32),) + + def create_model(self, scale, zero_point, quant_min, quant_max): + class _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(torch.nn.Module): + def __init__(self, scale, zero_point, quant_min, quant_max): + super(_fake_quantize_per_tensor_affine_cachemask_tensor_qparams, self).__init__() + self.scale = torch.tensor(scale) + self.zero_point = torch.tensor(zero_point) + self.fake_quant_enabled = torch.tensor(1) + self.quant_min = quant_min + self.quant_max = quant_max + + def forward(self, x): + return torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams( + x, self.scale, self.zero_point, self.fake_quant_enabled, self.quant_min, self.quant_max + ) + + ref_net = None + + return ( + _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(scale, zero_point, quant_min, quant_max), + ref_net, + "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + ) + + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize( + "scale, zero_point, quant_min, quant_max", + [ + (1.0, 1, 0, 255), + (0.01, 0, 0, 255), + (-0.01, 0, 0, 255), + (0.5, 0, -128, 127), + (0.5, -1, -128, 127), + (1.0, 0, 0, 127), + ], + ) + @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', + reason='Ticket - 122715') + def test__fake_quantize_per_tensor_affine_cachemask_tensor_qparams( + self, ie_device, precision, ir_version, scale, zero_point, quant_min, quant_max + ): + self._test( + *self.create_model(scale, zero_point, quant_min, quant_max), + ie_device, + precision, + ir_version, + freeze_model=False + ) + class TestFakeQuantizePerChannelAffine(PytorchLayerTest): def _prepare_input(self): @@ -91,6 +144,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize( "scale, zero_point, axis, quant_min, quant_max", [ diff --git a/tests/layer_tests/pytorch_tests/test_full.py b/tests/layer_tests/pytorch_tests/test_full.py index 4c54c14e06dd4a..ca949b14cb134a 100644 --- a/tests/layer_tests/pytorch_tests/test_full.py +++ b/tests/layer_tests/pytorch_tests/test_full.py @@ -84,6 +84,7 @@ def forward(self, x: float): @pytest.mark.parametrize("value", [0, 1, -1, 0.5]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.precommit_torch_export def test_full(self, shape, value, ie_device, precision, ir_version): self._test(*self.create_model(shape), ie_device, precision, @@ -94,6 +95,7 @@ def test_full(self, shape, value, ie_device, precision, ir_version): @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"]) @pytest.mark.parametrize("with_names", [True, False]) @pytest.mark.nightly + @pytest.mark.precommit_fx_backend @pytest.mark.precommit_torch_export def test_full_dtype(self, shape, value, dtype, with_names, ie_device, precision, ir_version): self._test(*self.create_model(shape, dtype=dtype, use_dtype=True, with_names=with_names), ie_device, precision, @@ -280,6 +282,7 @@ def forward(self, input_t: torch.Tensor, x: float): @pytest.mark.parametrize("value", [0, 1, -1, 0.5]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.precommit_torch_export def test_full_like(self, shape, value, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, @@ -349,6 +352,7 @@ def forward(self, input_tensor: torch.Tensor, x: float): @pytest.mark.parametrize("value,input_dtype", [(0, np.uint8), (1, np.int32), (-1, np.float32), (0.5, np.float64)]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.precommit_torch_export def test_new_full(self, shape, value, input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(shape), ie_device, precision, ir_version, @@ -480,6 +484,7 @@ def forward(self, x): @pytest.mark.parametrize("op_type", ["aten::zeros", "aten::ones", "aten::zeros_like", "aten::ones_like"]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.precommit_torch_export def test_zeros_ones(self, op_type, shape, ie_device, precision, ir_version): self._test(*self.create_model(op_type), ie_device, precision, @@ -631,6 +636,7 @@ def forward(self, input_tensor: torch.Tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_new_ones(self, shape, input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(shape), ie_device, precision, ir_version, kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True) @@ -640,6 +646,7 @@ def test_new_ones(self, shape, input_dtype, ie_device, precision, ir_version): @pytest.mark.parametrize("dtype", ["bool", "uint8", "int8", "int32", "int64", "float32", "float64"]) @pytest.mark.nightly @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_new_ones_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version, kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True) diff --git a/tests/layer_tests/pytorch_tests/test_glu.py b/tests/layer_tests/pytorch_tests/test_glu.py index 3dbb1f423ee4ab..8a1fbb4e23a153 100644 --- a/tests/layer_tests/pytorch_tests/test_glu.py +++ b/tests/layer_tests/pytorch_tests/test_glu.py @@ -30,6 +30,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2]) def test_glu(self, dim, ie_device, precision, ir_version): - self._test(*self.create_model(dim), ie_device, precision, ir_version) \ No newline at end of file + self._test(*self.create_model(dim), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_grid_sampler.py b/tests/layer_tests/pytorch_tests/test_grid_sampler.py index d81ba01aca3bca..5da728f9d564c0 100644 --- a/tests/layer_tests/pytorch_tests/test_grid_sampler.py +++ b/tests/layer_tests/pytorch_tests/test_grid_sampler.py @@ -37,6 +37,7 @@ def forward(self, input, grid): @pytest.mark.parametrize("align_corners", [True, False, None]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_grid_sampler(self, h_in, w_in, h_out, w_out, mode, padding_mode, align_corners, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_hardtanh.py b/tests/layer_tests/pytorch_tests/test_hardtanh.py new file mode 100644 index 00000000000000..d0c4c1aac1a38d --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_hardtanh.py @@ -0,0 +1,42 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import platform + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestHardtanh(PytorchLayerTest): + def _prepare_input(self, input_dtype="float32", input_shape=(1, 3, 10, 10)): + return (np.random.default_rng().uniform(-100.0, 100.0, input_shape).astype(input_dtype),) + + def create_model(self, min_val, max_val, inplace): + import torch + import torch.nn.functional as F + + class aten_hardtanh(torch.nn.Module): + def __init__(self, min_val, max_val, inplace): + super(aten_hardtanh, self).__init__() + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + + def forward(self, x): + return F.hardtanh(x, min_val=self.min_val, max_val=self.max_val, inplace=self.inplace) + + ref_net = None + + return aten_hardtanh(min_val, max_val, inplace), ref_net, "aten::hardtanh" + + @pytest.mark.parametrize(("min_val", "max_val"), [[-1.0,1.0], [0, 1.0], [-2.0, 2.0]]) + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.parametrize("input_dtype", ['float32', 'int32', 'int64', 'float64']) + @pytest.mark.parametrize("input_shape", [(1, 3, 10, 10), (100,), (24, 24)]) + @pytest.mark.precommit_fx_backend + def test_hardtanh(self, min_val, max_val, inplace, input_dtype, input_shape, ie_device, precision, ir_version): + self._test(*self.create_model(min_val, max_val, inplace), ie_device, precision, ir_version, + kwargs_to_prepare_input= {"input_dtype": input_dtype, "input_shape": input_shape}) diff --git a/tests/layer_tests/pytorch_tests/test_index_select.py b/tests/layer_tests/pytorch_tests/test_index_select.py index e74bc597661ebd..2cf29c9172b9c9 100644 --- a/tests/layer_tests/pytorch_tests/test_index_select.py +++ b/tests/layer_tests/pytorch_tests/test_index_select.py @@ -41,6 +41,7 @@ def forward_out(self, x, indices, out): @pytest.mark.parametrize("out", [False, True]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_index_select(self, dim, out, indices, ie_device, precision, ir_version): self._test(*self.create_model(dim, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"index": indices, "out": out, "dim": dim}) diff --git a/tests/layer_tests/pytorch_tests/test_isfinite.py b/tests/layer_tests/pytorch_tests/test_isfinite.py new file mode 100644 index 00000000000000..d028b50c8bb2cf --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_isfinite.py @@ -0,0 +1,31 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +@pytest.mark.parametrize('input_tensor', (torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))) +class TestIsFinite(PytorchLayerTest): + + def _prepare_input(self): + input_tensor = self.input_tensor + return (input_tensor,) + + def create_model(self): + class aten_isfinite(torch.nn.Module): + + def forward(self, input_tensor): + return torch.isfinite(input_tensor) + + ref_net = None + + return aten_isfinite(), ref_net, "aten::isfinite" + + @pytest.mark.precommit_fx_backend + def test_isfinite(self, ie_device, precision, ir_version, input_tensor): + self.input_tensor = input_tensor + self._test(*self.create_model(), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_isinf.py b/tests/layer_tests/pytorch_tests/test_isinf.py new file mode 100644 index 00000000000000..03e5dda64dd253 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_isinf.py @@ -0,0 +1,31 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +@pytest.mark.parametrize('input_tensor', (torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))) +class TestIsInf(PytorchLayerTest): + + def _prepare_input(self): + input_tensor = self.input_tensor + return (input_tensor,) + + def create_model(self): + class aten_isinf(torch.nn.Module): + + def forward(self, input_tensor): + return torch.isinf(input_tensor) + + ref_net = None + + return aten_isinf(), ref_net, "aten::isinf" + + @pytest.mark.precommit_fx_backend + def test_isinf(self, ie_device, precision, ir_version, input_tensor): + self.input_tensor = input_tensor + self._test(*self.create_model(), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_isnan.py b/tests/layer_tests/pytorch_tests/test_isnan.py new file mode 100644 index 00000000000000..463d59392f09b1 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_isnan.py @@ -0,0 +1,31 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +@pytest.mark.parametrize('input_tensor', (torch.tensor([1, float('nan'), 2]))) +class TestIsNan(PytorchLayerTest): + + def _prepare_input(self): + input_tensor = self.input_tensor + return (input_tensor,) + + def create_model(self): + class aten_isnan(torch.nn.Module): + + def forward(self, input_tensor): + return torch.isnan(input_tensor) + + ref_net = None + + return aten_isnan(), ref_net, "aten::isnan" + + @pytest.mark.precommit_fx_backend + def test_isnan(self, ie_device, precision, ir_version, input_tensor): + self.input_tensor = input_tensor + self._test(*self.create_model(), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_leaky_relu.py b/tests/layer_tests/pytorch_tests/test_leaky_relu.py index 002c76e5814001..f00e6d241220ab 100644 --- a/tests/layer_tests/pytorch_tests/test_leaky_relu.py +++ b/tests/layer_tests/pytorch_tests/test_leaky_relu.py @@ -32,5 +32,6 @@ def forward(self, x): @pytest.mark.parametrize("inplace", [skip_if_export(True), False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_leaky_relu(self, alpha, inplace, ie_device, precision, ir_version): self._test(*self.create_model(alpha, inplace), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_logical_ops.py b/tests/layer_tests/pytorch_tests/test_logical_ops.py index 210fd1a4bdb690..842d895542afb9 100644 --- a/tests/layer_tests/pytorch_tests/test_logical_ops.py +++ b/tests/layer_tests/pytorch_tests/test_logical_ops.py @@ -53,6 +53,7 @@ def forward_not_out(self, tensor_a, out): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"]) @pytest.mark.parametrize("first_dtype", ["bool", "int32", 'int8', 'float32']) @pytest.mark.parametrize("second_dtype", ["bool", "int32", 'int8', 'float32']) @@ -61,4 +62,4 @@ def test_logical(self, op_type, out, first_dtype, second_dtype, ie_device, preci self._test(*self.create_model(op_type, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "unary": op_type == "not", - "first_dtype": first_dtype, "second_dtype": second_dtype}) \ No newline at end of file + "first_dtype": first_dtype, "second_dtype": second_dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_masked_fill.py b/tests/layer_tests/pytorch_tests/test_masked_fill.py index 8fc59c9149c04e..4959411c26a04b 100644 --- a/tests/layer_tests/pytorch_tests/test_masked_fill.py +++ b/tests/layer_tests/pytorch_tests/test_masked_fill.py @@ -54,6 +54,7 @@ def forward(self, x, mask): @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_masked_fill(self, value, mask_fill, mask_dtype, input_dtype, inplace, ie_device, precision, ir_version): self._test(*self.create_model(value, inplace), ie_device, precision, ir_version, diff --git a/tests/layer_tests/pytorch_tests/test_mean.py b/tests/layer_tests/pytorch_tests/test_mean.py index 1ba45e1a3b3791..46ce6d33918baa 100644 --- a/tests/layer_tests/pytorch_tests/test_mean.py +++ b/tests/layer_tests/pytorch_tests/test_mean.py @@ -80,6 +80,7 @@ def forward_out(self, x, out): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_sum(self, axes, keep_dim, dtype, out, ie_device, precision, ir_version): if PytorchLayerTest.use_torch_export() and out: pytest.skip(reason="export fails for out") diff --git a/tests/layer_tests/pytorch_tests/test_min_max.py b/tests/layer_tests/pytorch_tests/test_min_max.py index beba07b1d02540..d857c222158ac5 100644 --- a/tests/layer_tests/pytorch_tests/test_min_max.py +++ b/tests/layer_tests/pytorch_tests/test_min_max.py @@ -76,6 +76,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_reduce_min_max(self, axes, keep_dims, op_type, ie_device, precision, ir_version): self._test(*self.create_model(op_type, axes, keep_dims, single_input=True), ie_device, precision, ir_version) @@ -86,6 +87,7 @@ def test_reduce_min_max(self, axes, keep_dims, op_type, ie_device, precision, ir @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_min_max(self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(op_type, None, None, single_input=False, dtypes=(first_input_dtype, second_input_dtype)), ie_device, precision, ir_version, kwargs_to_prepare_input= @@ -266,6 +268,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_minimum_maximum( self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version ): @@ -342,4 +345,4 @@ def test_amin_amax(self, op_type, input_dtype, axis, keep_dims, out, ie_device, self._test(*self.create_model(op_type, axis, keep_dims, out), ie_device, precision, ir_version, kwargs_to_prepare_input= {"input_dtype": input_dtype, "out": out, "axes": axis, "keep_dims": keep_dims} - ) \ No newline at end of file + ) diff --git a/tests/layer_tests/pytorch_tests/test_pooling.py b/tests/layer_tests/pytorch_tests/test_pooling.py index f8c190917c2c92..24fc01fdcffaed 100644 --- a/tests/layer_tests/pytorch_tests/test_pooling.py +++ b/tests/layer_tests/pytorch_tests/test_pooling.py @@ -157,6 +157,7 @@ def test_avg_pool1d(self, params, ceil_mode, count_include_pad, ie_device, preci @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version): @@ -169,6 +170,7 @@ def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, preci @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_avg_pool3d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version): @@ -232,6 +234,7 @@ def test_max_pool1d_indices(self, params, ceil_mode, dilation, ie_device, precis @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version): @@ -248,6 +251,7 @@ def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, preci @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_max_pool3d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_pow.py b/tests/layer_tests/pytorch_tests/test_pow.py index fb59a8ab4e1cc5..1321f4d6dd79aa 100644 --- a/tests/layer_tests/pytorch_tests/test_pow.py +++ b/tests/layer_tests/pytorch_tests/test_pow.py @@ -47,6 +47,7 @@ def forward_inplace(self, input_data, exponent): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_pow(self, inplace, ie_device, precision, ir_version, test_input): if inplace and PytorchLayerTest.use_torch_export(): pytest.skip(reason="export fails for inplace") @@ -109,6 +110,7 @@ def forward3(self, lhs, rhs): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_pow_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape): self.lhs_type = lhs_type self.lhs_shape = lhs_shape diff --git a/tests/layer_tests/pytorch_tests/test_repeat.py b/tests/layer_tests/pytorch_tests/test_repeat.py index beab3ab8b10341..9ac59e2f02e5e3 100644 --- a/tests/layer_tests/pytorch_tests/test_repeat.py +++ b/tests/layer_tests/pytorch_tests/test_repeat.py @@ -30,6 +30,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_repeat(self, repeats, ie_device, precision, ir_version): self._test(*self.create_model(repeats), ie_device, precision, ir_version) @@ -56,6 +57,7 @@ def forward(self, x, y): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_repeat(self, repeats, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"repeats_shape": repeats}) @@ -79,5 +81,6 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_repeat_t5(self, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True) diff --git a/tests/layer_tests/pytorch_tests/test_roi_align.py b/tests/layer_tests/pytorch_tests/test_roi_align.py index 0079d9c0ca77e6..896bd1a13c0966 100644 --- a/tests/layer_tests/pytorch_tests/test_roi_align.py +++ b/tests/layer_tests/pytorch_tests/test_roi_align.py @@ -52,6 +52,7 @@ def forward(self, input_tensor, rois): @pytest.mark.parametrize('aligned', (True, False)) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_roi_align(self, ie_device, precision, ir_version, input_shape, boxes, output_size, spatial_scale, sampling_ratio, aligned): self.input_tensor = np.random.randn(*input_shape).astype(np.float32) diff --git a/tests/layer_tests/pytorch_tests/test_roll.py b/tests/layer_tests/pytorch_tests/test_roll.py index 5a5f63c772d5c5..eabae207e2b3d6 100644 --- a/tests/layer_tests/pytorch_tests/test_roll.py +++ b/tests/layer_tests/pytorch_tests/test_roll.py @@ -18,12 +18,12 @@ class aten_roll(torch.nn.Module): def __init__(self, shifts, dim=None): super(aten_roll, self).__init__() self.dim = dim - self.shits = shifts + self.shifts = shifts def forward(self, x): if self.dim is not None: - return torch.roll(x, self.shits, self.dim) - return torch.roll(x, self.shits) + return torch.roll(x, self.shifts, self.dim) + return torch.roll(x, self.shifts) ref_net = None @@ -38,5 +38,6 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_roll(self, shifts, dim, ie_device, precision, ir_version): self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_rsub.py b/tests/layer_tests/pytorch_tests/test_rsub.py index 70380a9e5f807d..0d918120f29c6a 100644 --- a/tests/layer_tests/pytorch_tests/test_rsub.py +++ b/tests/layer_tests/pytorch_tests/test_rsub.py @@ -104,9 +104,10 @@ def forward2(self, lhs, rhs:int): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_rsub_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type): self.lhs_type = lhs_type self.lhs_shape = lhs_shape self.rhs_type = rhs_type self._test(*self.create_model(lhs_type, rhs_type), - ie_device, precision, ir_version) \ No newline at end of file + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_scatter.py b/tests/layer_tests/pytorch_tests/test_scatter.py index 6f8e0cdd1623d7..34c4ad84adf142 100644 --- a/tests/layer_tests/pytorch_tests/test_scatter.py +++ b/tests/layer_tests/pytorch_tests/test_scatter.py @@ -91,6 +91,7 @@ def _forward_inplace_reduce(self, x: torch.Tensor): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dim", [1, -1, 0]) @pytest.mark.parametrize( "index", diff --git a/tests/layer_tests/pytorch_tests/test_select_scatter.py b/tests/layer_tests/pytorch_tests/test_select_scatter.py new file mode 100644 index 00000000000000..112675264c74a5 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_select_scatter.py @@ -0,0 +1,36 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from pytorch_layer_test_class import PytorchLayerTest +import torch + + +class TestSelectScatter(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 5, 3, 4).astype(np.float32),) + + def create_model(self, src, dim, index): + + class aten_select_scatter(torch.nn.Module): + def __init__(self, src=None, dim=None, index=None): + super(aten_select_scatter, self).__init__() + self.src = src + self.dim = dim + self.index = index + + def forward(self, x): + return torch.select_scatter(x, self.src, self.dim, self.index); + + + ref_net = None + + return aten_select_scatter(src, dim, index), ref_net, "aten::select_scatter" + + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize(("src", "dim", "index"), + [(torch.ones(2), 0, 0),]) + def aten_select_scatter(self, src, dim, index, ie_device, precision, ir_version): + self._test(*self.create_model(src, dim, index), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_sign.py b/tests/layer_tests/pytorch_tests/test_sign.py index dac0b32f70d05d..9cad2fbd6ea745 100644 --- a/tests/layer_tests/pytorch_tests/test_sign.py +++ b/tests/layer_tests/pytorch_tests/test_sign.py @@ -45,6 +45,7 @@ def forward_out(self, x, out): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("input_type", ["zeros", "positive", "negative", "mixed"]) @pytest.mark.parametrize("out", [True, False]) def test_sign(self, input_type, out, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_slice_scatter.py b/tests/layer_tests/pytorch_tests/test_slice_scatter.py new file mode 100644 index 00000000000000..0d291f6bb4d3aa --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_slice_scatter.py @@ -0,0 +1,41 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + + +class TestSliceScatter(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 5, 3, 4).astype(np.float32),) + + def create_model(self, src, dim, start, end, step): + + import torch + class aten_slice_scatter(torch.nn.Module): + def __init__(self, src=None, dim=None, start=None, end=None, step=None): + super(aten_slice_scatter, self).__init__() + self.src = src + self.dim = dim + self.start = start + self.end = end + self.step = step + + def forward(self, x): + return torch.slice_scatter(x, src=self.src, dim=self.dim, start=self.start, end=self.end, step=self.step); + + + ref_net = None + + return aten_slice_scatter(src, dim, start, end, step), ref_net, "aten::slice_scatter" + + import torch + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize(("src", "dim", "start", "end", "step"), + [(torch.ones(2), 1, 1, 2, 1),]) + def aten_slice_scatter(self, src, dim, start, end, step, ie_device, precision, ir_version): + self._test(*self.create_model(src, dim, start, end, step), + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_sort.py b/tests/layer_tests/pytorch_tests/test_sort.py index c92508b9e47d16..53f21833a21c50 100644 --- a/tests/layer_tests/pytorch_tests/test_sort.py +++ b/tests/layer_tests/pytorch_tests/test_sort.py @@ -78,6 +78,7 @@ def forward(self, input_tensor): ]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_sort(self, input_shape, descending, stable, ie_device, precision, ir_version): self.input_tensor = [] if type(input_shape) is list: diff --git a/tests/layer_tests/pytorch_tests/test_topk.py b/tests/layer_tests/pytorch_tests/test_topk.py index 512d9ed41f606e..cee8c103ab791d 100644 --- a/tests/layer_tests/pytorch_tests/test_topk.py +++ b/tests/layer_tests/pytorch_tests/test_topk.py @@ -61,6 +61,7 @@ def forward(self, input_tensor): ]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") == 'true', reason="Ticket - 115085") def test_topK(self, input_shape, k, dim, largest, sort, ie_device, precision, ir_version): self.input_tensor = np.random.randn(*input_shape).astype(np.float32) diff --git a/tests/layer_tests/pytorch_tests/test_trilu.py b/tests/layer_tests/pytorch_tests/test_trilu.py index 87afd796da5268..0bdb1dfd983778 100644 --- a/tests/layer_tests/pytorch_tests/test_trilu.py +++ b/tests/layer_tests/pytorch_tests/test_trilu.py @@ -41,6 +41,7 @@ def forward(self, x): @pytest.mark.parametrize("op", ["triu", "tril"]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version): self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) @@ -89,6 +90,7 @@ def triu_(self, x): @pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"]) @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version): self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, - kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) \ No newline at end of file + kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype}) diff --git a/tests/layer_tests/pytorch_tests/test_unary_ops.py b/tests/layer_tests/pytorch_tests/test_unary_ops.py index 25a6aeccf93d99..d54d1102737134 100644 --- a/tests/layer_tests/pytorch_tests/test_unary_ops.py +++ b/tests/layer_tests/pytorch_tests/test_unary_ops.py @@ -66,7 +66,8 @@ "aten::asinh": torch.asinh, "aten::asinh_": torch.asinh_, "aten::atanh": torch.atanh, - "aten::atanh_": torch.atanh_ + "aten::atanh_": torch.atanh_, + "aten::hardswish": F.hardswish } @@ -117,6 +118,7 @@ def _prepare_input(self): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int8, torch.uint8, torch.int32, torch.int64]) @pytest.mark.parametrize("op_type", [ @@ -160,6 +162,7 @@ def test_unary_op(self, op_type, dtype, ie_device, precision, ir_version): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("op_type", [ @@ -192,7 +195,8 @@ def test_unary_op(self, op_type, dtype, ie_device, precision, ir_version): "aten::atan_", "aten::acosh_", "aten::asinh_", - "aten::atanh_" + "aten::atanh_", + "aten::hardswish" ]) def test_unary_op_float(self, op_type, dtype, ie_device, precision, ir_version): self.dtype = dtype @@ -241,12 +245,14 @@ def test_unary_op_out(self, op_type, dtype, ie_device, precision, ir_version): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("op_type", [ "aten::relu6", "aten::selu", "aten::silu", + "aten::hardswish", "aten::mish", ]) def test_unary_func_op_inplace(self, op_type, dtype, ie_device, precision, ir_version): diff --git a/tests/layer_tests/pytorch_tests/test_unfold.py b/tests/layer_tests/pytorch_tests/test_unfold.py index 4d5b9ee57ffe21..671af872d96973 100644 --- a/tests/layer_tests/pytorch_tests/test_unfold.py +++ b/tests/layer_tests/pytorch_tests/test_unfold.py @@ -39,6 +39,7 @@ def forward(self, input_tensor): @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export + @pytest.mark.precommit_fx_backend def test_unfold(self, ie_device, precision, ir_version, dimension, size, step, input_shape): self.input_tensor = np.random.randn(*input_shape).astype(np.float32) self._test(*self.create_model(dimension, size, step), diff --git a/tests/layer_tests/pytorch_tests/test_var_mean.py b/tests/layer_tests/pytorch_tests/test_var_mean.py index 4863d4d29677b7..8318b6330e0bc0 100644 --- a/tests/layer_tests/pytorch_tests/test_var_mean.py +++ b/tests/layer_tests/pytorch_tests/test_var_mean.py @@ -52,6 +52,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("unbiased", [True, False]) @pytest.mark.parametrize("op_type", ["var", "var_mean", "std", "std_mean"]) @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', @@ -61,6 +62,7 @@ def test_op2args(self, unbiased, op_type, ie_device, precision, ir_version): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_fx_backend @pytest.mark.parametrize("unbiased", [False, True]) @pytest.mark.parametrize("dim", [None, 0, 1, 2, 3, -1, -2, (0, 1), (-1, -2), (0, 1, -1), (0, 1, 2, 3)]) @pytest.mark.parametrize("keepdim", [True, False]) @@ -68,4 +70,4 @@ def test_op2args(self, unbiased, op_type, ie_device, precision, ir_version): @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', reason='Ticket - 122715') def test_op(self, unbiased, dim, keepdim, op_type, ie_device, precision, ir_version): - self._test(*self.create_model(unbiased, dim, keepdim, two_args_case=False, op_type=op_type), ie_device, precision, ir_version) \ No newline at end of file + self._test(*self.create_model(unbiased, dim, keepdim, two_args_case=False, op_type=op_type), ie_device, precision, ir_version)