Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python: trim registration and loading of dialects and passes #1084

Merged
merged 1 commit into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/torch-mlir-c/Registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ extern "C" {
*/
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);

/** Registers upstream (MLIR) dialects used in Torch-MLIR IRs. */
MLIR_CAPI_EXPORTED void torchMlirRegisterRequiredDialects(MlirContext context);

/** Registers all passes for symbolic access with the global registry. */
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();

Expand Down
17 changes: 17 additions & 0 deletions lib/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ add_mlir_public_c_api_library(TorchMLIRCAPI

LINK_LIBS PUBLIC
MLIRIR
MLIRAffineToStandard
MLIRArithmeticToLLVM
MLIRArithmeticTransforms
MLIRBufferizationTransforms
MLIRControlFlowToLLVM
MLIRFuncToLLVM
MLIRFuncTransforms
MLIRLinalgToLLVM
MLIRLinalgTransforms
MLIRMathToLLVM
MLIRMemRefToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
MLIRSCFTransforms
MLIRTensorTransforms
MLIRTosaToArith
MLIRTosaToLinalg
MLIRSupport
TorchMLIRTorchDialect
TorchMLIRInitAll
Expand Down
44 changes: 40 additions & 4 deletions lib/CAPI/Registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,27 @@
#include "torch-mlir-c/Registration.h"

#include "mlir/CAPI/IR.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/InitAllPasses.h"
#include "torch-mlir/InitAll.h"

void torchMlirRegisterRequiredDialects(MlirContext context) {
mlir::DialectRegistry registry;
registry.insert<mlir::AffineDialect, mlir::arith::ArithmeticDialect,
mlir::bufferization::BufferizationDialect,
mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::tensor::TensorDialect,
mlir::tosa::TosaDialect>();
unwrap(context)->appendDialectRegistry(registry);
}

void torchMlirRegisterAllDialects(MlirContext context) {
mlir::DialectRegistry registry;
mlir::torch::registerAllDialects(registry);
Expand All @@ -23,4 +39,24 @@ void torchMlirRegisterAllDialects(MlirContext context) {
unwrap(context)->loadAllAvailableDialects();
}

void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); }
void torchMlirRegisterAllPasses() {
mlir::arith::registerArithmeticPasses();
mlir::bufferization::registerBufferizationPasses();
mlir::func::registerFuncPasses();
mlir::registerConvertAffineToStandardPass();
mlir::registerConvertArithmeticToLLVMPass();
mlir::registerConvertControlFlowToLLVMPass();
mlir::registerConvertFuncToLLVMPass();
mlir::registerConvertLinalgToLLVMPass();
mlir::registerConvertMathToLLVMPass();
mlir::registerConvertMemRefToLLVMPass();
mlir::registerLinalgPasses();
mlir::registerReconcileUnrealizedCastsPass();
mlir::registerSCFPasses();
mlir::registerSCFToControlFlowPass();
mlir::registerTosaToArithPass();
mlir::registerTosaToLinalgNamedPass();
mlir::registerTosaToLinalgPass();
mlir::tensor::registerTensorPasses();
mlir::torch::registerAllPasses();
}
13 changes: 4 additions & 9 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
SOURCES
__init__.py
compiler_utils.py
_mlir_libs/_site_initialize_0.py
)

declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
Expand Down Expand Up @@ -102,16 +103,10 @@ add_subdirectory(torch_mlir/eager_mode)
################################################################################

set(_source_components
# TODO: Core is now implicitly building/registering all dialects, increasing
# build burden by ~5x. Make it stop.
# TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes
# for the reference backend, but logically they can be separate. But seemingly
# the only way to handle that is to create a separate mlir python package
# tree, which seems excessive.
MLIRPythonSources
MLIRPythonSources.Core
MLIRPythonSources.Dialects.func
MLIRPythonSources.ExecutionEngine
ashay marked this conversation as resolved.
Show resolved Hide resolved
MLIRPythonExtension.Core
MLIRPythonExtension.RegisterEverything
MLIRPythonExtension.ExecutionEngine
TorchMLIRPythonSources
TorchMLIRPythonExtensions
)
Expand Down
13 changes: 3 additions & 10 deletions python/TorchMLIRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "torch-mlir-c/Dialects.h"
#include "torch-mlir-c/Registration.h"

Expand All @@ -19,14 +20,6 @@ PYBIND11_MODULE(_torchMlir, m) {

m.doc() = "torch-mlir main python extension";

m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__torch__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}
},
py::arg("context"), py::arg("load") = true);
m.def("register_required_dialects", torchMlirRegisterRequiredDialects,
py::arg("context"));
}
4 changes: 4 additions & 0 deletions python/torch_mlir/_mlir_libs/_site_initialize_0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import _torchMlir

def context_init_hook(context):
_torchMlir.register_required_dialects(context)
2 changes: 1 addition & 1 deletion python/torch_mlir/dialects/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Also available under a BSD-style license. See LICENSE.

from .._torch_ops_gen import *
from ..._mlir_libs._torchMlir import register_dialect
from ..._mlir_libs._torchMlir import register_required_dialects
2 changes: 1 addition & 1 deletion test/python/smoketest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from torch_mlir.dialects import torch

with torch_mlir.ir.Context() as ctx:
torch.register_dialect(ctx)
torch.register_required_dialects(ctx)