Skip to content

Commit

Permalink
[FIRRTL] Change min width of shr for UInt to 0
Browse files Browse the repository at this point in the history
Major change in FIRRTL 4.0.0.

Co-authored-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
  • Loading branch information
jackkoenig and seldridge committed Feb 15, 2024
1 parent 4001ec8 commit 24d784e
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 24 deletions.
6 changes: 1 addition & 5 deletions lib/Conversion/FIRRTLToHW/LowerToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3748,10 +3748,6 @@ LogicalResult FIRRTLLowering::visitExpr(ShlPrimOp op) {
}

LogicalResult FIRRTLLowering::visitExpr(ShrPrimOp op) {
// If this is a 0-bit value shifted by any amount, then return a 1-bit zero.
if (isZeroBitFIRRTLType(op.getInput().getType()))
return setLowering(op, getOrCreateIntConstant(1, 0));

auto input = getLoweredValue(op.getInput());
if (!input)
return failure();
Expand All @@ -3762,7 +3758,7 @@ LogicalResult FIRRTLLowering::visitExpr(ShrPrimOp op) {
if (shiftAmount >= inWidth) {
// Unsigned shift by full width returns a single-bit zero.
if (type_cast<IntType>(op.getInput().getType()).isUnsigned())
return setLowering(op, getOrCreateIntConstant(1, 0));
return setLowering(op, {});

// Signed shift by full width is equivalent to extracting the sign bit.
shiftAmount = inWidth - 1;
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
// shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
// If x is signed, it is the sign bit.
if (shiftAmount >= inputWidth && inputType.isUnsigned())
return getIntAttr(getType(), APInt(1, 0));
return getIntAttr(getType(), APInt(0, 0, false));

// Constant fold.
if (auto cst = getConstant(adaptor.getInput())) {
Expand Down
7 changes: 5 additions & 2 deletions lib/Dialect/FIRRTL/FIRRTLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5359,8 +5359,11 @@ FIRRTLType ShrPrimOp::inferReturnType(ValueRange operands,
loc, "shr input must be integer and amount must be >= 0");

int32_t width = inputi.getWidthOrSentinel();
if (width != -1)
width = std::max<int32_t>(1, width - amount);
if (width != -1) {
// UInt saturates at 0 bits, SInt at 1 bit
int32_t minWidth = inputi.isUnsigned() ? 0 : 1;
width = std::max<int32_t>(minWidth, width - amount);
}

return IntType::get(input.getContext(), inputi.isSigned(), width,
inputi.isConst());
Expand Down
17 changes: 14 additions & 3 deletions lib/Dialect/FIRRTL/Import/FIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2278,12 +2278,23 @@ ParseResult FIRStmtParser::parsePrimExp(Value &result) {
return failure(); \
} \
result = builder.create<CLASS>(resultTy, operands, attrs); \
return success(); \
break; \
}
#include "FIRTokenKinds.def"
}

llvm_unreachable("all cases should return");
// Don't add code here, the common cases of these switch statements will be
// merged. This allows for fixing up primops after they have been created.
switch (kind) {
default:
break;
case FIRToken::lp_shr:
// For FIRRTL versions earlier than 4.0.0, insert pad(_, 1) around any
// unsigned shr This ensures the minimum width is 1 (but can be greater)
if (version < FIRVersion(4, 0, 0) && type_isa<UIntType>(result.getType()))
result = builder.create<PadPrimOp>(result, 1);
break;
}
return success();
}

/// integer-literal-exp ::= 'UInt' optional-width '(' intLit ')'
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/FIRRTLToHW/lower-to-hw.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ firrtl.circuit "Simple" attributes {annotations = [{class =
// CHECK-NEXT: = comb.extract [[CONCAT1]] from 3 : (i8) -> i5
%11 = firrtl.shr %6, 3 : (!firrtl.uint<8>) -> !firrtl.uint<5>

%12 = firrtl.shr %6, 8 : (!firrtl.uint<8>) -> !firrtl.uint<1>
%12 = firrtl.shr %6, 8 : (!firrtl.uint<8>) -> !firrtl.uint<0>

// CHECK-NEXT: = comb.extract %in3 from 7 : (i8) -> i1
%13 = firrtl.shr %in3, 8 : (!firrtl.sint<8>) -> !firrtl.sint<1>
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/FIRRTLToHW/zero-width.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ firrtl.circuit "Arithmetic" {
// See: https://github.com/llvm/circt/issues/6652
// CHECK-LABEL: hw.module @ShrZW
firrtl.module @ShrZW(in %x: !firrtl.uint<0>, out %out: !firrtl.uint<1>) attributes {convention = #firrtl<convention scalarized>} {
%0 = firrtl.shr %x, 5 : (!firrtl.uint<0>) -> !firrtl.uint<1>
firrtl.connect %out, %0 : !firrtl.uint<1>, !firrtl.uint<1>
%0 = firrtl.shr %x, 5 : (!firrtl.uint<0>) -> !firrtl.uint<0>
firrtl.connect %out, %0 : !firrtl.uint<1>, !firrtl.uint<0>
// CHECK: %[[false:.+]] = hw.constant false
// CHECK-NEXT: hw.output %false
}
Expand Down
15 changes: 8 additions & 7 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ firrtl.module @Shr(in %in1u: !firrtl.uint<1>,
in %ins: !firrtl.sint,
in %in0u: !firrtl.uint<0>,
in %in0s: !firrtl.sint<0>,
out %out0u: !firrtl.uint<0>,
out %out1s: !firrtl.sint<1>,
out %out1u: !firrtl.uint<1>,
out %out4u: !firrtl.uint<4>,
Expand All @@ -617,12 +618,12 @@ firrtl.module @Shr(in %in1u: !firrtl.uint<1>,
firrtl.connect %out1u, %0 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: firrtl.strictconnect %out1u, %c0_ui1
%1 = firrtl.shr %in4u, 4 : (!firrtl.uint<4>) -> !firrtl.uint<1>
firrtl.connect %out1u, %1 : !firrtl.uint<1>, !firrtl.uint<1>
%1 = firrtl.shr %in4u, 4 : (!firrtl.uint<4>) -> !firrtl.uint<0>
firrtl.connect %out1u, %1 : !firrtl.uint<1>, !firrtl.uint<0>

// CHECK: firrtl.strictconnect %out1u, %c0_ui1
%2 = firrtl.shr %in4u, 5 : (!firrtl.uint<4>) -> !firrtl.uint<1>
firrtl.connect %out1u, %2 : !firrtl.uint<1>, !firrtl.uint<1>
%2 = firrtl.shr %in4u, 5 : (!firrtl.uint<4>) -> !firrtl.uint<0>
firrtl.connect %out1u, %2 : !firrtl.uint<1>, !firrtl.uint<0>

// CHECK: [[BITS:%.+]] = firrtl.bits %in4s 3 to 3
// CHECK-NEXT: [[CAST:%.+]] = firrtl.asSInt [[BITS]]
Expand Down Expand Up @@ -664,9 +665,9 @@ firrtl.module @Shr(in %in1u: !firrtl.uint<1>,
firrtl.connect %out1u, %9 : !firrtl.uint<1>, !firrtl.uint<0>

// Issue #6608: https://github.com/llvm/circt/issues/6608
// CHECK: firrtl.strictconnect %out1u, %c0_ui1
%10 = firrtl.shr %in0u, 0 : (!firrtl.uint<0>) -> !firrtl.uint<1>
firrtl.strictconnect %out1u, %10 : !firrtl.uint<1>
// CHECK: firrtl.strictconnect %out0u, %c0_ui0
%10 = firrtl.shr %in0u, 0 : (!firrtl.uint<0>) -> !firrtl.uint<0>
firrtl.strictconnect %out0u, %10 : !firrtl.uint<0>

// Issue #6608: https://github.com/llvm/circt/issues/6608
// CHECK: firrtl.strictconnect %out1s, %c0_si1
Expand Down
2 changes: 1 addition & 1 deletion test/Dialect/FIRRTL/infer-widths.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ firrtl.circuit "Foo" {
// CHECK: %1 = firrtl.shl {{.*}} -> !firrtl.sint<8>
// CHECK: %2 = firrtl.shr {{.*}} -> !firrtl.uint<2>
// CHECK: %3 = firrtl.shr {{.*}} -> !firrtl.sint<2>
// CHECK: %4 = firrtl.shr {{.*}} -> !firrtl.uint<1>
// CHECK: %4 = firrtl.shr {{.*}} -> !firrtl.uint<0>
// CHECK: %5 = firrtl.shr {{.*}} -> !firrtl.sint<1>
%ui = firrtl.wire : !firrtl.uint
%si = firrtl.wire : !firrtl.sint
Expand Down
122 changes: 120 additions & 2 deletions test/Dialect/FIRRTL/parse-basic.fir
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ circuit MyModule : ; CHECK: firrtl.circuit "MyModule" {
node n4 = bits(i8, 4, 2)

; CHECK: firrtl.shl %i8, 4 : (!firrtl.uint<8>) -> !firrtl.uint<12>
; CHECK: firrtl.shr %i8, 8 : (!firrtl.uint<8>) -> !firrtl.uint<1>
; CHECK: firrtl.shr %i8, 8 : (!firrtl.uint<8>) -> !firrtl.uint<0>
node n5 = or(shl(i8, 4), shr(i8, 8))

; CHECK: firrtl.dshl %i8, %{{.*}} : (!firrtl.uint<8>, !firrtl.const.uint<4>) -> !firrtl.uint<23>
Expand Down Expand Up @@ -487,7 +487,7 @@ circuit MyModule : ; CHECK: firrtl.circuit "MyModule" {
; CHECK-LABEL: firrtl.module private @oversize_shift(
module oversize_shift :
wire value : UInt<2>
; CHECK: firrtl.shr %value, 5 : (!firrtl.uint<2>) -> !firrtl.uint<1>
; CHECK: firrtl.shr %value, 5 : (!firrtl.uint<2>) -> !firrtl.uint<0>
node n = shr(value, 5)

; CHECK-LABEL: firrtl.module private @when_else_ambiguity(
Expand Down Expand Up @@ -1934,3 +1934,121 @@ circuit LayerEnabledModule:
module UserOfLayerEnabledModule enablelayer A enablelayer B.C:
; CHECK: firrtl.instance i interesting_name {layers = [@A, @B::@C]} @LayerEnabledModule()
inst i of LayerEnabledModule

;// -----

FIRRTL version 3.3.0
; CHECK-LABEL: circuit "StaticShiftRight"
circuit StaticShiftRight:
; CHECK: firrtl.module @StaticShiftRight
module StaticShiftRight:
input a : UInt<8>
input b : UInt<0>
input c : SInt<8>
input d : SInt<0>

wire w : UInt
connect w, a

wire x : SInt
connect x, c

; CHECK: %0 = firrtl.shr %a, 1
; CHECK: %1 = firrtl.pad %0, 1
; CHECK: %a_1 = firrtl.node {{.*}} %1 : !firrtl.uint<7>
node a_1 = shr(a, 1)
; CHECK: %2 = firrtl.shr %a, 8
; CHECK: %3 = firrtl.pad %2, 1
; CHECK: %a_2 = firrtl.node {{.*}} %3 : !firrtl.uint<1>
node a_2 = shr(a, 8)
; CHECK: %4 = firrtl.shr %a, 10
; CHECK: %5 = firrtl.pad %4, 1
; CHECK: %a_3 = firrtl.node {{.*}} %5 : !firrtl.uint<1>
node a_3 = shr(a, 10)
; CHECK: %6 = firrtl.shr %b, 0
; CHECK: %7 = firrtl.pad %6, 1
; CHECK: %b_1 = firrtl.node {{.*}} %7 : !firrtl.uint<1>
node b_1 = shr(b, 0)
; CHECK: %8 = firrtl.shr %b, 1
; CHECK: %9 = firrtl.pad %8, 1
; CHECK: %b_2 = firrtl.node {{.*}} %9 : !firrtl.uint<1>
node b_2 = shr(b, 1)
; CHECK: %10 = firrtl.shr %w, 10
; CHECK: %11 = firrtl.pad %10, 1
; CHECK: %w_1 = firrtl.node {{.*}} %11 : !firrtl.uint
node w_1 = shr(w, 10)

; CHECK: %12 = firrtl.shr %c, 1
; CHECK: %c_1 = firrtl.node {{.*}} %12 : !firrtl.sint<7>
node c_1 = shr(c, 1)
; CHECK: %13 = firrtl.shr %c, 8
; CHECK: %c_2 = firrtl.node {{.*}} %13 : !firrtl.sint<1>
node c_2 = shr(c, 8)
; CHECK: %14 = firrtl.shr %c, 10
; CHECK: %c_3 = firrtl.node {{.*}} %14 : !firrtl.sint<1>
node c_3 = shr(c, 10)
; CHECK: %15 = firrtl.shr %d, 0
; CHECK: %d_1 = firrtl.node {{.*}} %15 : !firrtl.sint<1>
node d_1 = shr(d, 0)
; CHECK: %16 = firrtl.shr %d, 1
; CHECK: %d_2 = firrtl.node {{.*}} %16 : !firrtl.sint<1>
node d_2 = shr(d, 1)
; CHECK: %17 = firrtl.shr %x, 10
; CHECK: %x_1 = firrtl.node {{.*}} %17 : !firrtl.sint
node x_1 = shr(x, 10)

;// -----

FIRRTL version 4.0.0
; CHECK-LABEL: circuit "StaticShiftRight"
circuit StaticShiftRight:
; CHECK: firrtl.module @StaticShiftRight
module StaticShiftRight:
input a : UInt<8>
input b : UInt<0>
input c : SInt<8>
input d : SInt<0>

wire w : UInt
connect w, a

wire x : SInt
connect x, c

; CHECK: %0 = firrtl.shr %a, 1
; CHECK: %a_1 = firrtl.node {{.*}} %0 : !firrtl.uint<7>
node a_1 = shr(a, 1)
; CHECK: %1 = firrtl.shr %a, 8
; CHECK: %a_2 = firrtl.node {{.*}} %1 : !firrtl.uint<0>
node a_2 = shr(a, 8)
; CHECK: %2 = firrtl.shr %a, 10
; CHECK: %a_3 = firrtl.node {{.*}} %2 : !firrtl.uint<0>
node a_3 = shr(a, 10)
; CHECK: %3 = firrtl.shr %b, 0
; CHECK: %b_1 = firrtl.node {{.*}} %3 : !firrtl.uint<0>
node b_1 = shr(b, 0)
; CHECK: %4 = firrtl.shr %b, 1
; CHECK: %b_2 = firrtl.node {{.*}} %4 : !firrtl.uint<0>
node b_2 = shr(b, 1)
; CHECK: %5 = firrtl.shr %w, 10
; CHECK: %w_1 = firrtl.node {{.*}} %5 : !firrtl.uint
node w_1 = shr(w, 10)

; CHECK: %6 = firrtl.shr %c, 1
; CHECK: %c_1 = firrtl.node {{.*}} %6 : !firrtl.sint<7>
node c_1 = shr(c, 1)
; CHECK: %7 = firrtl.shr %c, 8
; CHECK: %c_2 = firrtl.node {{.*}} %7 : !firrtl.sint<1>
node c_2 = shr(c, 8)
; CHECK: %8 = firrtl.shr %c, 10
; CHECK: %c_3 = firrtl.node {{.*}} %8 : !firrtl.sint<1>
node c_3 = shr(c, 10)
; CHECK: %9 = firrtl.shr %d, 0
; CHECK: %d_1 = firrtl.node {{.*}} %9 : !firrtl.sint<1>
node d_1 = shr(d, 0)
; CHECK: %10 = firrtl.shr %d, 1
; CHECK: %d_2 = firrtl.node {{.*}} %10 : !firrtl.sint<1>
node d_2 = shr(d, 1)
; CHECK: %11 = firrtl.shr %x, 10
; CHECK: %x_1 = firrtl.node {{.*}} %11 : !firrtl.sint
node x_1 = shr(x, 10)

0 comments on commit 24d784e

Please sign in to comment.