From 592f785c5f668964d3399aa0c6208e11587121e1 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Wed, 1 Jan 2025 22:03:53 -0800 Subject: [PATCH] [CombToAIG] Implement ArrayGet/ArrayCreate/AggregateConst lowering --- .../circt-synth/comb-lowering-lec.mlir | 16 ++ lib/Conversion/CombToAIG/CombToAIG.cpp | 143 +++++++++++++++++- .../CombToAIG/comb-to-aig-arith.mlir | 35 +++++ test/Conversion/CombToAIG/comb-to-aig.mlir | 40 +++++ 4 files changed, 231 insertions(+), 3 deletions(-) diff --git a/integration_test/circt-synth/comb-lowering-lec.mlir b/integration_test/circt-synth/comb-lowering-lec.mlir index 895c1d81d4a0..df8bd9f7cdf9 100644 --- a/integration_test/circt-synth/comb-lowering-lec.mlir +++ b/integration_test/circt-synth/comb-lowering-lec.mlir @@ -78,3 +78,19 @@ hw.module @shift5(in %lhs: i5, in %rhs: i5, out out_shl: i5, out out_shr: i5, ou %2 = comb.shrs %lhs, %rhs : i5 hw.output %0, %1, %2 : i5, i5, i5 } + +// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY +// COMB_ARRAY: c1 == c2 +hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) { + %0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2 + %1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2 + %2 = hw.array_create %arg0, %arg1, %arg2 : i2 + %c3_i2 = hw.constant 3 : i2 + // NOTE: If the index is out of bounds, the result value is undefined. + // In LEC such value is lowered into unbounded SMT variable and cause + // the LEC to fail. So just asssume that the index is in bounds. + %inbound = comb.icmp ult %sel2, %c3_i2 : i2 + verif.assume %inbound : i1 + %3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2 + hw.output %1, %3 : i2, i2 +} diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index f4a8c2652e44..79d72fa0bbde 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -595,6 +595,126 @@ struct CombShiftOpConversion : OpConversionPattern { } }; +template +struct HWArrayCreateLikeOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + LogicalResult + matchAndRewrite(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Lower to concat. + auto inputs = adaptor.getInputs(); + SmallVector results; + for (auto input : inputs) + results.push_back(rewriter.getRemappedValue(input)); + rewriter.replaceOpWithNewOp(op, results); + return success(); + } +}; + +struct HWAggregateConstantOpConversion + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static LogicalResult peelAttribute(Location loc, Attribute attr, + ConversionPatternRewriter &rewriter, + SmallVector &results) { + SmallVector worklist; + worklist.push_back(attr); + + while (!worklist.empty()) { + auto current = worklist.pop_back_val(); + if (auto innerArray = dyn_cast(current)) { + for (auto elem : llvm::reverse(innerArray)) + worklist.push_back(elem); + continue; + } + + if (auto intAttr = dyn_cast(current)) { + results.push_back(rewriter.create(loc, intAttr)); + continue; + } + + return failure(); + } + + return success(); + } + + LogicalResult + matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Lower to concat. + SmallVector results; + if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter, + results))) + return failure(); + rewriter.replaceOpWithNewOp(op, results); + return success(); + } +}; + +struct HWArrayGetOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector results; + auto arrayType = cast(op.getInput().getType()); + auto elemType = arrayType.getElementType(); + auto numElements = arrayType.getNumElements(); + auto elemWidth = hw::getBitWidth(elemType); + if (elemWidth < 0) + return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width"); + + auto lowered = rewriter.getRemappedValue(op.getInput()); + if (!lowered) + return failure(); + + for (size_t i = 0; i < numElements; ++i) + results.push_back(rewriter.createOrFold( + op.getLoc(), lowered, i * elemWidth, elemWidth)); + + auto bits = extractBits(rewriter, op.getIndex()); + auto result = constructMuxTree(rewriter, op.getLoc(), numElements, bits, + results, results.back()); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// A type converter is needed to perform the in-flight materialization of +/// aggregate types to integer types. +class CombToAIGTypeConverter : public TypeConverter { +public: + CombToAIGTypeConverter() { + addConversion([](Type type) -> Type { return type; }); + addConversion([](hw::ArrayType t) -> Type { + return IntegerType::get(t.getContext(), hw::getBitWidth(t)); + }); + addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType, + mlir::ValueRange inputs, + mlir::Location loc) -> mlir::Value { + if (inputs.size() != 1) + return Value(); + + return builder.create(loc, resultType, inputs[0]) + ->getResult(0); + }); + + addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType, + mlir::ValueRange inputs, + mlir::Location loc) -> mlir::Value { + if (inputs.size() != 1) + return Value(); + + return builder.create(loc, resultType, inputs[0]) + ->getResult(0); + }); + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -610,7 +730,9 @@ struct ConvertCombToAIGPass }; } // namespace -static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { +static void +populateCombToAIGConversionPatterns(RewritePatternSet &patterns, + CombToAIGTypeConverter &typeConverter) { patterns.add< // Bitwise Logical Ops CombAndOpConversion, CombOrOpConversion, CombXorOpConversion, @@ -625,17 +747,31 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { /*isSigned=*/false>, CombShiftOpConversion, + // Array Ops + HWArrayGetOpConversion, HWArrayCreateLikeOpConversion, + HWArrayCreateLikeOpConversion, + HWAggregateConstantOpConversion, // Variadic ops that must be lowered to binary operations CombLowerVariadicOp, CombLowerVariadicOp, - CombLowerVariadicOp>(patterns.getContext()); + CombLowerVariadicOp>(typeConverter, patterns.getContext()); } void ConvertCombToAIGPass::runOnOperation() { ConversionTarget target(getContext()); + + // Comb is source dialect. target.addIllegalDialect(); // Keep data movement operations like Extract, Concat and Replicate. target.addLegalOp(); + + // Treat array operations as illegal. Strictly speaking, other than array get + // operation with non-const index are legal but array types prevent a bunch of + // optimizations so just lower them to integer operations. + target.addIllegalOp(); + + // AIG is target dialect. target.addLegalDialect(); // This is a test only option to add logical ops. @@ -644,7 +780,8 @@ void ConvertCombToAIGPass::runOnOperation() { target.addLegalOp(OperationName(opName, &getContext())); RewritePatternSet patterns(&getContext()); - populateCombToAIGConversionPatterns(patterns); + CombToAIGTypeConverter typeConverter; + populateCombToAIGConversionPatterns(patterns, typeConverter); if (failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir index efcc1e862047..ffec8bffb7f9 100644 --- a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir @@ -158,3 +158,38 @@ hw.module @shift2(in %lhs: i2, in %rhs: i2, out out_shl: i2, out out_shr: i2, ou // ALLOW_ICMP-NEXT: hw.output %[[L_SHIFT_WITH_BOUND_CHECK]], %[[R_SHIFT_WITH_BOUND_CHECK]], %[[R_SIGNED_SHIFT_WITH_BOUND_CHECK]] hw.output %0, %1, %2 : i2, i2, i2 } + + +// CHECK-LABEL: @array( +hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, out out: !hw.array<4xi2>, in %sel: i2, out out_get: i2, out out_agg: !hw.array<4xi2>, out out_agg_get: i2) { + %0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2 + %1 = hw.array_get %0[%sel] : !hw.array<4xi2>, i2 + %2 = hw.aggregate_constant [0 : i2, 1 : i2, -2 : i2, -1 : i2] : !hw.array<4xi2> + %3 = hw.array_get %2[%sel] : !hw.array<4xi2>, i2 + // CHECK: %[[CONCAT:.+]] = comb.concat %arg0, %arg1, %arg2, %arg3 : i2, i2, i2, i2 + // CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONCAT]] : (i8) -> !hw.array<4xi2> + // CHECK-NEXT: %[[ARG_3:.+]] = comb.extract %[[CONCAT]] from 0 : (i8) -> i2 + // CHECK-NEXT: %[[ARG_2:.+]] = comb.extract %[[CONCAT]] from 2 : (i8) -> i2 + // CHECK-NEXT: %[[ARG_1:.+]] = comb.extract %[[CONCAT]] from 4 : (i8) -> i2 + // CHECK-NEXT: %[[ARG_0:.+]] = comb.extract %[[CONCAT]] from 6 : (i8) -> i2 + // CHECK-NEXT: %[[SEL_0:.+]] = comb.extract %sel from 0 : (i2) -> i1 + // CHECK-NEXT: %[[SEL_1:.+]] = comb.extract %sel from 1 : (i2) -> i1 + // CHECK-NEXT: %[[MUX_0:.+]] = comb.mux %[[SEL_0]], %[[ARG_0]], %[[ARG_1]] : i2 + // CHECK-NEXT: %[[MUX_1:.+]] = comb.mux %[[SEL_0]], %[[ARG_2]], %[[ARG_3]] : i2 + // CHECK-NEXT: %[[ARRAY_GET_1:.+]] = comb.mux %[[SEL_1]], %[[MUX_0]], %[[MUX_1]] : i2 + // CHECK-NEXT: %[[C_0:.+]] = hw.constant 0 + // CHECK-NEXT: %[[C_1:.+]] = hw.constant 1 + // CHECK-NEXT: %[[C_2:.+]] = hw.constant -2 + // CHECK-NEXT: %[[C_3:.+]] = hw.constant -1 + // CHECK-NEXT: %[[AGG_CONST:.+]] = comb.concat %[[C_0]], %[[C_1]], %[[C_2]], %[[C_3]] + // CHECK-NEXT: %[[BITCAST_CONST:.+]] = hw.bitcast %[[AGG_CONST]] + // CHECK-NEXT: %[[ARG_3:.+]] = comb.extract %[[AGG_CONST]] from 0 : (i8) -> i2 + // CHECK-NEXT: %[[ARG_2:.+]] = comb.extract %[[AGG_CONST]] from 2 : (i8) -> i2 + // CHECK-NEXT: %[[ARG_1:.+]] = comb.extract %[[AGG_CONST]] from 4 : (i8) -> i2 + // CHECK-NEXT: %[[ARG_0:.+]] = comb.extract %[[AGG_CONST]] from 6 : (i8) -> i2 + // CHECK-NEXT: %[[MUX_0:.+]] = comb.mux %[[SEL_0]], %[[ARG_0]], %[[ARG_1]] : i2 + // CHECK-NEXT: %[[MUX_1:.+]] = comb.mux %[[SEL_0]], %[[ARG_2]], %[[ARG_3]] : i2 + // CHECK-NEXT: %[[AGG_GET_2:.+]] = comb.mux %[[SEL_1]], %[[MUX_0]], %[[MUX_1]] : i2 + // CHECK-NEXT: hw.output %[[BITCAST]], %[[ARRAY_GET_1]], %[[BITCAST_CONST]], %[[AGG_GET_2]] + hw.output %0, %1, %2, %3 : !hw.array<4xi2>, i2, !hw.array<4xi2>, i2 +} diff --git a/test/Conversion/CombToAIG/comb-to-aig.mlir b/test/Conversion/CombToAIG/comb-to-aig.mlir index d65c0c150b29..1832c55c468c 100644 --- a/test/Conversion/CombToAIG/comb-to-aig.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig.mlir @@ -52,3 +52,43 @@ hw.module @mux(in %cond: i1, in %high: !hw.array<2xi4>, in %low: !hw.array<2xi4> %0 = comb.mux %cond, %high, %low : !hw.array<2xi4> hw.output %0 : !hw.array<2xi4> } + +// CHECK-LABEL: @agg_const +hw.module @agg_const(out out: !hw.array<4xi4>) { + // CHECK: %[[CONST:.+]] = comb.concat %c0_i4, %c1_i4, %c-2_i4, %c-1_i4 : i4, i4, i4, i4 + // CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONST]] : (i16) -> !hw.array<4xi4> + // CHECK-NEXT: hw.output %[[BITCAST]] : !hw.array<4xi4> + %0 = hw.aggregate_constant [0 : i4, 1 : i4, -2 : i4, -1 : i4] : !hw.array<4xi4> + hw.output %0 : !hw.array<4xi4> +} + +// CHECK-LABEL: @array_get_for_port +hw.module @array_get_for_port(in %in: !hw.array<5xi4>, out out: i4) { + %c_i2 = hw.constant 3 : i3 + // CHECK-NEXT: %[[BITCAST_IN:.+]] = hw.bitcast %in : (!hw.array<5xi4>) -> i20 + // CHECK: %[[EXTRACT:.+]] = comb.extract %[[BITCAST_IN]] from 12 : (i20) -> i4 + // CHECK: hw.output %[[EXTRACT]] : i4 + %1 = hw.array_get %in[%c_i2] : !hw.array<5xi4>, i3 + hw.output %1 : i4 +} + +// CHECK-LABEL: @array_concat +hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: i4) { + %0 = hw.array_concat %lhs, %rhs : !hw.array<2xi4>, !hw.array<3xi4> + %c_i2 = hw.constant 3 : i3 + // CHECK-NEXT: %[[BITCAST_RHS:.+]] = hw.bitcast %rhs : (!hw.array<3xi4>) -> i12 + // CHECK-NEXT: %[[BITCAST_LHS:.+]] = hw.bitcast %lhs : (!hw.array<2xi4>) -> i8 + // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[BITCAST_LHS]], %[[BITCAST_RHS]] : i8, i12 + // CHECK: %[[EXTRACT:.+]] = comb.extract %[[CONCAT]] from 12 : (i20) -> i4 + // CHECK: hw.output %[[EXTRACT]] : i4 + %1 = hw.array_get %0[%c_i2] : !hw.array<5xi4>, i3 + hw.output %1 : i4 +} + +hw.module.extern @foo(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>) +// CHECK-LABEL: @array_instance( +hw.module @array_instance(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>) { + // CHECK-NEXT: hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>) + %0 = hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>) + hw.output %0 : !hw.array<4xi2> +}