-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Eager mode for pytorch with torch-mlir backend. Uses
python
dispatc…
…h key and `__torch_dispatch__` mechanism. short term future work: 1. remove `TorchMLIRTensor` 2. compile cache 3. tracing jit
- Loading branch information
1 parent
b6d1330
commit 5aecdb6
Showing
12 changed files
with
842 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#------------------------------------------------------------------------------- | ||
# Setup PyTorch | ||
#------------------------------------------------------------------------------- | ||
|
||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") | ||
TorchMLIRProbeForPyTorchInstall() | ||
find_package(Torch 1.8 REQUIRED) | ||
|
||
TorchMLIRConfigurePyTorch() | ||
|
||
#------------------------------------------------------------------------------- | ||
# Subdirectories | ||
#------------------------------------------------------------------------------- | ||
|
||
## Declare the sources of the Python module. | ||
|
||
declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode | ||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" | ||
ADD_TO_PARENT TorchMLIRPythonSources | ||
SOURCES_GLOB eager_mode/*.py lazytensor/*.py | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# -*- Python -*- | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
# | ||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme | ||
|
||
from typing import Iterable, Union | ||
from torch.fx import GraphModule | ||
from torch_mlir import ir | ||
from torch_mlir.eager_mode.torch_mlir_types import TorchTensorType, PythonType | ||
|
||
class Annotation: | ||
def __init__(self, types: Iterable[Union[TorchTensorType, type]]): | ||
self.types = list(map(lambda t: | ||
PythonType(t) if isinstance(t, type) else t, | ||
types)) | ||
|
||
def __str__(self): | ||
result = f'Annotation instance with {len(self.types)} types\n' | ||
for e, type_ in enumerate(self.types): | ||
result += f' Type of argument {e + 1}: {str(type_)}\n' | ||
return result | ||
|
||
def __iter__(self): | ||
return iter(self.types) | ||
|
||
|
||
class AnnotationConverter: | ||
@staticmethod | ||
def to_mlir_array_attr(annotation: Annotation, | ||
context: ir.Context) -> ir.ArrayAttr: | ||
dict_attrs = [] | ||
for type_ in annotation: | ||
if not isinstance(type_, TorchTensorType): | ||
dict_attrs.append(ir.DictAttr.get({}, context=context)) | ||
continue | ||
|
||
ir_type = type_.to_mlir(context) | ||
with context: | ||
type_attr = ir.TypeAttr.get(ir_type) | ||
dict_attr = ir.DictAttr.get({'torch.type_bound': type_attr}) | ||
dict_attrs.append(dict_attr) | ||
|
||
return ir.ArrayAttr.get(dict_attrs, context=context) | ||
|
||
|
||
def annotate_forward_args(module: GraphModule, | ||
types: Iterable[Union[TorchTensorType, type]] | ||
) -> GraphModule: | ||
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes) | ||
for operand, type_ in zip(operands, types): | ||
if isinstance(type_, type): | ||
type_ = PythonType(type_) | ||
operand.update_kwarg('torch_mlir_type', type_) | ||
|
||
return module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
""" | ||
Translator from torch.jit.ScriptFunction to MLIR. | ||
The following defines a function that take a torch.jit.ScriptFunction | ||
and converts it into an MLIR module. | ||
The expected use for this module is to use the function | ||
`build_module(jit_function: torch.jit.ScriptFunction | ||
annotation: Annotation) -> ir.Module` | ||
to convert the TorchScript function into MLIR using the `torch` | ||
dialect. | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from torch.jit import ScriptFunction | ||
|
||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder | ||
from torch_mlir.dialects.builtin import FuncOp | ||
from torch_mlir import ir | ||
|
||
from torch_mlir.eager_mode.annotator import Annotation, AnnotationConverter as ac | ||
|
||
def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]: | ||
with module.context: | ||
name_attr = ir.StringAttr.get(name) | ||
for op in module.body.operations: | ||
if isinstance(op, FuncOp) and op.name == name_attr: | ||
return op | ||
|
||
return None | ||
|
||
def build_module(jit_function: ScriptFunction, | ||
annotation: Annotation) -> ir.Module: | ||
""" | ||
Translate input function into an MLIR module in the `torch` dialect. | ||
Parameters | ||
---------- | ||
jit_function: ScriptFunction | ||
Function in TorchScript IR to turn into MLIR. | ||
annotation: Annotation | ||
Annotation object representing the types of | ||
the operands of `jit_function`. | ||
Returns | ||
------- | ||
ir.Module | ||
Translation of the input module into an MLIR module | ||
""" | ||
mb = ModuleBuilder() | ||
mb.import_function(jit_function) | ||
|
||
func_op = _get_func_op_with_name(mb.module, jit_function.name) | ||
assert func_op is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder' | ||
|
||
arg_attrs = ac.to_mlir_array_attr(annotation, mb.context) | ||
func_op.attributes['arg_attrs'] = arg_attrs | ||
|
||
return mb.module |
Oops, something went wrong.