Skip to content

Commit

Permalink
Support aten::bucketize for pytorch models #23328 (#23527)
Browse files Browse the repository at this point in the history
### Closes [#23328
](#23328)
### Details:
 - Support aten::bucketize for pytorch models
  • Loading branch information
awayzjj authored Mar 25, 2024
1 parent 6af58f2 commit b8b2bd9
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/frontends/pytorch/src/op/bucketize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/bucketize.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_bucketize(const NodeContext& context) {
num_inputs_check(context, 2, 5);
auto input = context.get_input(0);
auto boundaries = context.get_input(1);

element::Type output_type = ov::element::i64;
if (!context.input_is_none(2) && context.const_input<bool>(2)) {
output_type = ov::element::i32;
}

bool with_right_bound = true;
if (!context.input_is_none(3)) {
with_right_bound = !context.const_input<bool>(3);
}

auto bucketize =
context.mark_node(std::make_shared<v3::Bucketize>(input, boundaries, output_type, with_right_bound));

if (!context.input_is_none(4)) {
context.mutate_input(4, bucketize);
}

return {bucketize};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ OP_CONVERTER(translate_bitwise_and);
OP_CONVERTER(translate_bitwise_not);
OP_CONVERTER(translate_bitwise_or);
OP_CONVERTER(translate_bitwise_xor);
OP_CONVERTER(translate_bucketize);
OP_CONVERTER(translate_cat);
OP_CONVERTER(translate_cdist);
OP_CONVERTER(translate_celu);
Expand Down Expand Up @@ -374,6 +375,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::Bool", op::translate_bool},
// aten::broadcast_tensors - Supported in limited set of patterns
{"aten::broadcast_to", op::translate_expand},
{"aten::bucketize", op::translate_bucketize},
{"aten::cat", op::translate_cat},
{"aten::cdist", op::translate_cdist},
{"aten::ceil", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Ceiling>, 1>},
Expand Down
53 changes: 53 additions & 0 deletions tests/layer_tests/pytorch_tests/test_bucketize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestBucketize(PytorchLayerTest):

def _prepare_input(self, input_shape, boundaries_range, input_dtype, boundaries_dtype):
return (
np.random.randn(*input_shape).astype(input_dtype),
np.arange(*boundaries_range).astype(boundaries_dtype))

def create_model(self, out_int32, right, is_out):
class aten_bucketize(torch.nn.Module):

def __init__(self, out_int32, right, is_out) -> None:
super().__init__()
self.out_int32 = out_int32
self.right = right
self.is_out = is_out

def forward(self, input, boundaries):
if self.is_out:
output_dtype = torch.int32 if self.out_int32 else torch.int64
output = torch.zeros_like(input, dtype=output_dtype)
torch.bucketize(input, boundaries, out_int32=self.out_int32, right=self.right, out=output)
return output
else:
return torch.bucketize(input, boundaries, out_int32=self.out_int32, right=self.right)

ref_net = None

return aten_bucketize(out_int32, right, is_out), ref_net, "aten::bucketize"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("out_int32", [True, False])
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("is_out", [True, False])
@pytest.mark.parametrize("input_shape", [[1, ], [2, 1], [2, 2, 1]])
@pytest.mark.parametrize("input_dtype", ["float32", "int32"])
@pytest.mark.parametrize("boundaries_range", [[1, 10], (100, 200)])
@pytest.mark.parametrize("boundaries_dtype", ["float32", "int32"])
def test_bucketize(self, input_shape, boundaries_range, input_dtype, boundaries_dtype, out_int32, right, is_out, ie_device, precision, ir_version):
self._test(*self.create_model(out_int32, right, is_out), ie_device, precision, ir_version, kwargs_to_prepare_input={
"input_shape": input_shape, "input_dtype": input_dtype,
"boundaries_range": boundaries_range, "boundaries_dtype": boundaries_dtype,
})

0 comments on commit b8b2bd9

Please sign in to comment.