Skip to content

Commit

Permalink
Eager mode for pytorch with torch-mlir backend. Uses python dispatch key and `__torch_dispatch__` mechanism.
Browse files Browse the repository at this point in the history

short term future work:

1. remove `TorchMLIRTensor`
2. compile cache
3. tracing jit
  • Loading branch information
makslevental committed Mar 14, 2022
1 parent b6d1330 commit 5aecdb6
Show file tree
Hide file tree
Showing 12 changed files with 842 additions and 3 deletions.
9 changes: 6 additions & 3 deletions e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

# Available test configs.
from torch_mlir_e2e_test.torchscript.configs import (
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS, EAGERMODEBACKED_XFAIL_SET

# Import tests to register them in the global registry.
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
Expand Down Expand Up @@ -56,7 +56,7 @@
from . import cast

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external', 'eager_mode']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
Expand Down Expand Up @@ -120,6 +120,9 @@ def main():
elif args.config == 'torchscript':
config = TorchScriptTestConfig()
xfail_set = {}
elif args.config == 'eager_mode':
config = EagerModeTestConfig()
xfail_set = EAGERMODEBACKED_XFAIL_SET
elif args.config == 'external':
with open(args.external_config, 'r') as f:
code = compile(f.read(), args.external_config, 'exec')
Expand Down
10 changes: 10 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
}
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS

# Eager mode inherits every RefBackend failure, plus a handful of scalar
# tests: scalars get passed down as tensors of float64 for some reason.
# These do not actually fail the behavioral tests (the values are
# interpreted as the same), but they fail the dtype checks in framework.py.
EAGERMODEBACKED_XFAIL_SET = REFBACKEND_XFAIL_SET | {
    "ElementwiseMulScalarModule_basic",
    "ElementwiseSubScalarIntModule_basic",
    "ElementwiseAddScalarIntModule_basic",
    "ElementwiseAddScalarInt64Module_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
Expand Down
6 changes: 6 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
add_subdirectory(torch_mlir_e2e_test)
endif()

################################################################################
# Eager mode
################################################################################

add_subdirectory(torch_mlir/eager_mode)

################################################################################
# Generate packages and shared library
# Downstreams typically will not use these, but they are useful for local
Expand Down
21 changes: 21 additions & 0 deletions python/torch_mlir/eager_mode/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#-------------------------------------------------------------------------------
# Setup PyTorch
#-------------------------------------------------------------------------------

# NOTE(review): assumes a cmake/modules directory next to this file provides
# the TorchMLIR* helper macros -- confirm, since the top-level python build
# may already configure PyTorch the same way.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
# Locate a PyTorch installation, then require at least Torch 1.8.
TorchMLIRProbeForPyTorchInstall()
find_package(Torch 1.8 REQUIRED)

# Presumably aligns build flags with the detected libtorch; see the
# TorchMLIRConfigurePyTorch macro in cmake/modules for the specifics.
TorchMLIRConfigurePyTorch()

#-------------------------------------------------------------------------------
# Subdirectories
#-------------------------------------------------------------------------------

## Declare the sources of the Python module.

# Register all Python files under eager_mode/ and lazytensor/ (relative to
# the Python root) as part of the TorchMLIRPythonSources parent source set.
declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
  ADD_TO_PARENT TorchMLIRPythonSources
  SOURCES_GLOB eager_mode/*.py lazytensor/*.py
)
Empty file.
58 changes: 58 additions & 0 deletions python/torch_mlir/eager_mode/annotator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme

from typing import Iterable, Union
from torch.fx import GraphModule
from torch_mlir import ir
from torch_mlir.eager_mode.torch_mlir_types import TorchTensorType, PythonType

class Annotation:
    """Ordered collection of operand types for a function signature.

    Plain Python ``type`` objects are wrapped in ``PythonType``; any other
    entry (e.g. a ``TorchTensorType``) is stored as-is.
    """

    def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
        self.types = [PythonType(t) if isinstance(t, type) else t
                      for t in types]

    def __str__(self):
        lines = [f'Annotation instance with {len(self.types)} types\n']
        for position, type_ in enumerate(self.types, start=1):
            lines.append(f' Type of argument {position}: {str(type_)}\n')
        return ''.join(lines)

    def __iter__(self):
        return iter(self.types)


class AnnotationConverter:
    """Lowers an ``Annotation`` into MLIR attribute form."""

    @staticmethod
    def to_mlir_array_attr(annotation: Annotation,
                           context: ir.Context) -> ir.ArrayAttr:
        # Build one DictAttr per argument: tensor-typed arguments carry a
        # `torch.type_bound` entry; every other argument maps to an empty
        # dictionary attribute.
        per_arg_attrs = []
        for arg_type in annotation:
            if isinstance(arg_type, TorchTensorType):
                ir_type = arg_type.to_mlir(context)
                with context:
                    bound_attr = ir.TypeAttr.get(ir_type)
                    per_arg_attrs.append(
                        ir.DictAttr.get({'torch.type_bound': bound_attr}))
            else:
                per_arg_attrs.append(ir.DictAttr.get({}, context=context))

        return ir.ArrayAttr.get(per_arg_attrs, context=context)


def annotate_forward_args(module: GraphModule,
                          types: Iterable[Union[TorchTensorType, type]]
                          ) -> GraphModule:
    """Attach a ``torch_mlir_type`` kwarg to each placeholder node of *module*.

    Plain Python ``type`` objects are wrapped in ``PythonType`` before being
    attached; other entries are attached unchanged. Returns the same
    (mutated) GraphModule for convenience.
    """
    placeholders = (node for node in module.graph.nodes
                    if node.op == 'placeholder')
    for node, arg_type in zip(placeholders, types):
        wrapped = PythonType(arg_type) if isinstance(arg_type, type) else arg_type
        node.update_kwarg('torch_mlir_type', wrapped)

    return module
64 changes: 64 additions & 0 deletions python/torch_mlir/eager_mode/lazytensor/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a function that takes a torch.jit.ScriptFunction
and converts it into an MLIR module.
The expected use for this module is to use the function
`build_module(jit_function: torch.jit.ScriptFunction,
annotation: Annotation) -> ir.Module`
to convert the TorchScript function into MLIR using the `torch`
dialect.
"""

from typing import Optional

from torch.jit import ScriptFunction

from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch_mlir.dialects.builtin import FuncOp
from torch_mlir import ir

from torch_mlir.eager_mode.annotator import Annotation, AnnotationConverter as ac

def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
    """Return the FuncOp in *module* whose name matches *name*, or None."""
    with module.context:
        target_attr = ir.StringAttr.get(name)
        matches = (op for op in module.body.operations
                   if isinstance(op, FuncOp) and op.name == target_attr)
        return next(matches, None)

def build_module(jit_function: ScriptFunction,
                 annotation: Annotation) -> ir.Module:
    """Translate input function into an MLIR module in the `torch` dialect.

    Parameters
    ----------
    jit_function: ScriptFunction
        Function in TorchScript IR to turn into MLIR.
    annotation: Annotation
        Annotation object representing the types of
        the operands of `jit_function`.

    Returns
    -------
    ir.Module
        Translation of the input module into an MLIR module
    """
    module_builder = ModuleBuilder()
    module_builder.import_function(jit_function)

    # The imported function is looked up by name so its arg attributes can
    # be annotated in place.
    func_op = _get_func_op_with_name(module_builder.module, jit_function.name)
    assert func_op is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder'

    func_op.attributes['arg_attrs'] = ac.to_mlir_array_attr(
        annotation, module_builder.context)

    return module_builder.module
Loading

0 comments on commit 5aecdb6

Please sign in to comment.