Skip to content

Commit

Permalink
[CombToAIG] Implement ArrayGet/ArrayCreate/AggregateConst lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
uenoku committed Jan 12, 2025
1 parent fab85cf commit 592f785
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 3 deletions.
16 changes: 16 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,19 @@ hw.module @shift5(in %lhs: i5, in %rhs: i5, out out_shl: i5, out out_shr: i5, ou
%2 = comb.shrs %lhs, %rhs : i5
hw.output %0, %1, %2 : i5, i5, i5
}

// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY
// COMB_ARRAY: c1 == c2
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2
%2 = hw.array_create %arg0, %arg1, %arg2 : i2
%c3_i2 = hw.constant 3 : i2
// NOTE: If the index is out of bounds, the result value is undefined.
// In LEC such value is lowered into unbounded SMT variable and cause
// the LEC to fail. So just asssume that the index is in bounds.
%inbound = comb.icmp ult %sel2, %c3_i2 : i2
verif.assume %inbound : i1
%3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2
hw.output %1, %3 : i2, i2
}
143 changes: 140 additions & 3 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,126 @@ struct CombShiftOpConversion : OpConversionPattern<OpTy> {
}
};

template <typename OpTy>
struct HWArrayCreateLikeOpConversion : OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Lower to concat.
auto inputs = adaptor.getInputs();
SmallVector<Value> results;
for (auto input : inputs)
results.push_back(rewriter.getRemappedValue(input));
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
return success();
}
};

struct HWAggregateConstantOpConversion
: OpConversionPattern<hw::AggregateConstantOp> {
using OpConversionPattern<hw::AggregateConstantOp>::OpConversionPattern;

static LogicalResult peelAttribute(Location loc, Attribute attr,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &results) {
SmallVector<Attribute> worklist;
worklist.push_back(attr);

while (!worklist.empty()) {
auto current = worklist.pop_back_val();
if (auto innerArray = dyn_cast<ArrayAttr>(current)) {
for (auto elem : llvm::reverse(innerArray))
worklist.push_back(elem);
continue;
}

if (auto intAttr = dyn_cast<IntegerAttr>(current)) {
results.push_back(rewriter.create<hw::ConstantOp>(loc, intAttr));
continue;
}

return failure();
}

return success();
}

LogicalResult
matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Lower to concat.
SmallVector<Value> results;
if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter,
results)))
return failure();
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
return success();
}
};

struct HWArrayGetOpConversion : OpConversionPattern<hw::ArrayGetOp> {
using OpConversionPattern<hw::ArrayGetOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> results;
auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
auto elemType = arrayType.getElementType();
auto numElements = arrayType.getNumElements();
auto elemWidth = hw::getBitWidth(elemType);
if (elemWidth < 0)
return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width");

auto lowered = rewriter.getRemappedValue(op.getInput());
if (!lowered)
return failure();

for (size_t i = 0; i < numElements; ++i)
results.push_back(rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), lowered, i * elemWidth, elemWidth));

auto bits = extractBits(rewriter, op.getIndex());
auto result = constructMuxTree(rewriter, op.getLoc(), numElements, bits,
results, results.back());

rewriter.replaceOp(op, result);
return success();
}
};

/// A type converter is needed to perform the in-flight materialization of
/// aggregate types to integer types.
class CombToAIGTypeConverter : public TypeConverter {
public:
CombToAIGTypeConverter() {
addConversion([](Type type) -> Type { return type; });
addConversion([](hw::ArrayType t) -> Type {
return IntegerType::get(t.getContext(), hw::getBitWidth(t));
});
addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> mlir::Value {
if (inputs.size() != 1)
return Value();

return builder.create<hw::BitcastOp>(loc, resultType, inputs[0])
->getResult(0);
});

addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> mlir::Value {
if (inputs.size() != 1)
return Value();

return builder.create<hw::BitcastOp>(loc, resultType, inputs[0])
->getResult(0);
});
}
};
} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -610,7 +730,9 @@ struct ConvertCombToAIGPass
};
} // namespace

