From 51bfb5caa9265b75b79b1ac6fcb47dded9e63c9e Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Sun, 12 May 2024 11:45:55 +0200 Subject: [PATCH] [CombToSMT] Make result of div-by-zero undefined This adapts the conversion pass to match the recently agreed upon definition for division by zero. Integration tests for circt-lec are added to check the behavior. Note that two syntactically equivalent modules are not considered equivalent if they aren't guaranteed to deterministically produce the same outputs. Alternatively, we could consider two undefined output values equivalent by modeling each value as a pair of a boolean and the bit-vector where the boolean determines if the value is undefined, then two outputs are equivalent if either the boolean is true or the boolean is false and the bitvectors match. There are probably use-cases for both, so maybe we'd want a flag to let the user decide. --- integration_test/circt-lec/comb.mlir | 72 ++++++++++++++++++++-- lib/Conversion/CombToSMT/CombToSMT.cpp | 36 +++++++++-- lib/Tools/circt-lec/ConstructLEC.cpp | 49 +++++++-------- test/Conversion/CombToSMT/comb-to-smt.mlir | 24 ++++++-- 4 files changed, 142 insertions(+), 39 deletions(-) diff --git a/integration_test/circt-lec/comb.mlir b/integration_test/circt-lec/comb.mlir index b52db7ce7be3..5b02b403f6b7 100644 --- a/integration_test/circt-lec/comb.mlir +++ b/integration_test/circt-lec/comb.mlir @@ -59,10 +59,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) { // TODO // comb.divs -// TODO +// RUN: circt-lec %s -c1=divs_unsafe -c2=divs_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS_UNSAFE +// RUN: circt-lec %s -c1=divs -c2=divs --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS +// COMB_DIVS_UNSAFE: c1 != c2 +// COMB_DIVS: c1 == c2 + +hw.module @divs_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.divs %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @divs(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.divs %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.divu -// TODO +// RUN: circt-lec %s -c1=divu_unsafe -c2=divu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU_UNSAFE +// RUN: circt-lec %s -c1=divu -c2=divu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU +// COMB_DIVU_UNSAFE: c1 != c2 +// COMB_DIVU: c1 == c2 + +hw.module @divu_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.divu %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @divu(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.divu %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.extract // TODO @@ -71,10 +103,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) { // TODO // comb.mods -// TODO +// RUN: circt-lec %s -c1=mods_unsafe -c2=mods_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS_UNSAFE +// RUN: circt-lec %s -c1=mods -c2=mods --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS +// COMB_MODS_UNSAFE: c1 != c2 +// COMB_MODS: c1 == c2 + +hw.module @mods_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.mods %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @mods(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.mods %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.modu -// TODO +// RUN: circt-lec %s -c1=modu_unsafe -c2=modu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU_UNSAFE +// RUN: circt-lec %s -c1=modu -c2=modu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU +// COMB_MODU_UNSAFE: c1 != c2 +// COMB_MODU: c1 == c2 + +hw.module @modu_unsafe(in %in1: i32, in %in2: i32, out out: i32) { + %0 = comb.modu %in1, %in2 : i32 + hw.output %0 : i32 +} + +hw.module @modu(in %in1: i32, in %in2: i32, out out: i32) { + %0 = hw.constant 0 : i32 + %1 = comb.icmp eq %in2, %0 : i32 + %2 = comb.modu %in1, %in2 : i32 + %3 = comb.mux %1, %0, %2 : i32 + hw.output %3 : i32 +} // comb.mul // RUN: circt-lec %s -c1=mulBy2 -c2=addTwice --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL diff --git a/lib/Conversion/CombToSMT/CombToSMT.cpp b/lib/Conversion/CombToSMT/CombToSMT.cpp index 103e7aacb3cf..b632cf137bbe 100644 --- a/lib/Conversion/CombToSMT/CombToSMT.cpp +++ b/lib/Conversion/CombToSMT/CombToSMT.cpp @@ -192,6 +192,34 @@ struct OneToOneOpConversion : OpConversionPattern { } }; +/// Lower the SourceOp to the TargetOp special-casing if the second operand is +/// zero to return a new symbolic value. +template +struct DivisionOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto type = dyn_cast(adaptor.getRhs().getType()); + if (!type) + return failure(); + + auto resultType = OpConversionPattern::typeConverter->convertType( + op.getResult().getType()); + Value zero = + rewriter.create(loc, APInt(type.getWidth(), 0)); + Value isZero = rewriter.create(loc, adaptor.getRhs(), zero); + Value symbolicVal = rewriter.create(loc, resultType); + Value division = + rewriter.create(loc, resultType, adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, isZero, symbolicVal, division); + return success(); + } +}; + /// Converts an operation with a variadic number of operands to a chain of /// binary operations assuming left-associativity of the operation. template @@ -236,10 +264,10 @@ void circt::populateCombToSMTConversionPatterns(TypeConverter &converter, OneToOneOpConversion, OneToOneOpConversion, OneToOneOpConversion, - OneToOneOpConversion, - OneToOneOpConversion, - OneToOneOpConversion, - OneToOneOpConversion, + DivisionOpConversion, + DivisionOpConversion, + DivisionOpConversion, + DivisionOpConversion, VariadicToBinaryOpConversion, VariadicToBinaryOpConversion, VariadicToBinaryOpConversion, diff --git a/lib/Tools/circt-lec/ConstructLEC.cpp b/lib/Tools/circt-lec/ConstructLEC.cpp index 29fe0921821e..cd1869546fe9 100644 --- a/lib/Tools/circt-lec/ConstructLEC.cpp +++ b/lib/Tools/circt-lec/ConstructLEC.cpp @@ -110,37 +110,32 @@ void ConstructLECPass::runOnOperation() { builder.createBlock(&entryFunc.getBody()); - Value areEquivalent; - if (moduleA == moduleB) { - // Trivially equivalent - areEquivalent = - builder.create(loc, builder.getI1Type(), 1); - moduleA->erase(); - } else { - auto lecOp = builder.create(loc); - areEquivalent = lecOp.getAreEquivalent(); - auto *outputOpA = moduleA.getBodyBlock()->getTerminator(); - auto *outputOpB = moduleB.getBodyBlock()->getTerminator(); - lecOp.getFirstCircuit().takeBody(moduleA.getBody()); - lecOp.getSecondCircuit().takeBody(moduleB.getBody()); - - moduleA->erase(); + auto lecOp = builder.create(loc); + Value areEquivalent = lecOp.getAreEquivalent(); + builder.cloneRegionBefore(moduleA.getBody(), lecOp.getFirstCircuit(), + lecOp.getFirstCircuit().end()); + builder.cloneRegionBefore(moduleB.getBody(), lecOp.getSecondCircuit(), + lecOp.getSecondCircuit().end()); + + moduleA->erase(); + if (moduleA != moduleB) moduleB->erase(); - { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(outputOpA); - builder.create(loc, outputOpA->getOperands()); - outputOpA->erase(); - builder.setInsertionPoint(outputOpB); - builder.create(loc, outputOpB->getOperands()); - outputOpB->erase(); - } - - sortTopologically(&lecOp.getFirstCircuit().front()); - sortTopologically(&lecOp.getSecondCircuit().front()); + { + auto *term = lecOp.getFirstCircuit().front().getTerminator(); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(term); + builder.create(loc, term->getOperands()); + term->erase(); + term = lecOp.getSecondCircuit().front().getTerminator(); + builder.setInsertionPoint(term); + builder.create(loc, term->getOperands()); + term->erase(); } + sortTopologically(&lecOp.getFirstCircuit().front()); + sortTopologically(&lecOp.getSecondCircuit().front()); + // TODO: we should find a more elegant way of reporting the result than // already inserting some LLVM here Value eqFormatString = diff --git a/test/Conversion/CombToSMT/comb-to-smt.mlir b/test/Conversion/CombToSMT/comb-to-smt.mlir index 78745b20696c..5d480409bdde 100644 --- a/test/Conversion/CombToSMT/comb-to-smt.mlir +++ b/test/Conversion/CombToSMT/comb-to-smt.mlir @@ -10,13 +10,29 @@ func.func @test(%a0: !smt.bv<32>, %a1: !smt.bv<32>, %a2: !smt.bv<32>, %a3: !smt. %arg4 = builtin.unrealized_conversion_cast %a4 : !smt.bv<1> to i1 %arg5 = builtin.unrealized_conversion_cast %a5 : !smt.bv<4> to i4 - // CHECK: smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %0 = comb.divs %arg0, %arg1 : i32 - // CHECK-NEXT: smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %1 = comb.divu %arg0, %arg1 : i32 - // CHECK-NEXT: smt.bv.srem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.srem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %2 = comb.mods %arg0, %arg1 : i32 - // CHECK-NEXT: smt.bv.urem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> + // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32> + // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32> + // CHECK-NEXT: [[DIV:%.+]] = smt.bv.urem [[A0]], [[A1]] : !smt.bv<32> + // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32> %3 = comb.modu %arg0, %arg1 : i32 // CHECK-NEXT: [[NEG:%.+]] = smt.bv.neg [[A1]] : !smt.bv<32>