-
Notifications
You must be signed in to change notification settings - Fork 9
/
Simplify.hs
274 lines (246 loc) · 10.1 KB
/
Simplify.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
{-# LANGUAGE BangPatterns
#-}
module Simplify
( do_simplify
) where
import Data.List (foldl')
import qualified Data.Set as Set
import Control.Monad.State
import Common
import Errors
import Field
import Constraints
import UnionFind
import SimplMonad
----------------------------------------------------------------
-- Substitution --
----------------------------------------------------------------
-- | Normalize constraint 'constr', by substituting roots/constants
-- for the variables that appear in the constraint. Note that, when
-- normalizing a multiplicative constraint, it may be necessary to
-- convert it into an additive constraint.
--
-- Per constructor:
--
--   * 'CMagic': the callback @mf xs@ is run only when 'solve_mode_flag'
--     is set. If it returns 'True', the constraint is replaced by the
--     empty additive constraint @cadd zero []@. In every other case the
--     constraint is returned unchanged.
--
--   * 'CAdd': terms whose variable 'bind_of_var' resolves to a constant
--     ('Right') are folded into the constant part. Terms with a zero
--     coefficient are dropped. Each remaining variable is replaced by its
--     root ('root_of_var').
--
--   * 'CMult': each of @x@, @y@ and the optional right-hand variable is
--     resolved to a root ('Left') or a constant ('Right'). If both factors
--     remain variables the result is still a 'CMult'. Otherwise the
--     product is linear and an equivalent 'cadd' constraint is returned.
--     NOTE(review): these rewrites are consistent with reading
--     @cadd k [(v,c),...]@ as @k + c*v + ... = 0@ and
--     @CMult (c,x) (d,y) (e,mz)@ as @(c*x)*(d*y) = e*z@, with @z@
--     omitted when @mz@ is 'Nothing'. Confirm against "Constraints".
subst_constr :: Field a
             => Constraint a
             -> State (SEnv a) (Constraint a)
subst_constr !constr = case constr of
    CMagic !_ !xs !mf ->
      do { solve <- solve_mode_flag
         ; if solve then
             do { b <- mf xs
                ; if b then return $ cadd zero []
                  else return constr
                }
           else return constr
         }
    CAdd a m ->
      do { -- Variables resolvable to constants: each term (x,a0) whose
           -- variable is bound to a constant a' yields [(x, a0*a')];
           -- terms whose variable is still a root yield [].
           consts' <- mapM (\(x,a0) ->
                              do { var_or_a <- bind_of_var x
                                 ; case var_or_a of
                                     Left _ -> return []
                                     Right a' -> return $! [(x,a0 `mult` a')]
                                 })
                      $! asList m
         ; let consts = concat consts'
         ; let const_keys = map fst consts
         ; let const_vals = map snd consts
           -- The new constant folding in all constant constraint variables
         ; let new_const = foldl' add a const_vals
           -- The linear combination minus
           --   (1) Terms whose variables resolve to constants, and
           --   (2) Terms with coeff 0.
         ; let less_consts
                 = filter (\(k,v) -> not (elem k const_keys) && v/=zero)
                   $! asList m
           -- The new linear combination: 'less_consts' with all variables
           -- replaced by their roots.
         ; new_map <- mapM (\(x,a0) ->
                              do { rx <- root_of_var x
                                 ; return $! (rx,a0)
                                 })
                      less_consts
         ; return $! cadd new_const new_map
         }
    CMult !(c,x) !(d,y) !ez ->
      do { bx <- bind_of_var x
         ; by <- bind_of_var y
         ; bz <- bind_of_term ez
         ; case (bx,by,bz) of
             -- Both factors are still variables: remains multiplicative.
             (Left rx,Left ry,Left (e,rz)) ->
               return
                 $! CMult (c,rx) (d,ry) (e,Just rz)
             (Left rx,Left ry,Right e) ->
               return
                 $! CMult (c,rx) (d,ry) (e,Nothing)
             -- 'y' is the constant d0: linear in rx (and rz, if present).
             (Left rx,Right d0,Left (e,rz)) ->
               return
                 $! cadd zero [(rx,c `mult` d `mult` d0),(rz,neg e)]
             (Left rx,Right d0,Right e) ->
               return
                 $! cadd (neg e) [(rx,c `mult` d `mult` d0)]
             -- 'x' is the constant c0: linear in ry (and rz, if present).
             (Right c0,Left ry,Left (e,rz)) ->
               return
                 $! cadd zero [(ry,c0 `mult` c `mult` d),(rz,neg e)]
             (Right c0,Left ry,Right e) ->
               return
                 $! cadd (neg e) [(ry,c0 `mult` c `mult` d)]
             -- Both factors are constants: the left side folds to one value.
             (Right c0,Right d0,Left (e,rz)) ->
               return
                 $! cadd (c `mult` c0 `mult` d `mult` d0) [(rz,neg e)]
             (Right c0,Right d0,Right e) ->
               return
                 $! cadd (c `mult` c0 `mult` d `mult` d0 `add` (neg e)) []
         }
  where -- Resolve the right-hand side (e, Maybe z) of a 'CMult'. With no
        -- variable, or with z bound to a constant e0, the result is
        -- 'Right' of the folded constant. Otherwise it is 'Left' (e, root
        -- of z).
        bind_of_term (e,Nothing)
          = return $! Right e
        bind_of_term (e,Just z)
          = do { var_or_a <- bind_of_var z
               ; case var_or_a of
                   Left rz -> return $! Left (e,rz)
                   Right e0 -> return $! Right (e `mult` e0)
               }