static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
static void
populateCombToAIGConversionPatterns(RewritePatternSet &patterns,
CombToAIGTypeConverter &typeConverter) {
patterns.add<
// Bitwise Logical Ops
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
Expand All @@ -625,17 +747,31 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
/*isSigned=*/false>,
CombShiftOpConversion<comb::ShrSOp, /*isLeftShift=*/false,
/*isSigned=*/true>,
// Array Ops
HWArrayGetOpConversion, HWArrayCreateLikeOpConversion<hw::ArrayCreateOp>,
HWArrayCreateLikeOpConversion<hw::ArrayConcatOp>,
HWAggregateConstantOpConversion,
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
CombLowerVariadicOp<MulOp>>(patterns.getContext());
CombLowerVariadicOp<MulOp>>(typeConverter, patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());

// Comb is source dialect.
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp, hw::ConstantOp>();

// Treat array operations as illegal. Strictly speaking, other than array get
// operation with non-const index are legal but array types prevent a bunch of
// optimizations so just lower them to integer operations.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();

// AIG is target dialect.
target.addLegalDialect<aig::AIGDialect>();

// This is a test only option to add logical ops.
Expand All @@ -644,7 +780,8 @@ void ConvertCombToAIGPass::runOnOperation() {
target.addLegalOp(OperationName(opName, &getContext()));

RewritePatternSet patterns(&getContext());
populateCombToAIGConversionPatterns(patterns);
CombToAIGTypeConverter typeConverter;
populateCombToAIGConversionPatterns(patterns, typeConverter);

if (failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down
35 changes: 35 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,38 @@ hw.module @shift2(in %lhs: i2, in %rhs: i2, out out_shl: i2, out out_shr: i2, ou
// ALLOW_ICMP-NEXT: hw.output %[[L_SHIFT_WITH_BOUND_CHECK]], %[[R_SHIFT_WITH_BOUND_CHECK]], %[[R_SIGNED_SHIFT_WITH_BOUND_CHECK]]
hw.output %0, %1, %2 : i2, i2, i2
}


// CHECK-LABEL: @array(
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, out out: !hw.array<4xi2>, in %sel: i2, out out_get: i2, out out_agg: !hw.array<4xi2>, out out_agg_get: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel] : !hw.array<4xi2>, i2
%2 = hw.aggregate_constant [0 : i2, 1 : i2, -2 : i2, -1 : i2] : !hw.array<4xi2>
%3 = hw.array_get %2[%sel] : !hw.array<4xi2>, i2
// CHECK: %[[CONCAT:.+]] = comb.concat %arg0, %arg1, %arg2, %arg3 : i2, i2, i2, i2
// CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONCAT]] : (i8) -> !hw.array<4xi2>
// CHECK-NEXT: %[[ARG_3:.+]] = comb.extract %[[CONCAT]] from 0 : (i8) -> i2
// CHECK-NEXT: %[[ARG_2:.+]] = comb.extract %[[CONCAT]] from 2 : (i8) -> i2
// CHECK-NEXT: %[[ARG_1:.+]] = comb.extract %[[CONCAT]] from 4 : (i8) -> i2
// CHECK-NEXT: %[[ARG_0:.+]] = comb.extract %[[CONCAT]] from 6 : (i8) -> i2
// CHECK-NEXT: %[[SEL_0:.+]] = comb.extract %sel from 0 : (i2) -> i1
// CHECK-NEXT: %[[SEL_1:.+]] = comb.extract %sel from 1 : (i2) -> i1
// CHECK-NEXT: %[[MUX_0:.+]] = comb.mux %[[SEL_0]], %[[ARG_0]], %[[ARG_1]] : i2
// CHECK-NEXT: %[[MUX_1:.+]] = comb.mux %[[SEL_0]], %[[ARG_2]], %[[ARG_3]] : i2
// CHECK-NEXT: %[[ARRAY_GET_1:.+]] = comb.mux %[[SEL_1]], %[[MUX_0]], %[[MUX_1]] : i2
// CHECK-NEXT: %[[C_0:.+]] = hw.constant 0
// CHECK-NEXT: %[[C_1:.+]] = hw.constant 1
// CHECK-NEXT: %[[C_2:.+]] = hw.constant -2
// CHECK-NEXT: %[[C_3:.+]] = hw.constant -1
// CHECK-NEXT: %[[AGG_CONST:.+]] = comb.concat %[[C_0]], %[[C_1]], %[[C_2]], %[[C_3]]
// CHECK-NEXT: %[[BITCAST_CONST:.+]] = hw.bitcast %[[AGG_CONST]]
// CHECK-NEXT: %[[ARG_3:.+]] = comb.extract %[[AGG_CONST]] from 0 : (i8) -> i2
// CHECK-NEXT: %[[ARG_2:.+]] = comb.extract %[[AGG_CONST]] from 2 : (i8) -> i2
// CHECK-NEXT: %[[ARG_1:.+]] = comb.extract %[[AGG_CONST]] from 4 : (i8) -> i2
// CHECK-NEXT: %[[ARG_0:.+]] = comb.extract %[[AGG_CONST]] from 6 : (i8) -> i2
// CHECK-NEXT: %[[MUX_0:.+]] = comb.mux %[[SEL_0]], %[[ARG_0]], %[[ARG_1]] : i2
// CHECK-NEXT: %[[MUX_1:.+]] = comb.mux %[[SEL_0]], %[[ARG_2]], %[[ARG_3]] : i2
// CHECK-NEXT: %[[AGG_GET_2:.+]] = comb.mux %[[SEL_1]], %[[MUX_0]], %[[MUX_1]] : i2
// CHECK-NEXT: hw.output %[[BITCAST]], %[[ARRAY_GET_1]], %[[BITCAST_CONST]], %[[AGG_GET_2]]
hw.output %0, %1, %2, %3 : !hw.array<4xi2>, i2, !hw.array<4xi2>, i2
}
40 changes: 40 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,43 @@ hw.module @mux(in %cond: i1, in %high: !hw.array<2xi4>, in %low: !hw.array<2xi4>
%0 = comb.mux %cond, %high, %low : !hw.array<2xi4>
hw.output %0 : !hw.array<2xi4>
}

// CHECK-LABEL: @agg_const
hw.module @agg_const(out out: !hw.array<4xi4>) {
// CHECK: %[[CONST:.+]] = comb.concat %c0_i4, %c1_i4, %c-2_i4, %c-1_i4 : i4, i4, i4, i4
// CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONST]] : (i16) -> !hw.array<4xi4>
// CHECK-NEXT: hw.output %[[BITCAST]] : !hw.array<4xi4>
%0 = hw.aggregate_constant [0 : i4, 1 : i4, -2 : i4, -1 : i4] : !hw.array<4xi4>
hw.output %0 : !hw.array<4xi4>
}

// CHECK-LABEL: @array_get_for_port
hw.module @array_get_for_port(in %in: !hw.array<5xi4>, out out: i4) {
%c_i2 = hw.constant 3 : i3
// CHECK-NEXT: %[[BITCAST_IN:.+]] = hw.bitcast %in : (!hw.array<5xi4>) -> i20
// CHECK: %[[EXTRACT:.+]] = comb.extract %[[BITCAST_IN]] from 12 : (i20) -> i4
// CHECK: hw.output %[[EXTRACT]] : i4
%1 = hw.array_get %in[%c_i2] : !hw.array<5xi4>, i3
hw.output %1 : i4
}

// CHECK-LABEL: @array_concat
hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: i4) {
%0 = hw.array_concat %lhs, %rhs : !hw.array<2xi4>, !hw.array<3xi4>
%c_i2 = hw.constant 3 : i3
// CHECK-NEXT: %[[BITCAST_RHS:.+]] = hw.bitcast %rhs : (!hw.array<3xi4>) -> i12
// CHECK-NEXT: %[[BITCAST_LHS:.+]] = hw.bitcast %lhs : (!hw.array<2xi4>) -> i8
// CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[BITCAST_LHS]], %[[BITCAST_RHS]] : i8, i12
// CHECK: %[[EXTRACT:.+]] = comb.extract %[[CONCAT]] from 12 : (i20) -> i4
// CHECK: hw.output %[[EXTRACT]] : i4
%1 = hw.array_get %0[%c_i2] : !hw.array<5xi4>, i3
hw.output %1 : i4
}

hw.module.extern @foo(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>)
// CHECK-LABEL: @array_instance(
hw.module @array_instance(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>) {
// CHECK-NEXT: hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>)
%0 = hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>)
hw.output %0 : !hw.array<4xi2>
}

0 comments on commit 592f785

Please sign in to comment.