Skip to content

Commit

Permalink
test that instances for Eq and Ord agree with going via toAscList (#670)
Browse files Browse the repository at this point in the history
* Test that instances for Eq and Ord agree with going via toAscList

* Add benchmark for "instance Ord IntSet", using "Set IntSet"

* Improve implementation of "instance Ord IntSet"
that avoids toAscList and walks the tree directly. See #470
  • Loading branch information
jwaldmann authored and treeowl committed Dec 22, 2019
1 parent eb55a75 commit 7aff529
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 35 deletions.
133 changes: 102 additions & 31 deletions containers-tests/benchmarks/IntSet.hs
Original file line number Diff line number Diff line change
@@ -1,52 +1,123 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns, GeneralizedNewtypeDeriving #-}

module Main where

import Control.DeepSeq (rnf)
import Control.Exception (evaluate)
import Gauge (bench, defaultMain, whnf)
import Data.List (foldl')
import qualified Data.IntSet as S
import Data.Monoid (Sum(..))
#if !MIN_VERSION_base(4,8,0)
import Data.Foldable (foldMap)
#endif
import qualified Data.IntSet as IS
-- benchmarks for "instance Ord IntSet"
-- uses IntSet as keys of maps, and elements of sets
import qualified Data.Set as S
import qualified Data.IntMap as IM
import qualified Data.Map.Strict as M

main = do
let s = S.fromAscList elems :: S.IntSet
s_even = S.fromAscList elems_even :: S.IntSet
s_odd = S.fromAscList elems_odd :: S.IntSet
let s = IS.fromAscList elems :: IS.IntSet
s_even = IS.fromAscList elems_even :: IS.IntSet
s_odd = IS.fromAscList elems_odd :: IS.IntSet
evaluate $ rnf [s, s_even, s_odd]
defaultMain
[ bench "member" $ whnf (member elems) s
, bench "insert" $ whnf (ins elems) S.empty
, bench "map" $ whnf (S.map (+ 1)) s
, bench "filter" $ whnf (S.filter ((== 0) . (`mod` 2))) s
, bench "partition" $ whnf (S.partition ((== 0) . (`mod` 2))) s
, bench "fold" $ whnf (S.fold (:) []) s
, bench "insert" $ whnf (ins elems) IS.empty
, bench "map" $ whnf (IS.map (+ 1)) s
, bench "filter" $ whnf (IS.filter ((== 0) . (`mod` 2))) s
, bench "partition" $ whnf (IS.partition ((== 0) . (`mod` 2))) s
, bench "fold" $ whnf (IS.fold (:) []) s
, bench "delete" $ whnf (del elems) s
, bench "findMin" $ whnf S.findMin s
, bench "findMax" $ whnf S.findMax s
, bench "deleteMin" $ whnf S.deleteMin s
, bench "deleteMax" $ whnf S.deleteMax s
, bench "unions" $ whnf S.unions [s_even, s_odd]
, bench "union" $ whnf (S.union s_even) s_odd
, bench "difference" $ whnf (S.difference s) s_even
, bench "intersection" $ whnf (S.intersection s) s_even
, bench "fromList" $ whnf S.fromList elems
, bench "fromAscList" $ whnf S.fromAscList elems
, bench "fromDistinctAscList" $ whnf S.fromDistinctAscList elems
, bench "disjoint:false" $ whnf (S.disjoint s) s_even
, bench "disjoint:true" $ whnf (S.disjoint s_odd) s_even
, bench "null.intersection:false" $ whnf (S.null. S.intersection s) s_even
, bench "null.intersection:true" $ whnf (S.null. S.intersection s_odd) s_even
, bench "findMin" $ whnf IS.findMin s
, bench "findMax" $ whnf IS.findMax s
, bench "deleteMin" $ whnf IS.deleteMin s
, bench "deleteMax" $ whnf IS.deleteMax s
, bench "unions" $ whnf IS.unions [s_even, s_odd]
, bench "union" $ whnf (IS.union s_even) s_odd
, bench "difference" $ whnf (IS.difference s) s_even
, bench "intersection" $ whnf (IS.intersection s) s_even
, bench "fromList" $ whnf IS.fromList elems
, bench "fromAscList" $ whnf IS.fromAscList elems
, bench "fromDistinctAscList" $ whnf IS.fromDistinctAscList elems
, bench "disjoint:false" $ whnf (IS.disjoint s) s_even
, bench "disjoint:true" $ whnf (IS.disjoint s_odd) s_even
, bench "null.intersection:false" $ whnf (IS.null. IS.intersection s) s_even
, bench "null.intersection:true" $ whnf (IS.null. IS.intersection s_odd) s_even
, bench "instanceOrd:dense" -- the IntSet will just use one Tip
$ whnf (num_transitions . det 2 0) $ hard_nfa 1 16
, bench "instanceOrd:sparse" -- many Bin, each Tip is singleton
$ whnf (num_transitions . det 2 0) $ hard_nfa 1111 16
]
where
elems = [1..2^12]
elems_even = [2,4..2^12]
elems_odd = [1,3..2^12]

member :: [Int] -> S.IntSet -> Int
member xs s = foldl' (\n x -> if S.member x s then n + 1 else n) 0 xs
member :: [Int] -> IS.IntSet -> Int
member xs s = foldl' (\n x -> if IS.member x s then n + 1 else n) 0 xs

ins :: [Int] -> S.IntSet -> S.IntSet
ins xs s0 = foldl' (\s a -> S.insert a s) s0 xs
ins :: [Int] -> IS.IntSet -> IS.IntSet
ins xs s0 = foldl' (\s a -> IS.insert a s) s0 xs

del :: [Int] -> S.IntSet -> S.IntSet
del xs s0 = foldl' (\s k -> S.delete k s) s0 xs
del :: [Int] -> IS.IntSet -> IS.IntSet
del xs s0 = foldl' (\s k -> IS.delete k s) s0 xs



-- | Automata contain just the transitions
type NFA = IM.IntMap (IM.IntMap IS.IntSet)
type DFA = IM.IntMap (M.Map IS.IntSet IS.IntSet)

newtype State = State Int deriving (Num, Enum)
instance Show State where show (State s) = show s
newtype Sigma = Sigma Int deriving (Num, Enum, Eq)

num_transitions :: DFA -> Int
num_transitions = getSum . foldMap (Sum . M.size)

det :: Sigma -> State -> NFA -> DFA
det sigma (State initial) aut =
let get :: State -> Sigma -> IS.IntSet
get (State p) (Sigma s) = IM.findWithDefault IS.empty p
$ IM.findWithDefault IM.empty s aut
go :: DFA -> S.Set IS.IntSet -> S.Set IS.IntSet -> DFA
go !accu !done !todo = case S.minView todo of
Nothing -> accu
Just (t, odo) ->
if S.member t done
then go accu done odo
else let ts = do
s <- [0 .. sigma-1]
let next :: IS.IntSet
next = foldMap (\p -> get (State p) s) $ IS.toList t
return (t, s, next)
in go (union_dfa (dfa ts) accu)
(S.insert t done)
(Data.List.foldl' (\ o (_,_,q) -> S.insert q o) odo ts)
in go IM.empty S.empty $ S.singleton $ IS.singleton initial

nfa :: [(State,Sigma,State)] -> NFA
nfa ts = IM.fromListWith ( IM.unionWith IS.union )
$ Prelude.map (\(State p,Sigma s,State q) ->
(s, IM.singleton p (IS.singleton q))) ts

dfa :: [(IS.IntSet, Sigma, IS.IntSet)] -> DFA
dfa ts = IM.fromListWith ( M.unionWith ( error "WAT") )
$ Prelude.map (\( p, Sigma s, q) ->
(s, M.singleton p q)) ts

union_dfa a b = IM.unionWith (M.unionWith (error "WAT")) a b

-- | for the language Sigma^* 1 Sigma^{n-2} where Sigma={0,1}.
-- this NFA has n states. DFA has 2^(n-1) states
-- since it needs to remember the last n characters.
-- Extra parameter delta: the automaton will use states [0, delta .. ]
-- for IntSet, larger deltas should be harder,
-- since for delta=1, all the states do fit in one Tip
hard_nfa :: State -> Int -> NFA
hard_nfa delta n = nfa
$ [ (0, 0, 0), (0,1,0), (0, 1, delta) ]
++ do k <- [1 .. State n - 2] ; c <- [0,1] ; return (delta * k,c,delta *(k+1))
12 changes: 12 additions & 0 deletions containers-tests/tests/intset-properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ main = defaultMain [ testCase "lookupLT" test_lookupLT
, testProperty "prop_EmptyValid" prop_EmptyValid
, testProperty "prop_SingletonValid" prop_SingletonValid
, testProperty "prop_InsertIntoEmptyValid" prop_InsertIntoEmptyValid
, testProperty "prop_instanceEqIntSet" prop_instanceEqIntSet
, testProperty "prop_instanceOrdIntSet" prop_instanceOrdIntSet
, testProperty "prop_Single" prop_Single
, testProperty "prop_Member" prop_Member
, testProperty "prop_NotMember" prop_NotMember
Expand Down Expand Up @@ -141,6 +143,16 @@ prop_InsertIntoEmptyValid :: Int -> Property
prop_InsertIntoEmptyValid x =
valid (insert x empty)

{--------------------------------------------------------------------
Instances for Eq and Ord
--------------------------------------------------------------------}

prop_instanceEqIntSet :: IntSet -> IntSet -> Bool
prop_instanceEqIntSet x y = (x == y) == (toAscList x == toAscList y)

prop_instanceOrdIntSet :: IntSet -> IntSet -> Bool
prop_instanceOrdIntSet x y = (compare x y) == (compare (toAscList x) (toAscList y))

{--------------------------------------------------------------------
Single, Member, Insert, Delete, Member, FromList
--------------------------------------------------------------------}
Expand Down
154 changes: 150 additions & 4 deletions containers/src/Data/IntSet/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ import Utils.Containers.Internal.BitUtil
import Utils.Containers.Internal.StrictPair

#if __GLASGOW_HASKELL__
import Data.Data (Data(..), Constr, mkConstr, constrIndex, Fixity(Prefix), DataType, mkDataType)
import Data.Data (Data(..), Constr, mkConstr, constrIndex, DataType, mkDataType)
import qualified Data.Data
import Text.Read
#endif

Expand Down Expand Up @@ -311,7 +312,7 @@ instance Data IntSet where
dataTypeOf _ = intSetDataType

fromListConstr :: Constr
fromListConstr = mkConstr intSetDataType "fromList" [] Prefix
fromListConstr = mkConstr intSetDataType "fromList" [] Data.Data.Prefix

intSetDataType :: DataType
intSetDataType = mkDataType "Data.IntSet.Internal.IntSet" [fromListConstr]
Expand Down Expand Up @@ -1173,8 +1174,153 @@ nequal _ _ = True
--------------------------------------------------------------------}

instance Ord IntSet where
compare s1 s2 = compare (toAscList s1) (toAscList s2)
-- tentative implementation. See if more efficient exists.
compare Nil Nil = EQ
compare Nil _ = LT
compare _ Nil = GT
compare t1@(Tip _ _) t2@(Tip _ _)
= orderingOf $ relateTipTip t1 t2
compare xs ys
| (xsNeg, xsNonNeg) <- splitSign xs
, (ysNeg, ysNonNeg) <- splitSign ys
= case relate xsNeg ysNeg of
Less -> LT
Prefix -> if null xsNonNeg then LT else GT
Equals -> orderingOf (relate xsNonNeg ysNonNeg)
FlipPrefix -> if null ysNonNeg then GT else LT
Greater -> GT

-- | detailed outcome of lexicographic comparison of lists.
-- w.r.t. Ordering, there are two extra cases,
-- since (++) is not monotonic w.r.t. lex. order on lists
-- (which is used by definition):
-- consider comparison of (Bin [0,3,4] [ 6] ) to (Bin [0,3] [7] )
-- where [0,3,4] > [0,3] but [0,3,4,6] < [0,3,7].

data Relation
= Less -- ^ holds for [0,3,4] [0,3,5,1]
| Prefix -- ^ holds for [0,3,4] [0,3,4,5]
| Equals -- ^ holds for [0,3,4] [0,3,4]
| FlipPrefix -- ^ holds for [0,3,4] [0,3]
| Greater -- ^ holds for [0,3,4] [0,2,5]
deriving (Show, Eq)

orderingOf :: Relation -> Ordering
{-# INLINE orderingOf #-}
orderingOf r = case r of
Less -> LT
Prefix -> LT
Equals -> EQ
FlipPrefix -> GT
Greater -> GT

-- | precondition: each argument is non-mixed
relate :: IntSet -> IntSet -> Relation
relate Nil Nil = Equals
relate Nil t2 = Prefix
relate t1 Nil = FlipPrefix
relate t1@(Tip p1 bm1) t2@(Tip p2 bm2) = relateTipTip t1 t2
relate t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
| succUpperbound t1 <= lowerbound t2 = Less
| lowerbound t1 >= succUpperbound t2 = Greater
| otherwise = case compare (natFromInt m1) (natFromInt m2) of
GT -> combine_left (relate l1 t2)
EQ -> combine (relate l1 l2) (relate r1 r2)
LT -> combine_right (relate t1 l2)
relate t1@(Bin p1 m1 l1 r1) t2@(Tip p2 _)
| succUpperbound t1 <= lowerbound t2 = Less
| lowerbound t1 >= succUpperbound t2 = Greater
| 0 == (m1 .&. p2) = combine_left (relate l1 t2)
| otherwise = Less
relate t1@(Tip p1 _) t2@(Bin p2 m2 l2 r2)
| succUpperbound t1 <= lowerbound t2 = Less
| lowerbound t1 >= succUpperbound t2 = Greater
| 0 == (p1 .&. m2) = combine_right (relate t1 l2)
| otherwise = Greater

relateTipTip :: IntSet -> IntSet -> Relation
{-# INLINE relateTipTip #-}
relateTipTip t1@(Tip p1 bm1) t2@(Tip p2 bm2) = case compare p1 p2 of
LT -> Less
EQ -> relateBM bm1 bm2
GT -> Greater

relateBM :: BitMap -> BitMap -> Relation
{-# inline relateBM #-}
relateBM w1 w2 | w1 == w2 = Equals
relateBM w1 w2 =
let delta = xor w1 w2
lowest_diff_mask = delta .&. complement (delta-1)
prefix = (complement lowest_diff_mask + 1)
.&. (complement lowest_diff_mask)
in if 0 == lowest_diff_mask .&. w1
then if 0 == w1 .&. prefix
then Prefix else Greater
else if 0 == w2 .&. prefix
then FlipPrefix else Less

-- | This function has the property
-- relate t1@(Bin p m l1 r1) t2@(Bin p m l2 r2) = combine (relate l1 l2) (relate r1 r2)
-- It is important that `combine` is lazy in the second argument (achieved by inlining)
combine :: Relation -> Relation -> Relation
{-# inline combine #-}
combine r eq = case r of
Less -> Less
Prefix -> Greater
Equals -> eq
FlipPrefix -> Less
Greater -> Greater

-- | This function has the property
-- relate t1@(Bin p1 m1 l1 r1) t2 = combine_left (relate l1 t2)
-- under the precondition that the range of l1 contains the range of t2,
-- and r1 is non-empty
combine_left :: Relation -> Relation
{-# inline combine_left #-}
combine_left r = case r of
Less -> Less
Prefix -> Greater
Equals -> FlipPrefix
FlipPrefix -> FlipPrefix
Greater -> Greater

-- | This function has the property
-- relate t1 t2@(Bin p2 m2 l2 r2) = combine_right (relate t1 l2)
-- under the precondition that the range of t1 is included in the range of l2,
-- and r2 is non-empty
combine_right :: Relation -> Relation
{-# inline combine_right #-}
combine_right r = case r of
Less -> Less
Prefix -> Prefix
Equals -> Prefix
FlipPrefix -> Less
Greater -> Greater

-- | shall only be applied to non-mixed non-Nil trees
lowerbound :: IntSet -> Int
{-# INLINE lowerbound #-}
lowerbound (Tip p _) = p
lowerbound (Bin p _ _ _) = p

-- | this is one more than the actual upper bound (to save one operation)
-- shall only be applied to non-mixed non-Nil trees
succUpperbound :: IntSet -> Int
{-# INLINE succUpperbound #-}
succUpperbound (Tip p _) = p + wordSize
succUpperbound (Bin p m _ _) = p + shiftR m 1

-- | split a set into subsets of negative and non-negative elements
splitSign :: IntSet -> (IntSet,IntSet)
{-# INLINE splitSign #-}
splitSign t@(Tip kx _)
| kx >= 0 = (Nil, t)
| otherwise = (t, Nil)
splitSign t@(Bin p m l r)
-- m < 0 is the usual way to find out if we have positives and negatives (see findMax)
| m < 0 = (r, l)
| p < 0 = (t, Nil)
| otherwise = (Nil, t)
splitSign Nil = (Nil, Nil)

{--------------------------------------------------------------------
Show
Expand Down

0 comments on commit 7aff529

Please sign in to comment.