Skip to content

Commit

Permalink
✨ Add derive Mergeable for GADT
Browse files Browse the repository at this point in the history
  • Loading branch information
lsrcz committed Dec 8, 2024
1 parent 96f7fab commit 7ba3934
Show file tree
Hide file tree
Showing 4 changed files with 697 additions and 36 deletions.
153 changes: 118 additions & 35 deletions examples/basic/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,139 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Main (main) where

import GHC.Generics
import Grisette
import Grisette.Internal.TH.GADT.DeriveMergeable

data SymExpr
= SymConst SymInteger
| SymInput SymInteger
| SymAdd (Union SymExpr) (Union SymExpr)
| SymMul (Union SymExpr) (Union SymExpr)
deriving stock (Generic, Show)
deriving (Mergeable, EvalSym) via (Default SymExpr)
type IntExpr = Expr SymInteger

-- You may use the following template haskell call to derive everything we need.
-- It will require a long list of extensions though, as it generates some
-- redundant constraints. We will refine this in future releases.
type BoolExpr = Expr SymBool

-- deriveAllExcept ''SymExpr [''Ord]
type UExpr a = Union (Expr a)

mkMergeConstructor "mrg" ''SymExpr
type UIntExpr = UExpr SymInteger

progSpace :: SymInteger -> SymExpr
progSpace x =
SymAdd
(mrgSymInput x)
(mrgIf "choice" (mrgSymInput x) (mrgSymConst "c"))
type UBoolExpr = UExpr SymBool

interpret :: SymExpr -> SymInteger
interpret (SymConst x) = x
interpret (SymInput x) = x
interpret (SymAdd x y) = interpretSpace x + interpretSpace y
interpret (SymMul x y) = interpretSpace x * interpretSpace y
data Expr a where
IntVal :: SymInteger -> IntExpr
BoolVal :: SymBool -> BoolExpr
Add :: UIntExpr -> UIntExpr -> IntExpr
Mul :: UIntExpr -> UIntExpr -> IntExpr
BAnd :: UBoolExpr -> UBoolExpr -> BoolExpr
BOr :: UBoolExpr -> UBoolExpr -> BoolExpr
Eq :: (BasicSymPrim a) => UExpr a -> UExpr a -> BoolExpr

interpretSpace :: Union SymExpr -> SymInteger
interpretSpace = onUnion interpret
deriving instance Show (Expr a)

executableSpace :: Integer -> SymInteger
executableSpace = interpret . progSpace . toSym
deriveGADTMergeable ''Expr
makeSmartCtor ''Expr

instance (EvalSym a) => EvalSym (Expr a) where
evalSym fillDefault m (IntVal a) = IntVal $ evalSym fillDefault m a
evalSym fillDefault m (BoolVal a) = BoolVal $ evalSym fillDefault m a
evalSym fillDefault m (Add a b) = Add (evalSym fillDefault m a) (evalSym fillDefault m b)
evalSym fillDefault m (Mul a b) = Mul (evalSym fillDefault m a) (evalSym fillDefault m b)
evalSym fillDefault m (BAnd a b) = BAnd (evalSym fillDefault m a) (evalSym fillDefault m b)
evalSym fillDefault m (BOr a b) = BOr (evalSym fillDefault m a) (evalSym fillDefault m b)
evalSym fillDefault m (Eq a b) = Eq (evalSym fillDefault m a) (evalSym fillDefault m b)

instance (ExtractSym a) => ExtractSym (Expr a) where
extractSymMaybe (IntVal a) = extractSymMaybe a
extractSymMaybe (BoolVal a) = extractSymMaybe a
extractSymMaybe (Add a b) = extractSymMaybe a <> extractSymMaybe b
extractSymMaybe (Mul a b) = extractSymMaybe a <> extractSymMaybe b
extractSymMaybe (BAnd a b) = extractSymMaybe a <> extractSymMaybe b
extractSymMaybe (BOr a b) = extractSymMaybe a <> extractSymMaybe b
extractSymMaybe (Eq a b) = extractSymMaybe a <> extractSymMaybe b

eval :: Expr a -> a
eval (IntVal a) = a
eval (BoolVal a) = a
eval (Add a b) = eval .# a + eval .# b
eval (Mul a b) = eval .# a * eval .# b
eval (BAnd a b) = eval .# a .&& eval .# b
eval (BOr a b) = eval .# a .|| eval .# b
eval (Eq a b) = eval .# a .== eval .# b

