diff --git a/grpc-kotlin-example-chatserver/README.md b/grpc-kotlin-example-chatserver/README.md index ffa09a8..fc24474 100644 --- a/grpc-kotlin-example-chatserver/README.md +++ b/grpc-kotlin-example-chatserver/README.md @@ -1,11 +1,11 @@ # grpc-kotlin-example-chatserver -A simple command line chat server written using bidirectional gRPC. +A simple command line chat server written using both bidirectional & server streaming gRPC. Build the parent project. From the repo root, run ```sh -mvn package +./mvnw clean package ``` Start the server @@ -14,12 +14,18 @@ Start the server java -jar grpc-kotlin-example-chatserver/target/grpc-kotlin-example-chatserver.jar server ``` -From another shell, start a client +From another shell, start a bidirectional streaming client ```sh java -jar grpc-kotlin-example-chatserver/target/grpc-kotlin-example-chatserver.jar client ``` +From the third shell, start a server streaming client + +```sh +java -jar grpc-kotlin-example-chatserver/target/grpc-kotlin-example-chatserver.jar clientSS +``` + --- Big thanks to [Björn Hegerfors](https://github.com/Bj0rnen) and [Emilio Del Tessandoro](https://github.com/emmmile) diff --git a/grpc-kotlin-example-chatserver/pom.xml b/grpc-kotlin-example-chatserver/pom.xml index 28f4f75..535c569 100644 --- a/grpc-kotlin-example-chatserver/pom.xml +++ b/grpc-kotlin-example-chatserver/pom.xml @@ -113,8 +113,7 @@ grpc-kotlin - io.rouz:grpc-kotlin-gen:${project.version}:exe:${os.detected.classifier} - + io.rouz:grpc-kotlin-gen:${project.version}:exe:${os.detected.classifier} diff --git a/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ChatService.kt b/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ChatService.kt index d07f731..4c932cf 100644 --- a/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ChatService.kt +++ b/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ChatService.kt @@ -2,7 +2,7 @@ * -\-\- * simple-kotlin-standalone-example * -- - * Copyright (C) 2016 - 2018 rouz.io + * Copyright (C) 2016 - 2019 rouz.io * -- * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,86 +24,120 @@ import com.google.protobuf.Empty import com.google.protobuf.Timestamp import io.grpc.Status import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.launch +import java.time.Instant +import java.util.concurrent.ConcurrentSkipListMap @UseExperimental(ExperimentalCoroutinesApi::class) class ChatService : ChatServiceImplBase() { - - data class Client(val name: String, val channel: SendChannel) - - private val clientChannels = LinkedHashSet() + private val clientChannels = ConcurrentSkipListMap>() override suspend fun getNames(request: Empty): ChatRoom { return ChatRoom.newBuilder() - .addAllNames(clientChannels.map(Client::name)) - .build() + .addAllNames(clientChannels.keys) + .build() } - override fun chat(requests: ReceiveChannel): ReceiveChannel { - val channel = Channel(Channel.UNLIMITED) - channel.invokeOnClose { + private fun createChannel() = Channel(100).apply { + invokeOnClose { it?.printStackTrace() } - println("New client connection: $channel") - - launch { - // wait for first message - val hello = requests.receive() - val name = hello.from - val client = Client(name, channel) - clientChannels.add(client) - - try { - for (chatMessage in requests) { - println("Got request from $requests:") - println(chatMessage) - val message = createMessage(chatMessage) - clientChannels - .filter { it.name != chatMessage.from } - .forEach { other -> - println("Sending to $other") - other.channel.send(message) - } - } - } catch (t: Throwable) { - println("Threw $t") - if (Status.fromThrowable(t).code != Status.Code.CANCELLED) { - println("An actual error occurred") - t.printStackTrace() + } + + private fun subscribe(name: String, ch: SendChannel) { + println("New client connected: $name") + clientChannels.put(name, ch) + ?.apply { + println("Close duplicate channel of user: $name") + close() } - } finally { - println("$name hung up. Removing client channel") - clientChannels.remove(client) - if (!channel.isClosedForSend) { - channel.close() + } + + private suspend fun broadcast(message: ChatMessage) = createMessage(message) + .let { broadcastMessage -> + println("Broadcast ${message.from}: ${message.message}") + + clientChannels.asSequence() + .filterNot { (name, _) -> name == message.from } + .forEach { (other, ch) -> + launch { + try { + println("Sending to $other") + ch.send(broadcastMessage) + } catch (e: Throwable) { + println("$other hung up: ${e.message}. Removing client channel") + clientChannels.remove(other)?.close(e) + } + } + } + } + + + override fun chat(requests: ReceiveChannel): ReceiveChannel = + createChannel().also { + GlobalScope.launch { + doChat(requests, it) } } + + private suspend fun doChat(req: ReceiveChannel, resp: SendChannel) { + val hello = req.receive() + subscribe(hello.from, resp) + broadcast(hello) + + try { + for (chatMessage in req) { + println("Got request from $req:") + println(chatMessage) + broadcast(chatMessage) + } + } catch (t: Throwable) { + println("Threw $t") + if (Status.fromThrowable(t).code != Status.Code.CANCELLED) { + println("An actual error occurred") + t.printStackTrace() + } + } finally { + println("${hello.from} hung up. Removing client channel") + clientChannels.remove(hello.from) + if (!resp.isClosedForSend) { + resp.close() + } } + } - return channel + override suspend fun say(request: ChatMessage): Empty = Empty.getDefaultInstance().also { + broadcast(request) + } + + override fun listen(request: WhoAmI): ReceiveChannel = createChannel().also { + subscribe(request.name, it) } fun shutdown() { println("Shutting down Chat service") - clientChannels.stream().forEach { client -> + clientChannels.forEach { (client, channel) -> println("Closing client channel $client") - client.channel.close() + channel.close() } clientChannels.clear() } - private fun createMessage(request: ChatMessage): ChatMessageFromService { - return ChatMessageFromService.newBuilder() - .setTimestamp( - Timestamp.newBuilder() - .setSeconds(System.nanoTime() / 1000000000) - .setNanos((System.nanoTime() % 1000000000).toInt()) - .build() - ) - .setMessage(request) - .build() - } + private fun createMessage(request: ChatMessage) = ChatMessageFromService.newBuilder() + .run { + timestamp = Instant.now().run { + Timestamp.newBuilder().run { + seconds = epochSecond + nanos = nano + build() + } + } + message = request + build() + } + } diff --git a/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/Entrypoint.kt b/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/Entrypoint.kt index 7903acb..bd9dc51 100644 --- a/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/Entrypoint.kt +++ b/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/Entrypoint.kt @@ -24,5 +24,6 @@ fun main(args: Array) { when(args[0]) { "server" -> grpcServer() "client" -> chatClient() + "clientSS" -> serverStreamingChatClient() } } diff --git a/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ServerStreaminChatClient.kt b/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ServerStreaminChatClient.kt new file mode 100644 index 0000000..58e3dcc --- /dev/null +++ b/grpc-kotlin-example-chatserver/src/main/kotlin/io/rouz/grpc/examples/chat/ServerStreaminChatClient.kt @@ -0,0 +1,81 @@ +/*- + * -\-\- + * simple-kotlin-standalone-example + * -- + * Copyright (C) 2016 - 2019 rouz.io + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package io.rouz.grpc.examples.chat + +import io.grpc.ManagedChannelBuilder +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import java.util.concurrent.TimeUnit + +fun main() = serverStreamingChatClient() + +fun serverStreamingChatClient() { + val channel = ManagedChannelBuilder.forAddress("localhost", 15001) + .usePlaintext() + .build() + + val chatService = ChatServiceGrpc.newStub(channel) + + println("type :q to quit") + print("Enter your name: ") + val from = readLine() + + val listen = chatService.listen(WhoAmI.newBuilder().setName(from).build()) + + runBlocking(Dispatchers.IO) { + launch { + startPrintLoop(listen) + } + + try { + while (true) { + print("Message: ") + val message = readLine() + if (message == null || message == ":q") { + break + } + chatService.say( + ChatMessage.newBuilder() + .setFrom(from) + .setMessage(message) + .build() + ) + } + } finally { + println("closing") + channel.shutdownNow().awaitTermination(1, TimeUnit.SECONDS) + println("closed") + } + } +} + +private suspend fun startPrintLoop(chat: ReceiveChannel) = + try { + for (responseMessage in chat) { + val message = responseMessage.message + println("${message.from}: ${message.message}") + } + println("Server disconnected") + } catch (e: Throwable) { + println("Server disconnected badly: $e") + } diff --git a/grpc-kotlin-example-chatserver/src/main/proto/chat.proto b/grpc-kotlin-example-chatserver/src/main/proto/chat.proto index bfe08e4..065414d 100644 --- a/grpc-kotlin-example-chatserver/src/main/proto/chat.proto +++ b/grpc-kotlin-example-chatserver/src/main/proto/chat.proto @@ -12,12 +12,19 @@ option java_package = "io.rouz.grpc.examples.chat"; service ChatService { rpc Chat (stream ChatMessage) returns (stream ChatMessageFromService); rpc GetNames (google.protobuf.Empty) returns (ChatRoom); + + rpc Say(ChatMessage) returns (google.protobuf.Empty); + rpc Listen(WhoAmI) returns (stream ChatMessageFromService); } message ChatRoom { repeated string names = 1; } +message WhoAmI { + string name = 1; +} + message ChatMessage { string from = 1; string message = 2; diff --git a/grpc-kotlin-gen/src/main/resources/ImplBase.mustache b/grpc-kotlin-gen/src/main/resources/ImplBase.mustache index 92193bb..44a5110 100644 --- a/grpc-kotlin-gen/src/main/resources/ImplBase.mustache +++ b/grpc-kotlin-gen/src/main/resources/ImplBase.mustache @@ -63,7 +63,12 @@ abstract class {{serviceName}}ImplBase( tryCatchingStatus(responseObserver) { val responses = {{methodName}}(request) for (response in responses) { - onNext(response) + try { + onNext(response) + } catch (e: Throwable) { + responses.cancel() + throw e + } } } } @@ -104,7 +109,12 @@ abstract class {{serviceName}}ImplBase( tryCatchingStatus(responseObserver) { val responses = {{methodName}}(requests) for (response in responses) { - onNext(response) + try { + onNext(response) + } catch (e: Throwable) { + responses.cancel() + throw e + } } } } diff --git a/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/GreeterImpl.kt b/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/GreeterImpl.kt index 1df48b8..651b22f 100644 --- a/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/GreeterImpl.kt +++ b/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/GreeterImpl.kt @@ -33,11 +33,11 @@ import java.util.concurrent.Executors.newFixedThreadPool * Implementation of coroutine-based gRPC service defined in greeter.proto */ @UseExperimental(ExperimentalCoroutinesApi::class) -class GreeterImpl : GreeterImplBase( +open class GreeterImpl : GreeterImplBase( coroutineContext = newFixedThreadPool(4, threadFactory("server-worker-%d")).asCoroutineDispatcher() ) { - private val log = KotlinLogging.logger("server") + protected val log = KotlinLogging.logger("server") override suspend fun greet(request: GreetRequest): GreetReply { log.info(request.greeting) diff --git a/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/InfiniteStreamGreeterImpl.kt b/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/InfiniteStreamGreeterImpl.kt new file mode 100644 index 0000000..fe31675 --- /dev/null +++ b/grpc-kotlin-test/src/main/kotlin/io/rouz/greeter/InfiniteStreamGreeterImpl.kt @@ -0,0 +1,66 @@ +/*- + * -\-\- + * grpc-kotlin-test + * -- + * Copyright (C) 2016 - 2019 rouz.io + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + + +package io.rouz.greeter + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Semaphore + +class InfiniteStreamGreeterImpl : GreeterImpl() { + + val subscribers = ConcurrentHashMap>() + val startupSync = Semaphore(0) + + override fun greetServerStream(request: GreetRequest): ReceiveChannel = + Channel(100).also { + val name = request.greeting + subscribers[name] = it + log.info("Subscribed: {}", name) + startupSync.release() + } + + suspend fun greetAllSubscribers(word: String) { + withContext(Dispatchers.IO) { + subscribers.forEach { (subs, ch) -> + GreetReply.newBuilder().run { + reply = "$word $subs" + build() + }.also { + launch { + try { + ch.send(it) + log.info("Message to client: {}", it.reply) + } catch (e: Throwable) { + log.info("Unsubscribe client: {}", subs) + subscribers.remove(subs) + ch.close() + } + } + } + } + } + } +} diff --git a/grpc-kotlin-test/src/test/kotlin/io/rouz/greeter/ClientAbandonTest.kt b/grpc-kotlin-test/src/test/kotlin/io/rouz/greeter/ClientAbandonTest.kt new file mode 100644 index 0000000..8c990a4 --- /dev/null +++ b/grpc-kotlin-test/src/test/kotlin/io/rouz/greeter/ClientAbandonTest.kt @@ -0,0 +1,143 @@ +/*- + * -\-\- + * grpc-kotlin-test + * -- + * Copyright (C) 2016 - 2019 rouz.io + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package io.rouz.greeter + +import io.grpc.ManagedChannel +import io.grpc.ManagedChannelBuilder +import io.grpc.netty.NettyServerBuilder +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.yield +import mu.KotlinLogging +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import java.util.concurrent.TimeUnit.SECONDS +import kotlin.coroutines.resume +import kotlin.coroutines.suspendCoroutine + + +@RunWith(JUnit4::class) +class ClientAbandonTest { + private val log = KotlinLogging.logger {} + + private val svc = InfiniteStreamGreeterImpl() + + private fun server() = + NettyServerBuilder.forPort(16565) + .addService(svc) + .build() + .start() + + private fun getChannel() = ManagedChannelBuilder + .forAddress("localhost", 16565) + .usePlaintext() + .build().also { + channels.add(it) + } + + + private fun dumpSubscribers() = + log.info("Current subscribers: {}", svc.subscribers.keys) + + private val serv = server() + private val channels: MutableList = mutableListOf() + + @After + fun tearDown() { + channels.forEach { ch -> + if (!ch.shutdownNow().awaitTermination(1, SECONDS)) { + error("Failed to shutdown channel") + } + } + if (!serv.shutdownNow().awaitTermination(1, SECONDS)) { + error("Failed to shutdown server") + } + } + + @Test + fun infiniteServerStreamingToSubscribers() { + // Create two subscribers & subscribe them + val rouzStub = GreeterGrpc.newStub(getChannel()) + val igorStub = GreeterGrpc.newStub(getChannel()) + + val rouzCh = rouzStub.greetServerStream(req("rouz")) + val igorCh = igorStub.greetServerStream(req("igor")) + svc.startupSync.acquire(2) + + runBlocking { + dumpSubscribers() + assertEquals(setOf("rouz", "igor"), svc.subscribers.keys) + + // Publish greeting to all + val rouzJob = launch { + rouzCh.receive().also { + assertEquals("Hello rouz", it.reply) + } + } + + val igorJob = launch { + igorCh.receive().also { + assertEquals("Hello igor", it.reply) + } + } + + val sendJob = launch { + svc.greetAllSubscribers("Hello") + } + + joinAll(rouzJob, igorJob, sendJob) + assertEquals(setOf("rouz", "igor"), svc.subscribers.keys) + + // One client disconnects + suspendCoroutine { + val ch = igorStub.channel as ManagedChannel + ch.shutdownNow() + it.resume(assertTrue(ch.awaitTermination(1, SECONDS))) + } + + // Continue publishing of new greetings + val rouzJob2 = launch { + repeat(4) { + assertEquals("Hola $it rouz", rouzCh.receive().reply) + } + } + + // One of the first invocations should close the channel of Igor, the next one should remove subscription. + val senderJob2 = launch { + repeat(4) { + svc.greetAllSubscribers("Hola $it") + yield() + } + } + + joinAll(rouzJob2, senderJob2) + + // Check that subscription of Igor has gone + dumpSubscribers() + assertEquals(setOf("rouz"), svc.subscribers.keys) + } + } +}