Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Add ModuleExtension #23536

Merged
merged 23 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3f1bc39
Draft of ModuleExtension class which triggers PyTorch model patching …
slyalin Feb 2, 2024
bbd11ec
Provided a stub op PagedAttentionPlaceholder that translates input sh…
slyalin Feb 5, 2024
6d9c2fb
Merge remote-tracking branch 'origin/master' into pytorch_module_exte…
slyalin Feb 5, 2024
4a2f534
Wrapper to glue custom op and tracable PyTorch code
slyalin Feb 5, 2024
2d4a145
Added a stub for PageAttentionPlaceholder to run in CPU plugin
slyalin Feb 6, 2024
e1cb833
Cleanup, more checks, removed PagedAttentionPlacehoder, more accurate…
slyalin Feb 7, 2024
86e0263
Fixed model unpatching when using ModuleExtension
slyalin Feb 8, 2024
2af4abc
Fixed missing sharing of torch.Tensor when passing to OV
slyalin Feb 8, 2024
a157c65
Fix relative paths for shared library extensions in convert_model
slyalin Feb 12, 2024
b4293f1
Merged with master
slyalin Feb 12, 2024
5c6218e
Merged from master
slyalin Feb 12, 2024
ac97fa8
Added support for OPENVINO_FRAMEWORK_MAP(pytorch), including convert_…
slyalin Feb 13, 2024
0ead7f1
Merge remote-tracking branch 'origin/master' into pytorch_module_exte…
slyalin Feb 13, 2024
95b9f9e
Merged from master
slyalin Feb 15, 2024
c81019c
ModuleExtension: documentation, new names for arguments, reasonable d…
slyalin Feb 15, 2024
757caa1
Merge branch 'master' into pytorch_module_extension
slyalin Feb 26, 2024
a16b402
Add tests
mvafin Mar 11, 2024
a9ce494
Merge remote-tracking branch 'upstream/master' into mvafin/pt_fe/modu…
mvafin Mar 19, 2024
e8b8c55
Fix style
mvafin Mar 19, 2024
ebfe4c1
Merge branch 'master' into mvafin/pt_fe/module_extension
mvafin Mar 19, 2024
9f4813e
Remove debug print
mvafin Mar 19, 2024
1867c33
Update tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_uti…
mvafin Mar 19, 2024
25f9eb2
Merge branch 'master' into mvafin/pt_fe/module_extension
mvafin Mar 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.frontend.pytorch.py_pytorch_frontend import ConversionExtensionPytorch as ConversionExtension
from openvino.frontend.pytorch.py_pytorch_frontend import OpExtensionPytorch as OpExtension
from openvino.frontend.pytorch.module_extension import ModuleExtension
except ImportError as err:
raise ImportError("OpenVINO PyTorch frontend is not available, please make sure the frontend is built."
"{}".format(err))
mvafin marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

