diff --git a/src/Grisette/Internal/Core/Data/Class/CEGISSolver.hs b/src/Grisette/Internal/Core/Data/Class/CEGISSolver.hs index 55a46a44..bfb57573 100644 --- a/src/Grisette/Internal/Core/Data/Class/CEGISSolver.hs +++ b/src/Grisette/Internal/Core/Data/Class/CEGISSolver.hs @@ -188,9 +188,11 @@ solverGenericCEGIS solver rerun initConstr synthConstr verifiers = do firstResult <- solverSolve solver initConstr case firstResult of Left err -> return ([], CEGISSolverFailure err) - Right model -> go model False verifiers + Right model -> go model False numAllVerifiers 0 verifiers where - go prevModel needRerun (verifier : remainingVerifiers) = do + numAllVerifiers = length verifiers + go prevModel _ 0 _ (_ : _) = return ([], CEGISSuccess prevModel) + go prevModel needRerun runBound nextBound (verifier : remainingVerifiers) = do verifierResult <- verifier prevModel case verifierResult of CEGISVerifierFoundCex cex -> do @@ -199,14 +201,20 @@ solverGenericCEGIS solver rerun initConstr synthConstr verifiers = do Left err -> return ([cex], CEGISSolverFailure err) Right model -> do (cexes, result) <- - go model (needRerun || rerun) $ - verifier : remainingVerifiers + go + model + (needRerun || rerun) + (length remainingVerifiers + 1) + (numAllVerifiers - length remainingVerifiers) + $ verifier : remainingVerifiers return (cex : cexes, result) - CEGISVerifierNoCex {} -> go prevModel needRerun remainingVerifiers + CEGISVerifierNoCex {} -> + go prevModel needRerun (runBound - 1) nextBound remainingVerifiers CEGISVerifierException exception -> return ([], CEGISVerifierFailure exception) - go prevModel False [] = return ([], CEGISSuccess prevModel) - go prevModel True [] = go prevModel False verifiers + go prevModel False _ _ [] = return ([], CEGISSuccess prevModel) + go prevModel True _runBound nextBound [] = + go prevModel False nextBound 0 verifiers -- | Generic CEGIS procedure with refinement. See 'genericCEGISWithRefinement' -- for more details. diff --git a/test/Grisette/Backend/CEGISTests.hs b/test/Grisette/Backend/CEGISTests.hs index 4059c819..9b442ee6 100644 --- a/test/Grisette/Backend/CEGISTests.hs +++ b/test/Grisette/Backend/CEGISTests.hs @@ -9,6 +9,7 @@ module Grisette.Backend.CEGISTests (cegisTests) where import Control.Monad.Except (ExceptT) +import Data.IORef (atomicModifyIORef', modifyIORef', newIORef, readIORef) import Data.Proxy (Proxy (Proxy)) import Data.String (IsString (fromString)) import GHC.Stack (HasCallStack) @@ -20,7 +21,7 @@ import Grisette Function ((#)), GrisetteSMTConfig, ITEOp (symIte), - LogicalOp (symNot, symXor, (.&&), (.||)), + LogicalOp (symNot, symXor, true, (.&&), (.||)), ModelRep (buildModel), ModelValuePair ((::=)), SizedBV (sizedBVConcat, sizedBVSelect, sizedBVSext, sizedBVZext), @@ -29,6 +30,7 @@ import Grisette SymOrd ((.<), (.>=)), Union, VerificationConditions, + VerifierResult (CEGISVerifierFoundCex, CEGISVerifierNoCex), cegis, cegisExceptVC, cegisForAll, @@ -37,8 +39,10 @@ import Grisette cegisPostCond, mrgIf, solve, + solverGenericCEGIS, symAssert, symAssume, + withSolver, z3, ) import Grisette.SymPrim @@ -471,5 +475,45 @@ cegisTests = ) m @?= expectedModel CEGISVerifierFailure _ -> fail "Verifier failed" - CEGISSolverFailure failure -> fail $ show failure + CEGISSolverFailure failure -> fail $ show failure, + testGroup "rerun" $ do + let verifier n trace retsIORef _ = do + modifyIORef' trace (n :) + ret <- atomicModifyIORef' retsIORef (\(x : xs) -> (xs, x)) + if ret + then return $ CEGISVerifierFoundCex "Found" + else return $ CEGISVerifierNoCex True + let createTestCase :: String -> Bool -> [[Bool]] -> [Int] -> Test + createTestCase name rerun rets expected = testCase name $ do + trace <- newIORef [] + retsIORefs <- traverse newIORef rets + withSolver unboundedConfig $ \handle -> + solverGenericCEGIS + handle + rerun + true + (const $ return true) + (zipWith (`verifier` trace) [0 ..] retsIORefs) + tracev <- readIORef trace + tracev @?= expected + [ createTestCase + "no rerun" + False + [[False], [True, False], [False]] + [2, 1, 1, 0], + createTestCase + "do rerun" + True + [[False, False], [True, False, False], [False]] + [1, 0, 2, 1, 1, 0], + createTestCase + "do rerun complex" + True + [ [False, False, False], + [True, False, True, False, False], + [True, False, False], + [False, False] + ] + [1, 0, 3, 2, 1, 1, 0, 3, 2, 2, 1, 1, 0] + ] ]