Torch Compile - New Op Support (#23310)
New op support for:
 - torch.export updates
 - benchmarking model support
 - chatglm2 support

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: ynimmaga <yamini.nimmagadda@intel.com>
Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Co-authored-by: suryasidd <surya.siddharth.pemmaraju@intel.com>
5 people authored Mar 21, 2024
1 parent 82021a3 commit 09a388f
Showing 57 changed files with 759 additions and 75 deletions.
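
For context, the table below is the operator support list consulted by OpenVINO's torch.compile backend when it partitions an FX graph; an entry mapped to None is treated as supported without dtype constraints, and anything absent (or disabled via options) falls back to eager PyTorch. A minimal usage sketch, assuming an OpenVINO build that ships the torch.compile integration (the model here is illustrative, not from this PR):

import torch
import openvino.torch  # assumption: registers the "openvino" backend for torch.compile

# Linear and GELU lower to aten ops present in the support table below
# (addmm/mm, gelu), so the whole graph can be offloaded to OpenVINO.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
compiled = torch.compile(model, backend="openvino")
print(compiled(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
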
@@ -29,104 +29,217 @@ class OperatorSupport(OperatorSupport):
def __init__(self, options):
support_dict = {
"_operator.getitem": None,
"torch.ops.aten._adaptive_avg_pool1d.default": None,
"torch.ops.aten._adaptive_avg_pool2d.default": None,
"torch.ops.aten._adaptive_avg_pool3d.default": None,
"torch.ops.aten._convolution.default": None,
"torch.ops.aten._embedding_bag.default": None,
"torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None,
"torch.ops.aten._local_scalar_dense.default": None,
"torch.ops.aten._log_softmax.default": None,
"torch.ops.aten._native_batch_norm_legit.default": None,
"torch.ops.aten._native_batch_norm_legit.no_stats": None,
"torch.ops.aten._native_batch_norm_legit_functional.default": None,
"torch.ops.aten._native_batch_norm_legit_no_training.default": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default": None,
"torch.ops.aten._softmax.default": None,
"torch.ops.aten._to_copy.default": None,
"torch.ops.aten._unsafe_view.default": None,
"torch.ops.aten._unsafe_view.default": None,
"torch.ops.aten.abs.default": None,
"torch.ops.aten.acos.default": None,
"torch.ops.aten.acosh.default": None,
"torch.ops.aten.adaptive_max_pool1d.default": None,
"torch.ops.aten.adaptive_max_pool2d.default": None,
"torch.ops.aten.adaptive_max_pool3d.default": None,
"torch.ops.aten.add.Scalar": None,
"torch.ops.aten.add.Tensor": None,
"torch.ops.aten.add_.Tensor": None,
"torch.ops.aten.addcmul.default": None,
"torch.ops.aten.addmm.default": None,
"torch.ops.aten.alias.default": None,
"torch.ops.aten.all.default": None,
"torch.ops.aten.amax.default": None,
"torch.ops.aten.arange.start": None,
"torch.ops.aten.amin.default": None,
"torch.ops.aten.any.default": None,
"torch.ops.aten.any.dim": None,
"torch.ops.aten.arange.default": None,
"torch.ops.aten.arange.start": None,
"torch.ops.aten.arange.start_step": None,
"torch.ops.aten.argmax.default": None,
"torch.ops.aten.argmin.default": None,
"torch.ops.aten.as_strided.default": None,
"torch.ops.aten.asin.default": None,
"torch.ops.aten.asinh.default": None,
"torch.ops.aten.asinh.default": None,
"torch.ops.aten.atanh.default": None,
"torch.ops.aten.avg_pool2d.default": None,
"torch.ops.aten.avg_pool3d.default": None,
"torch.ops.aten.baddbmm.default": None,
"torch.ops.aten.bitwise_and.Tensor": None,
"torch.ops.aten.bitwise_not.default": None,
"torch.ops.aten.bitwise_or.Tensor": None,
"torch.ops.aten.bitwise_xor.Tensor": None,
"torch.ops.aten.bmm.default": None,
"torch.ops.aten.cat.default": None,
"torch.ops.aten.ceil.default": None,
"torch.ops.aten.clamp.default": None,
"torch.ops.aten.clamp_max.default": None,
"torch.ops.aten.clamp_max.Tensor": None,
"torch.ops.aten.clamp_min.default": None,
"torch.ops.aten.clamp_min.Tensor": None,
"torch.ops.aten.clone.default": None,
"torch.ops.aten.constant_pad_nd.default": None,
"torch.ops.aten.convolution.default": None,
"torch.ops.aten.copy.default": None,
"torch.ops.aten.copy_.default": None,
"torch.ops.aten.cos.default": None,
"torch.ops.aten.cosh.default": None,
"torch.ops.aten.cumsum.default": None,
"torch.ops.aten.detach.default": None,
"torch.ops.aten.detach_.default": None,
"torch.ops.aten.div.Scalar": None,
"torch.ops.aten.div.Tensor": None,
"torch.ops.aten.div.Tensor_mode": None,
"torch.ops.aten.div_.Tensor": None,
"torch.ops.aten.elu.default": None,
"torch.ops.aten.elu_.default": None,
"torch.ops.aten.embedding.default": None,
"torch.ops.aten.empty.memory_format": None,
"torch.ops.aten.erf.default": None,
"torch.ops.aten.eq.Scalar": None,
"torch.ops.aten.eq.Tensor": None,
"torch.ops.aten.erf.default": None,
"torch.ops.aten.exp.default": None,
"torch.ops.aten.expand.default": None,
"torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None,
"torch.ops.aten.fill.Scalar": None,
"torch.ops.aten.fill_.Scalar": None,
"torch.ops.aten.fill.Tensor": None,
"torch.ops.aten.fill_.Tensor": None,
"torch.ops.aten.flip.default": None,
"torch.ops.aten.floor.default": None,
"torch.ops.aten.floor.default": None,
"torch.ops.aten.fmod.Scalar": None,
"torch.ops.aten.fmod.Tensor": None,
"torch.ops.aten.full.default": None,
"torch.ops.aten.full.names": None,
"torch.ops.aten.full_like.default": None,
"torch.ops.aten.gather.default": None,
"torch.ops.aten.ge.Scalar": None,
"torch.ops.aten.ge.Tensor": None,
"torch.ops.aten.gelu.default": None,
"torch.ops.aten.glu.default": None,
"torch.ops.aten.grid_sampler_2d.default": None,
"torch.ops.aten.gt.Scalar": None,
"torch.ops.aten.gt.Tensor": None,
"torch.ops.aten.hardsigmoid.default": None,
"torch.ops.aten.hardswish.default": None,
"torch.ops.aten.hardswish_.default": None,
"torch.ops.aten.hardtanh.default": None,
"torch.ops.aten.hardtanh_.default": None,
"torch.ops.aten.index.Tensor": None,
"torch.ops.aten.index_select.default": None,
"torch.ops.aten.isfinite.default": None,
"torch.ops.aten.isinf.default": None,
"torch.ops.aten.isnan.default": None,
"torch.ops.aten.le.Scalar": None,
"torch.ops.aten.le.Tensor": None,
"torch.ops.aten.leaky_relu.default": None,
"torch.ops.aten.leaky_relu_.default": None,
"torch.ops.aten.lift_fresh_copy.default": None,
"torch.ops.aten.linalg_vector_norm.default": None,
"torch.ops.aten.lt.Tensor": None,
"torch.ops.aten.log.default": None,
"torch.ops.aten.log_sigmoid_forward.default": None,
"torch.ops.aten.log10.default": None,
"torch.ops.aten.log1p.default": None,
"torch.ops.aten.log2.default": None,
"torch.ops.aten.logical_not.default": None,
"torch.ops.aten.logsumexp.default": None,
"torch.ops.aten.masked_fill_.Scalar": None,
"torch.ops.aten.lt.Scalar": None,
"torch.ops.aten.lt.Tensor": None,
"torch.ops.aten.masked_fill.Scalar": None,
"torch.ops.aten.masked_fill.Tensor": None,
"torch.ops.aten.masked_fill_.Scalar": None,
"torch.ops.aten.masked_fill_.Tensor": None,
"torch.ops.aten.max.default": None,
"torch.ops.aten.max.dim": None,
"torch.ops.aten.max_pool2d_with_indices.default": None,
"torch.ops.aten.max_pool3d_with_indices.default": None,
"torch.ops.aten.maximum.default": None,
"torch.ops.aten.mean.default": None,
"torch.ops.aten.mean.dim": None,
"torch.ops.aten.min.default": None,
"torch.ops.aten.min.dim": None,
"torch.ops.aten.minimum.default": None,
"torch.ops.aten.mm.default": None,
"torch.ops.aten.mul.Scalar": None,
"torch.ops.aten.mul.Tensor": None,
"torch.ops.aten.native_batch_norm.default": None,
"torch.ops.aten._native_batch_norm_legit.default": None,
"torch.ops.aten._native_batch_norm_legit_no_training.default": None,
"torch.ops.aten.native_dropout.default": None,
"torch.ops.aten.native_group_norm.default": None,
"torch.ops.aten.native_layer_norm.default": None,
"torch.ops.aten.new_full.default": None,
"torch.ops.aten.ne.Scalar": None,
"torch.ops.aten.ne.Tensor": None,
"torch.ops.aten.neg.default": None,
"torch.ops.aten.new_full.default": None,
"torch.ops.aten.new_ones.default": None,
"torch.ops.aten.new_zeros.default": None,
"torch.ops.aten.ones.default": None,
"torch.ops.aten.permute.default": None,
"torch.ops.aten.pow.Scalar": None,
"torch.ops.aten.pow.Tensor_Scalar": None,
"torch.ops.aten.pow.Tensor_Tensor": None,
"torch.ops.aten.rand.default": None,
"torch.ops.aten.reciprocal.default": None,
"torch.ops.aten.relu.default": None,
"torch.ops.aten.relu_.default": None,
"torch.ops.aten.repeat.default": None,
"torch.ops.aten.roll.default": None,
"torch.ops.aten.rsqrt.default": None,
"torch.ops.aten.rsub.Scalar": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten.rsub.Tensor": None,
"torch.ops.aten.scalar_tensor.default": None,
"torch.ops.aten.scatter.src": None,
"torch.ops.aten.scatter.value": None,
"torch.ops.aten.select.int": None,
"torch.ops.aten.select_scatter.default": None,
"torch.ops.aten.sigmoid.default": None,
"torch.ops.aten.sign.default": None,
"torch.ops.aten.silu.default": None,
"torch.ops.aten.silu_.default": None,
"torch.ops.aten.sin.default": None,
"torch.ops.aten.sinh.default": None,
"torch.ops.aten.slice.Tensor": None,
"torch.ops.aten.slice_scatter.default": None,
"torch.ops.aten.sort.default": None,
"torch.ops.aten.split.Tensor": None,
"torch.ops.aten.split_with_sizes.default": None,
"torch.ops.aten.sqrt.default": None,
"torch.ops.aten.squeeze.dim": None,
"torch.ops.aten.squeeze.dims": None,
"torch.ops.aten.stack.default": None,
"torch.ops.aten.sub.default": None,
"torch.ops.aten.sub.Tensor": None,
"torch.ops.aten.sum.default": None,
"torch.ops.aten.sum.dim_IntList": None,
"torch.ops.aten.t.default": None,
"torch.ops.aten.tan.default": None,
"torch.ops.aten.tanh.default": None,
"torch.ops.aten.topk.default": None,
"torch.ops.aten.transpose.int": None,
"torch.ops.aten.tril.default": None,
"torch.ops.aten.tril_.default": None,
"torch.ops.aten.unbind.int": None,
"torch.ops.aten.unfold.default": None,
"torch.ops.aten.unsqueeze.default": None,
"torch.ops.aten.upsample_nearest2d.default": None,
"torch.ops.aten.var.correction": None,
"torch.ops.aten.var_mean.correction": None,
"torch.ops.aten.view.default": None,
"torch.ops.aten.where.self": None,
"torch.ops.aten.zeros_like.default": None,
"torch.ops.torchvision.deform_conv2d.default": None,
"torch.ops.torchvision.roi_align.default": None,
}

for op in _get_disabled_ops(options):
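
This class extends torch.fx's OperatorSupport, so the dictionary plugs straight into FX's capability-based partitioning. A self-contained sketch of that mechanism, with an illustrative one-entry table (the names below are not taken from this file):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

# None means "supported regardless of dtype"; a constraint dict could narrow it.
support = OperatorSupport(support_dict={"torch.ops.aten.add.Tensor": None})

def f(x, y):
    return x + y

gm = make_fx(f)(torch.randn(2), torch.randn(2))  # lowers to aten-level ops
partitioner = CapabilityBasedPartitioner(gm, support, allows_single_node_partition=True)
print(len(partitioner.propose_partitions()))  # 1: the add node forms a partition
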
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/input_model.cpp
@@ -24,7 +24,7 @@ InputModel::InputModel(const std::shared_ptr<TorchDecoder>& model_decoder) : m_m
const auto& outputs = m_model_decoder->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_place = std::make_shared<pytorch::Place>(*this, outputs[i]);
-        m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast<frontend::Place>(out_place));
+        m_name_to_place.emplace(std::to_string(outputs[i]), std::dynamic_pointer_cast<frontend::Place>(out_place));
for (const auto& name : out_place->get_names()) {
m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(out_place));
}
38 changes: 38 additions & 0 deletions src/frontends/pytorch/src/op/any.cpp
@@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_logical_or.hpp"
#include "openvino/op/reshape.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_any_fx(const NodeContext& context) {
num_inputs_check(context, 1, 3);
auto x = context.get_input(0);

Output<Node> dims;
if (!context.input_is_none(1)) {
dims = context.get_input(1);
} else {
dims = get_axes_range(context, 0);
}
bool keep_dims = false;
if (!context.input_is_none(2))
keep_dims = context.const_input<bool>(2);
auto any = context.mark_node(std::make_shared<ov::op::v1::ReduceLogicalOr>(x, dims, keep_dims));
return {any};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
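
For reference, translate_any_fx falls back to reducing over every axis (get_axes_range) when no dim input is given, which matches aten::any's eager semantics:

import torch
x = torch.tensor([[False, True], [False, False]])
print(torch.any(x))                             # tensor(True): all axes reduced
print(torch.any(x, dim=1))                      # tensor([ True, False])
print(torch.any(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1])
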
6 changes: 4 additions & 2 deletions src/frontends/pytorch/src/op/argmax_argmin.cpp
@@ -31,7 +31,8 @@ OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode)
}
if (!context.input_is_none(1)) {
auto axis = context.const_input<int64_t>(1);
-        auto topk = context.mark_node(std::make_shared<v3::TopK>(input, k, axis, mode, TopKSortType::NONE));
+        auto topk = context.mark_node(
+            std::make_shared<v11::TopK>(input, k, axis, mode, TopKSortType::SORT_VALUES, element::i32, true));
indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
if (!keep_dims) {
auto axis_to_remove = context.mark_node(v0::Constant::create(element::i32, Shape{}, {axis}));
@@ -41,7 +42,8 @@ OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode)
int64_t axis = 0;
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto flatten_input = context.mark_node(std::make_shared<v1::Reshape>(input, minus_one, false));
-        auto topk = context.mark_node(std::make_shared<v3::TopK>(flatten_input, k, axis, mode, TopKSortType::NONE));
+        auto topk = context.mark_node(
+            std::make_shared<v11::TopK>(flatten_input, k, axis, mode, TopKSortType::SORT_VALUES, element::i32, true));
indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
if (keep_dims) {
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
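
The switch from opset-3 TopK with SORT_NONE to opset-11 TopK with sorted values and stable mode (the trailing true argument) makes tie-breaking deterministic, matching PyTorch's rule that argmax/argmin return the index of the first extremal value:

import torch
x = torch.tensor([3.0, 1.0, 3.0])
print(torch.argmax(x))  # tensor(0): ties resolve to the lowest index
print(torch.argmin(x))  # tensor(1)
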
16 changes: 10 additions & 6 deletions src/frontends/pytorch/src/op/cat.cpp
@@ -102,20 +102,24 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
};

OutputVector translate_stack_fx(const NodeContext& context) {
-    num_inputs_check(context, 2, context.get_input_size());
+    num_inputs_check(context, 1, context.get_input_size());
auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
std::deque<Output<Node>> list_elems;
auto num_elements = context.get_input_size();
-    if (num_elements > 2)
-        num_elements = num_elements - 1;
-    for (size_t i = 0; i < num_elements; i++) {
+    for (size_t i = 0; i < num_elements - 1; i++) {
auto stack_input =
context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(i)), dim));
list_elems.push_back(stack_input);
}
int64_t axis = 0;
-    if (context.get_input_size() > 2)
-        axis = context.const_input<int64_t>(context.get_input_size() - 1);
+    if (!context.get_input_type(num_elements - 1).is<type::List>()) {
+        // the axis may be absent, in which case the last input has List type
+        axis = context.const_input<int64_t>(num_elements - 1);
+    } else {
+        auto stack_input = context.mark_node(
+            std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(num_elements - 1)), dim));
+        list_elems.push_back(stack_input);
+    }
return translate_cat_common(context, list_elems, axis, true);
}

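
The rewritten translate_stack_fx now accepts both FX calling conventions: tensors followed by a trailing axis scalar, or tensors only, in which case the last input belongs to the list and the axis defaults to 0. The eager-mode equivalents:

import torch
a, b = torch.zeros(2, 3), torch.ones(2, 3)
print(torch.stack([a, b]).shape)         # torch.Size([2, 2, 3]): implicit dim=0
print(torch.stack([a, b], dim=2).shape)  # torch.Size([2, 3, 2]): explicit axis
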
11 changes: 11 additions & 0 deletions src/frontends/pytorch/src/op/div.cpp
@@ -90,6 +90,17 @@ OutputVector translate_div_fx(const NodeContext& context) {
return translate_div_common(context, x, y, rounding_mode, false);
};

OutputVector translate_div_fx_(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto x = context.get_input(0);
auto y = context.get_input(1);
std::string rounding_mode = "";
if (context.has_attribute("rounding_mode")) {
rounding_mode = context.get_attribute<std::string>("rounding_mode");
}
return translate_div_common(context, x, y, rounding_mode, true);
};

} // namespace op
} // namespace pytorch
} // namespace frontend
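
The new translate_div_fx_ covers the in-place aten.div_ overload (note the trailing underscore) by calling translate_div_common with the in-place flag set to true. In eager terms:

import torch
x = torch.tensor([5.0, 7.0])
y = x.div(2.0)   # out-of-place: returns a new tensor, x unchanged
x.div_(2.0)      # in-place: x itself is overwritten
print(y, x)      # tensor([2.5000, 3.5000]) tensor([2.5000, 3.5000])
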
20 changes: 16 additions & 4 deletions src/frontends/pytorch/src/op/embedding_bag.cpp
@@ -15,10 +15,9 @@ namespace frontend {
namespace pytorch {
namespace op {

-OutputVector translate_embedding_bag(const NodeContext& context) {
+OutputVector translate_embedding_bag_common(const NodeContext& context) {
// aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False,
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
-    num_inputs_check(context, 9, 9);
// we have only EmbeddingBagSum case support, check it before translation
auto mode = context.const_input<int64_t>(4);
PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation");
Expand All @@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
// with offsets case
auto offsets = context.get_input(2);
offsets = context.mark_node(std::make_shared<ov::op::v0::Convert>(offsets, element::i32));
-    auto include_last_offset = context.const_input<bool>(7);
+    bool include_last_offset = false;
+    if (!context.input_is_none(7))
+        include_last_offset = context.const_input<bool>(7);
PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported");
// no per_sample_wights
if (context.input_is_none(6)) {
@@ -63,7 +64,18 @@
return {result, zero, zero, zero};
};

OutputVector translate_embedding_bag(const NodeContext& context) {
num_inputs_check(context, 9, 9);
return translate_embedding_bag_common(context);
}

OutputVector translate_embedding_bag_fx(const NodeContext& context) {
num_inputs_check(context, 7, 9);
ov::OutputVector output = translate_embedding_bag_common(context);
return {context.mark_node(make_list_construct(output))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
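
The refactor splits the body into translate_embedding_bag_common so that the FX variant can accept the shorter 7-to-9-input signature (include_last_offset may be absent, defaulting to false) and wrap the four outputs in a list construct; only sum mode (mode enum 0) is translated. The eager operation being covered:

import torch
import torch.nn.functional as F

weight = torch.arange(12.0).reshape(4, 3)
indices = torch.tensor([0, 1, 2, 3])
offsets = torch.tensor([0, 2])  # two bags: rows {0, 1} and rows {2, 3}
print(F.embedding_bag(indices, weight, offsets, mode="sum"))
# tensor([[ 3.,  5.,  7.],
#         [15., 17., 19.]])
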