class ModuleExtension:
    """Describes the replacement of an entire PyTorch module by a single operation."""

    def __init__(self, module, target_op, evaluate=None, convert=None):
        """
        Creates an extension that replaces entire PyTorch module by a single operation.
        This functionality works with PyTorch models only. A module can be identified by
        module type (e.g. torch.nn.Linear), module instance in the model or module name.

        Args:
            module (str, torch.nn.Module, type(torch.nn.Module)): PyTorch module to replace

            target_op (str): a target operation that will be used as a replacer for the module,
                             could be a name of the extension operation or existing PyTorch operation
                             (with prim:: or aten:: prefix following TorchScript syntax).

            evaluate (callable with args module, *args, **kwargs): a callable that will replace a target
                             module in model execution. It is responsible for producing valid output for
                             the module to allow correct model tracing. By default it calls the original
                             module with the same arguments. The provided code will not be a part of the
                             final traced model, it is used only to produce valid results in the tracing.

            convert (callable with args module, target_op, *args, **kwargs): a callable that will be traced
                             and become a part of the final model instead of the target module. It accepts
                             the replaced module as the first parameter and target_op as the second one;
                             target_op is a callable that will appear as a single node in the graph, the
                             type of the node is target_op provided as another argument above. By default
                             it ignores the module and calls target_op with the remaining arguments.
        """
        self.module = module
        self.target_op = target_op
        # Fall back to the pass-through implementations when no user callable is given.
        self.evaluate = evaluate if evaluate is not None else self._default_evaluate
        self.convert = convert if convert is not None else self._default_convert

    @staticmethod
    def _default_evaluate(module, *args, **kwargs):
        # Default `evaluate`: run the module itself with the same arguments.
        return module(*args, **kwargs)

    @staticmethod
    def _default_convert(module, target_op, *args, **kwargs):
        # Default `convert`: the module is not needed, emit a single target_op call.
        return target_op(*args, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

import torch


class no_jit_trace:
    """Context manager that switches TorchScript tracing off inside the ``with`` body.

    The tracing state found on entry is saved and put back on exit.
    """

    def __enter__(self):
        # Remember the current tracing state before clearing it.
        self.state = torch._C._get_tracing_state()
        torch._C._set_tracing_state(None)

    def __exit__(self, *args):
        torch._C._set_tracing_state(self.state)
        self.state = None


def _patch_module(m, extension, orig_forward_name):
    # Patches one module. Being a separate function (not the body of the loop in patch_model)
    # is essential: `m`, `extension` and `Trampoline` captured by the closures below are then
    # bound per module instead of following the loop variables to their last values.

    # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
    class Trampoline(torch.autograd.Function):
        target_extension = extension
        original_module = m
        stashed_args = None
        stashed_kwargs = None

        @staticmethod
        @torch.jit.ignore
        def forward(*args, **kwargs):
            with no_jit_trace():
                # `module` is going to be passed to a user-defined function `evaluate`
                # `module` is patched: forward function was replaced, and we are actually in this patched function right in this code
                # if we pass `module` as-is to the user code below, and it happens to call forward it will lead to infinite recursion or fail
                # so we need to temporary patch the module back to the original forward and then return it back again
                # stash the current forward to be able to return it back
                patched_forward = m.forward
                # set original forward for the module
                m.forward = getattr(m, orig_forward_name)
                try:
                    # call user code
                    results = extension.evaluate(
                        m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs)
                finally:
                    # return patched forward back even if the user code raised
                    m.forward = patched_forward
                return results

    def new_forward(*args, **kwargs):
        # Stash the real call arguments: Trampoline.forward evaluates the module with them.
        Trampoline.stashed_args = args
        Trampoline.stashed_kwargs = kwargs
        return extension.convert(m, Trampoline.apply, *args, **kwargs)

    setattr(m, orig_forward_name, m.forward)
    m.forward = new_forward


def patch_model(model, module_extensions, orig_forward_name):
    """Replace ``forward`` of every module of ``model`` that has a matching extension.

    Args:
        model: object providing ``named_modules()`` that yields ``(name, module)`` pairs.
        module_extensions: mapping that is searched, in this order, by module instance,
            by module class and by module name; the value found is the extension whose
            ``evaluate`` and ``convert`` callables are used by the patched forward.
        orig_forward_name (str): name of the attribute created on each patched module to
            keep its original ``forward``.

    Modules that already have the ``orig_forward_name`` attribute are skipped with a warning.
    """
    for name, m in model.named_modules():
        if hasattr(m, orig_forward_name):
            # already patched, skipping with a warning because it is unexpected
            print(f'[ WARNING ] Unexpectedly found already patched module {name} while applying ModuleExtension during PyTorch model conversion. '
                  'Result of the conversion maybe broken. Depending on the exact issue it may lead to broken original model.')
            continue
        extension = None
        if m in module_extensions:
            extension = module_extensions[m]
        elif m.__class__ in module_extensions:
            extension = module_extensions[m.__class__]
        elif name in module_extensions:
            extension = module_extensions[name]

        if extension:
            _patch_module(m, extension, orig_forward_name)


def unpatch_model(model, orig_forward_name):
    """Undo forward patching on every module of ``model``.

    For each module yielded by ``model.named_modules()`` that carries the
    ``orig_forward_name`` attribute, the value stored there is assigned back to
    ``forward`` and the attribute is removed. Any exception raised while doing
    so is reported on stdout and does not stop processing of the other modules.
    """
    for _name, submodule in model.named_modules():
        if not hasattr(submodule, orig_forward_name):
            # this module was never patched
            continue
        try:
            original_forward = getattr(submodule, orig_forward_name)
            submodule.forward = original_forward
            delattr(submodule, orig_forward_name)
        except Exception as error:
            print('[ WARNING ] Exception raised during model unpatching. Depending on the exact issue it may lead to broken original model.')
            print('Original exception details:')
            print(error)
mvafin marked this conversation as resolved.
Show resolved Hide resolved
46 changes: 42 additions & 4 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,32 @@
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor, graph_has_ops
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq
from openvino.frontend.pytorch import patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import typing
import torch


class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True, skip_freeze=False, constant_cache=None):
def __init__(
self,
pt_module,
graph_element=None,
example_input=None,
alias_db=None,
shared_memory=True,
skip_freeze=False,
constant_cache=None,
module_extensions=None):
Decoder.__init__(self)
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
self.m_decoders = []
self._input_signature = None
self._shared_memory = shared_memory
self._input_is_list = False
self.constant_cache = constant_cache if constant_cache is not None else dict()
self.module_extensions = module_extensions
if graph_element is None:
try:
pt_module = self._get_scripted_model(
Expand Down Expand Up @@ -89,14 +101,22 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
input_params = inspect.signature(pt_module.forward if hasattr(
pt_module, "forward") else pt_module.__call__).parameters
input_signature = list(input_params)

if example_inputs is None:
if self.module_extensions:
raise RuntimeError("ModuleExtension is not supported for scripting. Please provide valid example_input argument to run tracing.")
scripted = torch.jit.script(pt_module)
freeze_by_default = True
else:
input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
example_inputs, input_params, pt_module)
gptq_patched = False

