Skip to content

Commit

Permalink
Fix tapir-based subscriptions (ghostdogpr#1205)
Browse files Browse the repository at this point in the history
* Fix tapir subscriptions

* Cleanup

* Cleanup

* Fix keep alive
  • Loading branch information
ghostdogpr authored Dec 14, 2021
1 parent a196173 commit c4e0345
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 99 deletions.
8 changes: 4 additions & 4 deletions adapters/zio-http/src/main/scala/caliban/ZHttpAdapter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object ZHttpAdapter {
case _ => Stream.empty
}

val response = connectionAck ++ keepAlive(keepAliveTime)
val response = ZStream.succeed(connectionAck) ++ keepAlive(keepAliveTime)

val after = webSocketHooks.afterInit match {
case Some(afterInit) => ZStream.fromEffect(afterInit).drain.catchAll(toStreamError(id, _))
Expand Down Expand Up @@ -109,14 +109,14 @@ object ZHttpAdapter {

webSocketHooks.onMessage.map(_.transform(stream)).getOrElse(stream).catchAll(toStreamError(id, _))

case None => connectionError
case None => ZStream.succeed(connectionError)
}
case GraphQLWSInput("stop", id, _) =>
removeSubscription(id, subscriptions) *> ZStream.empty
ZStream.fromEffect(removeSubscription(id, subscriptions)) *> ZStream.empty

}
.flatten
.catchAll(_ => connectionError)
.catchAll(_ => ZStream.succeed(connectionError))
.map(output => WebSocketFrame.Text(output.asJson.noSpaces))
}

