Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: allow result variadic inference #3559

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
BoolAttr,
Float64Type,
FloatAttr,
IndexType,
IntegerAttr,
MemRefType,
ModuleOp,
Expand Down Expand Up @@ -1450,6 +1451,25 @@ class VariadicResultOp(IRDLOperation):
check_equivalence(program, generic_program, ctx)


def test_variadic_result_failure():
"""Test that inferring a range of inferrable attributes of unknown length fails."""

with pytest.raises(
PyRDLOpDefinitionError,
match="type of result 'res' cannot be inferred",
):

@irdl_op_definition
class VariadicResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.var_results_op"

res = var_result_def(IndexType())

irdl_options = [AttrSizedResultSegments()]

assembly_format = "attr-dict"


@pytest.mark.parametrize(
"format, program, generic_program",
[
Expand Down Expand Up @@ -2374,23 +2394,19 @@ class RangeVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]

assembly_format = "$ins attr-dict `:` type($ins)"

with pytest.raises(
NotImplementedError,
match="Inference of length of variadic result 'outs' not implemented",
):
ctx = MLContext()
ctx.load_op(RangeVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent("""\
%in0, %in1 = "test.op"() : () -> (index, index)
%out0, %out1 = test.range_var %in0, %in1 : index, index
""")
ctx = MLContext()
ctx.load_op(RangeVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent("""\
%in0, %in1 = "test.op"() : () -> (index, index)
%out0, %out1 = test.range_var %in0, %in1 : index, index
""")

parser = Parser(ctx, program)
test_op = parser.parse_optional_operation()
assert isinstance(test_op, test.Operation)
my_op = parser.parse_optional_operation()
assert isinstance(my_op, RangeVarOp)
parser = Parser(ctx, program)
test_op = parser.parse_optional_operation()
assert isinstance(test_op, test.Operation)
my_op = parser.parse_optional_operation()
assert isinstance(my_op, RangeVarOp)


################################################################################
Expand Down
34 changes: 23 additions & 11 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,17 +677,21 @@ def get_variable_extractors(
"""
return {}

def can_infer(self, var_constraint_names: Set[str]) -> bool:
def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
"""
Check if there is enough information to infer the attribute given the
constraint variables that are already set.
constraint variables that are already set, and whether the length of the
range is known in advance.
"""
# By default, we cannot infer anything.
return False

def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]:
def infer(
self, context: InferenceContext, *, length: int | None
) -> Sequence[AttributeCovT]:
"""
Infer the attribute given the the values for all variables.
Infer the attribute given the the values for all variables, and possibly
the length of the range if known.

Raises an exception if the attribute cannot be inferred. If `can_infer`
returns `True` with the given constraint variables, this method should
Expand Down Expand Up @@ -737,10 +741,12 @@ def get_variable_extractors(
) -> dict[str, VarExtractor[Sequence[AttributeCovT]]]:
return {self.name: IdExtractor[Sequence[AttributeCovT]]()}

def can_infer(self, var_constraint_names: Set[str]) -> bool:
def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
return self.name in var_constraint_names

def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]:
def infer(
self, context: InferenceContext, *, length: int | None
) -> Sequence[AttributeCovT]:
v = context.variables[self.name]
return cast(Sequence[AttributeCovT], v)

Expand All @@ -761,14 +767,16 @@ def verify(
for a in attrs:
self.constr.verify(a, constraint_context)

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.constr.can_infer(var_constraint_names)
def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
return length_known and self.constr.can_infer(var_constraint_names)

def infer(
self,
length: int,
context: InferenceContext,
*,
length: int | None,
) -> Sequence[AttributeCovT]:
assert length is not None
attr = self.constr.infer(context)
return (attr,) * length

Expand Down Expand Up @@ -805,10 +813,14 @@ def get_variable_extractors(
for v, r in self.constr.get_variable_extractors().items()
}

def can_infer(self, var_constraint_names: Set[str]) -> bool:
def can_infer(
self, var_constraint_names: Set[str], *, length_known: int | None
) -> bool:
return self.constr.can_infer(var_constraint_names)

def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]:
def infer(
self, context: InferenceContext, *, length: int | None
) -> Sequence[AttributeCovT]:
return (self.constr.infer(context),)


Expand Down
29 changes: 12 additions & 17 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def resolve_operand_types(self, state: ParsingState, op_def: OpDef) -> None:
operand = state.operands[i]
range_length = len(operand) if isinstance(operand, Sequence) else 1
operand_type = operand_def.constr.infer(
range_length,
InferenceContext(state.variables),
length=range_length,
)
resolved_operand_type: Attribute | Sequence[Attribute]
if isinstance(operand_def, OptionalDef):
Expand All @@ -203,27 +203,22 @@ def resolve_result_types(self, state: ParsingState, op_def: OpDef) -> None:
Use the inferred type resolutions to fill missing result types from other parsed
types.
"""
for i, (result_type, (result_name, result_def)) in enumerate(
for i, (result_type, (_, result_def)) in enumerate(
zip(state.result_types, op_def.results, strict=True)
):
if result_type is None:
# The number of results is not passed in when parsing operations.
# In the generic format, the type of the operation always specifies the
# types of the results, and `resultSegmentSizes` specifies the ranges of
# of the results if multiple are variadic.
# In order to support variadic results, the types an length of all
# variadic results must be present in the custom syntax.
if isinstance(result_def, OptionalDef | VariadicDef):
raise NotImplementedError(
f"Inference of length of variadic result '{result_name}' not "
"implemented"
)
range_length = 1
inferred_result_types = result_def.constr.infer(
range_length,
InferenceContext(state.variables),
InferenceContext(state.variables), length=None
)
resolved_result_type = inferred_result_types[0]
resolved_result_type: Attribute | Sequence[Attribute]
if isinstance(result_def, OptionalDef):
resolved_result_type = (
inferred_result_types[0] if inferred_result_types else ()
)
elif isinstance(result_def, VariadicDef):
resolved_result_type = inferred_result_types
else:
resolved_result_type = inferred_result_types[0]
state.result_types[i] = resolved_result_type

def print(self, printer: Printer, op: IRDLOperation) -> None:
Expand Down
8 changes: 6 additions & 2 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ def verify_operands(self, var_constraint_names: Set[str]):
"directive to the custom assembly format"
)
if not seen_operand_type:
if not operand_def.constr.can_infer(var_constraint_names):
if not operand_def.constr.can_infer(
var_constraint_names, length_known=True
):
self.raise_error(
f"type of operand '{operand_name}' cannot be inferred, "
f"consider adding a 'type(${operand_name})' directive to the "
Expand All @@ -345,7 +347,9 @@ def verify_results(self, var_constraint_names: Set[str]):
self.seen_result_types, self.op_def.results, strict=True
):
if not result_type:
if not result_def.constr.can_infer(var_constraint_names):
if not result_def.constr.can_infer(
var_constraint_names, length_known=False
):
self.raise_error(
f"type of result '{result_name}' cannot be inferred, "
f"consider adding a 'type(${result_name})' directive to the "
Expand Down
Loading