# name of attribute in a patched module where the original forward method is kept
orig_forward_name = '_openvino_module_extension_patch_orig_forward'
if self.module_extensions:
patch_model.patch_model(pt_module, self.module_extensions, orig_forward_name)

gptq_patched = False
if gptq.detect_gptq_model(pt_module):
try:
gptq.patch_model(pt_module)
Expand All @@ -115,6 +135,8 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
finally:
if gptq_patched:
gptq.unpatch_model(pt_module)
if self.module_extensions:
patch_model.unpatch_model(pt_module, orig_forward_name)

if not freeze_by_default and graph_has_ops(scripted.inlined_graph, ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"]):
# freeze models with unsupported ops
Expand Down Expand Up @@ -232,7 +254,8 @@ def visit_subgraph(self, node_visitor) -> None:
node,
alias_db=self.alias_db,
shared_memory=self._shared_memory,
constant_cache=self.constant_cache)
constant_cache=self.constant_cache,
module_extensions=self.module_extensions)
self.m_decoders.append(decoder)
node_visitor(decoder)

Expand All @@ -255,13 +278,28 @@ def get_subgraph_decoder(self, index: int):
decoder = TorchScriptPythonDecoder(self.pt_module,
self.get_subgraphs()[index],
alias_db=self.alias_db,
shared_memory=self._shared_memory)
shared_memory=self._shared_memory,
module_extensions=self.module_extensions)
self.m_decoders.append(decoder)
return decoder

def get_op_type(self) -> str:
    """Return the operation type of the wrapped graph node.

    For a ``prim::PythonOp`` node whose Python object is bound to a class carrying a
    ``ModuleExtension`` in ``target_extension``, the type comes from the extension's
    ``target_op``: the string itself, or the result of calling it with the original module.
    Every other node reports ``self.graph_element.kind()``.

    Raises:
        TypeError: if ``target_op`` of the extension is neither a callable nor a str.
    """
    assert isinstance(
        self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
    if self.graph_element.kind() == "prim::PythonOp":
        if hasattr(self.graph_element, 'pyobj') and callable(self.graph_element.pyobj) and hasattr(self.graph_element.pyobj(), '__self__'):
            trampoline = self.graph_element.pyobj().__self__
            if hasattr(trampoline, 'target_extension') and isinstance(trampoline.target_extension, ModuleExtension):
                target_op = trampoline.target_extension.target_op
                if callable(target_op):
                    return target_op(trampoline.original_module)
                if isinstance(target_op, str):
                    return target_op
                # TODO: Support target as a callable that will play a role of ConversionExtension for an entire module instead of a single op.
                #       Without supporting target as a callable here, ConversionExtension functionality is still possible to implement
                #       by combining two extensions: ModuleExtension that use temporary name as a target op and another extension of type ConversionExtension
                #       that translates that particular temporary name to custom graph. But providing conversion code as a callable `target` is more convenient.
                # Neither supported form matched: fail with an explicit message instead of
                # an UnboundLocalError on a never-assigned result variable.
                raise TypeError(
                    "ModuleExtension target_op must be a str or a callable, got {}".format(type(target_op).__name__))
    return self.graph_element.kind()

def get_schema(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

Expand Down
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
}

std::map<std::string, CreatorFunction> FrontEnd::get_supported_ops(const ov::frontend::InputModel::Ptr& model) const {
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops_fx();
std::map<std::string, CreatorFunction> supported_ops;
if (std::dynamic_pointer_cast<pytorch::InputModel>(model)->decoder_type_name() == "fx")
supported_ops = get_supported_ops_fx();
else
Expand Down
74 changes: 54 additions & 20 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,17 @@ def forward(self, x):
"Parameter", "ReluCustom", "Result"]


def test_op_extension():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch import OpExtension

class CosModel(torch.nn.Module):
def __init__(self):
class CosModel(torch.nn.Module):
def __init__(self):
super(CosModel, self).__init__()

def forward(self, x):
def forward(self, x):
return torch.cos(x.to(torch.float32))

def test_op_extension():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch import OpExtension

model = CosModel()
decoder = TorchScriptPythonDecoder(get_scripted_model(model))

Expand All @@ -327,13 +327,6 @@ def test_op_extension_generic():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend import OpExtension

class CosModel(torch.nn.Module):
def __init__(self):
super(CosModel, self).__init__()

def forward(self, x):
return torch.cos(x.to(torch.float32))

model = CosModel()
decoder = TorchScriptPythonDecoder(get_scripted_model(model))

Expand All @@ -355,6 +348,49 @@ def forward(self, x):
"Parameter", "Convert", "Sin", "Result"]


