Skip to content

Commit

Permalink
Merge branch 'master' into kotlin-protos
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesward authored Oct 12, 2021
2 parents 23c167a + 9e98ca2 commit c0d1b97
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.StatusException
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import io.grpc.Context as GrpcContext
Expand Down Expand Up @@ -33,6 +34,9 @@ abstract class CoroutineContextServerInterceptor : ServerInterceptor {
* server object.
*
* This function will be called each time a [call] is executed.
*
* @throws StatusException if the call should be closed with the [Status][io.grpc.Status] in the
* exception and further processing suppressed
*/
abstract fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext

Expand All @@ -49,8 +53,15 @@ abstract class CoroutineContextServerInterceptor : ServerInterceptor {
call: ServerCall<ReqT, RespT>,
headers: Metadata,
next: ServerCallHandler<ReqT, RespT>
): ServerCall.Listener<ReqT> =
withGrpcContext(GrpcContext.current().extendCoroutineContext(coroutineContext(call, headers))) {
): ServerCall.Listener<ReqT> {
val coroutineContext = try {
coroutineContext(call, headers)
} catch (e: StatusException) {
call.close(e.status, e.trailers ?: Metadata())
throw e
}
return withGrpcContext(GrpcContext.current().extendCoroutineContext(coroutineContext)) {
next.startCall(call, headers)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package io.grpc.kotlin
import com.google.common.truth.Truth.assertThat
import io.grpc.ServerCall
import io.grpc.ServerInterceptors
import io.grpc.Status
import io.grpc.StatusException
import io.grpc.StatusRuntimeException
import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase
import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub
import io.grpc.examples.helloworld.HelloReply
Expand Down Expand Up @@ -97,4 +100,46 @@ class CoroutineContextServerInterceptorTest : AbstractCallsTest() {
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("interceptor")
}
}
}

@Test
fun statusExceptionThrownFromCoroutineContextClosesCall() {
val interceptor = object : CoroutineContextServerInterceptor() {
override fun coroutineContext(
call: ServerCall<*, *>,
headers: GrpcMetadata
): CoroutineContext {
throw StatusException(Status.INTERNAL.withDescription("An error"))
}
}

val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor)
val client = GreeterCoroutineStub(channel)

runBlocking {
assertThrows<StatusException> { client.sayHello(helloRequest("")) }
}
}

@Test
fun retainsTrailersFromStatusExceptionThrownFromCoroutineContext() {
val aMetadataKey = GrpcMetadata.Key.of("a-metadata-key", GrpcMetadata.ASCII_STRING_MARSHALLER)
val interceptor = object : CoroutineContextServerInterceptor() {
override fun coroutineContext(
call: ServerCall<*, *>,
headers: GrpcMetadata
): CoroutineContext {
val trailers = GrpcMetadata().apply { put(aMetadataKey, "A value") }
throw StatusException(Status.INTERNAL, trailers)
}
}

val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor)
val client = GreeterCoroutineStub(channel)

runBlocking {
val thrown = assertThrows<StatusException> { client.sayHello(helloRequest("")) }

assertThat(thrown.trailers.get(aMetadataKey)).isEqualTo("A value")
}
}
}

0 comments on commit c0d1b97

Please sign in to comment.