From 9724331970afe53032f2a672e4cdbd9d360d3499 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 28 Jun 2024 14:07:40 +0200 Subject: [PATCH] python bindings for emitc --- mlir/python/CMakeLists.txt | 8 ++ mlir/python/mlir/dialects/EmitCOps.td | 19 +++++ mlir/python/mlir/dialects/emitc.py | 116 ++++++++++++++++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 mlir/python/mlir/dialects/EmitCOps.td create mode 100644 mlir/python/mlir/dialects/emitc.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 563d035f1552676..fb7f13e87939aa2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -124,6 +124,14 @@ declare_mlir_dialect_python_bindings( dialects/func.py DIALECT_NAME func) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/EmitCOps.td + SOURCES + dialects/emitc.py + DIALECT_NAME emitc) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/EmitCOps.td b/mlir/python/mlir/dialects/EmitCOps.td new file mode 100644 index 000000000000000..e27d589920eda11 --- /dev/null +++ b/mlir/python/mlir/dialects/EmitCOps.td @@ -0,0 +1,19 @@ +//===-- EmitCOps.td - Entry point for Func bind -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the main file from which the Python bindings for the Func dialect +// are generated. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_FUNC +#define PYTHON_BINDINGS_FUNC + +include "mlir/Dialect/EmitC/IR/EmitC.td" + +#endif diff --git a/mlir/python/mlir/dialects/emitc.py b/mlir/python/mlir/dialects/emitc.py new file mode 100644 index 000000000000000..9bc0b48452d91c8 --- /dev/null +++ b/mlir/python/mlir/dialects/emitc.py @@ -0,0 +1,116 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._emitc_ops_gen import * +from ._emitc_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + _cext as _ods_cext, + ) + + from typing import Optional, Sequence, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute