Skip to content

Commit

Permalink
Merge pull request #317 from GaloisInc/vr/max-restarts
Browse files Browse the repository at this point in the history
add max-restarts CLI arguments to stop long constraint solving loops
  • Loading branch information
Ptival authored Oct 14, 2024
2 parents ccd5dbf + e5859d5 commit 07488fb
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 334 deletions.
1 change: 1 addition & 0 deletions reopt.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ library
attoparsec-aeson,
base64,
bytestring,
composition-extra,
containers,
directory,
elf-edit >= 0.40,
Expand Down
16 changes: 15 additions & 1 deletion reopt/Main_reopt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ data Args = Args
-- ^ Trace unification of the type inference solver
, traceConstraintOrigins :: Bool
-- ^ Trace the origin of constraints in the type inference solver
, maxRestarts :: Maybe Int
-- ^ How many time the type constraint solver should restart before giving up. `Nothing` means
-- infinite restarts.
}
deriving (Generic)

Expand Down Expand Up @@ -247,6 +250,7 @@ defaultArgs =
, performRecovery = False
, traceTypeUnification = False
, traceConstraintOrigins = False
, maxRestarts = Nothing
}

------------------------------------------------------------------------
Expand Down Expand Up @@ -299,6 +303,13 @@ traceConstraintOriginsP =
long "trace-constraint-origins"
<> help "Trace the origins of constraints in the type inference engine"

maxRestartsP :: Parser (Maybe Int)
maxRestartsP =
optional $ option auto $
long "max-restarts"
<> metavar "NUMBER"
<> help "Number of times the type constraint solver should restart before giving up"

outputPathP :: Parser String
outputPathP =
strOption $
Expand Down Expand Up @@ -623,6 +634,7 @@ arguments =
<*> performRecoveryP
<*> traceTypeUnificationP
<*> traceConstraintOriginsP
<*> maxRestartsP

-- | Parser to set the path to the binary to analyze.
programPathP :: Parser String
Expand Down Expand Up @@ -668,6 +680,7 @@ argsReoptOptions args = do
, roDiscoveryOptions = args ^. #discOpts
, roDynDepPaths = dynDepPath args
, roDynDepDebugPaths = dynDepDebugPath args ++ gdbDebugDirs
, roMaxRestarts = maxRestarts args
, roTraceUnification = traceTypeUnification args
, roTraceConstraintOrigins = traceConstraintOrigins args
}
Expand Down Expand Up @@ -722,7 +735,8 @@ showConstraints args elfPath = do
doRecoverX86 funPrefix sysp symAddrMap debugTypeMap discState Map.empty

let recMod = recoveredModule recoverX86Output
pure $ genModuleConstraints recMod (Macaw.memory discState) (traceTypeUnification args) (traceConstraintOrigins args)
pure $ genModuleConstraints recMod (Macaw.memory discState)
(maxRestarts args) (traceTypeUnification args) (traceConstraintOrigins args)

mc <- handleEitherWithExit mr

