Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid rebinding PrimExprs that were created in the Aggregator itself #586

Merged
merged 3 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Test/Opaleye/Test/Arbitrary.hs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ genSelectArr =
genSelectMapper :: [TQ.Gen (O.Select Fields -> O.Select Fields)]
genSelectMapper =
[ do
return (O.distinctExplicit distinctFields)
return (O.distinctExplicit unpackFields distinctFields)
, do
ArbitraryPositiveInt l <- TQ.arbitrary
return (O.limit l)
Expand All @@ -481,9 +481,8 @@ genSelectMapper =
, do
o <- TQ.arbitrary
return (O.orderBy (arbitraryOrder o))

, do
return (O.aggregate aggregateFields)
return (O.aggregateExplicit unpackFields aggregateFields)
, do
let q' q = P.dimap (\_ -> fst . firstBoolOrTrue (O.sqlBool True))
(fieldsList
Expand Down
4 changes: 2 additions & 2 deletions Test/QuickCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ order o (ArbitrarySelect q) =

distinct :: ArbitrarySelect -> Connection -> IO TQ.Property
distinct =
compareDenotation' (O.distinctExplicit distinctFields) nub
compareDenotation' (O.distinctExplicit unpackFields distinctFields) nub

-- When we generalise compareDenotation... we can just test
--
Expand All @@ -455,7 +455,7 @@ valuesEmpty l =

aggregate :: ArbitrarySelect -> Connection -> IO TQ.Property
aggregate =
compareDenotationNoSort' (O.aggregate aggregateFields)
compareDenotationNoSort' (O.aggregateExplicit unpackFields aggregateFields)
aggregateDenotation


Expand Down
20 changes: 18 additions & 2 deletions Test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ testStringArrayAggregateOrdered = it "" $ q `selectShouldReturnSorted` expected


testStringArrayAggregateOrderedDistinct :: Test
testStringArrayAggregateOrderedDistinct = xit "" $ q `selectShouldReturnSorted` expected
testStringArrayAggregateOrderedDistinct = it "" $ q `selectShouldReturnSorted` expected
where q =
O.aggregateOrdered
(O.asc snd)
Expand All @@ -632,6 +632,21 @@ testStringArrayAggregateOrderedDistinct = xit "" $ q `selectShouldReturnSorted`
]
sortedData = L.sortBy (Ord.comparing snd) table7data

-- See
--
-- https://github.com/tomjaguarpaw/haskell-opaleye/pull/578#issuecomment-1782638274
testStringArrayAggregateOrderedDistinctDuplicateFields :: Test
testStringArrayAggregateOrderedDistinctDuplicateFields = xit "" $ q `selectShouldReturnSorted` expected
where q =
O.aggregateOrdered
(O.asc (\x -> snd x O..++ snd x))
(PP.p2 (O.arrayAgg, O.distinctAggregator . O.stringAgg . O.sqlString $ ","))
table7Q
expected = [ ( map fst sortedData
, L.intercalate "," $ map NE.head $ NE.group $ map snd sortedData
)
]
sortedData = L.sortBy (Ord.comparing snd) table7data

-- | Using orderAggregate you can apply different orderings to
-- different aggregates.
Expand Down Expand Up @@ -1462,7 +1477,7 @@ testUnnest = do

testSetAggregate :: Test
testSetAggregate = do
xit "set aggregate (percentile_cont)" $ testH query (`shouldBe` [expectation])
it "set aggregate (percentile_cont)" $ testH query (`shouldBe` [expectation])
where query :: Select (Field O.SqlFloat8)
query = O.aggregate median (O.values as)

Expand Down Expand Up @@ -1559,6 +1574,7 @@ main = do
testMultipleAggregateOrdered
testStringArrayAggregateOrdered
testStringArrayAggregateOrderedDistinct
testStringArrayAggregateOrderedDistinctDuplicateFields
testDistinctAndAggregate
testDoubleAggregate
testSetAggregate
Expand Down
22 changes: 15 additions & 7 deletions src/Opaleye/Aggregate.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}

-- | Perform aggregation on 'S.Select's. To aggregate a 'S.Select' you
-- should construct an 'Aggregator' encoding how you want the
Expand Down Expand Up @@ -33,11 +34,14 @@ module Opaleye.Aggregate
, stringAgg
-- * Counting rows
, countRows
-- * Explicit
, aggregateExplicit
) where

