diff --git a/integration_test/circt-lec/comb.mlir b/integration_test/circt-lec/comb.mlir index b52db7ce7be3..5b02b403f6b7 100644 --- a/integration_test/circt-lec/comb.mlir +++ b/integration_test/circt-lec/comb.mlir @@ -59,10 +59,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) { // TODO // comb.divs -// TODO +// RUN: circt-lec %s -c1=divs_unsafe -c2=divs_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS_UNSAFE +// RUN: circt-lec %s -c1=divs -c2=divs --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS +// COMB_DIVS_UNSAFE: c1 != c2 +// COMB_DIVS: c1 == c2 + +hw.module @divs_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.divs %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @divs(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.divs %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.divu -// TODO +// RUN: circt-lec %s -c1=divu_unsafe -c2=divu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU_UNSAFE +// RUN: circt-lec %s -c1=divu -c2=divu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU +// COMB_DIVU_UNSAFE: c1 != c2 +// COMB_DIVU: c1 == c2 + +hw.module @divu_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.divu %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @divu(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.divu %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.extract // TODO @@ -71,10 +103,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) { // TODO // comb.mods -// TODO +// RUN: circt-lec %s -c1=mods_unsafe -c2=mods_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS_UNSAFE +// RUN: circt-lec %s -c1=mods -c2=mods --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS +// COMB_MODS_UNSAFE: c1 != c2 +// COMB_MODS: c1 == c2 + +hw.module @mods_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.mods %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @mods(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.mods %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.modu -// TODO +// RUN: circt-lec %s -c1=modu_unsafe -c2=modu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU_UNSAFE +// RUN: circt-lec %s -c1=modu -c2=modu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU +// COMB_MODU_UNSAFE: c1 != c2 +// COMB_MODU: c1 == c2 + +hw.module @modu_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.modu %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @modu(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.modu %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.mul // RUN: circt-lec %s -c1=mulBy2 -c2=addTwice --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL diff --git a/lib/Conversion/CombToSMT/CombToSMT.cpp b/lib/Conversion/CombToSMT/CombToSMT.cpp index 103e7aacb3cf..b632cf137bbe 100644 --- a/lib/Conversion/CombToSMT/CombToSMT.cpp +++ b/lib/Conversion/CombToSMT/CombToSMT.cpp @@ -192,6 +192,34 @@ struct OneToOneOpConversion : OpConversionPattern { } }; +/// Lower the SourceOp to the TargetOp special-casing if the second operand is +/// zero to return a new symbolic value. +template +struct DivisionOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto type = dyn_cast(adaptor.getRhs().getType()); + if (!type) + return failure(); + + auto resultType = OpConversionPattern::typeConverter->convertType( + op.getResult().getType()); + Value zero = + rewriter.create(loc, APInt(type.getWidth(), 0)); + Value isZero = rewriter.create(loc, adaptor.getRhs(), zero); + Value symbolicVal = rewriter.create(loc, resultType); + Value division = + rewriter.create(loc, resultType, adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, isZero, symbolicVal, division); + return success(); + } +}; + /// Converts an operation with a variadic number of operands to a chain of /// binary operations assuming left-associativity of the operation. template @@ -236,10 +264,10 @@ void circt::populateCombToSMTConversionPatterns(TypeConverter &converter, OneToOneOpConversion, OneToOneOpConversion, OneToOneOpConversion, - OneToOneOpConversion, - OneToOneOpConversion, - OneToOneOpConversion, - OneToOneOpConversion, + DivisionOpConversion, + DivisionOpConversion, + DivisionOpConversion, + DivisionOpConversion, VariadicToBinaryOpConversion, VariadicToBinaryOpConversion, VariadicToBinaryOpConversion, diff --git a/lib/Tools/circt-lec/ConstructLEC.cpp b/lib/Tools/circt-lec/ConstructLEC.cpp index 29fe0921821e..cd1869546fe9 100644 --- a/lib/Tools/circt-lec/ConstructLEC.cpp +++ b/lib/Tools/circt-lec/ConstructLEC.cpp @@ -110,37 +110,32 @@ void ConstructLECPass::runOnOperation() { builder.createBlock(&entryFunc.getBody()); - Value areEquivalent; - if (moduleA == moduleB) { - // Trivially equivalent - areEquivalent = - builder.create(loc, builder.getI1Type(), 1); - moduleA->erase(); - } else { - auto lecOp = builder.create(loc); - areEquivalent = lecOp.getAreEquivalent(); - auto *outputOpA = moduleA.getBodyBlock()->getTerminator(); - auto *outputOpB = moduleB.getBodyBlock()->getTerminator(); - lecOp.getFirstCircuit().takeBody(moduleA.getBody()); - lecOp.getSecondCircuit().takeBody(moduleB.getBody()); - - moduleA->erase(); + auto lecOp = builder.create(loc); + Value areEquivalent = lecOp.getAreEquivalent(); + builder.cloneRegionBefore(moduleA.getBody(), lecOp.getFirstCircuit(), + lecOp.getFirstCircuit().end()); + builder.cloneRegionBefore(moduleB.getBody(), lecOp.getSecondCircuit(), + lecOp.getSecondCircuit().end()); + + moduleA->erase(); + if (moduleA != moduleB) moduleB->erase(); - { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(outputOpA); - builder.create(loc, outputOpA->getOperands()); - outputOpA->erase(); - builder.setInsertionPoint(outputOpB); - builder.create(loc, outputOpB->getOperands()); - outputOpB->erase(); - } - - sortTopologically(&lecOp.getFirstCircuit().front()); - sortTopologically(&lecOp.getSecondCircuit().front()); + { + auto *term = lecOp.getFirstCircuit().front().getTerminator(); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(term); + builder.create(loc, term->getOperands()); + term->erase(); + term = lecOp.getSecondCircuit().front().getTerminator(); + builder.setInsertionPoint(term); + builder.create(loc, term->getOperands()); + term->erase(); } + sortTopologically(&lecOp.getFirstCircuit().front()); + sortTopologically(&lecOp.getSecondCircuit().front()); + // TODO: we should find a more elegant way of reporting the result than // already inserting some LLVM here Value eqFormatString = diff --git a/test/Conversion/CombToSMT/comb-to-smt.mlir b/test/Conversion/CombToSMT/comb-to-smt.mlir index 78745b20696c..5d480409bdde 100644 --- a/test/Conversion/CombToSMT/comb-to-smt.mlir +++ b/test/Conversion/CombToSMT/comb-to-smt.mlir @@ -10,13 +10,29 @@ func.func @test(%a0: !smt.bv<32>, %a1: !smt.bv<32>, %a2: !smt.bv<32>, %a3: !smt. %arg4 = builtin.unrealized_conversion_cast %a4 : !smt.bv<1> to i1 %arg5 = builtin.unrealized_conversion_cast %a5 : !smt.bv<4> to i4 - // CHECK: smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %0 = comb.divs %arg0, %arg1 : i32 - // CHECK-NEXT: smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %1 = comb.divu %arg0, %arg1 : i32 - // CHECK-NEXT: smt.bv.srem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.srem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %2 = comb.mods %arg0, %arg1 : i32 - // CHECK-NEXT: smt.bv.urem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.urem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %3 = comb.modu %arg0, %arg1 : i32 // CHECK-NEXT: [[NEG:%.+]] = smt.bv.neg [[A1]] : !smt.bv<32>