----------------------------------------------------------------
-- Constraint Set Minimization --
----------------------------------------------------------------
-- | Is 'constr' a tautology?
--
--   * An additive constraint with no variable terms is a tautology only
--     when its constant is 'zero'. @cadd k []@ with a nonzero @k@ says
--     @k = 0@, which is a contradiction. It must not be discarded as
--     trivially satisfied, or an unsatisfiable system would silently
--     lose the offending constraint.
--   * An additive constraint with at least one variable term, and any
--     multiplicative constraint, is never reported as a tautology.
--   * A magic constraint is a tautology exactly when its callback
--     @mf xs@ returns 'True'.
is_taut :: Field a
        => Constraint a
        -> State (SEnv a) Bool
is_taut constr
  = case constr of
      CAdd a (CoeffList []) -> return (a == zero)
      CAdd _ (CoeffList (_ : _)) -> return False
      CMult _ _ _ -> return False
      CMagic _ xs mf -> mf xs
-- | Remove tautologous constraints.
--
-- Each constraint is first normalized with 'subst_constr'. Only the
-- normalized constraints that 'is_taut' rejects are returned, in their
-- original relative order. Constraints are processed left to right.
remove_tauts :: Field a => [Constraint a] -> State (SEnv a) [Constraint a]
remove_tauts sigma =
  do classified <- mapM classify sigma
     return [ kept | (False, kept) <- classified ]
  where
    -- Pair a normalized constraint with its tautology verdict.
    classify t =
      do t' <- subst_constr t
         taut <- is_taut t'
         return (taut, t')
-- | Learn bindings and variable equalities from constraint 'constr'.
--
-- Only two shapes of additive constraint are used:
--
--   * @CAdd a [(x,c)]@ with nonzero @c@ calls 'bind_var' on @x@ with the
--     value @(neg a) * c^-1@. If @c@ has no inverse, it fails via
--     'fail_with'.
--   * @CAdd zero [(x,c),(y,d)]@ with nonzero @c@ and @c == neg d@ calls
--     'unite_vars' on @x@ and @y@.
--
-- Every other constraint is ignored. The @c /= zero@ check in the
-- two-term case stops @0*x + 0*y = 0@, which carries no information,
-- from equating @x@ and @y@. This matches the zero check in the
-- one-term case.
learn :: Field a
      => Constraint a
      -> State (SEnv a) ()
learn = go
  where
    go (CAdd a (CoeffList [(x,c)]))
      | c == zero = return ()
      | otherwise =
          case inv c of
            Nothing ->
              fail_with
                $ ErrMsg (show c ++ " not invertible")
            Just c' -> bind_var (x,neg a `mult` c')
    go (CAdd a (CoeffList [(x,c),(y,d)]))
      | a == zero && c /= zero && c == neg d = unite_vars x y
    go _ = return ()
-- | Simplify constraint system 'cs' starting from assignment 'env'.
--
-- Runs 'simplify' in a fresh state, built from 'new_uf' with 'env' as its
-- 'extras' and a mode taken from the flag ('UseMagic' or 'JustSimplify').
-- It then returns:
--
--   * the assignment 'assgn_of_vars' computes for the pinned variables
--     plus the variables of the original constraint set, and
--   * 'cs' with its constraints replaced by the simplified set.
do_simplify :: Field a
            => Bool               -- ^ Solve mode? If 'True', use Magic.
            -> Assgn a            -- ^ Initial variable assignment
            -> ConstraintSystem a -- ^ Constraint set to be simplified
            -> (Assgn a,ConstraintSystem a)
               -- ^ Resulting assignment, simplified constraint set
