Skip to content

Commit

Permalink
api: use new base helper in more operation definitions (#2893)
Browse files Browse the repository at this point in the history
Pylance errors: 725 -> 678
  • Loading branch information
superlopuh authored Jul 17, 2024
1 parent 63ad0db commit ed144cf
Show file tree
Hide file tree
Showing 15 changed files with 91 additions and 44 deletions.
3 changes: 3 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,9 @@ def get_strides(self) -> Iterable[int | None]:
return self.layout.get_strides()


AnyMemRefType: TypeAlias = MemRefType[Attribute]


@irdl_attr_definition
class UnrankedMemrefType(
Generic[_UnrankedMemrefTypeElems], ParametrizedAttribute, TypeAttribute
Expand Down
21 changes: 13 additions & 8 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from typing import cast

from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects.builtin import IndexType, IntegerAttr, TensorType
from xdsl.dialects.builtin import (
AnyIntegerAttr,
AnyMemRefType,
IndexType,
TensorType,
)
from xdsl.dialects.experimental import dmp
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import (
Expand All @@ -17,6 +22,7 @@
IRDLOperation,
Operand,
ParameterDef,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand All @@ -40,6 +46,7 @@
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr


@irdl_attr_definition
Expand Down Expand Up @@ -113,7 +120,7 @@ class PrefetchOp(IRDLOperation):
name = "csl_stencil.prefetch"

input_stencil = operand_def(
stencil.TempType[Attribute] | memref.MemRefType[Attribute]
base(stencil.TempType[Attribute]) | base(memref.MemRefType[Attribute])
)

swaps = prop_def(builtin.ArrayAttr[ExchangeDeclarationAttr])
Expand Down Expand Up @@ -180,7 +187,7 @@ class ApplyOp(IRDLOperation):
name = "csl_stencil.apply"

communicated_stencil = operand_def(
stencil.TempType[Attribute] | memref.MemRefType[Attribute]
base(stencil.TempType[Attribute]) | base(memref.MemRefType[Attribute])
)

iter_arg = operand_def(TensorType[Attribute])
Expand All @@ -194,7 +201,7 @@ class ApplyOp(IRDLOperation):

topo = prop_def(dmp.RankTopoAttr)

num_chunks = prop_def(IntegerAttr)
num_chunks = prop_def(AnyIntegerAttr)

res = var_result_def(stencil.TempType)

Expand Down Expand Up @@ -366,7 +373,7 @@ class AccessOp(IRDLOperation):
"""

name = "csl_stencil.access"
op = operand_def(memref.MemRefType | stencil.TempType)
op = operand_def(base(AnyMemRefType) | base(stencil.AnyTempType))
offset = prop_def(stencil.IndexAttr)
offset_mapping = opt_prop_def(stencil.IndexAttr)
result = result_def(TensorType)
Expand Down Expand Up @@ -444,9 +451,7 @@ def parse(cls, parser: Parser):
props["offset_mapping"] = stencil.IndexAttr.get(*offset_mapping)
parser.parse_punctuation(":")
res_type = parser.parse_attribute()
if not isa(
res_type, memref.MemRefType[Attribute] | stencil.TempType[Attribute]
):
if not isattr(res_type, base(AnyMemRefType) | base(stencil.AnyTempType)):
parser.raise_error("Expected return type to be a memref or stencil.temp")
return cls.build(
operands=[temp], result_types=[res_type.element_type], properties=props
Expand Down
5 changes: 3 additions & 2 deletions xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from math import prod
from typing import Literal

from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects import builtin, stencil
from xdsl.ir import Attribute, Dialect, Operation, ParametrizedAttribute, SSAValue
from xdsl.irdl import (
IRDLOperation,
Operand,
ParameterDef,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -434,7 +435,7 @@ class SwapOp(IRDLOperation):
name = "dmp.swap"

input_stencil: Operand = operand_def(
stencil.TempType[Attribute] | memref.MemRefType[Attribute]
base(stencil.AnyTempType) | base(builtin.AnyMemRefType)
)

swaps: builtin.ArrayAttr[ExchangeDeclarationAttr] | None = opt_attr_def(
Expand Down
10 changes: 6 additions & 4 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloat,
AnyMemRefType,
AnyShapedType,
AnyTensorType,
ArrayAttr,
Expand All @@ -34,6 +35,7 @@
VarOperand,
VarOpResult,
attr_def,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -556,8 +558,8 @@ class TransposeOp(IRDLOperation):

name = "linalg.transpose"

input = operand_def(MemRefType | AnyTensorType)
init = operand_def(MemRefType | AnyTensorType)
input = operand_def(base(AnyMemRefType) | base(AnyTensorType))
init = operand_def(base(AnyMemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)

permutation = attr_def(DenseArrayBase)
Expand Down Expand Up @@ -806,8 +808,8 @@ class BroadcastOp(IRDLOperation):

name = "linalg.broadcast"

input = operand_def(MemRefType | AnyTensorType)
init = operand_def(MemRefType | AnyTensorType)
input = operand_def(base(AnyMemRefType) | base(AnyTensorType))
init = operand_def(base(AnyMemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)

dimensions = attr_def(DenseArrayBase)
Expand Down
13 changes: 11 additions & 2 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Annotated, Generic, Literal, TypeVar

from xdsl.dialects.builtin import (
I64,
AnyIntegerAttr,
ArrayAttr,
ContainerType,
Expand Down Expand Up @@ -42,6 +43,7 @@
OptOpResult,
ParameterDef,
VarOperand,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -732,7 +734,7 @@ class InlineAsmOp(IRDLOperation):
# 0 for AT&T inline assembly dialect
# 1 for Intel inline assembly dialect
# In this context dialect does not refer to an MLIR dialect
asm_dialect = opt_prop_def(IntegerAttr[Annotated[IntegerType, IntegerType(64)]])
asm_dialect = opt_prop_def(IntegerAttr[I64])

asm_string: StringAttr = prop_def(StringAttr)
constraints: StringAttr = prop_def(StringAttr)
Expand Down Expand Up @@ -1316,6 +1318,13 @@ def __init__(
LLVMType = (
LLVMStructType | LLVMPointerType | LLVMArrayType | LLVMVoidType | LLVMFunctionType
)
LLVMTypeConstr = (
base(LLVMStructType)
| base(LLVMPointerType)
| base(LLVMArrayType)
| base(LLVMVoidType)
| base(LLVMFunctionType)
)


@irdl_op_definition
Expand All @@ -1324,7 +1333,7 @@ class ZeroOp(IRDLOperation):

assembly_format = "attr-dict `:` type($res)"

res = result_def(LLVMType)
res = result_def(LLVMTypeConstr)


LLVM = Dialect(
Expand Down
23 changes: 17 additions & 6 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ParsePropInAttrDict,
VarOperand,
VarOpResult,
base,
irdl_op_definition,
operand_def,
opt_prop_def,
Expand Down Expand Up @@ -376,7 +377,9 @@ class AtomicRMWOp(IRDLOperation):
@irdl_op_definition
class Dealloc(IRDLOperation):
name = "memref.dealloc"
memref: Operand = operand_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
memref: Operand = operand_def(
base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
)

@staticmethod
def get(operand: Operation | SSAValue) -> Dealloc:
Expand Down Expand Up @@ -463,7 +466,9 @@ def get(
class Dim(IRDLOperation):
name = "memref.dim"

source: Operand = operand_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
source: Operand = operand_def(
base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
)
index: Operand = operand_def(IndexType)

result: OpResult = result_def(IndexType)
Expand Down Expand Up @@ -838,8 +843,12 @@ def parse(cls, parser: Parser) -> Subview:
class Cast(IRDLOperation):
name = "memref.cast"

source: Operand = operand_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
dest: OpResult = result_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
source: Operand = operand_def(
base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
)
dest: OpResult = result_def(
base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
)

traits = frozenset([NoMemoryEffect()])

Expand All @@ -855,8 +864,10 @@ def get(
class MemorySpaceCast(IRDLOperation):
name = "memref.memory_space_cast"

source = operand_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
dest = result_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
source = operand_def(
base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
)
dest = result_def(base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute]))

traits = frozenset([NoMemoryEffect()])

Expand Down
4 changes: 3 additions & 1 deletion xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from xdsl.dialects import memref, stream
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyMemRefType,
ArrayAttr,
IndexType,
IntAttr,
Expand All @@ -38,6 +39,7 @@
ConstraintVar,
IRDLOperation,
ParameterDef,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -367,7 +369,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(memref.MemRefType | stream.WritableStreamType)
outputs = var_operand_def(base(AnyMemRefType) | base(stream.AnyWritableStreamType))
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
10 changes: 6 additions & 4 deletions xdsl/dialects/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OptOpResult,
ParameterDef,
attr_def,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -111,6 +112,7 @@ class DataType(ParametrizedAttribute, TypeAttribute):


VectorWrappable = RequestType | StatusType | DataType
VectorWrappableConstr = base(RequestType) | base(StatusType) | base(DataType)
_VectorT = TypeVar("_VectorT", bound=VectorWrappable)


Expand Down Expand Up @@ -563,7 +565,7 @@ class Waitall(MPIBaseOp):
statuses: OptOpResult = opt_result_def(VectorType[StatusType])

def __init__(self, requests: Operand, count: Operand, ignore_status: bool = True):
result_types: list[list[Attribute]] = [[VectorType.of(StatusType)]]
result_types: list[list[Attribute]] = [[VectorType[StatusType].of(StatusType)]]
if ignore_status:
result_types = [[]]

Expand Down Expand Up @@ -712,7 +714,7 @@ class AllocateTypeOp(MPIBaseOp):
name = "mpi.allocate"

bindc_name: StringAttr | None = opt_attr_def(StringAttr)
dtype: VectorWrappable = attr_def(VectorWrappable)
dtype: VectorWrappable = attr_def(VectorWrappableConstr)
count: Operand = operand_def(i32)

result: OpResult = result_def(VectorType)
Expand All @@ -724,7 +726,7 @@ def __init__(
bindc_name: StringAttr | None = None,
):
return super().__init__(
result_types=[VectorType.of(dtype)],
result_types=[VectorType[dtype].of(dtype)],
attributes={
"dtype": dtype(),
"bindc_name": bindc_name,
Expand All @@ -745,7 +747,7 @@ class VectorGetOp(MPIBaseOp):
vect: Operand = operand_def(VectorType)
element: Operand = operand_def(i32)

result: OpResult = result_def(VectorWrappable)
result: OpResult = result_def(VectorWrappableConstr)

def __init__(self, vect: SSAValue | Operation, element: SSAValue | Operation):
ssa_val = SSAValue.get(vect)
Expand Down
9 changes: 5 additions & 4 deletions xdsl/dialects/omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from xdsl.irdl import (
AttrSizedOperandSegments,
IRDLOperation,
base,
irdl_attr_definition,
irdl_op_definition,
opt_operand_def,
Expand Down Expand Up @@ -68,9 +69,9 @@ class OrderKindAttr(EnumAttribute[OrderKind], SpacedOpaqueSyntaxAttribute):
class WsLoopOp(IRDLOperation):
name = "omp.wsloop"

lowerBound = var_operand_def(IntegerType | IndexType)
upperBound = var_operand_def(IntegerType | IndexType)
step = var_operand_def(IntegerType | IndexType)
lowerBound = var_operand_def(base(IntegerType) | base(IndexType))
upperBound = var_operand_def(base(IntegerType) | base(IndexType))
step = var_operand_def(base(IntegerType) | base(IndexType))
linear_vars = var_operand_def()
linear_step_vars = var_operand_def(i32)
# TODO: this is constrained to OpenMP_PointerLikeTypeInterface upstream
Expand Down Expand Up @@ -108,7 +109,7 @@ class ParallelOp(IRDLOperation):
name = "omp.parallel"

if_expr_var = opt_operand_def(IntegerType(1))
num_threads_var = opt_operand_def(IntegerType | IndexType)
num_threads_var = opt_operand_def(base(IntegerType) | base(IndexType))
allocate_vars = var_operand_def()
allocators_vars = var_operand_def()
# TODO: this is constrained to OpenMP_PointerLikeTypeInterface upstream
Expand Down
9 changes: 5 additions & 4 deletions xdsl/dialects/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ConstraintVar,
IRDLOperation,
attr_def,
base,
irdl_op_definition,
operand_def,
opt_attr_def,
Expand Down Expand Up @@ -470,7 +471,7 @@ class Conv(IRDLOperation):
T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")]
data = operand_def(TensorType[T])
weight = operand_def(TensorType[T])
bias = operand_def(TensorType[T] | NoneType)
bias = operand_def(base(TensorType[T]) | base(NoneType))
res = result_def(TensorType[T])

auto_pad = attr_def(StringAttr)
Expand Down Expand Up @@ -727,8 +728,8 @@ class MaxPoolSingleOut(IRDLOperation):
name = "onnx.MaxPoolSingleOut"

T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")]
data = operand_def(TensorType[T] | MemRefType[T])
output = result_def(TensorType[T] | MemRefType[T])
data = operand_def(base(TensorType[T]) | base(MemRefType[T]))
output = result_def(base(TensorType[T]) | base(MemRefType[T]))

auto_pad = attr_def(StringAttr)
ceil_mode = attr_def(AnyIntegerAttr)
Expand Down Expand Up @@ -1029,7 +1030,7 @@ class Squeeze(IRDLOperation):

T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")]
input_tensor = operand_def(TensorType[T])
axes = opt_attr_def(IntegerAttr, attr_name="axes")
axes = opt_attr_def(base(AnyIntegerAttr), attr_name="axes")

output_tensor = result_def(TensorType[T])

Expand Down
Loading

0 comments on commit ed144cf

Please sign in to comment.