Skip to content

Commit

Permalink
Merge pull request #170 from gemini-hlsw/topic/count
Browse files Browse the repository at this point in the history
Add Count query algebra term
  • Loading branch information
milessabin authored Aug 26, 2021
2 parents 13bc886 + e8d4437 commit 17f159a
Show file tree
Hide file tree
Showing 21 changed files with 699 additions and 115 deletions.
15 changes: 14 additions & 1 deletion modules/circe/src/test/scala/CirceData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ package circetests

import cats.Id
import cats.catsInstancesForId
import cats.implicits._

import edu.gemini.grackle.circe.CirceMapping
import edu.gemini.grackle.syntax._

import Query._
import QueryCompiler._

object TestCirceMapping extends CirceMapping[Id] {
val schema =
schema"""
Expand All @@ -23,8 +27,9 @@ object TestCirceMapping extends CirceMapping[Id] {
string: String
id: ID
choice: Choice
arrary: [Int!]
array: [Int!]
object: A
numChildren: Int
children: [Child!]!
}
enum Choice {
Expand All @@ -46,6 +51,7 @@ object TestCirceMapping extends CirceMapping[Id] {
"""

val QueryType = schema.ref("Query")
val RootType = schema.ref("Root")

val data =
json"""
Expand Down Expand Up @@ -84,4 +90,11 @@ object TestCirceMapping extends CirceMapping[Id] {
)
)
)

override val selectElaborator = new SelectElaborator(Map(
RootType -> {
case Select("numChildren", Nil, Empty) =>
Count("numChildren", Select("children", Nil, Empty)).rightIor
}
))
}
36 changes: 36 additions & 0 deletions modules/circe/src/test/scala/CirceSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,40 @@ final class CirceSpec extends CatsSuite {

assert(res == expected)
}

test("count") {
val query = """
query {
root {
numChildren
children {
id
}
}
}
"""

val expected = json"""
{
"data" : {
"root" : {
"numChildren" : 2,
"children" : [
{
"id" : "a"
},
{
"id" : "b"
}
]
}
}
}
"""

val res = TestCirceMapping.compileAndRun(query)
//println(res)

assert(res == expected)
}
}
2 changes: 2 additions & 0 deletions modules/core/src/main/scala/compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ object QueryCompiler {
case n@Narrow(subtpe, child) => transform(child, vars, schema, subtpe).map(ec => n.copy(child = ec))
case w@Wrap(_, child) => transform(child, vars, schema, tpe).map(ec => w.copy(child = ec))
case r@Rename(_, child) => transform(child, vars, schema, tpe).map(ec => r.copy(child = ec))
case c@Count(_, child) => transform(child, vars, schema, tpe).map(ec => c.copy(child = ec))
case g@Group(children) => children.traverse(q => transform(q, vars, schema, tpe)).map(eqs => g.copy(queries = eqs))
case g@GroupList(children) => children.traverse(q => transform(q, vars, schema, tpe)).map(eqs => g.copy(queries = eqs))
case u@Unique(child) => transform(child, vars, schema, tpe.nonNull.list).map(ec => u.copy(child = ec))
Expand Down Expand Up @@ -689,6 +690,7 @@ object QueryCompiler {
def loop(q: Query, depth: Int, width: Int, group: Boolean): (Int, Int) =
q match {
case Select(_, _, Empty) => if (group) (depth, width + 1) else (depth + 1, width + 1)
case Count(_, _) => if (group) (depth, width + 1) else (depth + 1, width + 1)
case Select(_, _, child) => if (group) loop(child, depth, width, false) else loop(child, depth + 1, width, false)
case Group(queries) => handleGroupedQueries(queries, depth, width)
case GroupList(queries) => handleGroupedQueries(queries, depth, width)
Expand Down
4 changes: 4 additions & 0 deletions modules/core/src/main/scala/mapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ abstract class Mapping[F[_]](implicit val M: Monad[F]) extends QueryExecutor[F,
def pos: SourcePos
}

case class PrimitiveField(fieldName: String, hidden: Boolean = false)(implicit val pos: SourcePos) extends FieldMapping {
def withParent(tpe: Type): PrimitiveField = this
}

/**
* Root mappings can perform a mutation prior to constructing the result `Cursor`. A `Mutation`
* may perform a Unit effect and simply return the passed arguments; or it may refine the passed
Expand Down
7 changes: 5 additions & 2 deletions modules/core/src/main/scala/query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ object Query {
case class Select(name: String, args: List[Binding], child: Query = Empty) extends Query {
def eliminateArgs(elim: Query => Query): Query = copy(args = Nil, child = elim(child))

def transformChild(f: Query => Query): Query = copy(child = f(child))

def render = {
val rargs = if(args.isEmpty) "" else s"(${args.map(_.render).mkString(", ")})"
val rchild = if(child == Empty) "" else s" { ${child.render} }"
Expand Down Expand Up @@ -174,6 +172,11 @@ object Query {
}
}

/** Computes the number of top-level elements of `child` as field `name` */
case class Count(name: String, child: Query) extends Query {
def render = s"$name:count { ${child.render} }"
}

/** A placeholder for a skipped node */
case object Skipped extends Query {
def render = "<skipped>"
Expand Down
20 changes: 17 additions & 3 deletions modules/core/src/main/scala/queryinterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,29 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
} yield List((resultName, value))

case (Rename(resultName, Wrap(_, child)), tpe) =>
for {
value <- runValue(child, tpe, cursor)
} yield List((resultName, value))
runFields(Wrap(resultName, child), tpe, cursor)

case (Wrap(fieldName, child), tpe) =>
for {
value <- runValue(child, tpe, cursor)
} yield List((fieldName, value))

case (Rename(resultName, Count(_, child)), tpe) =>
runFields(Count(resultName, child), tpe, cursor)

case (Count(fieldName, Select(countName, _, _)), _) =>
cursor.field(countName, None).flatMap { c0 =>
if (c0.isNullable)
c0.asNullable.flatMap {
case None => 0.rightIor
case Some(c1) =>
if (c1.isList) c1.asList.map(c2 => c2.size)
else 1.rightIor
}
else if (c0.isList) c0.asList.map(c2 => c2.size)
else 1.rightIor
}.map { value => List((fieldName, ProtoJson.fromJson(Json.fromInt(value)))) }

case (Group(siblings), _) =>
siblings.flatTraverse(query => runFields(query, tpe, cursor))

Expand Down
19 changes: 17 additions & 2 deletions modules/core/src/test/scala/starwars/StarWarsData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,22 @@ object StarWarsMapping extends ValueMapping[Id] {
interface Character {
id: String!
name: String
numberOfFriends: Int
friends: [Character!]
appearsIn: [Episode!]
}
type Human implements Character {
id: String!
name: String
numberOfFriends: Int
friends: [Character!]
appearsIn: [Episode!]
homePlanet: String
}
type Droid implements Character {
id: String!
name: String
numberOfFriends: Int
friends: [Character!]
appearsIn: [Episode!]
primaryFunction: String
Expand All @@ -154,6 +157,7 @@ object StarWarsMapping extends ValueMapping[Id] {
val CharacterType = schema.ref("Character")
val HumanType = schema.ref("Human")
val DroidType = schema.ref("Droid")
val EpisodeType = schema.ref("Episode")

val typeMappings =
List(
Expand All @@ -174,6 +178,7 @@ object StarWarsMapping extends ValueMapping[Id] {
ValueField("id", _.id),
ValueField("name", _.name),
ValueField("appearsIn", _.appearsIn),
PrimitiveField("numberOfFriends"),
ValueField("friends", resolveFriends _)
)
),
Expand All @@ -190,17 +195,27 @@ object StarWarsMapping extends ValueMapping[Id] {
List(
ValueField("primaryFunction", _.primaryFunction)
)
)
),
PrimitiveMapping(EpisodeType)
)

val numberOfFriends: PartialFunction[Query, Result[Query]] = {
case Select("numberOfFriends", Nil, Empty) =>
Count("numberOfFriends", Select("friends", Nil, Empty)).rightIor
}


override val selectElaborator = new SelectElaborator(Map(
QueryType -> {
case Select("hero", List(Binding("episode", TypedEnumValue(e))), child) =>
val episode = Episode.values.find(_.toString == e.name).get
Select("hero", Nil, Unique(Filter(Eql(UniquePath(List("id")), Const(hero(episode).id)), child))).rightIor
case Select(f@("character" | "human" | "droid"), List(Binding("id", IDValue(id))), child) =>
Select(f, Nil, Unique(Filter(Eql(UniquePath(List("id")), Const(id)), child))).rightIor
}
},
CharacterType -> numberOfFriends,
HumanType -> numberOfFriends,
DroidType -> numberOfFriends
))

val querySizeValidator = new QuerySizeValidator(5, 5)
Expand Down
96 changes: 96 additions & 0 deletions modules/core/src/test/scala/starwars/StarWarsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ import edu.gemini.grackle.syntax._

final class StarWarsSpec extends CatsSuite {

test("validate mapping") {
val es = StarWarsMapping.validator.validateMapping()
es match {
case Nil => succeed
case _ => fail(es.foldMap(_.toErrorMessage))
}
}

test("simple query") {
val query = """
query {
Expand Down Expand Up @@ -313,4 +321,92 @@ final class StarWarsSpec extends CatsSuite {

assert(res == expected)
}

test("count") {
val query = """
query {
human(id: "1000") {
name
numberOfFriends
friends {
name
}
}
}
"""

val expected = json"""
{
"data" : {
"human" : {
"name" : "Luke Skywalker",
"numberOfFriends" : 4,
"friends" : [
{
"name" : "Han Solo"
},
{
"name" : "Leia Organa"
},
{
"name" : "C-3PO"
},
{
"name" : "R2-D2"
}
]
}
}
}
"""

val res = StarWarsMapping.compileAndRun(query)
//println(res)

assert(res == expected)
}

test("renamed count") {
val query = """
query {
human(id: "1000") {
name
num:numberOfFriends
friends {
name
}
}
}
"""

val expected = json"""
{
"data" : {
"human" : {
"name" : "Luke Skywalker",
"num" : 4,
"friends" : [
{
"name" : "Han Solo"
},
{
"name" : "Leia Organa"
},
{
"name" : "C-3PO"
},
{
"name" : "R2-D2"
}
]
}
}
}
"""

val res = StarWarsMapping.compileAndRun(query)
//println(res)

assert(res == expected)
}
}
13 changes: 8 additions & 5 deletions modules/doobie/src/main/scala/DoobieMapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
package edu.gemini.grackle
package doobie

import java.sql.ResultSet

import cats.Reducible
import edu.gemini.grackle.sql._
import cats.effect.Sync
import _root_.doobie.{ Meta, Put, Read, Transactor, Fragment => DoobieFragment }
import _root_.doobie.enumerated.JdbcType._
import _root_.doobie.enumerated.Nullability.{ NoNulls, Nullable }
import _root_.doobie.implicits._
import _root_.doobie.util.fragments

import java.sql.ResultSet
import org.tpolecat.sourcepos.SourcePos
import org.tpolecat.typename.TypeName

import edu.gemini.grackle.sql._

abstract class DoobieMapping[F[_]: Sync](
val transactor: Transactor[F],
val monitor: DoobieMonitor[F]
Expand All @@ -33,8 +35,8 @@ abstract class DoobieMapping[F[_]: Sync](
def doubleEncoder = (Put[Double], false)

class TableDef(name: String) {
def col[T: TypeName](colName: String, codec: Meta[T], nullable: Boolean = false): ColumnRef =
ColumnRef(name, colName, (codec, nullable))
def col[T](colName: String, codec: Meta[T], nullable: Boolean = false)(implicit typeName: TypeName[T], pos: SourcePos): Column.ColumnRef =
Column.ColumnRef(name, colName, (codec, nullable), typeName.value, pos)
}

implicit def Fragments: SqlFragment[Fragment] =
Expand All @@ -53,6 +55,7 @@ abstract class DoobieMapping[F[_]: Sync](
}
}
def const(s: String): Fragment = DoobieFragment.const(s)
def and(fs: Fragment*): Fragment = fragments.and(fs: _*)
def andOpt(fs: Option[Fragment]*): Fragment = fragments.andOpt(fs: _*)
def orOpt(fs: Option[Fragment]*): Fragment = fragments.orOpt(fs: _*)
def whereAnd(fs: Fragment*): Fragment = fragments.whereAnd(fs: _*)
Expand Down
4 changes: 2 additions & 2 deletions modules/doobie/src/test/scala/world/WorldCompilerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ final class WorldCompilerSpec extends DatabaseSuite with SqlWorldCompilerSpec {
WorldMapping.mkMapping(xa, monitor)

def simpleRestrictedQuerySql: String =
"SELECT country.name, country.code FROM (SELECT country.name, country.code FROM country INNER JOIN (SELECT DISTINCT country.code FROM country WHERE (country.code IS NOT NULL ) AND (country.code = ?) ) AS pred_0 ON pred_0.code = country.code ) AS country"
"SELECT country.name , country.code FROM (SELECT country.name , country.code FROM country INNER JOIN (SELECT DISTINCT country.code FROM country WHERE (country.code IS NOT NULL ) AND (country.code = ?) ) AS pred_0 ON (pred_0.code = country.code ) ) AS country"

def simpleFilteredQuerySql: String =
"SELECT city.name, city.id FROM (SELECT city.name, city.id FROM city INNER JOIN (SELECT DISTINCT city.id FROM city WHERE (city.id IS NOT NULL ) AND (city.name ILIKE ?) ) AS pred_0 ON pred_0.id = city.id ) AS city"
"SELECT city.name , city.id FROM (SELECT city.name , city.id FROM city INNER JOIN (SELECT DISTINCT city.id FROM city WHERE (city.id IS NOT NULL ) AND (city.name ILIKE ?) ) AS pred_0 ON (pred_0.id = city.id ) ) AS city"
}
Loading

0 comments on commit 17f159a

Please sign in to comment.