Skip to content

Commit

Permalink
Fix the aten::mv for pytorch models openvinotoolkit#22073 (openvinoto…
Browse files Browse the repository at this point in the history
…olkit#22677)

### Details:
 - *item1*
 - *...*
Add aten::mv operator
close openvinotoolkit#22073
### Tickets:
 - *ticket-id*

---------

Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com>
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
  • Loading branch information
3 people authored and alvoron committed Apr 29, 2024
1 parent c131bb0 commit 4d4d0e2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::masked_scatter_", op::inplace_op<op::translate_masked_scatter>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::max", op::translate_max},
{"aten::mv", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::maximum", op::translate_maximum},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
Expand Down
42 changes: 42 additions & 0 deletions tests/layer_tests/pytorch_tests/test_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest

class TestMatMulOperation(PytorchLayerTest):
def _prepare_input(self, matrix, vector):
matrix_input = np.array(matrix, dtype=np.float32)
vector_input = np.array(vector, dtype=np.float32)
return matrix_input, vector_input

def create_model(self, matrix, vector):
class CustomMatMulOperation(torch.nn.Module):
def forward(self, matrix, vector):
return torch.mv(matrix, vector)

model_class = CustomMatMulOperation()
ref_net = None
return model_class, ref_net, "aten::mv"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("matrix, vector, dtype", [
(np.array([[1, 2], [3, 4]]), np.array([5, 6]), torch.float64),
(np.array([[0, 0], [0, 0]]), np.array([1, 2]), torch.float32),
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([0, 1, 0]), torch.float64),
(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), np.array([2, 3, 4]), torch.float32),
])
def test_matmul_operation(self, matrix, vector, dtype, ie_device, precision, ir_version):
matrix_input = torch.tensor(matrix, dtype=torch.float32)
vector_input = torch.tensor(vector, dtype=torch.float32)

matrix_input = matrix_input.to(dtype=dtype)
vector_input = vector_input.to(dtype=dtype)

self._test(
*self.create_model(matrix_input, vector_input),
ie_device,
precision,
ir_version,
kwargs_to_prepare_input={"matrix": matrix_input, "vector": vector_input}
)

0 comments on commit 4d4d0e2

Please sign in to comment.