verifyEquivalent :: (BasicSymPrim a) => Expr a -> Expr a -> IO ()
verifyEquivalent e1 e2 = do
res <- solve z3 $ eval e1 ./= eval e2
case res of
Left Unsat -> putStrLn "The two expressions are equivalent"
Left err -> putStrLn $ "The solver returned unexpected response: " <> show err
Right model -> do
putStrLn "The two expressions are not equivalent, under the model:"
print model
putStrLn $ "lhs: " <> show e1
putStrLn $ "rhs: " <> show e2
putStrLn $ "lhs evaluates to: " <> show (evalSym False model $ eval e1)
putStrLn $ "rhs evaluates to: " <> show (evalSym False model $ eval e2)

synthesisRewriteTarget :: (BasicSymPrim a) => Expr a -> UExpr a -> IO ()
synthesisRewriteTarget expr sketch = do
r <- cegisForAll z3 expr $ cegisPostCond $ eval expr .== eval .# sketch
case r of
(_, CEGISSuccess model) -> do
putStrLn $ "For the target expression: " <> show expr
putStrLn "Successfully synthesized RHS:"
print $ evalSym False model sketch
(cex, failure) -> do
putStrLn $ "Synthesis failed with error: " ++ show failure
putStrLn $ "Counter example list: " ++ show cex

productOfSum :: Expr SymInteger
productOfSum = Mul (intVal "a") (add (intVal "b") (intVal "c"))

sumOfProduct :: Expr SymInteger
sumOfProduct =
Add (mul (intVal "a") (intVal "b")) (mul (intVal "a") (intVal "c"))

allSum :: Expr SymInteger
allSum = Add (intVal "a") (add (intVal "b") (intVal "c"))

xPlusX :: Expr SymInteger
xPlusX = Add (intVal "x") (intVal "x")

xTimesC :: UExpr SymInteger
xTimesC = mul (intVal "x") (intVal "c")

nextLevel :: [UExpr SymInteger] -> Fresh (UExpr SymInteger)
nextLevel exprs = do
lhs <- chooseUnionFresh exprs
rhs <- chooseUnionFresh exprs
chooseUnionFresh [add lhs rhs, mul lhs rhs, lhs]

getSketch :: Fresh (UExpr SymInteger)
getSketch = do
let atom = [intVal "a", intVal "b", intVal "c"]
l2 <- nextLevel atom
r2 <- nextLevel atom
nextLevel [l2, r2]

sketch :: UExpr SymInteger
sketch = runFresh getSketch "sketch"

main :: IO ()
main = do
Right model <- solve z3 $ executableSpace 2 .== 5
-- result: SymPlus {SymInput x} {SymConst 3}
print $ evalSym False model (progSpace "x")
let synthesizedProgram :: Integer -> Integer =
evalSymToCon model . executableSpace
-- result: 13
print $ synthesizedProgram 10
putStrLn "---- verifying productOfSum and sumOfProduct are equivalent ----"
verifyEquivalent productOfSum sumOfProduct
putStrLn "---- verifying productOfSum and allSum are equivalent (should fail) ----"
verifyEquivalent productOfSum allSum

putStrLn "---- synthesis x + x => x * 2 ----"
synthesisRewriteTarget xPlusX xTimesC
putStrLn "---- synthesis a * (b + c) => a * b + a * c ----"
synthesisRewriteTarget productOfSum sketch
2 changes: 1 addition & 1 deletion examples/grisette-examples.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cabal-version: 1.12

-- This file has been generated from package.yaml by hpack version 0.36.1.
-- This file has been generated from package.yaml by hpack version 0.37.0.
--
-- see: https://github.com/sol/hpack

Expand Down
1 change: 1 addition & 0 deletions grisette.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ library
Grisette.Internal.TH.DeriveTypeParamHandler
Grisette.Internal.TH.DeriveUnifiedInterface
Grisette.Internal.TH.DeriveWithHandlers
Grisette.Internal.TH.GADT.DeriveMergeable
Grisette.Internal.TH.Util
Grisette.Internal.Utils.Derive
Grisette.Internal.Utils.Parameterized
Expand Down
Loading

0 comments on commit 7ba3934

Please sign in to comment.