Expand Down
173 changes: 84 additions & 89 deletions interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,66 +224,68 @@ object TapirAdapter {
): ServerEndpoint[ZioWebSockets, RIO[R, *]] = {

val io: URIO[R, Either[Nothing, CalibanPipe]] =
RIO
.environment[R]
.flatMap(env =>
Ref
.make(Map.empty[String, Promise[Any, Unit]])
.flatMap(subscriptions =>
UIO.right[CalibanPipe](
_.collect {
case GraphQLWSInput("connection_init", id, payload) =>
val before = (webSocketHooks.beforeInit, payload) match {
case (Some(beforeInit), Some(payload)) =>
ZStream.fromEffect(beforeInit(payload)).drain.catchAll(toStreamError(id, _))
case _ => Stream.empty
}

val response = connectionAck ++ keepAlive(keepAliveTime)

val after = webSocketHooks.afterInit match {
case Some(afterInit) => ZStream.fromEffect(afterInit).drain.catchAll(toStreamError(id, _))
case _ => Stream.empty
}

before ++ ZStream.mergeAllUnbounded()(response, after)
case GraphQLWSInput("start", id, payload) =>
val request = payload.collect { case InputValue.ObjectValue(fields) =>
val query = fields.get("query").collect { case StringValue(v) => v }
val operationName = fields.get("operationName").collect { case StringValue(v) => v }
val variables = fields.get("variables").collect { case InputValue.ObjectValue(v) => v }
val extensions = fields.get("extensions").collect { case InputValue.ObjectValue(v) => v }
GraphQLRequest(query, operationName, variables, extensions)
}
request match {
case Some(req) =>
val stream = generateGraphQLResponse(
req,
id.getOrElse(""),
interpreter,
skipValidation,
enableIntrospection,
queryExecution,
subscriptions
)
webSocketHooks.onMessage
.map(_.transform(stream))
.getOrElse(stream)
.catchAll(toStreamError(id, _))

case None => connectionError
}
case GraphQLWSInput("stop", id, _) =>
removeSubscription(id, subscriptions) *> ZStream.empty
case GraphQLWSInput("connection_terminate", _, _) =>
ZStream.fromEffect(ZIO.interrupt)
}.flatten
.catchAll(_ => connectionError)
.ensuring(subscriptions.get.flatMap(m => ZIO.foreach(m.values)(_.succeed(()))))
.provide(env)
)
)
)
for {
env <- RIO.environment[R]
subscriptions <- Ref.make(Map.empty[String, Promise[Any, Unit]])
output <- Queue.unbounded[GraphQLWSOutput]
pipe <- UIO.right[CalibanPipe] { input =>
ZStream
.bracket(
input.collectM {
case GraphQLWSInput("connection_init", id, payload) =>
val before = ZIO.whenCase((webSocketHooks.beforeInit, payload)) {
case (Some(beforeInit), Some(payload)) =>
beforeInit(payload).catchAll(e => output.offer(makeError(id, e)))
}
val response = output.offer(connectionAck)
val ka = keepAlive(keepAliveTime).mapM(output.offer).runDrain.fork
val after = ZIO.whenCase(webSocketHooks.afterInit) { case Some(afterInit) =>
afterInit.catchAll(e => output.offer(makeError(id, e)))
}

before *> response *> ka *> after
case GraphQLWSInput("start", id, payload) =>
val request = payload.collect { case InputValue.ObjectValue(fields) =>
val query = fields.get("query").collect { case StringValue(v) => v }
val operationName = fields.get("operationName").collect { case StringValue(v) => v }
val variables = fields.get("variables").collect { case InputValue.ObjectValue(v) => v }
val extensions = fields.get("extensions").collect { case InputValue.ObjectValue(v) => v }
GraphQLRequest(query, operationName, variables, extensions)
}
request match {
case Some(req) =>
val stream = generateGraphQLResponse(
req,
id.getOrElse(""),
interpreter,
skipValidation,
enableIntrospection,
queryExecution,
subscriptions
)
webSocketHooks.onMessage
.map(_.transform(stream))
.getOrElse(stream)
.mapM(output.offer)
.runDrain
.catchAll(e => output.offer(makeError(id, e)))
.fork
.unit

case None => output.offer(connectionError)
}
case GraphQLWSInput("stop", id, _) =>
removeSubscription(id, subscriptions)
case GraphQLWSInput("connection_terminate", _, _) =>
ZIO.interrupt
}.runDrain
.catchAll(_ => output.offer(connectionError))
.ensuring(subscriptions.get.flatMap(m => ZIO.foreach(m.values)(_.succeed(()))))
.provide(env)
.forkDaemon
)(_.interrupt) *> ZStream.fromQueueWithShutdown(output)
}
} yield pipe

makeWebSocketEndpoint.serverLogic[RIO[R, *]](serverRequest =>
requestInterceptor(serverRequest).foldM(statusCode => ZIO.left(statusCode), _ => io)
Expand Down Expand Up @@ -322,10 +324,8 @@ object TapirAdapter {
.provideLayer(Clock.live)
}

private[caliban] val connectionError: UStream[GraphQLWSOutput] =
ZStream.succeed(GraphQLWSOutput("connection_error", None, None))
private[caliban] val connectionAck: UStream[GraphQLWSOutput] =
ZStream.succeed(GraphQLWSOutput("connection_ack", None, None))
private[caliban] val connectionError: GraphQLWSOutput = GraphQLWSOutput("connection_error", None, None)
private[caliban] val connectionAck: GraphQLWSOutput = GraphQLWSOutput("connection_ack", None, None)

type Subscriptions = Ref[Map[String, Promise[Any, Unit]]]

Expand Down Expand Up @@ -355,41 +355,36 @@ object TapirAdapter {
(resp ++ complete(id)).catchAll(toStreamError(Option(id), _))
}

private[caliban] def trackSubscription(id: String, subs: Subscriptions): UStream[Promise[Any, Unit]] =
private def trackSubscription(id: String, subs: Subscriptions): UStream[Promise[Any, Unit]] =
ZStream.fromEffect(Promise.make[Any, Unit].tap(p => subs.update(_.updated(id, p))))

private[caliban] def removeSubscription(id: Option[String], subs: Subscriptions): UStream[Unit] =
ZStream
.fromEffect(IO.whenCase(id) { case Some(id) =>
subs.modify(map => (map.get(id), map - id)).flatMap { p =>
IO.whenCase(p) { case Some(p) => p.succeed(()) }
}
})
private[caliban] def removeSubscription(id: Option[String], subs: Subscriptions): UIO[Unit] =
IO.whenCase(id) { case Some(id) =>
subs.modify(map => (map.get(id), map - id)).flatMap { p =>
IO.whenCase(p) { case Some(p) => p.succeed(()) }
}
}

private[caliban] def toStreamError[E](id: Option[String], e: E): UStream[GraphQLWSOutput] =
ZStream.succeed(
GraphQLWSOutput(
"error",
id,
Some(ResponseValue.ListValue(List(e match {
case e: CalibanError => e.toResponseValue
case e => StringValue(e.toString)
})))
)
ZStream.succeed(makeError(id, e))

private def makeError[E](id: Option[String], e: E): GraphQLWSOutput =
GraphQLWSOutput(
"error",
id,
Some(ResponseValue.ListValue(List(e match {
case e: CalibanError => e.toResponseValue
case e => StringValue(e.toString)
})))
)

private[caliban] def complete(id: String): UStream[GraphQLWSOutput] =
private def complete(id: String): UStream[GraphQLWSOutput] =
ZStream.succeed(GraphQLWSOutput("complete", Some(id), None))

private[caliban] def toResponse[E](
id: String,
fieldName: String,
r: ResponseValue,
errors: List[E]
): GraphQLWSOutput =
private def toResponse[E](id: String, fieldName: String, r: ResponseValue, errors: List[E]): GraphQLWSOutput =
toResponse(id, GraphQLResponse(ObjectValue(List(fieldName -> r)), errors))

private[caliban] def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput =
private def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput =
GraphQLWSOutput("data", Some(id), Some(r.toResponseValue))

private def parsePath(path: String): List[Either[String, Int]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,23 @@ object TapirAdapterSpec {
Some(ObjectValue(Map("query" -> StringValue("subscription { characterDeleted }"))))
)
)
sendDelete =
send(run((GraphQLRequest(Some("""mutation{ deleteCharacter(name: "Amos Burton") }""")), null)))
.delay(3 seconds)
sendDelete = send(
run((GraphQLRequest(Some("""mutation{ deleteCharacter(name: "Amos Burton") }""")), null))
).delay(3 seconds)
stop = inputQueue.offer(GraphQLWSInput("stop", Some("id"), None))
messages <- outputStream
.tap(out => ZIO.when(out.`type` == "connection_ack")(sendDelete))
.take(2)
.tap(out => ZIO.when(out.`type` == "data")(stop))
.take(3)
.runCollect
.timeoutFail(new Throwable("timeout ws"))(30.seconds)
.provideSomeLayer[SttpClient](Clock.live)
} yield messages

io.map { messages =>
assert(messages.head.`type`)(equalTo("connection_ack")) &&
assert(messages(1).payload.get.toString)(equalTo("""{"data":{"characterDeleted":"Amos Burton"}}"""))
assertTrue(messages.head.`type` == "connection_ack") &&
assertTrue(messages(1).payload.get.toString == """{"data":{"characterDeleted":"Amos Burton"}}""") &&
assertTrue(messages(2).`type` == "complete")
}
}
)
Expand Down

0 comments on commit c4e0345

Please sign in to comment.