def test_module_extension():
    """Check conversion of a model wrapping CosModel with and without ModuleExtension.

    Without an extension the converted graph is Parameter, Convert, Cos, Result. With a
    ModuleExtension targeting "aten::sin" the graph is Parameter, Sin, Result, whether the
    module is identified by its class, by its instance or by its name.
    """
    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
    from openvino.frontend.pytorch import ModuleExtension
    from openvino import convert_model

    class ModelWithModule(torch.nn.Module):
        def __init__(self):
            super(ModelWithModule, self).__init__()
            self.cos_module = CosModel()

        def forward(self, x):
            return self.cos_module(x)

    def op_types(ov_model):
        # type names of all operations of the converted model, in graph order
        return [n.get_type_name() for n in ov_model.get_ordered_ops()]

    model = ModelWithModule()
    decoder = TorchScriptPythonDecoder(model)

    fem = FrontEndManager()
    fe = fem.load_by_framework(framework="pytorch")
    assert fe

    input_model = fe.load(decoder)
    assert input_model
    converted_model = fe.convert(input_model)
    assert converted_model
    assert op_types(converted_model) == [
        "Parameter", "Convert", "Cos", "Result"]

    # the same module addressed by class, by instance and by name
    for module_id in (CosModel, model.cos_module, "cos_module"):
        converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(module_id, "aten::sin")])
        assert converted_model
        assert op_types(converted_model) == [
            "Parameter", "Sin", "Result"]


def test_pytorch_telemetry():
from openvino.frontend import TelemetryExtension
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
Expand Down Expand Up @@ -547,7 +583,7 @@ def forward(self, x: float, y: torch.Tensor):
assert PartialShape(pt_out_shape) == om.get_output_partial_shape(0)


class TestModel1(torch.nn.Module):
class ModelTest1(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.pool = torch.nn.AdaptiveAvgPool2d(1)
Expand All @@ -559,8 +595,7 @@ def forward(self, x):
def test_output_dict_names():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

input = torch.ones((1, 3, 224, 224))
model = TestModel1()
model = ModelTest1()
decoder = TorchScriptPythonDecoder(
model, example_input=(torch.randn(1, 3, 224, 224),))
fe_manager = FrontEndManager()
Expand All @@ -570,7 +605,7 @@ def test_output_dict_names():
assert om.outputs[0].any_name == "x1" and om.outputs[1].any_name == "x2", "Output dict names are not expected"


class TestModel2(torch.nn.Module):
class ModelTest2(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.pool = torch.nn.AdaptiveAvgPool2d(1)
Expand All @@ -582,8 +617,7 @@ def forward(self, x):
def test_output_tuple_names():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

input = torch.ones((1, 3, 224, 224))
model = TestModel2()
model = ModelTest2()
decoder = TorchScriptPythonDecoder(
model, example_input=(torch.randn(1, 3, 224, 224),))
fe_manager = FrontEndManager()
Expand Down
4 changes: 3 additions & 1 deletion tools/ovc/openvino/tools/ovc/convert_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

# pylint: disable=no-name-in-module,import-error
from openvino.frontend import FrontEndManager, OpConversionFailure, TelemetryExtension
from openvino.frontend.pytorch.module_extension import ModuleExtension
mvafin marked this conversation as resolved.
Show resolved Hide resolved
from openvino.runtime import get_version as get_rt_version
from openvino.runtime import Type, PartialShape

Expand Down Expand Up @@ -173,7 +174,8 @@ def prepare_ir(argv: argparse.Namespace):
moc_front_end.add_extension(TelemetryExtension("ovc", t.send_event, t.send_error, t.send_stack_trace))
if any_extensions_used(argv):
for extension in argv.extension:
moc_front_end.add_extension(extension)
if not isinstance(extension, ModuleExtension):
mvafin marked this conversation as resolved.
Show resolved Hide resolved
moc_front_end.add_extension(extension)
mvafin marked this conversation as resolved.
Show resolved Hide resolved
ov_model = moc_pipeline(argv, moc_front_end)
return ov_model

Expand Down
Loading
Loading