Expand Down
3 changes: 3 additions & 0 deletions src/Reopt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ data ReoptOptions = ReoptOptions
, roDynDepPaths :: ![FilePath]
-- ^ Additional paths to search for dynamic dependencies.
, roDynDepDebugPaths :: ![FilePath]
, roMaxRestarts :: Maybe Int
-- ^ Additional paths to search for debug versions of dynamic dependencies.
, roTraceUnification :: !Bool
-- ^ Trace unification in the solver
Expand All @@ -499,6 +500,7 @@ defaultReoptOptions =
, roDiscoveryOptions = reoptDefaultDiscoveryOptions
, roDynDepPaths = []
, roDynDepDebugPaths = []
, roMaxRestarts = Nothing
, roTraceUnification = False
, roTraceConstraintOrigins = False
}
Expand Down Expand Up @@ -2745,6 +2747,7 @@ reoptRecoveryLoop symAddrMap rOpts funPrefix sysp debugTypeMap firstDiscState =
genModuleConstraints
recMod
(Macaw.memory discState')
(roMaxRestarts rOpts)
(roTraceUnification rOpts)
(roTraceConstraintOrigins rOpts)

Expand Down
31 changes: 21 additions & 10 deletions src/Reopt/TypeInference/ConstraintGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ import Reopt.TypeInference.Solver.Constraints (
ConstraintProvenance (..),
FnRepProvenance (..),
)
import Reopt.TypeInference.Solver.Monad (ConstraintSolvingReader (..))
import Reopt.TypeInference.Solver.Types (TyF (..))

-- This algorithm proceeds in stages:
Expand Down Expand Up @@ -241,20 +242,29 @@ inSolverM = CGenM . lift . lift

runCGenM ::
Memory (ArchAddrWidth arch) ->
Maybe Int ->
Bool ->
Bool ->
CGenM CGenGlobalContext arch a ->
a
runCGenM mem traceWanted orig (CGenM m) = runSolverM traceWanted orig ptrWidth $ do
let segs = memSegments mem
-- Allocate a row variable for each memory segment
memRows <- Map.fromList <$> mapM (\seg -> (,) seg <$> S.freshRowVar) segs
let ctxt0 =
CGenGlobalContext
{ _cgenMemory = mem
, _cgenMemoryRegions = memRows
runCGenM mem maxRestarts traceWanted orig (CGenM m) = do
let initReader =
ConstraintSolvingReader
{ rMaxNumberOfRestarts = maxRestarts
, rPtrWidth = ptrWidth
, rTraceConstraintOrigins = orig
, rTraceUnification = traceWanted
}
evalStateT (Reader.runReaderT m ctxt0) st0
runSolverM initReader $ do
let segs = memSegments mem
-- Allocate a row variable for each memory segment
memRows <- Map.fromList <$> mapM (\seg -> (,) seg <$> S.freshRowVar) segs
let ctxt0 =
CGenGlobalContext
{ _cgenMemory = mem
, _cgenMemoryRegions = memRows
}
evalStateT (Reader.runReaderT m ctxt0) st0
where
ptrWidth = widthVal (memWidth mem)

Expand Down Expand Up @@ -981,10 +991,11 @@ genModuleConstraints ::
FoldableFC (ArchFn arch) =>
RecoveredModule arch ->
Memory (ArchAddrWidth arch) ->
Maybe Int ->
Bool ->
Bool ->
ModuleConstraints arch
genModuleConstraints m mem traceWanted orig = runCGenM mem traceWanted orig $ do
genModuleConstraints m mem maxRestarts traceWanted orig = runCGenM mem maxRestarts traceWanted orig $ do
-- allocate type variables for functions without types
-- FIXME: we currently ignore hints

Expand Down
59 changes: 37 additions & 22 deletions src/Reopt/TypeInference/Solver/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module Reopt.TypeInference.Solver.Monad (
Conditional (..),
Conditional',
Conjunction (..),
ConstraintSolvingReader (..),
ConstraintSolvingState (..),
defineRowVar,
defineTyVar,
Expand Down Expand Up @@ -43,15 +44,16 @@ module Reopt.TypeInference.Solver.Monad (
withFresh,
) where

import Control.Lens (Lens', use, (%%=), (%=), (<<+=))
import Control.Monad.State (MonadState, State, evalState)
import Control.Lens (Lens', view, (%%=), (%=), (<<+=))
import Data.Foldable (asum)
import Data.Function.Slip (slipr)
import Data.Generics.Labels ()
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import GHC.Generics (Generic)
import Prettyprinter qualified as PP

import Control.Monad.RWS.Strict
import Reopt.TypeInference.Solver.Constraints (
ConstraintProvenance,
EqC (EqC),
Expand Down Expand Up @@ -81,6 +83,16 @@ import Reopt.TypeInference.Solver.UnionFindMap qualified as UM

type Conditional' = Conditional ([EqC], [EqRowC])

data ConstraintSolvingReader = ConstraintSolvingReader
{ rMaxNumberOfRestarts :: Maybe Int
, rPtrWidth :: Int
-- ^ The width of a pointer, in bits. This can go away when tyvars have an associated size, it is
-- only used for PtrAddC solving.
, rTraceUnification :: Bool
, rTraceConstraintOrigins :: Bool
}
deriving (Generic)

data ConstraintSolvingState = ConstraintSolvingState
{ ctxEqCs :: [EqC]
, ctxEqRowCs :: [EqRowC]
Expand All @@ -90,22 +102,16 @@ data ConstraintSolvingState = ConstraintSolvingState
, nextTraceId :: Int
, nextRowVar :: Int
, nextTyVar :: Int
, ptrWidth :: Int
-- ^ The width of a pointer, in bits. This can go away when
-- tyvars have an associated size, it is only used for PtrAddC
-- solving.
, ctxTyVars :: UnionFindMap TyVar TyVar ITy'
-- ^ The union-find data-structure mapping each tyvar onto its
-- representative tv. If no mapping exists, it is a self-mapping.
, ctxRowVars :: UnionFindMap RowVar RowInfo (FieldMap TyVar)
, -- Debugging
ctxTraceUnification :: Bool
, ctxTraceConstraintOrigins :: Bool
, ctxNumberOfRestarts :: Int
}
deriving (Generic)

emptyContext :: Int -> Bool -> Bool -> ConstraintSolvingState
emptyContext w trace orig =
emptyConstraintSolvingState :: ConstraintSolvingState
emptyConstraintSolvingState =
ConstraintSolvingState
{ ctxEqCs = []
, ctxEqRowCs = []
Expand All @@ -115,20 +121,29 @@ emptyContext w trace orig =
, nextTraceId = 0
, nextRowVar = 0
, nextTyVar = 0
, ptrWidth = w
, ctxTyVars = UM.empty
, ctxRowVars = UM.empty
, ctxTraceUnification = trace
, ctxTraceConstraintOrigins = orig
, ctxNumberOfRestarts = 0
}

newtype SolverM a = SolverM
{ getSolverM :: State ConstraintSolvingState a
{ getSolverM :: RWS ConstraintSolvingReader () ConstraintSolvingState a
}
deriving (Applicative, Functor, Monad, MonadState ConstraintSolvingState)

runSolverM :: Bool -> Bool -> Int -> SolverM a -> a
runSolverM b o w = flip evalState (emptyContext w b o) . getSolverM
deriving
( Applicative
, Functor
, Monad
, MonadState ConstraintSolvingState
, MonadReader ConstraintSolvingReader
)

runSolverM ::
ConstraintSolvingReader ->
SolverM a ->
a
runSolverM initReader = fst . slipr evalRWS initReader initState . getSolverM
where
initState = emptyConstraintSolvingState

--------------------------------------------------------------------------------
-- Adding constraints
Expand Down Expand Up @@ -272,13 +287,13 @@ unsafeUnifyTyVars root leaf = #ctxTyVars %= UM.unify root leaf
-- Other stuff

ptrWidthNumTy :: SolverM ITy'
ptrWidthNumTy = NumTy <$> use #ptrWidth
ptrWidthNumTy = NumTy <$> view #rPtrWidth

traceUnification :: SolverM Bool
traceUnification = use #ctxTraceUnification
traceUnification = view #rTraceUnification

traceConstraintOrigins :: SolverM Bool
traceConstraintOrigins = use #ctxTraceConstraintOrigins
traceConstraintOrigins = view #rTraceConstraintOrigins

--------------------------------------------------------------------------------
-- Conditional constraints
Expand Down
50 changes: 32 additions & 18 deletions src/Reopt/TypeInference/Solver/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Data.Set qualified as Set
import Debug.Trace (trace)
import Prettyprinter qualified as PP

import Control.Monad.RWS (asks)
import Control.Monad.Trans (lift)
import Reopt.TypeInference.Solver.Constraints (
ConstraintProvenance (..),
Expand All @@ -43,6 +44,7 @@ import Reopt.TypeInference.Solver.Finalize (
import Reopt.TypeInference.Solver.Monad (
Conditional (..),
Conditional',
ConstraintSolvingReader (rMaxNumberOfRestarts, rPtrWidth),
ConstraintSolvingState (..),
SolverM,
addEqC,
Expand Down Expand Up @@ -180,17 +182,19 @@ solveHeadReset fld doit = do
put resetSt

-- Forget everything we know in resetSt about the eqvs for tv
let eqs = eqvClasses (ctxTyVars oldSt)
eqsTv = Map.findWithDefault [] tv eqs
let
eqs = eqvClasses (ctxTyVars oldSt)
eqsTv = Map.findWithDefault [] tv eqs
traverse_ undefineTyVar eqsTv

-- FIXME: gross
defineTyVar tv (ConflictTy (ptrWidth resetSt))
ptrWidth <- asks rPtrWidth
defineTyVar tv (ConflictTy ptrWidth)
-- FIXME: this could cause problems if we allocate tyvars after
-- we start solving. Because we don't, this should work.
mapM_ (addTyVarEq' ConflictProv tv) eqsTv -- retain eqv class for conflict var.
get
put resetSt'
put $ resetSt'{ctxNumberOfRestarts = ctxNumberOfRestarts resetSt + 1}

solveFirst ::
Lens' ConstraintSolvingState [a] ->
Expand Down Expand Up @@ -227,8 +231,9 @@ _solveAll fld solve = do
go acc progd [] = restore acc $> progd -- finished here, we didn't so anything.
go acc progd (c : cs) = do
(retain, progress) <- solve c
let acc' = if retain == Retain then c : acc else acc
progd' = progd || madeProgress progress
let
acc' = if retain == Retain then c : acc else acc
progd' = progd || madeProgress progress
go acc' progd' cs

-- | @preprocess l f# just pre-processes the element at @l@, and so
Expand All @@ -246,11 +251,17 @@ preprocess fld f =
fld %= (<> r)

solverLoop :: SolverM ()
solverLoop = evalStateT go =<< get
solverLoop = do
maxRestarts <- asks rMaxNumberOfRestarts
evalStateT (go (exceeds maxRestarts)) =<< get
where
go = do
exceeds (Just maxRestarts) r = r > maxRestarts
exceeds Nothing _ = False

go isTooMany = do
tooManyRestarts <- gets $ isTooMany . ctxNumberOfRestarts
keepGoing <- orM solvers
when keepGoing go
when (keepGoing && not tooManyRestarts) (go isTooMany)

solvers =
[ solveHeadReset #ctxEqCs solveEqC
Expand Down Expand Up @@ -371,14 +382,16 @@ solveConditional c = traceContext' "solveConditional" c $ do
solveEqRowC :: EqRowC -> SolverM ()
solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do
(le, m_lfm) <- lookupRowExpr (eqRowLHS eqc)
let lo = rowExprShift le
lv = rowExprVar le
lfm = fromMaybe emptyFieldMap m_lfm
let
lo = rowExprShift le
lv = rowExprVar le
lfm = fromMaybe emptyFieldMap m_lfm

(re, m_rfm) <- lookupRowExpr (eqRowRHS eqc)
let ro = rowExprShift re
rv = rowExprVar re
rfm = fromMaybe emptyFieldMap m_rfm
let
ro = rowExprShift re
rv = rowExprVar re
rfm = fromMaybe emptyFieldMap m_rfm

case () of
_
Expand All @@ -390,8 +403,9 @@ solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do
unify delta lowv lowfm highv highfm = do
undefineRowVar highv
unsafeUnifyRowVars (RowExprShift delta lowv) highv
let highfm' = shiftFieldMap delta highfm
(lowfm', newEqs) = unifyFieldMaps lowfm highfm'
let
highfm' = shiftFieldMap delta highfm
(lowfm', newEqs) = unifyFieldMaps lowfm highfm'
defineRowVar lowv lowfm'
traverse_ (uncurry (addTyVarEq' FromEqRowCProv)) newEqs

Expand Down Expand Up @@ -437,7 +451,7 @@ solveEqC eqc = do
-- when the type variable is conflicted.
unifyTypes :: TyVar -> ITy' -> ITy' -> SolverM (Maybe TyVar)
unifyTypes tv ty1 ty2 = do
pW <- gets ptrWidth
pW <- asks rPtrWidth
case (ty1, ty2) of
_ | ty1 == ty2 -> pure Nothing
(NumTy i, NumTy i')
Expand Down
Loading

0 comments on commit 07488fb

Please sign in to comment.