import Control.Arrow (second)
import Control.Arrow (second, (<<<))
import Data.Profunctor (lmap)
import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product.Default as D

import qualified Opaleye.Internal.Aggregate as A
import Opaleye.Internal.Aggregate (Aggregator, orderAggregate)
Expand All @@ -46,6 +50,7 @@ import qualified Opaleye.Internal.QueryArr as Q
import qualified Opaleye.Internal.HaskellDB.PrimQuery as HPQ
import qualified Opaleye.Internal.Operators as O
import qualified Opaleye.Internal.PackMap as PM
import Opaleye.Internal.Rebind (rebindExplicit)
import qualified Opaleye.Internal.Tag as Tag
import qualified Opaleye.Internal.Unpackspec as U

Expand Down Expand Up @@ -85,11 +90,8 @@ result of an aggregation.
-}
-- See 'Opaleye.Internal.Sql.aggregate' for details of how aggregating
-- by an empty query with no group by is handled.
aggregate :: Aggregator a b -> S.Select a -> S.Select b
aggregate agg q = Q.productQueryArr $ do
(a, pq) <- Q.runSimpleSelect q
t <- Tag.fresh
pure (second ($ pq) (A.aggregateU agg (a, t)))
aggregate :: D.Default U.Unpackspec a a => Aggregator a b -> S.Select a -> S.Select b
aggregate = aggregateExplicit D.def

-- | Order the values within each aggregation in `Aggregator` using
-- the given ordering. This is only relevant for aggregations that
Expand All @@ -100,7 +102,13 @@ aggregate agg q = Q.productQueryArr $ do
-- you need different orderings for different aggregations, use
-- 'Opaleye.Internal.Aggregate.orderAggregate'.

aggregateOrdered :: Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b
aggregateExplicit :: U.Unpackspec a a' -> Aggregator a' b -> S.Select a -> S.Select b
aggregateExplicit u agg q = Q.productQueryArr $ do
(a, pq) <- Q.runSimpleSelect (rebindExplicit u <<< q)
t <- Tag.fresh
pure (second ($ pq) (A.aggregateU agg (a, t)))

aggregateOrdered :: D.Default U.Unpackspec a a => Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b
aggregateOrdered o agg = aggregate (orderAggregate o agg)

-- | Aggregate only distinct values
Expand Down
4 changes: 3 additions & 1 deletion src/Opaleye/Distinct.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Opaleye.Internal.Distinct
import Opaleye.Order

import qualified Data.Profunctor.Product.Default as D
import Opaleye.Internal.Unpackspec (Unpackspec)

-- | Remove duplicate rows from the 'Select'.
--
Expand All @@ -40,5 +41,6 @@ import qualified Data.Profunctor.Product.Default as D
-- 'Opaleye.Lateral.laterally' 'distinct' :: 'Data.Profunctor.Product.Default' 'Distinctspec' fields fields => 'Opaleye.Select.SelectArr' i fields -> 'Opaleye.Select.SelectArr' i fields
-- @
distinct :: D.Default Distinctspec fields fields =>
D.Default Unpackspec fields fields =>
Select fields -> Select fields
distinct = distinctExplicit D.def
distinct = distinctExplicit D.def D.def
34 changes: 11 additions & 23 deletions src/Opaleye/Internal/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
module Opaleye.Internal.Aggregate where

import Control.Applicative (liftA2)
import Data.Foldable (toList)

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.4)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.4)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Foldable’ is redundant