do_simplify in_solve_mode env cs
  -- NOTE: Pinned vars include:
  --   - input vars
  --   - output vars
  --   - magic vars (those that appear in magic constraints, used to
  --     resolve nondeterministic inputs)
  -- Pinned vars are never optimized away.
  = let pinned_vars = cs_in_vars cs ++ cs_out_vars cs ++ magic_vars (cs_constraints cs)
        do_solve = if in_solve_mode then UseMagic else JustSimplify
        new_state = SEnv (new_uf { extras = env }) do_solve
    in evalState (go pinned_vars) new_state
  where go pinned_vars
          = do { sigma' <- simplify pinned_vars $ cs_constraints cs
                 -- NOTE: In the next line, it's OK that 'pinned_vars'
                 -- may overlap with 'constraint_vars cs'.
                 -- 'assgn_of_vars' might do a bit of duplicate
                 -- work (to look up the same key more than once).
               ; assgn <- assgn_of_vars
                            $ pinned_vars
                           ++ constraint_vars (cs_constraints cs)
               ; return (assgn,cs { cs_constraints = sigma' })
               }
        -- Concatenate the variable lists of every 'CMagic' constraint,
        -- in ascending set order. 'Set.foldr' replaces the deprecated
        -- 'Set.fold' and behaves the same.
        magic_vars cs0
          = Set.foldr (\c0 acc ->
                         case c0 of
                           CMagic _ xs _ -> xs ++ acc
                           _ -> acc
                      ) [] cs0
-- | Simplify 'sigma' to a fixpoint ('simplify_rec'), normalize the result
-- once more with 'subst_constr', drop tautologies ('remove_tauts'), and
-- finally prepend pin equations for the pinned variables.
simplify :: Field a
         => [Var]
         -> ConstraintSet a
         -> State (SEnv a) (ConstraintSet a)
simplify pinned_vars sigma =
  do reduced     <- simplify_rec sigma
     substituted <- mapM subst_constr (Set.toList reduced)
     nontrivial  <- remove_tauts substituted
     with_pins   <- add_pin_eqns nontrivial
     return (Set.fromList with_pins)
  where
    -- NOTE: Each pinned variable 'x' is handled as follows:
    --   (1) Look up what 'x' is bound to ('bind_of_var').
    --   (2) If that is just 'x' itself, do nothing: clauses that mention
    --       the pinned variable must keep mentioning it.
    --   (3) Otherwise add an equation tying 'x' to its binding.
    add_pin_eqns constrs =
      do bindings <- mapM (\x -> fmap ((,) x) (bind_of_var x)) pinned_vars
         let pin_eqns = [ pin_eqn x b | (x, b) <- bindings, Left x /= b ]
         return (pin_eqns ++ constrs)
    -- 'x' is represented by root 'rx': emit x - rx = 0.
    pin_eqn x (Left rx) = cadd zero [(x,one),(rx,neg one)]
    -- 'x' is bound to the constant 'c': emit x - c = 0.
    pin_eqn x (Right c) = cadd (neg c) [(x,one)]
-- | Repeat 'simplify_once' until a pass leaves every constraint of its
-- input intact, i.e. the input set is a subset of the output set.
simplify_rec :: Field a
             => ConstraintSet a -- ^ Initial constraint set
             -> State (SEnv a) (ConstraintSet a)
                -- ^ Resulting simplified constraint set
simplify_rec sigma
  = do { sigma' <- simplify_once sigma
         -- A pass that shrank the set can never contain all of 'sigma',
         -- so this single test covers the old "size decreased" case too.
       ; if sigma `Set.isSubsetOf` sigma'
           then return sigma'
           else simplify_rec sigma'
       }
  where
    -- One pass: substitute into and learn from each constraint in
    -- ascending order, then re-normalize the survivors and drop
    -- tautologies.
    simplify_once :: Field a
                  => ConstraintSet a -- ^ Initial constraint set
                  -> State (SEnv a) (ConstraintSet a)
                     -- ^ Resulting simplified constraint set
    simplify_once sigma0
      = do { sigma2 <- go Set.empty sigma0
           ; sigma' <- remove_tauts (Set.toList sigma2)
           ; return $ Set.fromList sigma'
           }
    -- Worklist loop. 'ws' holds processed constraints and 'us' the
    -- unprocessed ones. Each constraint taken from 'us' is substituted.
    -- Tautologies are dropped; everything else is passed to 'learn' and
    -- kept. 'Set.minView' is the total replacement for the previous
    -- emptiness guard plus partial 'Set.deleteFindMin'.
    go ws us
      = case Set.minView us of
          Nothing -> return ws
          Just (given, us') ->
            do { given' <- subst_constr given
               ; given_taut <- is_taut given'
               ; if given_taut then go ws us'
                 else do { learn given'
                         ; go (Set.insert given' ws) us'
                         }
               }