Skip to content

Commit

Permalink
dialects: (riscv) Fix I, S and shift operation immediate bounds (#2785)
Browse files Browse the repository at this point in the history
Turns out I broke the verification a while back, the assembly format
expects the immediates to be within the signed range, not unsigned. I
expect that the arith lowering will be impacted, but this PR does not
address that.

fixes #2056
  • Loading branch information
superlopuh authored Jun 26, 2024
1 parent a36af1f commit 17ec8ba
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 33 deletions.
8 changes: 4 additions & 4 deletions tests/dialects/test_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def test_return_op():


def test_immediate_i_inst():
# I-Type - 12-bits immediate
lb, ub = Signedness.SIGNLESS.value_range(12)
# I-Type - 12-bits signed immediate
lb, ub = Signedness.SIGNED.value_range(12)
a1 = TestSSAValue(riscv.Registers.A1)

with pytest.raises(VerifyException):
Expand All @@ -149,8 +149,8 @@ def test_immediate_i_inst():


def test_immediate_s_inst():
# S-Type - 12-bits immediate
lb, ub = Signedness.SIGNLESS.value_range(12)
# S-Type - 12-bits signed immediate
lb, ub = Signedness.SIGNED.value_range(12)
a1 = TestSSAValue(riscv.Registers.A1)
a2 = TestSSAValue(riscv.Registers.A2)

Expand Down
10 changes: 9 additions & 1 deletion tests/filecheck/backend/riscv/verify.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s
// RUN: xdsl-opt --split-input-file --verify-diagnostics --parsing-diagnostics %s | filecheck %s

%i1 = "test.op"() : () -> !riscv.reg<a1>
%1 = riscv.li 1 : () -> !riscv.reg<>
Expand Down Expand Up @@ -28,3 +28,11 @@
%wrong_1 = riscv_snitch.scfgwi %i1, 1 : (!riscv.reg<a1>) -> !riscv.reg<t0>

// CHECK: Operation does not verify: scfgwi rd must be ZERO, got !riscv.reg<t0>

// -----

%i1 = "test.op"() : () -> !riscv.reg<a1>
%ok_imm = riscv.addi %i1, 1 : (!riscv.reg<a1>) -> !riscv.reg<t0>
%big_imm = riscv.addi %i1, 2048 : (!riscv.reg<a1>) -> !riscv.reg<t0>

// CHECK: Integer value 2048 is out of range for type si12 which supports values in the range [-2048, 2048)
28 changes: 14 additions & 14 deletions tests/filecheck/dialects/riscv/riscv_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,12 @@
// CHECK-GENERIC-NEXT: "riscv_func.func"() ({
// CHECK-GENERIC-NEXT: %0 = "riscv.get_register"() : () -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %1 = "riscv.get_register"() : () -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %addi = "riscv.addi"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %slti = "riscv.slti"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %sltiu = "riscv.sltiu"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %andi = "riscv.andi"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %ori = "riscv.ori"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %xori = "riscv.xori"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %addi = "riscv.addi"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %slti = "riscv.slti"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %sltiu = "riscv.sltiu"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %andi = "riscv.andi"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %ori = "riscv.ori"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %xori = "riscv.xori"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %slli = "riscv.slli"(%0) {"immediate" = 1 : ui5} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %srli = "riscv.srli"(%0) {"immediate" = 1 : ui5} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %srai = "riscv.srai"(%0) {"immediate" = 1 : ui5} : (!riscv.reg<>) -> !riscv.reg<>
Expand Down Expand Up @@ -389,14 +389,14 @@
// CHECK-GENERIC-NEXT: "riscv.bge"(%0, %1) {"offset" = 1 : si12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: "riscv.bltu"(%0, %1) {"offset" = 1 : si12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: "riscv.bgeu"(%0, %1) {"offset" = 1 : si12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: %lb = "riscv.lb"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lbu = "riscv.lbu"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lh = "riscv.lh"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lhu = "riscv.lhu"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lw = "riscv.lw"(%0) {"immediate" = 1 : i12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: "riscv.sb"(%0, %1) {"immediate" = 1 : i12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: "riscv.sh"(%0, %1) {"immediate" = 1 : i12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: "riscv.sw"(%0, %1) {"immediate" = 1 : i12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: %lb = "riscv.lb"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lbu = "riscv.lbu"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lh = "riscv.lh"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lhu = "riscv.lhu"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %lw = "riscv.lw"(%0) {"immediate" = 1 : si12} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: "riscv.sb"(%0, %1) {"immediate" = 1 : si12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: "riscv.sh"(%0, %1) {"immediate" = 1 : si12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: "riscv.sw"(%0, %1) {"immediate" = 1 : si12} : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-GENERIC-NEXT: %csrrw_rw = "riscv.csrrw"(%0) {"csr" = 1024 : i32} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %csrrw_w = "riscv.csrrw"(%0) {"csr" = 1024 : i32, "writeonly"} : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %csrrs_rw = "riscv.csrrs"(%0) {"csr" = 1024 : i32} : (!riscv.reg<>) -> !riscv.reg<>
Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/riscv_scf/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
%step = "riscv.li"() {"immediate" = 1: i32} : () -> !riscv.reg<>
%acc = "riscv.li"() {"immediate" = 0 : i32} : () -> !riscv.reg<t0>
riscv_scf.for %i : !riscv.reg<> = %lb to %ub step %step {
"riscv.addi"(%acc) {"immediate" = 1 : i12} : (!riscv.reg<t0>) -> !riscv.reg<t0>
riscv.addi %acc, 1 : (!riscv.reg<t0>) -> !riscv.reg<t0>
}
riscv_scf.rof %j : !riscv.reg<> = %ub down to %lb step %step {
"riscv.addi"(%acc) {"immediate" = 1 : i12} : (!riscv.reg<t0>) -> !riscv.reg<t0>
riscv.addi %acc, 1 : (!riscv.reg<t0>) -> !riscv.reg<t0>
}
%i_last, %ub_last, %step_last = riscv_scf.while (%i0 = %lb, %ub_arg0 = %ub, %step_arg0 = %step) : (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>) {
%cond = riscv.slt %i0, %ub_arg0 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
riscv_scf.condition(%cond : !riscv.reg<>) %i0, %ub_arg0, %step_arg0 : !riscv.reg<>, !riscv.reg<>, !riscv.reg<>
} do {
^1(%i1 : !riscv.reg<>, %ub_arg1 : !riscv.reg<>, %step_arg1 : !riscv.reg<>):
"riscv.addi"(%acc) {"immediate" = 1 : i12} : (!riscv.reg<t0>) -> !riscv.reg<t0>
riscv.addi %acc, 1 : (!riscv.reg<t0>) -> !riscv.reg<t0>
%i_next = "riscv.add"(%i1, %step_arg1) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"riscv_scf.yield"(%i_next, %ub_arg1, %step_arg1) : (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -> ()
}
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/dialects/riscv_snitch/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ riscv_func.func @xdma() {
// CHECK-GENERIC-NEXT: %0 = "riscv.get_register"() : () -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %1 = "riscv.get_register"() : () -> !riscv.reg<>
// CHECK-GENERIC-NEXT: %scfgw = "riscv_snitch.scfgw"(%0, %1) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<zero>
// CHECK-GENERIC-NEXT: %scfgwi_zero = "riscv_snitch.scfgwi"(%0) {"immediate" = 42 : i12} : (!riscv.reg<>) -> !riscv.reg<zero>
// CHECK-GENERIC-NEXT: %scfgwi_zero = "riscv_snitch.scfgwi"(%0) {"immediate" = 42 : si12} : (!riscv.reg<>) -> !riscv.reg<zero>
// CHECK-GENERIC-NEXT: "riscv_snitch.frep_outer"(%{{.*}}) ({
// CHECK-GENERIC-NEXT: %{{.*}} = "riscv.add"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-GENERIC-NEXT: "riscv_snitch.frep_yield"() : () -> ()
Expand Down
50 changes: 40 additions & 10 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,8 +760,7 @@ class RdRsImmIntegerOperation(IRDLOperation, RISCVInstruction, ABC):

rd = result_def(IntRegisterType)
rs1 = operand_def(IntRegisterType)
# https://github.com/xdslproject/xdsl/issues/2056
immediate = attr_def(IntegerAttr[IntegerType] | LabelAttr)
immediate: SImm12Attr | LabelAttr = attr_def(SImm12Attr | LabelAttr)

def __init__(
self,
Expand All @@ -772,7 +771,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, i12)
immediate = IntegerAttr(immediate, si12)
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)

Expand All @@ -797,7 +796,7 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
@classmethod
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
attributes = dict[str, Attribute]()
attributes["immediate"] = _parse_immediate_value(parser, i12)
attributes["immediate"] = _parse_immediate_value(parser, si12)
return attributes

def custom_print_attributes(self, printer: Printer) -> Set[str]:
Expand All @@ -806,7 +805,7 @@ def custom_print_attributes(self, printer: Printer) -> Set[str]:
return {"immediate"}


class RdRsImmShiftOperation(RdRsImmIntegerOperation):
class RdRsImmShiftOperation(IRDLOperation, RISCVInstruction, ABC):
"""
A base class for RISC-V operations that have one destination register, one source
register and one immediate operand.
Expand All @@ -820,6 +819,10 @@ class RdRsImmShiftOperation(RdRsImmIntegerOperation):
imm[5] 6 != 0 but the shift amount is encoded in the lower 6 bits of the I-immediate field for RV64I.
"""

rd = result_def(IntRegisterType)
rs1 = operand_def(IntRegisterType)
immediate: UImm5Attr | LabelAttr = attr_def(UImm5Attr | LabelAttr)

def __init__(
self,
rs1: Operation | SSAValue,
Expand All @@ -830,15 +833,38 @@ def __init__(
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, ui5)
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)

if rd is None:
rd = IntRegisterType.unallocated()
elif isinstance(rd, str):
rd = IntRegisterType(rd)
if isinstance(comment, str):
comment = StringAttr(comment)
super().__init__(
operands=[rs1],
result_types=[rd],
attributes={
"immediate": immediate,
"comment": comment,
},
)

super().__init__(rs1, immediate, rd=rd, comment=comment)
def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
return self.rd, self.rs1, self.immediate

@classmethod
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
attributes = dict[str, Attribute]()
attributes["immediate"] = _parse_immediate_value(parser, ui5)
return attributes

def custom_print_attributes(self, printer: Printer) -> Set[str]:
printer.print(", ")
_print_immediate_value(printer, self.immediate)
return {"immediate"}


class RdRsImmJumpOperation(IRDLOperation, RISCVInstruction, ABC):
"""
Expand Down Expand Up @@ -997,7 +1023,7 @@ class RsRsImmIntegerOperation(IRDLOperation, RISCVInstruction, ABC):

rs1 = operand_def(IntRegisterType)
rs2 = operand_def(IntRegisterType)
immediate = attr_def(Imm12Attr)
immediate = attr_def(SImm12Attr)

def __init__(
self,
Expand All @@ -1008,7 +1034,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, i12)
immediate = IntegerAttr(immediate, si12)
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)
if isinstance(comment, str):
Expand All @@ -1028,7 +1054,7 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
@classmethod
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
attributes = dict[str, Attribute]()
attributes["immediate"] = _parse_immediate_value(parser, i12)
attributes["immediate"] = _parse_immediate_value(parser, si12)
return attributes

def custom_print_attributes(self, printer: Printer) -> Set[str]:
Expand Down Expand Up @@ -3724,8 +3750,12 @@ def _parse_optional_immediate_value(
"""
Parse an optional immediate value. If an integer is parsed, an integer attr with the specified type is created.
"""
pos = parser.pos
if (immediate := parser.parse_optional_integer()) is not None:
return IntegerAttr(immediate, integer_type)
try:
return IntegerAttr(immediate, integer_type)
except VerifyException as e:
parser.raise_error(e.args[0], pos)
if (immediate := parser.parse_optional_str_literal()) is not None:
return LabelAttr(immediate)

Expand Down

0 comments on commit 17ec8ba

Please sign in to comment.