Check warning on line 5 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Foldable’ is redundant
import Data.Traversable (for)

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.4)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.4)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Traversable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Traversable’ is redundant

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP
Expand Down Expand Up @@ -129,43 +129,31 @@
-- https://github.com/tomjaguarpaw/haskell-opaleye/pull/460#issuecomment-626716160
--
-- Instead of detecting when we are aggregating over a field from a
-- previous query we just create new names for all field before we
-- previous query we just create new names for all fields before we
-- aggregate. On the other hand, referring to a field from a previous
-- query in an ORDER BY expression is totally fine!
aggregateU :: Aggregator a b
-> (a, T.Tag) -> (b, PQ.PrimQuery -> PQ.PrimQuery)
aggregateU agg (c0, t0) = (c1, primQ')
where (c1, projPEs_inners) =
where projPEs_inners :: PQ.Bindings HPQ.Aggregate
(c1, projPEs_inners) =
PM.run (runAggregator agg (extractAggregateFields t0) c0)

projPEs = map fst projPEs_inners
inners = concatMap snd projPEs_inners
projPEs = projPEs_inners

primQ' = PQ.Aggregate projPEs . PQ.Rebind True inners
primQ' = PQ.Aggregate projPEs

extractAggregateFields
:: Traversable t
=> T.Tag
-> t HPQ.PrimExpr
-> PM.PM [((HPQ.Symbol,
t HPQ.Symbol),
PQ.Bindings HPQ.PrimExpr)]
HPQ.PrimExpr
:: T.Tag
-> HPQ.Aggregate
-> PM.PM (PQ.Bindings HPQ.Aggregate) HPQ.PrimExpr
extractAggregateFields tag agg = do
i <- PM.new
let sinner = HPQ.Symbol ("result" ++ i) tag

let souter = HPQ.Symbol ("result" ++ i) tag
PM.write (sinner, agg)

bindings <- for agg $ \pe -> do
j <- PM.new
let sinner = HPQ.Symbol ("inner" ++ j) tag
pure (sinner, pe)

let agg' = fmap fst bindings

PM.write ((souter, agg'), toList bindings)

pure (HPQ.AttrExpr souter)
pure (HPQ.AttrExpr sinner)

unsafeMax :: Aggregator (C.Field a) (C.Field a)
unsafeMax = makeAggr HPQ.AggrMax
Expand Down
8 changes: 5 additions & 3 deletions src/Opaleye/Internal/Distinct.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@ module Opaleye.Internal.Distinct where
import qualified Opaleye.Internal.MaybeFields as M
import Opaleye.Select (Select)
import Opaleye.Field (Field_)
import Opaleye.Aggregate (Aggregator, groupBy, aggregate)
import Opaleye.Aggregate (Aggregator, groupBy, aggregateExplicit)

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP
import Data.Profunctor.Product.Default (Default, def)
import Opaleye.Internal.Unpackspec (Unpackspec)

-- We implement distinct simply by grouping by all columns. We could
-- instead implement it as SQL's DISTINCT but implementing it in terms
-- of something else that we already have is easier at this point.

distinctExplicit :: Distinctspec fields fields'
distinctExplicit :: Unpackspec fields fields
-> Distinctspec fields fields'
-> Select fields -> Select fields'
distinctExplicit (Distinctspec agg) = aggregate agg
distinctExplicit u (Distinctspec agg) = aggregateExplicit u agg

newtype Distinctspec a b = Distinctspec (Aggregator a b)

Expand Down
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ data PrimQuery' a = Unit
| Product (NEL.NonEmpty (Lateral, PrimQuery' a)) [HPQ.PrimExpr]
-- | The subqueries to take the product of and the
-- restrictions to apply
| Aggregate (Bindings (HPQ.Aggregate' HPQ.Symbol))
| Aggregate (Bindings HPQ.Aggregate)
(PrimQuery' a)
| Window (Bindings (HPQ.WndwOp, HPQ.Partition)) (PrimQuery' a)
-- | Represents both @DISTINCT ON@ and @ORDER BY@
Expand Down Expand Up @@ -178,7 +178,7 @@ data PrimQueryFoldP a p p' = PrimQueryFold
, empty :: a -> p'
, baseTable :: TableIdentifier -> Bindings HPQ.PrimExpr -> p'
, product :: NEL.NonEmpty (Lateral, p) -> [HPQ.PrimExpr] -> p'
, aggregate :: Bindings (HPQ.Aggregate' HPQ.Symbol)
, aggregate :: Bindings HPQ.Aggregate
-> p
-> p'
, window :: Bindings (HPQ.WndwOp, HPQ.Partition) -> p -> p'
Expand Down
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/Sql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ product ss pes = SelectFrom $
PQ.Lateral -> Lateral
PQ.NonLateral -> NonLateral

aggregate :: PQ.Bindings (HPQ.Aggregate' HPQ.Symbol)
aggregate :: PQ.Bindings HPQ.Aggregate
-> Select
-> Select
aggregate aggrs' s =
Expand Down Expand Up @@ -191,7 +191,7 @@ aggregate aggrs' s =
handleEmpty = ensureColumnsGen SP.deliteral

aggrs :: [(Symbol, HPQ.Aggregate)]
aggrs = (map . Arr.second . fmap) HPQ.AttrExpr aggrs'
aggrs = aggrs'

groupBy' :: [(symbol, HPQ.Aggregate)]
-> NEL.NonEmpty HSql.SqlExpr
Expand Down
Loading