From d7c833b96096520a161aa6922a8467f47eca36ee Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Tue, 19 Jul 2022 12:25:13 -0700 Subject: [PATCH] python: trim registration and loading of dialects and passes In the interest of merging upstream LLVM quickly, a previous patch (7f08169) updated the torch-mlir build to register all dialects and passes through Python bindings. This patch limits the dialects and passes to only those that are used in torch-mlir. Key to this change are the removal of `MLIRPythonExtension.RegisterEverything` and the introduction of a new Python module (`_mlir_libs/_site_initialize_0.py`), where we register the dialects and passes used by torch-mlir. --- include/torch-mlir-c/Registration.h | 3 ++ lib/CAPI/CMakeLists.txt | 17 +++++++ lib/CAPI/Registration.cpp | 44 +++++++++++++++++-- python/CMakeLists.txt | 13 ++---- python/TorchMLIRModule.cpp | 13 ++---- .../_mlir_libs/_site_initialize_0.py | 4 ++ python/torch_mlir/dialects/torch/__init__.py | 2 +- test/python/smoketest.py | 2 +- 8 files changed, 73 insertions(+), 25 deletions(-) create mode 100644 python/torch_mlir/_mlir_libs/_site_initialize_0.py diff --git a/include/torch-mlir-c/Registration.h b/include/torch-mlir-c/Registration.h index 4d582e61f13..e83823b7e9f 100644 --- a/include/torch-mlir-c/Registration.h +++ b/include/torch-mlir-c/Registration.h @@ -22,6 +22,9 @@ extern "C" { */ MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context); +/** Registers upstream (MLIR) dialects used in Torch-MLIR IRs. */ +MLIR_CAPI_EXPORTED void torchMlirRegisterRequiredDialects(MlirContext context); + /** Registers all passes for symbolic access with the global registry. */ MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(); diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index 87977a86fa1..4e4a058dd1e 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -13,6 +13,23 @@ add_mlir_public_c_api_library(TorchMLIRCAPI LINK_LIBS PUBLIC MLIRIR + MLIRAffineToStandard + MLIRArithmeticToLLVM + MLIRArithmeticTransforms + MLIRBufferizationTransforms + MLIRControlFlowToLLVM + MLIRFuncToLLVM + MLIRFuncTransforms + MLIRLinalgToLLVM + MLIRLinalgTransforms + MLIRMathToLLVM + MLIRMemRefToLLVM + MLIRReconcileUnrealizedCasts + MLIRSCFToControlFlow + MLIRSCFTransforms + MLIRTensorTransforms + MLIRTosaToArith + MLIRTosaToLinalg MLIRSupport TorchMLIRTorchDialect TorchMLIRInitAll diff --git a/lib/CAPI/Registration.cpp b/lib/CAPI/Registration.cpp index 52cd10b38ba..c83f3f43ee7 100644 --- a/lib/CAPI/Registration.cpp +++ b/lib/CAPI/Registration.cpp @@ -10,11 +10,27 @@ #include "torch-mlir-c/Registration.h" #include "mlir/CAPI/IR.h" -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/InitAllPasses.h" #include "torch-mlir/InitAll.h" +void torchMlirRegisterRequiredDialects(MlirContext context) { + mlir::DialectRegistry registry; + registry.insert(); + unwrap(context)->appendDialectRegistry(registry); +} + void torchMlirRegisterAllDialects(MlirContext context) { mlir::DialectRegistry registry; mlir::torch::registerAllDialects(registry); @@ -23,4 +39,24 @@ void torchMlirRegisterAllDialects(MlirContext context) { unwrap(context)->loadAllAvailableDialects(); } -void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); } +void torchMlirRegisterAllPasses() { + mlir::arith::registerArithmeticPasses(); + mlir::bufferization::registerBufferizationPasses(); + mlir::func::registerFuncPasses(); + mlir::registerConvertAffineToStandardPass(); + mlir::registerConvertArithmeticToLLVMPass(); + mlir::registerConvertControlFlowToLLVMPass(); + mlir::registerConvertFuncToLLVMPass(); + mlir::registerConvertLinalgToLLVMPass(); + mlir::registerConvertMathToLLVMPass(); + mlir::registerConvertMemRefToLLVMPass(); + mlir::registerLinalgPasses(); + mlir::registerReconcileUnrealizedCastsPass(); + mlir::registerSCFPasses(); + mlir::registerSCFToControlFlowPass(); + mlir::registerTosaToArithPass(); + mlir::registerTosaToLinalgNamedPass(); + mlir::registerTosaToLinalgPass(); + mlir::tensor::registerTensorPasses(); + mlir::torch::registerAllPasses(); +} diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 77a31184c45..8bd781d5585 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -26,6 +26,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel SOURCES __init__.py compiler_utils.py + _mlir_libs/_site_initialize_0.py ) declare_mlir_python_sources(TorchMLIRPythonSources.Dialects @@ -102,16 +103,10 @@ add_subdirectory(torch_mlir/eager_mode) ################################################################################ set(_source_components - # TODO: Core is now implicitly building/registering all dialects, increasing - # build burden by ~5x. Make it stop. - # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes - # for the reference backend, but logically they can be separate. But seemingly - # the only way to handle that is to create a separate mlir python package - # tree, which seems excessive. - MLIRPythonSources + MLIRPythonSources.Core + MLIRPythonSources.Dialects.func + MLIRPythonSources.ExecutionEngine MLIRPythonExtension.Core - MLIRPythonExtension.RegisterEverything - MLIRPythonExtension.ExecutionEngine TorchMLIRPythonSources TorchMLIRPythonExtensions ) diff --git a/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp index e0b04514336..509fb374901 100644 --- a/python/TorchMLIRModule.cpp +++ b/python/TorchMLIRModule.cpp @@ -9,6 +9,7 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include "torch-mlir-c/Dialects.h" #include "torch-mlir-c/Registration.h" @@ -19,14 +20,6 @@ PYBIND11_MODULE(_torchMlir, m) { m.doc() = "torch-mlir main python extension"; - m.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__torch__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - }, - py::arg("context"), py::arg("load") = true); + m.def("register_required_dialects", torchMlirRegisterRequiredDialects, + py::arg("context")); } diff --git a/python/torch_mlir/_mlir_libs/_site_initialize_0.py b/python/torch_mlir/_mlir_libs/_site_initialize_0.py new file mode 100644 index 00000000000..9978b72c739 --- /dev/null +++ b/python/torch_mlir/_mlir_libs/_site_initialize_0.py @@ -0,0 +1,4 @@ +from . import _torchMlir + +def context_init_hook(context): + _torchMlir.register_required_dialects(context) diff --git a/python/torch_mlir/dialects/torch/__init__.py b/python/torch_mlir/dialects/torch/__init__.py index bd362849a7e..b94a7bdab45 100644 --- a/python/torch_mlir/dialects/torch/__init__.py +++ b/python/torch_mlir/dialects/torch/__init__.py @@ -4,4 +4,4 @@ # Also available under a BSD-style license. See LICENSE. from .._torch_ops_gen import * -from ..._mlir_libs._torchMlir import register_dialect +from ..._mlir_libs._torchMlir import register_required_dialects diff --git a/test/python/smoketest.py b/test/python/smoketest.py index 88e0a10f7ef..c423351239e 100644 --- a/test/python/smoketest.py +++ b/test/python/smoketest.py @@ -4,4 +4,4 @@ from torch_mlir.dialects import torch with torch_mlir.ir.Context() as ctx: - torch.register_dialect(ctx) + torch.register_required_dialects(ctx)