diff --git a/.idea/codeStyles/Project.xml b/.idea/codeStyles/Project.xml index b408497..461a31e 100644 --- a/.idea/codeStyles/Project.xml +++ b/.idea/codeStyles/Project.xml @@ -3,7 +3,6 @@ @@ -13,4 +12,4 @@ - \ No newline at end of file + diff --git a/.idea/gradle.xml b/.idea/gradle.xml index 04cdbe6..5b75b81 100644 --- a/.idea/gradle.xml +++ b/.idea/gradle.xml @@ -25,6 +25,7 @@ - \ No newline at end of file + diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..2953945 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/build-logic/src/main/kotlin/plugin.common.gradle.kts b/build-logic/src/main/kotlin/plugin.common.gradle.kts index b56c38b..c555066 100644 --- a/build-logic/src/main/kotlin/plugin.common.gradle.kts +++ b/build-logic/src/main/kotlin/plugin.common.gradle.kts @@ -85,3 +85,14 @@ tasks { } } } + +kotlin.targets.withType { + // Do not activate backtrace for Mingw + // https://kotlinlang.org/docs/whatsnew1620.html?_ga=2.5870007.58710271.1649248900-2086887657.1620731764#better-stack-traces-with-libbacktrace + // https://youtrack.jetbrains.com/issue/KT-51866/Compile-error-to-mingwX64-with-libbacktrace + if (this.konanTarget.family != Family.MINGW) { + binaries.all { + binaryOptions["sourceInfoType"] = "libbacktrace" + } + } +} diff --git a/gradle.properties b/gradle.properties index ee1afd8..e8e5879 100644 --- a/gradle.properties +++ b/gradle.properties @@ -10,7 +10,6 @@ kotlin.code.style=official kotlin.mpp.enableCInteropCommonization=true kotlin.native.ignoreDisabledTargets=true kotlin.native.ignoreIncorrectDependencies=true -#kotlin.native.binary.sourceInfoType=libbacktrace kotlin.js.generate.executable.default=false kotlin.mpp.stability.nowarn=true kotlin.tests.individualTaskReports=true diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 005b24d..5e7a946 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -11,6 +11,7 @@ kotlinx-coroutines = "1.9.0-RC" kotlinx-atomicfu = "0.25.0" kotlinx-benchmark = "0.4.11" kotlinx-cli = "0.3.6" +kotlinx-io = "0.5.1" ktor = "3.0.0-beta-2" kermit = "2.0.4" @@ -44,6 +45,7 @@ kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-t kotlinx-atomicfu = { module = "org.jetbrains.kotlinx:atomicfu", version.ref = "kotlinx-atomicfu" } kotlinx-benchmark-runtime = { module = "org.jetbrains.kotlinx:kotlinx-benchmark-runtime", version.ref = "kotlinx-benchmark" } kotlinx-cli = { module = "org.jetbrains.kotlinx:kotlinx-cli", version.ref = "kotlinx-cli" } +kotlinx-io-core = { module = "org.jetbrains.kotlinx:kotlinx-io-core", version.ref = "kotlinx-io" } ktor-io = { module = "io.ktor:ktor-io", version.ref = "ktor" } ktor-network = { module = "io.ktor:ktor-network", version.ref = "ktor" } diff --git a/kzmq-benchmarks/build.gradle.kts b/kzmq-benchmarks/build.gradle.kts index 2c24752..e02a371 100644 --- a/kzmq-benchmarks/build.gradle.kts +++ b/kzmq-benchmarks/build.gradle.kts @@ -18,6 +18,7 @@ kotlin { dependencies { implementation(libs.kotlinx.benchmark.runtime) implementation(project(":kzmq-core")) + implementation(libs.kotlinx.io.core) } } diff --git a/kzmq-benchmarks/src/commonMain/kotlin/org/zeromq/benchmarks/PullPushBenchmark.kt b/kzmq-benchmarks/src/commonMain/kotlin/org/zeromq/benchmarks/PullPushBenchmark.kt index ad64414..9c854d2 100644 --- a/kzmq-benchmarks/src/commonMain/kotlin/org/zeromq/benchmarks/PullPushBenchmark.kt +++ b/kzmq-benchmarks/src/commonMain/kotlin/org/zeromq/benchmarks/PullPushBenchmark.kt @@ -7,6 +7,8 @@ package org.zeromq.benchmarks import kotlinx.benchmark.* import kotlinx.coroutines.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.* import kotlin.random.* @@ -23,7 +25,7 @@ open class PullPushBenchmark() { @Param("10", "100", "1000", "10000", "100000") var messageSize = 10 - private lateinit var message: Message + private lateinit var messageData: ByteString private lateinit var scope: CoroutineScope private lateinit var context: Context @@ -41,7 +43,7 @@ open class PullPushBenchmark() { else -> error("Unsuported transport '$transport'") } - message = Message(ByteArray(messageSize)) + messageData = ByteString(ByteArray(messageSize)) val engine = engines.find { it.name.lowercase() == engineName } ?: error("Engine '$engineName' not found") if (!engine.supportedTransports.contains(transport)) @@ -63,8 +65,8 @@ open class PullPushBenchmark() { } @Benchmark - fun sendReceive() = runBlocking { - pushSocket.send(message) - pullSocket.receive() + fun sendReceive(blackhole: Blackhole) = runBlocking { + pushSocket.send { writeFrame { write(messageData) } } + blackhole.consume(pullSocket.receive { readFrame { readByteString() } }) } } diff --git a/kzmq-cio/build.gradle.kts b/kzmq-cio/build.gradle.kts index 46cdd6e..21479bc 100644 --- a/kzmq-cio/build.gradle.kts +++ b/kzmq-cio/build.gradle.kts @@ -22,5 +22,10 @@ kotlin { implementation(libs.kermit) } } + commonTest { + dependencies { + implementation(project(":kzmq-test")) + } + } } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIODealerSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIODealerSocket.kt index a4bea87..5e0a2fe 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIODealerSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIODealerSocket.kt @@ -7,6 +7,7 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlinx.io.bytestring.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -60,52 +61,13 @@ internal class CIODealerSocket( ) : CIOSocket(engine, Type.DEALER), CIOSendSocket, CIOReceiveSocket, DealerSocket { override val validPeerTypes: Set get() = validPeerSocketTypes - - override val sendChannel = Channel() - override val receiveChannel = Channel() - - init { - setHandler { - val forwardJobs = JobMap() - - while (isActive) { - val (kind, peerMailbox) = peerEvents.receive() - when (kind) { - PeerEvent.Kind.ADDITION -> forwardJobs.add(peerMailbox) { dispatchRequestsReplies(peerMailbox) } - PeerEvent.Kind.REMOVAL -> forwardJobs.remove(peerMailbox) - else -> {} - } - } - } - } - - private fun CoroutineScope.dispatchRequestsReplies(peerMailbox: PeerMailbox) = launch { - launch { - while (isActive) { - val request = sendChannel.receive() - logger.d { "Dispatching request $request to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(request)) - } - } - launch { - try { - while (isActive) { - val reply = peerMailbox.receiveChannel.receive().messageOrThrow() - logger.d { "Dispatching reply $reply from $peerMailbox" } - receiveChannel.send(reply) - } - } catch (e: ClosedReceiveChannelException) { - // Coroutine's cancellation happened while suspending on receive - // and the receiveChannel of the peerMailbox has already been closed - } - } - } + override val handler = setupHandler(DealerSocketHandler()) override var conflate: Boolean get() = TODO("Not yet implemented") set(value) {} - override var routingId: ByteArray? by options::routingId + override var routingId: ByteString? by options::routingId override var probeRouter: Boolean get() = TODO("Not yet implemented") @@ -115,3 +77,22 @@ internal class CIODealerSocket( private val validPeerSocketTypes = setOf(Type.REP, Type.ROUTER) } } + +internal class DealerSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() + + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + mailboxes.update(peerEvents.receive()) + } + } + + override suspend fun send(message: Message) { + mailboxes.sendToFirstAvailable(message) + } + + override suspend fun receive(): Message { + val (_, message) = mailboxes.receiveFromFirst() + return message + } +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPairSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPairSocket.kt index d643eb3..5427baa 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPairSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPairSocket.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import org.zeromq.internal.* @@ -36,7 +37,7 @@ import org.zeromq.internal.* * socket SHALL destroy its double queue and SHALL discard any messages it contains. * 6. SHOULD constrain incoming and outgoing queue sizes to a runtime-configurable limit. * - * B. For processing incoming messages: + * B. For processing outgoing messages: * 1. SHALL consider its peer as available only when it has a outgoing queue that is not full. * 2. SHALL block on sending, or return a suitable error, when it has no available peer. * 3. SHALL not accept further messages when it has no available peer. @@ -51,56 +52,48 @@ internal class CIOPairSocket( ) : CIOSocket(engine, Type.PAIR), CIOReceiveSocket, CIOSendSocket, PairSocket { override val validPeerTypes: Set get() = validPeerSocketTypes + override val handler = setupHandler(PairSocketHandler()) - override val receiveChannel = Channel() - override val sendChannel = Channel() - - init { - setHandler { - var forwardJob: Job? = null - - while (isActive) { - val (kind, peerMailbox) = peerEvents.receive() - when (kind) { - PeerEvent.Kind.ADDITION -> { - // FIXME what should we do if it already has a peer? - if (forwardJob != null) continue - - forwardJob = forwardJob(peerMailbox) - } - - PeerEvent.Kind.REMOVAL -> { - if (forwardJob == null) continue + companion object { + private val validPeerSocketTypes = setOf(Type.PAIR) + } +} - forwardJob.cancel() - forwardJob = null - } +internal class PairSocketHandler : SocketHandler { + private val mailbox = atomic(null) - else -> {} - } - } + private suspend fun awaitCurrentPeer() { + var counter = 0 + while (mailbox.value == null) { + if (counter++ < 100) println("in awaitCurrentPeer: ${mailbox.value}") + yield() } } - private fun CoroutineScope.forwardJob(mailbox: PeerMailbox) = launch { - launch { - while (isActive) { - val message = sendChannel.receive() - logger.v { "Sending $message to $mailbox" } - mailbox.sendChannel.send(CommandOrMessage(message)) - } - } - launch { - while (isActive) { - val commandOrMessage = mailbox.receiveChannel.receive() - val message = commandOrMessage.messageOrThrow() - logger.v { "Receiving $message from $mailbox" } - receiveChannel.send(message) + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + val (kind, peerMailbox) = peerEvents.receive() + when (kind) { + PeerEvent.Kind.ADDITION -> mailbox.value = peerMailbox + PeerEvent.Kind.REMOVAL -> mailbox.value = null + else -> {} } } } - companion object { - private val validPeerSocketTypes = setOf(Type.PAIR) + override suspend fun send(message: Message) { + awaitCurrentPeer() + val mailbox = mailbox.value!! + logger.v { "Sending $message to $mailbox" } + mailbox.sendChannel.send(CommandOrMessage(message)) + } + + override suspend fun receive(): Message { + awaitCurrentPeer() + val mailbox = mailbox.value!! + val commandOrMessage = mailbox.receiveChannel.receive() + val message = commandOrMessage.messageOrThrow() + logger.v { "Receiving $message from $mailbox" } + return message } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPublisherSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPublisherSocket.kt index 5c28b63..7eb1b1d 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPublisherSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPublisherSocket.kt @@ -8,6 +8,7 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* +import kotlinx.io.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -87,45 +88,7 @@ internal class CIOPublisherSocket( ) : CIOSocket(engine, Type.PUB), CIOSendSocket, PublisherSocket { override val validPeerTypes: Set get() = validPeerSocketTypes - - override val sendChannel = Channel() - - init { - setHandler { - val peerMailboxes = hashSetOf() - var subscriptions = SubscriptionTrie() - - while (isActive) { - select { - peerEvents.onReceive { (kind, peerMailbox) -> - when (kind) { - PeerEvent.Kind.ADDITION -> peerMailboxes.add(peerMailbox) - PeerEvent.Kind.REMOVAL -> peerMailboxes.remove(peerMailbox) - else -> {} - } - } - - for (peerMailbox in peerMailboxes) { - peerMailbox.receiveChannel.onReceive { commandOrMessage -> - logger.d { "Handling $commandOrMessage from $peerMailbox" } - subscriptions = when (val command = commandOrMessage.commandOrThrow()) { - is SubscribeCommand -> subscriptions.add(command.topic, peerMailbox) - is CancelCommand -> subscriptions.remove(command.topic, peerMailbox) - else -> protocolError("Expected SUBSCRIBE or CANCEL, but got ${command.name}") - } - } - } - - sendChannel.onReceive { message -> - subscriptions.forEachMatching(message.first()) { peerMailbox -> - logger.d { "Dispatching $message to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(message)) - } - } - } - } - } - } + override val handler = setupHandler(PublisherSocketHandler()) override var conflate: Boolean get() = TODO("Not yet implemented") @@ -141,3 +104,40 @@ internal class CIOPublisherSocket( private val validPeerSocketTypes = setOf(Type.SUB, Type.XSUB) } } + +internal class PublisherSocketHandler : SocketHandler { + private val mailboxes = hashSetOf() + private var subscriptions = SubscriptionTrie() + + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + select { + peerEvents.onReceive { (kind, peerMailbox) -> + when (kind) { + PeerEvent.Kind.ADDITION -> mailboxes.add(peerMailbox) + PeerEvent.Kind.REMOVAL -> mailboxes.remove(peerMailbox) + else -> {} + } + } + + for (mailbox in mailboxes) { + mailbox.receiveChannel.onReceive { commandOrMessage -> + logger.d { "Handling $commandOrMessage from $mailbox" } + subscriptions = when (val command = commandOrMessage.commandOrThrow()) { + is SubscribeCommand -> subscriptions.add(command.topic, mailbox) + is CancelCommand -> subscriptions.remove(command.topic, mailbox) + else -> protocolError("Expected SUBSCRIBE or CANCEL, but got ${command.name}") + } + } + } + } + } + } + + override suspend fun send(message: Message) { + subscriptions.forEachMatching(message.peekFirstFrame().readByteArray()) { mailbox -> + logger.d { "Dispatching $message to $mailbox" } + mailbox.sendChannel.send(CommandOrMessage(message)) + } + } +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPullSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPullSocket.kt index fa790fb..574a540 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPullSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPullSocket.kt @@ -7,7 +7,6 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import kotlinx.coroutines.selects.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -50,12 +49,7 @@ internal class CIOPullSocket( ) : CIOSocket(engine, Type.PULL), CIOReceiveSocket, PullSocket { override val validPeerTypes: Set get() = validPeerSocketTypes - - override val receiveChannel = Channel() - - init { - setHandler { handlePullSocket(peerEvents, receiveChannel) } - } + override val handler = setupHandler(PullSocketHandler()) override var conflate: Boolean get() = TODO("Not yet implemented") @@ -66,25 +60,17 @@ internal class CIOPullSocket( } } -internal suspend fun handlePullSocket( - peerEvents: ReceiveChannel, - receiveChannel: SendChannel, -) = coroutineScope { - val mailboxes = CircularQueue() +internal class PullSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() - while (isActive) { - select { - peerEvents.onReceive(mailboxes::update) - - if (mailboxes.isNotEmpty()) { - mailboxes.forEachIndexed { index, mailbox -> - mailbox.receiveChannel.onReceive { commandOrMessage -> - logger.v { "Received command or message from $mailbox" } - mailboxes.rotateAfter(index) - receiveChannel.send(commandOrMessage.messageOrThrow()) - } - } - } + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + mailboxes.update(peerEvents.receive()) } } + + override suspend fun receive(): Message { + val (_, message) = mailboxes.receiveFromFirst() + return message + } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPushSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPushSocket.kt index e26f5de..19849b8 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPushSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOPushSocket.kt @@ -7,7 +7,6 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import kotlinx.coroutines.selects.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -53,12 +52,7 @@ internal class CIOPushSocket( ) : CIOSocket(engine, Type.PUSH), CIOSendSocket, PushSocket { override val validPeerTypes: Set get() = validPeerSocketTypes - - override val sendChannel = Channel() - - init { - setHandler { handlePushSocket(peerEvents, sendChannel) } - } + override val handler = setupHandler(PushSocketHandler()) override var conflate: Boolean get() = TODO("Not yet implemented") @@ -69,64 +63,16 @@ internal class CIOPushSocket( } } -internal suspend fun handlePushSocket( - peerEvents: ReceiveChannel, - sendChannel: ReceiveChannel, -) = coroutineScope { - val mailboxes = CircularQueue() - - while (isActive) { - select { - peerEvents.onReceive(mailboxes::update) - - if (mailboxes.isNotEmpty()) { - sendChannel.onReceive { message -> - // Fast path: Find the first mailbox we can send immediately - logger.v { "Try send message to first available" } - val sent = mailboxes.trySendToFirstAvailable(message) - - if (!sent) { - // Slow path: Biased select on each mailbox's onSend - logger.v { "Send message to first available" } - select { - peerEvents.onReceive(mailboxes::update) +internal class PushSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() - val commandOrMessage = CommandOrMessage(message) - mailboxes.forEachIndexed { index, mailbox -> - mailbox.sendChannel.onSend(commandOrMessage) { - logger.v { "Sent message to $mailbox" } - mailboxes.rotateAfter(index) - } - } - } - } - } - } + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + mailboxes.update(peerEvents.receive()) } } -} - -internal fun CircularQueue.update(event: PeerEvent) { - val mailbox = event.peerMailbox - when (event.kind) { - PeerEvent.Kind.ADDITION -> add(mailbox) - PeerEvent.Kind.REMOVAL -> remove(mailbox) - else -> {} - } -} -internal fun CircularQueue.trySendToFirstAvailable(message: Message): Boolean { - val commandOrMessage = CommandOrMessage(message) - val index = indexOfFirst { mailbox -> - val result = mailbox.sendChannel.trySend(commandOrMessage) - logger.v { - if (result.isSuccess) "Sent message to $mailbox" - else "Failed to send message to $mailbox" - } - result.isSuccess + override suspend fun send(message: Message) { + mailboxes.sendToFirstAvailable(message) } - - val sent = index != -1 - if (sent) rotateAfter(index) - return sent } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReceiveSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReceiveSocket.kt index 465071c..1c0b53c 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReceiveSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReceiveSocket.kt @@ -10,24 +10,22 @@ import org.zeromq.internal.* internal interface CIOReceiveSocket : ReceiveSocket { - val receiveChannel: ReceiveChannel + val handler: SocketHandler val options: SocketOptions - override suspend fun receive(): Message = receiveChannel.receive() + override suspend fun receive(): Message = handler.receive() override suspend fun receiveCatching(): SocketResult { - val result = receiveChannel.receiveCatching() + val result = runCatching { receive() } return if (result.isSuccess) SocketResult.success(result.getOrThrow()) else SocketResult.failure(result.exceptionOrNull()) } override fun tryReceive(): SocketResult { - val result = receiveChannel.tryReceive() - return if (result.isSuccess) SocketResult.success(result.getOrThrow()) - else SocketResult.failure(result.exceptionOrNull()) + TODO() } - override val onReceive get() = receiveChannel.onReceive + override val onReceive get() = TODO() override var receiveBufferSize: Int get() = TODO("Not yet implemented") diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReplySocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReplySocket.kt index fd4019e..ee237b8 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReplySocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOReplySocket.kt @@ -5,8 +5,11 @@ package org.zeromq +import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlinx.coroutines.sync.* +import kotlinx.io.bytestring.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -58,59 +61,57 @@ internal class CIOReplySocket( ) : CIOSocket(engine, Type.REP), CIOReceiveSocket, CIOSendSocket, ReplySocket { override val validPeerTypes: Set get() = validPeerSocketTypes + override val handler = setupHandler(ReplySocketHandler()) - override val receiveChannel = Channel() - override val sendChannel = Channel() + override var routingId: ByteString? by options::routingId - private val requestsChannel = Channel>() - - init { - setHandler { - launch { - val forwardJobs = JobMap() + companion object { + private val validPeerSocketTypes = setOf(Type.REQ, Type.DEALER) + } +} - while (isActive) { - val (kind, peerMailbox) = peerEvents.receive() - when (kind) { - PeerEvent.Kind.ADDITION -> forwardJobs.add(peerMailbox) { forwardRequests(peerMailbox) } - PeerEvent.Kind.REMOVAL -> forwardJobs.remove(peerMailbox) - else -> {} - } - } - } - launch { - while (isActive) { - val (peerMailbox, request) = requestsChannel.receive() +internal class ReplySocketHandler : SocketHandler { + private val mailboxes = CircularQueue() + private var state = atomic(ReplySocketState.Idle) + private val requestReplyLock = Mutex() - logger.v { "Received request $request from $peerMailbox" } - val (identities, requestData) = extractPrefixAddress(request) - receiveChannel.send(requestData) + private suspend fun awaitState(predicate: (ReplySocketState?) -> Boolean) { + while (!predicate(state.value)) yield() + } - val replyData = sendChannel.receive() - val reply = addPrefixAddress(replyData, identities) - logger.v { "Sending reply $reply back to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(reply)) - } - } + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + val event = peerEvents.receive() + mailboxes.update(event) } } - private fun CoroutineScope.forwardRequests(peerMailbox: PeerMailbox) = launch { - try { - while (isActive) { - val message = peerMailbox.receiveChannel.receive().messageOrThrow() - logger.v { "Forwarding request $message from $peerMailbox" } - requestsChannel.send(peerMailbox to message) - } - } catch (e: ClosedReceiveChannelException) { - // Coroutine's cancellation happened while suspending on receive - // and the receiveChannel of the peerMailbox has already been closed + override suspend fun receive(): Message { + awaitState { it is ReplySocketState.Idle } + requestReplyLock.withLock { + val (mailbox, message) = mailboxes.receiveFromFirst() + state.value = ReplySocketState.ProcessingRequest(mailbox, message.popPrefixAddress()) + return message } } - override var routingId: ByteArray? by options::routingId + override suspend fun send(message: Message) { + awaitState { it is ReplySocketState.ProcessingRequest } + requestReplyLock.withLock { + val (peer, address) = state.value as ReplySocketState.ProcessingRequest - companion object { - private val validPeerSocketTypes = setOf(Type.REQ, Type.DEALER) + message.pushPrefixAddress(address) + peer.sendChannel.send(CommandOrMessage(message)) + state.value = ReplySocketState.Idle + } } } + +private sealed interface ReplySocketState { + data object Idle : ReplySocketState + + data class ProcessingRequest( + val peer: PeerMailbox, + val address: List, + ) : ReplySocketState +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORequestSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORequestSocket.kt index 4e3160b..6e26d2a 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORequestSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORequestSocket.kt @@ -5,8 +5,11 @@ package org.zeromq +import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlinx.coroutines.sync.* +import kotlinx.io.bytestring.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -54,17 +57,9 @@ internal class CIORequestSocket( ) : CIOSocket(engine, Type.REQ), CIOSendSocket, CIOReceiveSocket, RequestSocket { override val validPeerTypes: Set get() = validPeerSocketTypes + override val handler = setupHandler(RequestSocketHandler()) - override val sendChannel = Channel() - override val receiveChannel = Channel() - - init { - setHandler { - handleRequestSocket(peerEvents, sendChannel, receiveChannel) - } - } - - override var routingId: ByteArray? by options::routingId + override var routingId: ByteString? by options::routingId override var probeRouter: Boolean get() = TODO("Not yet implemented") set(value) {} @@ -80,68 +75,48 @@ internal class CIORequestSocket( } } -internal suspend fun handleRequestSocket( - peerEvents: ReceiveChannel, - sendChannel: ReceiveChannel, - receiveChannel: SendChannel, -) = coroutineScope { +internal class RequestSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() + private var lastSentPeer = atomic(null) + private val requestReplyLock = Mutex() - val requestsChannel = Channel>() - val repliesChannel = Channel>() + private suspend fun awaitLastSentPeer(predicate: (PeerMailbox?) -> Boolean) { + while (!predicate(lastSentPeer.value)) yield() + } - fun CoroutineScope.dispatchRequestsReplies(peerMailbox: PeerMailbox) = launch { - launch { - while (isActive) { - val request = sendChannel.receive() - logger.v { "Dispatching request $request to $peerMailbox" } - requestsChannel.send(peerMailbox to request) - } - } - launch { - try { - while (isActive) { - val reply = peerMailbox.receiveChannel.receive().messageOrThrow() - logger.v { "Dispatching reply $reply from $peerMailbox" } - repliesChannel.send(peerMailbox to reply) - } - } catch (e: ClosedReceiveChannelException) { - // Coroutine's cancellation happened while suspending on receive - // and the receiveChannel of the peerMailbox has already been closed - } + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + mailboxes.update(peerEvents.receive()) } } - launch { - val forwardJobs = JobMap() - - while (isActive) { - val (kind, peerMailbox) = peerEvents.receive() - when (kind) { - PeerEvent.Kind.ADDITION -> forwardJobs.add(peerMailbox) { dispatchRequestsReplies(peerMailbox) } - PeerEvent.Kind.REMOVAL -> forwardJobs.remove(peerMailbox) - else -> {} - } + override suspend fun send(message: Message) { + awaitLastSentPeer { it == null } + requestReplyLock.withLock { + message.pushPrefixAddress() + val mailbox = mailboxes.sendToFirstAvailable(message) + lastSentPeer.value = mailbox + logger.v { "Sent request to $mailbox" } } } - launch { - while (isActive) { - val (peerMailbox, requestData) = requestsChannel.receive() - val request = addPrefixAddress(requestData) - logger.v { "Sending request $request to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(request)) + override suspend fun receive(): Message { + awaitLastSentPeer { it != null } + requestReplyLock.withLock { + while (true) { + val (mailbox, message) = mailboxes.receiveFromFirst() + + message.popPrefixAddress() - while (isActive) { - val (otherPeerMailbox, reply) = repliesChannel.receive() - if (otherPeerMailbox != peerMailbox) { - logger.w { "Ignoring reply $reply from $otherPeerMailbox" } + // Should we "discard" messages in another coroutine in `handle()`? + if (mailbox != lastSentPeer.value) { + logger.w { "Ignoring reply $message from $mailbox" } continue } - logger.v { "Sending back reply $reply from $peerMailbox" } - val (_, replyData) = extractPrefixAddress(reply) - receiveChannel.send(replyData) - break + logger.v { "Received reply $message from $mailbox" } + lastSentPeer.value = null + return message } } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORouterSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORouterSocket.kt index 90ae770..d736496 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORouterSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIORouterSocket.kt @@ -7,10 +7,9 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import kotlinx.coroutines.selects.* +import kotlinx.io.bytestring.* import org.zeromq.internal.* import org.zeromq.internal.utils.* -import kotlin.random.* /** * An implementation of the [ROUTER socket](https://rfc.zeromq.org/spec/28/). @@ -66,91 +65,70 @@ internal class CIORouterSocket( ) : CIOSocket(engine, Type.ROUTER), CIOReceiveSocket, CIOSendSocket, RouterSocket { override val validPeerTypes: Set get() = validPeerSocketTypes + override val handler = setupHandler(RouterSocketHandler()) - override val receiveChannel = Channel() - override val sendChannel = Channel() - - init { - setHandler { - val forwardJobs = JobMap() - val perIdentityMailboxes = hashMapOf() + override var routingId: ByteString? by options::routingId + override var probeRouter: Boolean + get() = TODO("Not yet implemented") + set(value) {} + override var mandatory: Boolean + get() = TODO("Not yet implemented") + set(value) {} + override var handover: Boolean + get() = TODO("Not yet implemented") + set(value) {} - fun generateNewIdentity(): Identity { - while (true) { - val identity = Identity(Random.nextBytes(ByteArray(16))) - if (identity !in perIdentityMailboxes) return identity - } - } + companion object { + private val validPeerSocketTypes = setOf(Type.REQ, Type.DEALER, Type.ROUTER) + } +} - while (isActive) { - select { - peerEvents.onReceive { (kind, peerMailbox) -> - when (kind) { - PeerEvent.Kind.CONNECTION -> { - val identity = - peerMailbox.identity ?: generateNewIdentity().also { peerMailbox.identity = it } - perIdentityMailboxes[identity] = peerMailbox +internal class RouterSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() + private val perIdentityMailboxes = hashMapOf() - forwardJobs.add(peerMailbox) { routeRequests(peerMailbox) } - } + private fun randomIdentity(): Identity { + while (true) { + val identity = Identity.random() + if (identity !in perIdentityMailboxes) return identity + } + } - PeerEvent.Kind.DISCONNECTION -> { - peerMailbox.identity?.let { identity -> - perIdentityMailboxes[identity] = peerMailbox - } + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + val event = peerEvents.receive() - forwardJobs.remove(peerMailbox) - } + mailboxes.update(event) - else -> {} - } - } + val (kind, mailbox) = event + when (kind) { + PeerEvent.Kind.CONNECTION -> { + val identity = mailbox.identity ?: randomIdentity().also { mailbox.identity = it } + perIdentityMailboxes[identity] = mailbox + } - sendChannel.onReceive { - val (identity, message) = extractIdentity(it) - perIdentityMailboxes[identity]?.let { peerMailbox -> - logger.d { "Forwarding reply $message to $peerMailbox with identity $identity" } - peerMailbox.sendChannel.send(CommandOrMessage(message)) - } - } + PeerEvent.Kind.DISCONNECTION -> { + val identity = mailbox.identity ?: error("Peer identity should not be null") + perIdentityMailboxes.remove(identity) } + + else -> {} } } } - private fun CoroutineScope.routeRequests(peerMailbox: PeerMailbox) = launch { - try { - while (isActive) { - val message = peerMailbox.receiveChannel.receive().messageOrThrow() - peerMailbox.identity?.let { identity -> - logger.d { "Forwarding request $message from $peerMailbox with identity $identity" } - receiveChannel.send(prependIdentity(message, identity)) - } - } - } catch (e: ClosedReceiveChannelException) { - // Coroutine's cancellation happened while suspending on receive - // and the receiveChannel of the peerMailbox has already been closed + override suspend fun send(message: Message) { + val identity = message.popIdentity() + perIdentityMailboxes[identity]?.let { peerMailbox -> + logger.d { "Forwarding reply $message to $peerMailbox with identity $identity" } + peerMailbox.sendChannel.send(CommandOrMessage(message)) } } - override var routingId: ByteArray? by options::routingId - override var probeRouter: Boolean - get() = TODO("Not yet implemented") - set(value) {} - override var mandatory: Boolean - get() = TODO("Not yet implemented") - set(value) {} - override var handover: Boolean - get() = TODO("Not yet implemented") - set(value) {} - - companion object { - private val validPeerSocketTypes = setOf(Type.REQ, Type.DEALER, Type.ROUTER) + override suspend fun receive(): Message { + val (peerMailbox, message) = mailboxes.receiveFromFirst() + val identity = peerMailbox.identity ?: error("Peer identity should not be null") + message.pushIdentity(identity) + return message } } - -private fun prependIdentity(message: Message, identity: Identity): Message = - Message(listOf(identity.value) + message.frames) - -private fun extractIdentity(message: Message): Pair = - Identity(message.frames[0]) to Message(message.frames.subList(1, message.frames.size)) diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSendSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSendSocket.kt index c6c0039..63e12fc 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSendSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSendSocket.kt @@ -10,22 +10,19 @@ import org.zeromq.internal.* internal interface CIOSendSocket : SendSocket { - val sendChannel: SendChannel + val handler: SocketHandler val options: SocketOptions - override suspend fun send(message: Message) = sendChannel.send(message) + override suspend fun send(message: Message) = handler.send(message) - override suspend fun sendCatching(message: Message): SocketResult = try { - sendChannel.send(message) - SocketResult.success(Unit) - } catch (t: Throwable) { - SocketResult.failure(t) + override suspend fun sendCatching(message: Message): SocketResult { + val result = runCatching { send(message) } + return if (result.isSuccess) SocketResult.success(result.getOrThrow()) + else SocketResult.failure(result.exceptionOrNull()) } override fun trySend(message: Message): SocketResult { - val result = sendChannel.trySend(message) - return if (result.isSuccess) SocketResult.success(Unit) - else SocketResult.failure(result.exceptionOrNull()) + TODO() } override var multicastHops: Int diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSocket.kt index 766ce36..449b124 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSocket.kt @@ -20,7 +20,7 @@ internal abstract class CIOSocket( engine.lingerScope, engine.transportRegistry ) - protected val peerEvents = peerManager.peerEvents + private val peerEvents = peerManager.peerEvents private val exceptionHandler = CoroutineExceptionHandler { _, throwable -> logger.e(throwable) { "An error occurred in socket" } @@ -29,10 +29,11 @@ internal abstract class CIOSocket( private lateinit var socketJob: Job - fun setHandler(block: suspend CoroutineScope.() -> Unit) { + fun setupHandler(handler: H): H { socketJob = engine.mainScope.launch(exceptionHandler + coroutineName) { - block() + handler.handle(peerEvents) } + return handler } override fun close() { diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSubscriberSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSubscriberSocket.kt index 1c479e3..935d98e 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSubscriberSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOSubscriberSocket.kt @@ -8,7 +8,9 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* +import kotlinx.io.bytestring.* import org.zeromq.internal.* +import org.zeromq.internal.utils.* /** * An implementation of the [SUB socket](https://rfc.zeromq.org/spec/29/). @@ -78,100 +80,30 @@ internal class CIOSubscriberSocket( ) : CIOSocket(engine, Type.SUB), CIOReceiveSocket, SubscriberSocket { override val validPeerTypes: Set get() = validPeerSocketTypes - - override val receiveChannel = Channel() - - private var subscriptions = mutableListOf() - private var lateSubscriptionCommands = Channel(10) - - init { - setHandler { - val peerMailboxes = hashSetOf() - - while (isActive) { - select { - peerEvents.onReceive { (kind, peerMailbox) -> - when (kind) { - PeerEvent.Kind.ADDITION -> { - peerMailboxes.add(peerMailbox) - - for (subscription in subscriptions) { - logger.d { "Sending subscription ${subscription.contentToString()} to $peerMailbox" } - peerMailbox.sendChannel.send( - CommandOrMessage(SubscribeCommand(subscription)) - ) - } - } - - PeerEvent.Kind.REMOVAL -> peerMailboxes.remove(peerMailbox) - else -> {} - } - } - - lateSubscriptionCommands.onReceive { command -> - for (peerMailbox in peerMailboxes) { - logger.d { "Sending late subscription $command to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(command)) - } - } - - for (peerMailbox in peerMailboxes) { - peerMailbox.receiveChannel.onReceive { commandOrMessage -> - val message = commandOrMessage.messageOrThrow() - logger.v { "Receiving $message from $peerMailbox" } - receiveChannel.send(message) - } - } - } - } - } - } + override val handler = setupHandler(SubscriberSocketHandler()) override suspend fun subscribe() { - subscribe(listOf()) + handler.subscriptions.subscribe(listOf()) } - override suspend fun subscribe(vararg topics: ByteArray) { - subscribe(topics.toList()) + override suspend fun subscribe(vararg topics: ByteString) { + handler.subscriptions.subscribe(topics.toList()) } override suspend fun subscribe(vararg topics: String) { - subscribe(topics.map { it.encodeToByteArray() }) - } - - private suspend fun subscribe(topics: List) { - val effectiveTopics = topics.ifEmpty { listOf(byteArrayOf()) } - - subscriptions.addAll(effectiveTopics) - - for (topic in effectiveTopics) { - lateSubscriptionCommands.send(SubscribeCommand(topic)) - } + handler.subscriptions.subscribe(topics.map { it.encodeToByteString() }) } override suspend fun unsubscribe() { - unsubscribe(listOf()) + handler.subscriptions.unsubscribe(listOf()) } - override suspend fun unsubscribe(vararg topics: ByteArray) { - unsubscribe(topics.toList()) + override suspend fun unsubscribe(vararg topics: ByteString) { + handler.subscriptions.unsubscribe(topics.toList()) } override suspend fun unsubscribe(vararg topics: String) { - unsubscribe(topics.map { it.encodeToByteArray() }) - } - - private suspend fun unsubscribe(topics: List) { - val effectiveTopics = topics.ifEmpty { listOf(byteArrayOf()) } - - val removedTopics = mutableListOf() - for (topic in effectiveTopics) { - if (subscriptions.remove(topic)) removedTopics += topic - } - - for (topic in removedTopics) { - lateSubscriptionCommands.send(CancelCommand(topic)) - } + handler.subscriptions.unsubscribe(topics.map { it.encodeToByteString() }) } override var conflate: Boolean @@ -186,3 +118,42 @@ internal class CIOSubscriberSocket( private val validPeerSocketTypes = setOf(Type.PUB, Type.XPUB) } } + +internal class SubscriberSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() + val subscriptions = SubscriptionManager() + + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + select { + peerEvents.onReceive { event -> + mailboxes.update(event) + + val (kind, mailbox) = event + when (kind) { + PeerEvent.Kind.ADDITION -> { + for (subscription in subscriptions.existing) { + logger.d { "Sending subscription $subscription to $mailbox" } + mailbox.sendChannel.send(CommandOrMessage(SubscribeCommand(subscription))) + } + } + + else -> {} + } + } + + subscriptions.lateSubscriptionCommands.onReceive { command -> + for (mailbox in mailboxes) { + logger.d { "Sending late subscription $command to $mailbox" } + mailbox.sendChannel.send(CommandOrMessage(command)) + } + } + } + } + } + + override suspend fun receive(): Message { + val (_, message) = mailboxes.receiveFromFirst() + return message + } +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXPublisherSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXPublisherSocket.kt index a3611bd..7093076 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXPublisherSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXPublisherSocket.kt @@ -8,6 +8,7 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* +import kotlinx.io.* import org.zeromq.internal.* import org.zeromq.internal.utils.* @@ -94,58 +95,7 @@ internal class CIOXPublisherSocket( ) : CIOSocket(engine, Type.XPUB), CIOSendSocket, CIOReceiveSocket, XPublisherSocket { override val validPeerTypes: Set get() = validPeerSocketTypes - - override val sendChannel = Channel() - override val receiveChannel = Channel() - - init { - setHandler { - val peerMailboxes = hashSetOf() - var subscriptions = SubscriptionTrie() - - while (isActive) { - select { - peerEvents.onReceive { (kind, peerMailbox) -> - when (kind) { - PeerEvent.Kind.ADDITION -> peerMailboxes.add(peerMailbox) - PeerEvent.Kind.REMOVAL -> peerMailboxes.remove(peerMailbox) - else -> {} - } - } - - for (peerMailbox in peerMailboxes) { - peerMailbox.receiveChannel.onReceive { commandOrMessage -> - logger.d { "Handling $commandOrMessage from $peerMailbox" } - if (commandOrMessage.isCommand) { - subscriptions = when (val command = commandOrMessage.commandOrThrow()) { - is SubscribeCommand -> { - receiveChannel.send(subscriptionMessageOf(true, command.topic)) - subscriptions.add(command.topic, peerMailbox) - } - - is CancelCommand -> { - receiveChannel.send(subscriptionMessageOf(false, command.topic)) - subscriptions.remove(command.topic, peerMailbox) - } - - else -> protocolError("Expected SUBSCRIBE or CANCEL, but got ${command.name}") - } - } else { - receiveChannel.send(commandOrMessage.messageOrThrow()) - } - } - } - - sendChannel.onReceive { message -> - subscriptions.forEachMatching(message.first()) { peerMailbox -> - logger.d { "Dispatching $message to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(message)) - } - } - } - } - } - } + override val handler = setupHandler(XPublisherSocketHandler()) override var invertMatching: Boolean get() = TODO("Not yet implemented") @@ -164,3 +114,56 @@ internal class CIOXPublisherSocket( private val validPeerSocketTypes = setOf(Type.SUB, Type.XSUB) } } + +internal class XPublisherSocketHandler : SocketHandler { + private val mailboxes = hashSetOf() + private var subscriptions = SubscriptionTrie() + + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + select { + peerEvents.onReceive { (kind, peerMailbox) -> + when (kind) { + PeerEvent.Kind.ADDITION -> mailboxes.add(peerMailbox) + PeerEvent.Kind.REMOVAL -> mailboxes.remove(peerMailbox) + else -> {} + } + } + } + } + } + + override suspend fun send(message: Message) { + subscriptions.forEachMatching(message.peekFirstFrame().readByteArray()) { peerMailbox -> + logger.d { "Dispatching $message to $peerMailbox" } + peerMailbox.sendChannel.send(CommandOrMessage(message)) + } + } + + override suspend fun receive(): Message { + return select { + for (mailbox in mailboxes) { + mailbox.receiveChannel.onReceive { commandOrMessage -> + logger.d { "Handling $commandOrMessage from $mailbox" } + if (commandOrMessage.isCommand) { + when (val command = commandOrMessage.commandOrThrow()) { + is SubscribeCommand -> { + subscriptions = subscriptions.add(command.topic, mailbox) + SubscriptionMessage(true, command.topic).toMessage() + } + + is CancelCommand -> { + subscriptions = subscriptions.remove(command.topic, mailbox) + SubscriptionMessage(false, command.topic).toMessage() + } + + else -> protocolError("Expected SUBSCRIBE or CANCEL, but got ${command.name}") + } + } else { + commandOrMessage.messageOrThrow() + } + } + } + } + } +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXSubscriberSocket.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXSubscriberSocket.kt index ffdbab7..0d83472 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXSubscriberSocket.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/CIOXSubscriberSocket.kt @@ -9,6 +9,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* import org.zeromq.internal.* +import org.zeromq.internal.utils.* /** * An implementation of the [XSUB socket](https://rfc.zeromq.org/spec/29/). @@ -90,92 +91,64 @@ internal class CIOXSubscriberSocket( override val validPeerTypes: Set get() = validPeerSocketTypes - override val sendChannel = Channel() - override val receiveChannel = Channel() + override val handler = setupHandler(XSubscriberSocketHandler()) - private var subscriptions = mutableListOf() - private var lateSubscriptionCommands = Channel(10) - - init { - setHandler { - val peerMailboxes = hashSetOf() - - while (isActive) { - select { - peerEvents.onReceive { (kind, peerMailbox) -> - when (kind) { - PeerEvent.Kind.ADDITION -> { - peerMailboxes.add(peerMailbox) - - for (subscription in subscriptions) { - logger.d { "Sending subscription ${subscription.contentToString()} to $peerMailbox" } - peerMailbox.sendChannel.send( - CommandOrMessage(SubscribeCommand(subscription)) - ) - } - } - - PeerEvent.Kind.REMOVAL -> peerMailboxes.remove(peerMailbox) - else -> {} - } - } + companion object { + private val validPeerSocketTypes = setOf(Type.PUB, Type.XPUB) + } +} - sendChannel.onReceive { message -> - val subscriptionTopicPair = destructureSubscriptionMessage(message) - if (subscriptionTopicPair != null) { - val (subscribe, topic) = subscriptionTopicPair - if (subscribe) subscribe(listOf(topic)) else unsubscribe(listOf(topic)) - } else { - peerMailboxes.forEach { peerMailbox -> - logger.v { "Sending message $message to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(message)) +internal class XSubscriberSocketHandler : SocketHandler { + private val mailboxes = CircularQueue() + private val subscriptions = SubscriptionManager() + + override suspend fun handle(peerEvents: ReceiveChannel) = coroutineScope { + while (isActive) { + select { + peerEvents.onReceive { event -> + mailboxes.update(event) + + val (kind, peerMailbox) = event + when (kind) { + PeerEvent.Kind.ADDITION -> { + for (subscription in subscriptions.existing) { + logger.d { "Sending subscription $subscription to $peerMailbox" } + peerMailbox.sendChannel.send(CommandOrMessage(SubscribeCommand(subscription))) } } - } - lateSubscriptionCommands.onReceive { command -> - for (peerMailbox in peerMailboxes) { - logger.d { "Sending late subscription $command to $peerMailbox" } - peerMailbox.sendChannel.send(CommandOrMessage(command)) - } + else -> {} } + } - for (peerMailbox in peerMailboxes) { - peerMailbox.receiveChannel.onReceive { commandOrMessage -> - val message = commandOrMessage.messageOrThrow() - logger.v { "Receiving $message from $peerMailbox" } - receiveChannel.send(message) - } + subscriptions.lateSubscriptionCommands.onReceive { command -> + for (peerMailbox in mailboxes) { + logger.d { "Sending late subscription $command to $peerMailbox" } + peerMailbox.sendChannel.send(CommandOrMessage(command)) } } } } } - private suspend fun subscribe(topics: List) { - val effectiveTopics = topics.ifEmpty { listOf(byteArrayOf()) } - - subscriptions.addAll(effectiveTopics) - - for (topic in effectiveTopics) { - lateSubscriptionCommands.send(SubscribeCommand(topic)) - } - } - - private suspend fun unsubscribe(topics: List) { - val effectiveTopics = topics.ifEmpty { listOf(byteArrayOf()) } - - val removedTopics = mutableListOf() - for (topic in effectiveTopics) { - if (subscriptions.remove(topic)) removedTopics += topic + override suspend fun send(message: Message) { + val subscriptionTopicPair = message.toSubscriptionMessage()?.let { + it.subscribe to it.topic } - - for (topic in removedTopics) { - lateSubscriptionCommands.send(CancelCommand(topic)) + if (subscriptionTopicPair != null) { + val (subscribe, topic) = subscriptionTopicPair + if (subscribe) subscriptions.subscribe(listOf(topic)) + else subscriptions.unsubscribe(listOf(topic)) + } else { + mailboxes.forEach { mailbox -> + logger.v { "Sending message $message to $mailbox" } + mailbox.sendChannel.send(CommandOrMessage(message)) + } } } - companion object { - private val validPeerSocketTypes = setOf(Type.PUB, Type.XPUB) + override suspend fun receive(): Message { + val (_, message) = mailboxes.receiveFromFirst() + return message } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/Command.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/Command.kt index ea7e922..3703459 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/Command.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/Command.kt @@ -5,6 +5,8 @@ package org.zeromq.internal +import kotlinx.io.bytestring.* + internal sealed interface Command { val name: CommandName } @@ -17,38 +19,38 @@ internal enum class CommandName { PING, PONG; - val bytes: ByteArray = name.encodeToByteArray() + val bytes: ByteString = ByteString(name.encodeToByteArray()) companion object { fun find(string: String): CommandName? { - return values().find { it.name.lowercase() == string.lowercase() } + return entries.find { it.name.lowercase() == string.lowercase() } } } } -internal data class ReadyCommand(val properties: Map) : Command { +internal data class ReadyCommand(val properties: Map) : Command { override val name = CommandName.READY - constructor(vararg properties: Pair) : this(mapOf(*properties)) + constructor(vararg properties: Pair) : this(mapOf(*properties)) } internal data class ErrorCommand(val reason: String) : Command { override val name = CommandName.READY } -internal data class SubscribeCommand(val topic: ByteArray) : Command { +internal data class SubscribeCommand(val topic: ByteString) : Command { override val name = CommandName.SUBSCRIBE } -internal data class CancelCommand(val topic: ByteArray) : Command { +internal data class CancelCommand(val topic: ByteString) : Command { override val name = CommandName.CANCEL } -internal data class PingCommand(val ttl: UShort, val context: ByteArray) : Command { +internal data class PingCommand(val ttl: UShort, val context: ByteString) : Command { override val name = CommandName.PING } -internal data class PongCommand(val context: ByteArray) : Command { +internal data class PongCommand(val context: ByteString) : Command { override val name = CommandName.PONG } @@ -57,11 +59,11 @@ internal enum class PropertyName(val propertyName: String) { IDENTITY("Identity"), RESOURCE("Resource"); - val bytes: ByteArray = propertyName.encodeToByteArray() + val bytes: ByteString = ByteString(propertyName.encodeToByteArray()) companion object { fun find(string: String): PropertyName? { - return values().find { it.propertyName.lowercase() == string.lowercase() } + return entries.find { it.propertyName.lowercase() == string.lowercase() } } } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/MessageOps.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/MessageOps.kt new file mode 100644 index 0000000..7f5954e --- /dev/null +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/MessageOps.kt @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.internal + +import kotlinx.io.* +import kotlinx.io.bytestring.* +import org.zeromq.* + +internal fun Message.popIdentity(): Identity { + return Identity(readFrame().readByteString()) +} + +internal fun Message.pushIdentity(identity: Identity) { + writeFrames(listOf(Buffer().apply { write(identity.value) }) + readFrames()) +} + +internal fun Message.popPrefixAddress(): List { + val frames = readFrames() + val delimiterIndex = frames.indexOfFirst { it.exhausted() } + val identities = frames.subList(0, delimiterIndex).map { it.readByteString() } + this.writeFrames(frames.subList(delimiterIndex + 1, frames.size)) + return identities +} + +internal fun Message.pushPrefixAddress(identities: List = listOf()) { + writeFrames(identities.map { Buffer().apply { write(it) } } + Buffer() + readFrames()) +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/NullMechanismHandshake.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/NullMechanismHandshake.kt index 3f5922c..94bd79c 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/NullMechanismHandshake.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/NullMechanismHandshake.kt @@ -6,13 +6,14 @@ package org.zeromq.internal import io.ktor.utils.io.* +import kotlinx.io.bytestring.* internal suspend fun nullMechanismHandshake( - localProperties: MutableMap, + localProperties: MutableMap, isServer: Boolean, input: ByteReadChannel, output: ByteWriteChannel, -): Map { +): Map { return if (isServer) { logger.v { "Expecting READY command" } val properties = expectReadyCommand(input) @@ -27,7 +28,7 @@ internal suspend fun nullMechanismHandshake( } } -private suspend fun expectReadyCommand(input: ByteReadChannel): Map { +private suspend fun expectReadyCommand(input: ByteReadChannel): Map { return when (val command = input.readCommand()) { is ReadyCommand -> command.properties is ErrorCommand -> fatalProtocolError("Peer error occurred: ${command.reason}") @@ -35,6 +36,6 @@ private suspend fun expectReadyCommand(input: ByteReadChannel): Map) { +private suspend fun ByteWriteChannel.sendReadyCommand(properties: MutableMap) { writeCommand(ReadyCommand(properties)) } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerMailbox.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerMailbox.kt index 79b79d3..7c46e35 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerMailbox.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerMailbox.kt @@ -6,6 +6,9 @@ package org.zeromq.internal import kotlinx.coroutines.channels.* +import kotlinx.io.bytestring.* +import kotlin.jvm.* +import kotlin.random.* internal class PeerMailbox(val endpoint: String, socketOptions: SocketOptions) { val receiveChannel = Channel(socketOptions.receiveQueueSize) @@ -33,19 +36,9 @@ internal class PeerMailbox(val endpoint: String, socketOptions: SocketOptions) { } } -internal class Identity(val value: ByteArray) { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other == null || this::class != other::class) return false - - other as Identity - - if (!value.contentEquals(other.value)) return false - - return true - } - - override fun hashCode(): Int { - return value.contentHashCode() +@JvmInline +internal value class Identity(val value: ByteString) { + companion object { + fun random() = Identity(ByteString(Random.nextBytes(ByteArray(16)))) } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerManager.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerManager.kt index a5c3df0..4f7b28e 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerManager.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/PeerManager.kt @@ -50,6 +50,7 @@ internal class PeerManager( private val _peerEvents = Channel() val peerEvents: ReceiveChannel get() = _peerEvents + @OptIn(DelicateCoroutinesApi::class) suspend fun notify(event: PeerEvent) { if (!_peerEvents.isClosedForSend) { logger.d { "Peer event: $event" } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketHandler.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketHandler.kt new file mode 100644 index 0000000..f327fd6 --- /dev/null +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketHandler.kt @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.internal + +import kotlinx.coroutines.channels.* +import org.zeromq.* + +internal interface SocketHandler { + suspend fun handle(peerEvents: ReceiveChannel) + suspend fun send(message: Message): Unit = error("Should not be called") + suspend fun receive(): Message = error("Should not be called") +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketOptions.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketOptions.kt index 907f765..6f7dc56 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketOptions.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/SocketOptions.kt @@ -5,6 +5,7 @@ package org.zeromq.internal +import kotlinx.io.bytestring.* import kotlin.time.* internal class SocketOptions { @@ -16,7 +17,7 @@ internal class SocketOptions { var lingerTimeout: Duration = Duration.INFINITE - var routingId: ByteArray? = null + var routingId: ByteString? = null } internal class PlainMechanismOptions { @@ -26,9 +27,9 @@ internal class PlainMechanismOptions { } internal class CurveMechanismOptions { - var publicKey: ByteArray? = null - var secretKey: ByteArray? = null - var serverKey: ByteArray? = null + var publicKey: ByteString? = null + var secretKey: ByteString? = null + var serverKey: ByteString? = null var asServer: Boolean = false } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/message-utils.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/message-utils.kt deleted file mode 100644 index bd263e7..0000000 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/message-utils.kt +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2022-2024 Didier Villevalois and Kzmq contributors. - * Use of this source code is governed by the Apache 2.0 license. - */ - -package org.zeromq.internal - -import org.zeromq.* - -internal fun addPrefixAddress(message: Message, identities: List = listOf()): Message = - Message(identities + ByteArray(0) + message.frames) - -internal fun extractPrefixAddress(message: Message): Pair, Message> { - val delimiterIndex = message.frames.indexOfFirst { it.isEmpty() } - val identities = message.frames.subList(0, delimiterIndex) - val data = Message(message.frames.subList(delimiterIndex + 1, message.frames.size)) - return identities to data -} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/tcp/TcpSocketHandler.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/tcp/TcpSocketHandler.kt index 915d6c0..244fb3b 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/tcp/TcpSocketHandler.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/tcp/TcpSocketHandler.kt @@ -8,6 +8,7 @@ package org.zeromq.internal.tcp import io.ktor.network.sockets.* import io.ktor.network.sockets.Socket import kotlinx.coroutines.* +import kotlinx.io.bytestring.* import org.zeromq.* import org.zeromq.internal.* @@ -41,8 +42,8 @@ internal class TcpSocketHandler( if (mechanism != peerSecuritySpec.mechanism) protocolError("Invalid peer security mechanism: ${peerSecuritySpec.mechanism}") - val localProperties = mutableMapOf().apply { - put(PropertyName.SOCKET_TYPE, socketInfo.type.name.encodeToByteArray()) + val localProperties = mutableMapOf().apply { + put(PropertyName.SOCKET_TYPE, socketInfo.type.name.encodeToByteString()) socketInfo.options.routingId?.let { identity -> put(PropertyName.IDENTITY, identity) } } @@ -81,6 +82,7 @@ internal class TcpSocketHandler( } } + @OptIn(DelicateCoroutinesApi::class) suspend fun handleLinger() { withTimeout(socketInfo.options.lingerTimeout) { try { @@ -108,7 +110,6 @@ internal class TcpSocketHandler( transformSubscriptionMessages(raw) } else raw - logger.v { "Read: $incoming" } return incoming } @@ -119,8 +120,6 @@ internal class TcpSocketHandler( } else outgoing output.writeCommandOrMessage(transformed) - - logger.v { "Wrote: $outgoing" } } } @@ -130,7 +129,9 @@ private fun transformSubscriptionMessages(commandOrMessage: CommandOrMessage): C } else commandOrMessage private fun extractSubscriptionCommand(message: Message): CommandOrMessage? { - return destructureSubscriptionMessage(message)?.let { (subscribe, topic) -> + return message.toSubscriptionMessage()?.let { + it.subscribe to it.topic + }?.let { (subscribe, topic) -> CommandOrMessage( if (subscribe) SubscribeCommand(topic) else CancelCommand(topic) ) @@ -140,16 +141,16 @@ private fun extractSubscriptionCommand(message: Message): CommandOrMessage? { private fun transformSubscriptionCommands(commandOrMessage: CommandOrMessage): CommandOrMessage = if (commandOrMessage.isCommand) { when (val command = commandOrMessage.commandOrThrow()) { - is SubscribeCommand -> CommandOrMessage(subscriptionMessageOf(true, command.topic)) + is SubscribeCommand -> CommandOrMessage(SubscriptionMessage(true, command.topic).toMessage()) - is CancelCommand -> CommandOrMessage(subscriptionMessageOf(false, command.topic)) + is CancelCommand -> CommandOrMessage(SubscriptionMessage(false, command.topic).toMessage()) else -> commandOrMessage } } else commandOrMessage private fun validateSocketType( - properties: Map, + properties: Map, peerSocketTypes: Set, ) { val socketTypeProperty = @@ -159,4 +160,4 @@ private fun validateSocketType( } private fun findSocketType(socketTypeString: String): Type? = - Type.values().find { it.name == socketTypeString.uppercase() } + Type.entries.find { it.name == socketTypeString.uppercase() } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/CircularQueue.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/CircularQueue.kt index 0032e33..c4b2ebf 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/CircularQueue.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/CircularQueue.kt @@ -7,26 +7,84 @@ package org.zeromq.internal.utils internal class CircularQueue private constructor( private val _elements: MutableList, -) : List by _elements { +) : List { + + private var currentIndex: Int = 0 constructor() : this(mutableListOf()) + val elements get() = _elements.subList(currentIndex, _elements.size) + _elements.subList(0, currentIndex) + fun add(element: T) { - _elements.add(element) + _elements.add(currentIndex, element) + currentIndex = (currentIndex + 1) % _elements.size } fun remove(element: T) { - _elements.remove(element) + val index = _elements.indexOf(element) + if (index == -1) return + + _elements.removeAt(index) + if (currentIndex > index) { + currentIndex-- + } } fun rotate(): T { - check(isNotEmpty()) { "Queue is empty." } - val mailbox = _elements.removeFirst() - _elements += mailbox - return mailbox + check(_elements.isNotEmpty()) { "Queue is empty." } + val element = _elements[currentIndex] + currentIndex = (currentIndex + 1) % _elements.size + return element + } + + fun rotate(count: Int) { + check(_elements.isNotEmpty()) { "Queue is empty." } + currentIndex = (currentIndex + count) % elements.size + } + + fun rotateAfter(index: Int) { + rotate(index + 1) + } + + override val size: Int get() = _elements.size + + override fun isEmpty(): Boolean { + return _elements.isEmpty() + } + + override fun get(index: Int): T { + return elements[index] + } + + override fun iterator(): Iterator { + return elements.iterator() + } + + override fun listIterator(): ListIterator { + return elements.listIterator() + } + + override fun listIterator(index: Int): ListIterator { + return elements.listIterator(index) + } + + override fun subList(fromIndex: Int, toIndex: Int): List { + return elements.subList(fromIndex, toIndex) } -} -internal fun CircularQueue<*>.rotateAfter(index: Int) { - repeat(index + 1) { rotate() } + override fun lastIndexOf(element: T): Int { + return elements.lastIndexOf(element) + } + + override fun indexOf(element: T): Int { + return elements.indexOf(element) + } + + override fun containsAll(elements: Collection): Boolean { + return _elements.containsAll(elements) + } + + override fun contains(element: T): Boolean { + return _elements.contains(element) + } } diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/MailboxDistribution.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/MailboxDistribution.kt new file mode 100644 index 0000000..bae7cf4 --- /dev/null +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/MailboxDistribution.kt @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.internal.utils + +import kotlinx.coroutines.selects.* +import org.zeromq.* +import org.zeromq.internal.* + +internal fun CircularQueue.update(event: PeerEvent) { + val mailbox = event.peerMailbox + when (event.kind) { + PeerEvent.Kind.ADDITION -> add(mailbox) + PeerEvent.Kind.REMOVAL -> remove(mailbox) + else -> {} + } +} + +internal suspend fun CircularQueue.sendToFirstAvailable(message: Message): PeerMailbox? { + // Fast path: Find the first mailbox we can send immediately + logger.v { "Try sending message $message to first available" } + var targetMailbox = trySendToFirstAvailable(message) + + if (targetMailbox == null) { + // Slow path: Biased select on each mailbox's onSend + logger.v { "Sending message $message to first available" } + select { + val commandOrMessage = CommandOrMessage(message) + forEachIndexed { index, mailbox -> + mailbox.sendChannel.onSend(commandOrMessage) { + logger.v { "Sent message to $mailbox" } + rotateAfter(index) + targetMailbox = mailbox + } + } + } + } + + return targetMailbox +} + +internal fun CircularQueue.trySendToFirstAvailable(message: Message): PeerMailbox? { + val commandOrMessage = CommandOrMessage(message) + val index = indexOfFirst { mailbox -> + val result = mailbox.sendChannel.trySend(commandOrMessage) + logger.v { + if (result.isSuccess) "Sent message to $mailbox" + else "Failed to send message to $mailbox" + } + result.isSuccess + } + + val targetMailbox = if (index != -1) getOrNull(index) else null + if (targetMailbox != null) rotateAfter(index) + return targetMailbox +} + +internal suspend fun CircularQueue.receiveFromFirst(): Pair { + return select { + forEachIndexed { index, mailbox -> + mailbox.receiveChannel.onReceive { commandOrMessage -> + val message = commandOrMessage.messageOrThrow() + logger.v { "Received $message from $mailbox" } + rotateAfter(index) + mailbox to message + } + } + } +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionManager.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionManager.kt new file mode 100644 index 0000000..2327de1 --- /dev/null +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionManager.kt @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.internal.utils + +import kotlinx.coroutines.channels.* +import kotlinx.io.bytestring.* +import org.zeromq.internal.* + +internal class SubscriptionManager { + val existing = mutableListOf() + val lateSubscriptionCommands = Channel(10) + + suspend fun subscribe(topics: List) { + val effectiveTopics = topics.ifEmpty { listOf(ByteString()) } + + existing.addAll(effectiveTopics) + + for (topic in effectiveTopics) { + lateSubscriptionCommands.send(SubscribeCommand(topic)) + } + } + + suspend fun unsubscribe(topics: List) { + val effectiveTopics = topics.ifEmpty { listOf(ByteString()) } + + val removedTopics = mutableListOf() + for (topic in effectiveTopics) { + if (existing.remove(topic)) removedTopics += topic + } + + for (topic in removedTopics) { + lateSubscriptionCommands.send(CancelCommand(topic)) + } + } +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionTrie.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionTrie.kt index bb2bcb8..7a2183b 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionTrie.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/utils/SubscriptionTrie.kt @@ -5,6 +5,8 @@ package org.zeromq.internal.utils +import kotlinx.io.bytestring.* + /** * Represents a subscription trie. * @@ -20,7 +22,7 @@ internal data class SubscriptionTrie( val subscriptions: Map = hashMapOf(), val children: Map> = hashMapOf(), ) { - fun add(prefix: ByteArray, element: T): SubscriptionTrie = this.add(prefix.iterator(), element) + fun add(prefix: ByteString, element: T): SubscriptionTrie = this.add(prefix.iterator(), element) private fun add(prefix: ByteIterator, element: T): SubscriptionTrie = if (prefix.hasNext()) { val byte = prefix.nextByte() @@ -31,7 +33,7 @@ internal data class SubscriptionTrie( this.copy(subscriptions = subscriptions + (element to newCount)) } - fun remove(prefix: ByteArray, element: T): SubscriptionTrie = this.remove(prefix.iterator(), element) + fun remove(prefix: ByteString, element: T): SubscriptionTrie = this.remove(prefix.iterator(), element) private fun remove(prefix: ByteIterator, element: T): SubscriptionTrie = if (prefix.hasNext()) { val byte = prefix.nextByte() @@ -69,3 +71,9 @@ internal data class SubscriptionTrie( } } } + +private fun ByteString.iterator() = object : ByteIterator() { + private var index = 0 + override fun hasNext(): Boolean = index < size + override fun nextByte(): Byte = get(index++) +} diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-reading.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-reading.kt index 9b9ef1b..92d6f9b 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-reading.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-reading.kt @@ -9,6 +9,8 @@ import io.ktor.utils.io.* import io.ktor.utils.io.core.* import kotlinx.coroutines.channels.* import kotlinx.io.* +import kotlinx.io.Buffer +import kotlinx.io.bytestring.* import org.zeromq.* @OptIn(ExperimentalUnsignedTypes::class) @@ -65,15 +67,15 @@ private fun Source.readCommandContent(): Command { null -> invalidFrame("Invalid command name: ${readShortString()}") CommandName.READY -> ReadyCommand(readProperties()) CommandName.ERROR -> ErrorCommand(readShortString()) - CommandName.SUBSCRIBE -> SubscribeCommand(readByteArray()) - CommandName.CANCEL -> CancelCommand(readByteArray()) - CommandName.PING -> PingCommand(readUShort(), readByteArray()) - CommandName.PONG -> PongCommand(readByteArray()) + CommandName.SUBSCRIBE -> SubscribeCommand(readByteString()) + CommandName.CANCEL -> CancelCommand(readByteString()) + CommandName.PING -> PingCommand(readUShort(), readByteString()) + CommandName.PONG -> PongCommand(readByteString()) } } -private fun Source.readProperties(): Map { - val properties = mutableMapOf() +private fun Source.readProperties(): Map { + val properties = mutableMapOf() while (remaining > 0) { val (propertyName, value) = readProperty() properties[propertyName] = value @@ -81,11 +83,11 @@ private fun Source.readProperties(): Map { return properties } -private fun Source.readProperty(): Pair { +private fun Source.readProperty(): Pair { val propertyNameString = readShortString() val propertyName = PropertyName.find(propertyNameString) ?: invalidFrame("Can't read property") val valueSize = readInt() - val valueBytes = readByteArray(valueSize) + val valueBytes = readByteString(valueSize) return propertyName to valueBytes } @@ -96,7 +98,7 @@ private fun Source.readShortString(): String { private suspend fun ByteReadChannel.readMessageContent(initialFlags: ZmqFlags): Message { var flags = initialFlags - val parts = mutableListOf() + val parts = mutableListOf() do { if (flags.isCommand) invalidFrame("Expected message") @@ -110,9 +112,9 @@ private suspend fun ByteReadChannel.readMessageContent(initialFlags: ZmqFlags): return Message(parts) } -private suspend fun ByteReadChannel.readMessagePartContent(flags: ZmqFlags): ByteArray { +private suspend fun ByteReadChannel.readMessagePartContent(flags: ZmqFlags): Buffer { val size = readSize(flags) - return readBuffer(size.toInt()).readByteArray() + return readBuffer(size.toInt()) } private suspend fun ByteReadChannel.readSize(flags: ZmqFlags): Long { diff --git a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-writing.kt b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-writing.kt index d8debaf..5d35a22 100644 --- a/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-writing.kt +++ b/kzmq-cio/src/commonMain/kotlin/org/zeromq/internal/wire-format-writing.kt @@ -10,6 +10,7 @@ import io.ktor.utils.io.core.* import io.ktor.utils.io.core.writeUByte import kotlinx.io.* import kotlinx.io.Buffer +import kotlinx.io.bytestring.* import org.zeromq.* private suspend inline fun ByteWriteChannel.write(writer: Sink.() -> Unit) { @@ -68,19 +69,19 @@ internal suspend fun ByteWriteChannel.writeCommand(command: ReadyCommand) = writ internal suspend fun ByteWriteChannel.writeCommand(command: ErrorCommand) = write { writeCommand(CommandName.READY) { - writeShortString(command.reason.encodeToByteArray()) + writeShortString(command.reason.encodeToByteString()) } } internal suspend fun ByteWriteChannel.writeCommand(command: SubscribeCommand) = write { writeCommand(CommandName.SUBSCRIBE) { - writeFully(command.topic) + write(command.topic) } } internal suspend fun ByteWriteChannel.writeCommand(command: CancelCommand) = write { writeCommand(CommandName.CANCEL) { - writeFully(command.topic) + write(command.topic) } } @@ -88,13 +89,13 @@ internal suspend fun ByteWriteChannel.writeCommand(command: CancelCommand) = wri internal suspend fun ByteWriteChannel.writeCommand(command: PingCommand) = write { writeCommand(CommandName.PING) { writeUShort(command.ttl) - writeFully(command.context) + write(command.context) } } internal suspend fun ByteWriteChannel.writeCommand(command: PongCommand) = write { writeCommand(CommandName.PONG) { - writeFully(command.context) + write(command.context) } } @@ -113,20 +114,20 @@ private fun Sink.writeCommand( private fun Sink.writeProperty( propertyName: PropertyName, - valueBytes: ByteArray, + valueBytes: ByteString, ) { writeShortString(propertyName.bytes) writeInt(valueBytes.size) - writeFully(valueBytes) + write(valueBytes) } -private fun Sink.writeShortString(bytes: ByteArray) { +private fun Sink.writeShortString(bytes: ByteString) { writeUByte(bytes.size.toUByte()) - writeFully(bytes) + write(bytes) } private suspend fun ByteWriteChannel.writeMessage(message: Message) = write { - val parts = message.frames + val parts = message.readFrames() val lastIndex = parts.lastIndex for ((index, part) in parts.withIndex()) { val hasMore = index < lastIndex @@ -134,10 +135,10 @@ private suspend fun ByteWriteChannel.writeMessage(message: Message) = write { } } -private fun Sink.writeMessagePart(hasMore: Boolean, part: ByteArray) { +private fun Sink.writeMessagePart(hasMore: Boolean, part: Buffer) { val flags = if (hasMore) ZmqFlags.more else ZmqFlags.none - writeFrameHeader(flags, part.size.toLong()) - writeFully(part) + writeFrameHeader(flags, part.size) + transferFrom(part) } private fun Sink.writeFrameHeader(flags: ZmqFlags, size: Long) { diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/PairSocketHandlerTests.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/PairSocketHandlerTests.kt new file mode 100644 index 0000000..a9fc6ea --- /dev/null +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/PairSocketHandlerTests.kt @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq + +import io.kotest.core.spec.style.* +import io.kotest.core.test.* +import io.kotest.matchers.* +import kotlinx.coroutines.* +import org.zeromq.internal.* +import org.zeromq.test.* +import org.zeromq.utils.* +import kotlin.time.Duration.Companion.seconds + +class PairSocketHandlerTests : FunSpec({ + suspend fun TestScope.withHandler(test: SocketHandlerTest) = + withSocketHandler(PairSocketHandler(), test) + + test("SHALL consider a peer as available only when it has an outgoing queue that is not full") { + withHandler { peerEvents, send, _ -> + val peer = PeerMailbox("peer", SocketOptions().apply { sendQueueSize = 5 }).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + } + + val messages = messages(5) { index -> writeFrame { writeByte(index.toByte()) } } + + // Send each message of the first batch once + messages.forEach { send(it.buildMessage()) } + + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + + peer.sendChannel shouldReceiveExactly messages + } + } + + test("SHALL suspend on sending when it has no available peer") { + withHandler { peerEvents, send, _ -> + val peer = PeerMailbox("peer", SocketOptions().apply { sendQueueSize = 5 }).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + } + + val messages = messages(5) { index -> writeFrame { writeByte(index.toByte()) } } + val blockedMessage = message { writeFrame { writeByte((10).toByte()) } } + + // Send each message of the first batch once + messages.forEach { send(it.buildMessage()) } + + withTimeoutOrNull(1.seconds) { + // Send an additional message + send(blockedMessage.buildMessage()) + } shouldBe null + + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + + peer.sendChannel shouldReceiveExactly messages + } + } + + test("SHALL not accept further messages when it has no available peer") { + withHandler { peerEvents, send, _ -> + val peer = PeerMailbox("peer", SocketOptions().apply { sendQueueSize = 5 }).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + } + + val messages = messages(5) { index -> writeFrame { writeByte(index.toByte()) } } + val blockedMessage = message { writeFrame { writeByte((10).toByte()) } } + + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + + // Send each message of the first batch once + messages.forEach { send(it.buildMessage()) } + + peerEvents.send(PeerEvent(PeerEvent.Kind.DISCONNECTION, peer)) + + withTimeoutOrNull(1.seconds) { + // Send an additional message + send(blockedMessage.buildMessage()) + } shouldBe null + + peer.sendChannel shouldReceiveExactly messages + } + } + + test("SHALL receive incoming messages from its single peer if it has one") { + withHandler { peerEvents, _, receive -> + val peer = PeerMailbox("peer", SocketOptions().apply { sendQueueSize = 5 }).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } + + val messages = messages(5) { index -> writeFrame { writeByte(index.toByte()) } } + + // Send each message of the first batch once + messages.forEach { peer.receiveChannel.send(CommandOrMessage(it.buildMessage())) } + + receive shouldReceiveExactly messages + } + } +}) diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/PullSocketHandlerTests.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/PullSocketHandlerTests.kt index 73b4d6c..afdecba 100644 --- a/kzmq-cio/src/commonTest/kotlin/org/zeromq/PullSocketHandlerTests.kt +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/PullSocketHandlerTests.kt @@ -8,51 +8,53 @@ package org.zeromq import io.kotest.assertions.* import io.kotest.core.spec.style.* import io.kotest.core.test.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* import org.zeromq.internal.* +import org.zeromq.test.* import org.zeromq.utils.* class PullSocketHandlerTests : FunSpec({ + suspend fun TestScope.withHandler(test: SocketHandlerTest) = + withSocketHandler(PullSocketHandler(), test) test("SHALL receive incoming messages from its peers using a fair-queuing strategy") { - withHandler { peerEvents, receiveChannel -> - val peers = List(5) { index -> PeerMailbox(index.toString(), SocketOptions()) } - - peers.forEach { peer -> - peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) - peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + withHandler { peerEvents, _, receive -> + val peers = List(5) { index -> + PeerMailbox(index.toString(), SocketOptions()).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } } - val messages = List(10) { index -> Message(ByteArray(1) { index.toByte() }) } + val messages = messages(10) { index -> + writeFrame { writeByte(index.toByte()) } + } peers.forEach { peer -> - messages.forEach { message -> - peer.receiveChannel.send(CommandOrMessage(message)) - } + messages.forEach { peer.receiveChannel.send(it) } } all { messages.forEach { message -> - receiveChannel shouldReceiveExactly List(peers.size) { message } + receive shouldReceiveExactly List(peers.size) { message } } } } } -}) -private suspend fun TestScope.withHandler( - block: suspend TestScope.( - peerEvents: SendChannel, - receiveChannel: ReceiveChannel, - ) -> Unit, -) = coroutineScope { - val peerEvents = Channel() - val receiveChannel = Channel() + test("SHALL deliver these to its calling application") { + withHandler { peerEvents, _, receive -> + val peer = PeerMailbox("peer", SocketOptions()).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } - val handlerJob = launch { handlePullSocket(peerEvents, receiveChannel) } + val messages = messages(10) { index -> + writeFrame { writeByte(index.toByte()) } + } - block(peerEvents, receiveChannel) + messages.forEach { peer.receiveChannel.send(it) } - handlerJob.cancelAndJoin() -} + receive shouldReceiveExactly messages + } + } +}) diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/PushSocketHandlerTests.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/PushSocketHandlerTests.kt index 7f2835c..955b63a 100644 --- a/kzmq-cio/src/commonTest/kotlin/org/zeromq/PushSocketHandlerTests.kt +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/PushSocketHandlerTests.kt @@ -10,15 +10,18 @@ import io.kotest.core.spec.style.* import io.kotest.core.test.* import io.kotest.matchers.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* +import kotlinx.io.bytestring.* import org.zeromq.internal.* +import org.zeromq.test.* import org.zeromq.utils.* import kotlin.time.Duration.Companion.seconds class PushSocketHandlerTests : FunSpec({ + suspend fun TestScope.withHandler(test: SocketHandlerTest) = + withSocketHandler(PushSocketHandler(), test) test("SHALL consider a peer as available only when it has an outgoing queue that is not full") { - withHandler { peerEvents, sendChannel -> + withHandler { peerEvents, send, _ -> val peer1 = PeerMailbox("1", SocketOptions()) val peer2 = PeerMailbox("2", SocketOptions().apply { sendQueueSize = 5 }) @@ -27,13 +30,13 @@ class PushSocketHandlerTests : FunSpec({ peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer1)) - val firstBatch = List(5) { index -> Message(ByteArray(1) { index.toByte() }) } - val secondBatch = List(10) { index -> Message(ByteArray(1) { (index + 10).toByte() }) } + val firstBatch = messages(5) { index -> writeFrame { writeByte(index.toByte()) } } + val secondBatch = messages(10) { index -> writeFrame { writeByte((index + 10).toByte()) } } - // Send each message of the first batch once per receiver - firstBatch.forEach { message -> repeat(2) { sendChannel.send(message) } } + // Send each message of the first batch once per peer + firstBatch.forEach { message -> repeat(2) { send(message.buildMessage()) } } // Send each message of the second batch once - secondBatch.forEach { message -> sendChannel.send(message) } + secondBatch.forEach { message -> send(message.buildMessage()) } peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer2)) @@ -45,7 +48,7 @@ class PushSocketHandlerTests : FunSpec({ } test("SHALL route outgoing messages to available peers using a round-robin strategy") { - withHandler { peerEvents, sendChannel -> + withHandler { peerEvents, send, _ -> val peers = List(5) { index -> PeerMailbox(index.toString(), SocketOptions()) } peers.forEach { peer -> @@ -53,47 +56,57 @@ class PushSocketHandlerTests : FunSpec({ peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) } - val messages = List(10) { index -> Message(ByteArray(1) { index.toByte() }) } - - // Send each message once per receiver - messages.forEach { message -> repeat(peers.size) { sendChannel.send(message) } } + // Send each message once per peer + repeat(10) { messageIndex -> + repeat(peers.size) { peerIndex -> + send(message { + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }.buildMessage()) + } + } all { - // Check each receiver got every messages - peers.forEach { peer -> peer.sendChannel shouldReceiveExactly messages } + peers.forEachIndexed { peerIndex, peer -> + peer.sendChannel shouldReceiveExactly + messages(10) { messageIndex -> + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + } + } } } } test("SHALL suspend on sending when it has no available peers") { - withHandler { _, sendChannel -> - val message = Message("Won't be sent".encodeToByteArray()) + withHandler { _, send, _ -> + val message = buildMessage { writeFrame("Won't be sent".encodeToByteString()) } withTimeoutOrNull(1.seconds) { - sendChannel.send(message) + send(message) } shouldBe null } } test("SHALL not accept further messages when it has no available peers") { - withHandler { _, sendChannel -> - val message = Message("Won't be sent".encodeToByteArray()) + withHandler { _, send, _ -> + val message = buildMessage { writeFrame("Won't be sent".encodeToByteString()) } withTimeoutOrNull(1.seconds) { - sendChannel.send(message) + send(message) } shouldBe null } } test("SHALL NOT discard messages that it cannot queue") { - withHandler { peerEvents, sendChannel -> + withHandler { peerEvents, send, _ -> val peer = PeerMailbox("1", SocketOptions()) peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) - val messages = List(10) { index -> Message(ByteArray(1) { index.toByte() }) } + val messages = messages(10) { index -> writeFrame { writeByte(index.toByte()) } } // Send each message once - messages.forEach { message -> sendChannel.send(message) } + messages.forEach { send(it.buildMessage()) } peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) @@ -102,19 +115,3 @@ class PushSocketHandlerTests : FunSpec({ } } }) - -private suspend fun TestScope.withHandler( - block: suspend TestScope.( - peerEvents: SendChannel, - sendChannel: SendChannel, - ) -> Unit, -) = coroutineScope { - val peerEvents = Channel() - val sendChannel = Channel() - - val handlerJob = launch { handlePushSocket(peerEvents, sendChannel) } - - block(peerEvents, sendChannel) - - handlerJob.cancelAndJoin() -} diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/ReplySocketHandlerTests.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/ReplySocketHandlerTests.kt new file mode 100644 index 0000000..9e51426 --- /dev/null +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/ReplySocketHandlerTests.kt @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq + +import io.kotest.assertions.* +import io.kotest.core.spec.style.* +import io.kotest.core.test.* +import kotlinx.coroutines.* +import kotlinx.io.bytestring.* +import org.zeromq.internal.* +import org.zeromq.test.* +import org.zeromq.utils.* + +class ReplySocketHandlerTests : FunSpec({ + suspend fun TestScope.withHandler(test: SocketHandlerTest) = + withSocketHandler(ReplySocketHandler(), test) + + test("SHALL receive incoming messages from its peers using a fair-queuing strategy") { + withHandler { peerEvents, send, receive -> + val peerCount = 5 + val messageCount = 10 + + val peers = List(peerCount) { index -> + PeerMailbox(index.toString(), SocketOptions()).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } + } + + peers.forEachIndexed { peerIndex, peer -> + launch { + repeat(messageCount) { messageIndex -> + peer.receiveChannel.send(CommandOrMessage(message { + writeFrame("dummy-address-$messageIndex#1") + writeFrame("dummy-address-$messageIndex#2") + writeFrame("dummy-address-$messageIndex#3") + writeEmptyFrame() + writeFrame("REQUEST".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }.buildMessage())) + } + } + } + + all { + repeat(messageCount) { messageIndex -> + peers.forEachIndexed { peerIndex, peer -> + receive shouldReceiveExactly listOf(message { + writeFrame("REQUEST".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }) + + send(message { + writeFrame("REPLY".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }.buildMessage()) + + peer.sendChannel shouldReceiveExactly listOf(message { + writeFrame("dummy-address-$messageIndex#1") + writeFrame("dummy-address-$messageIndex#2") + writeFrame("dummy-address-$messageIndex#3") + writeEmptyFrame() + writeFrame("REPLY".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }) + } + } + } + } + } +}) diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/RequestSocketHandlerTests.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/RequestSocketHandlerTests.kt new file mode 100644 index 0000000..0dffcf7 --- /dev/null +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/RequestSocketHandlerTests.kt @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq + +import io.kotest.core.spec.style.* +import io.kotest.core.test.* +import io.kotest.matchers.* +import kotlinx.coroutines.* +import kotlinx.io.bytestring.* +import org.zeromq.internal.* +import org.zeromq.test.* +import org.zeromq.utils.* +import kotlin.time.Duration.Companion.seconds + +class RequestSocketHandlerTests : FunSpec({ + suspend fun TestScope.withHandler(test: SocketHandlerTest) = + withSocketHandler(RequestSocketHandler(), test) + + test("SHALL prefix the outgoing message with an empty delimiter frame") { + withHandler { peerEvents, send, _ -> + val peer = PeerMailbox("peer", SocketOptions().apply { sendQueueSize = 5 }).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } + + val request = message { writeFrame("Hello") } + + send.send(request) + + peer.sendChannel shouldReceiveExactly listOf(message { + writeEmptyFrame() + writeFrame("Hello") + }) + } + } + + test("SHALL route outgoing messages to connected peers using a round-robin strategy") { + withHandler { peerEvents, send, receive -> + val peerCount = 5 + val messageCount = 10 + + val peers = List(peerCount) { index -> + PeerMailbox(index.toString(), SocketOptions()).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } + } + + repeat(messageCount) { messageIndex -> + peers.forEachIndexed { peerIndex, peer -> + send(message { + writeFrame("REQUEST".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }.buildMessage()) + + peer.sendChannel shouldReceiveExactly listOf(message { + writeEmptyFrame() + writeFrame("REQUEST".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }) + + peer.receiveChannel.send(CommandOrMessage(message { + writeEmptyFrame() + writeFrame("REPLY".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }.buildMessage())) + + receive shouldReceiveExactly listOf(message { + writeFrame("REPLY".encodeToByteString()) + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(peerIndex.toByte()) } + }) + } + } + } + } + + test("SHALL suspend on sending when it has no available peers") { + withHandler { _, send, _ -> + val message = buildMessage { writeFrame("Won't be sent".encodeToByteString()) } + + withTimeoutOrNull(1.seconds) { + send(message) + } shouldBe null + } + } + + + test("SHALL accept an incoming message only from the last peer that it sent a request to") { + withHandler { peerEvents, send, receive -> + val peers = List(2) { index -> + PeerMailbox(index.toString(), SocketOptions()).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } + } + + send(message { + writeFrame("REQUEST".encodeToByteString()) + }.buildMessage()) + + peers[0].sendChannel shouldReceiveExactly listOf(message { + writeEmptyFrame() + writeFrame("REQUEST".encodeToByteString()) + }) + + peers[1].receiveChannel.send(CommandOrMessage(message { + writeEmptyFrame() + writeFrame("IGNORED-REPLY".encodeToByteString()) + }.buildMessage())) + + receive.shouldReceiveNothing() + } + } + + test("SHALL discard silently any messages received from other peers") { + withHandler { peerEvents, send, receive -> + val peers = List(2) { index -> + PeerMailbox(index.toString(), SocketOptions()).also { peer -> + peerEvents.send(PeerEvent(PeerEvent.Kind.ADDITION, peer)) + peerEvents.send(PeerEvent(PeerEvent.Kind.CONNECTION, peer)) + } + } + + send(message { + writeFrame("REQUEST".encodeToByteString()) + }.buildMessage()) + + peers[0].sendChannel shouldReceiveExactly listOf(message { + writeEmptyFrame() + writeFrame("REQUEST".encodeToByteString()) + }) + + repeat(10) { + peers[1].receiveChannel.send(CommandOrMessage(message { + writeEmptyFrame() + writeFrame("IGNORED-REPLY".encodeToByteString()) + }.buildMessage())) + } + + receive.shouldReceiveNothing() + + peers[0].receiveChannel.send(CommandOrMessage(message { + writeEmptyFrame() + writeFrame("REPLY".encodeToByteString()) + }.buildMessage())) + + receive shouldReceiveExactly listOf(message { + writeFrame("REPLY".encodeToByteString()) + }) + } + } +}) diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/CircularQueueTests.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/CircularQueueTests.kt index cdb1d7a..00b7190 100644 --- a/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/CircularQueueTests.kt +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/CircularQueueTests.kt @@ -8,6 +8,7 @@ package org.zeromq.internal.utils import io.kotest.assertions.throwables.* import io.kotest.core.spec.style.* import io.kotest.matchers.* +import io.kotest.matchers.equals.* @Suppress("unused") class CircularQueueTests : FunSpec({ @@ -32,4 +33,42 @@ class CircularQueueTests : FunSpec({ queue.rotate() shouldBe 3 } } + + test("remove current element") { + val queue = CircularQueue().apply { add(1); add(2); add(3) } + queue.remove(1) + queue.elements shouldBeEqual listOf(2, 3) + } + + test("remove other element") { + val queue = CircularQueue().apply { add(1); add(2); add(3) } + queue.remove(2) + queue.elements shouldBeEqual listOf(1, 3) + } + + test("rotates to original with rotate(count)") { + val queue = CircularQueue().apply { add(1); add(2); add(3) } + queue.rotate(3) + queue.elements shouldBeEqual listOf(1, 2, 3) + } + + test("rotates to original with rotateAfter(index)") { + val queue = CircularQueue().apply { add(1); add(2); add(3) } + queue.rotateAfter(2) + queue.elements shouldBeEqual listOf(1, 2, 3) + } + + test("remove last element after rotation") { + val queue = CircularQueue().apply { add(1); add(2); add(3) } + queue.rotate(2) + queue.remove(2) + queue.elements shouldBeEqual listOf(3, 1) + } + + test("remove next element after rotation") { + val queue = CircularQueue().apply { add(1); add(2); add(3) } + queue.rotate(2) + queue.remove(3) + queue.elements shouldBeEqual listOf(1, 2) + } }) diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/SubscriptionTrieTest.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/SubscriptionTrieTest.kt index 4d589e9..e2ac3bf 100644 --- a/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/SubscriptionTrieTest.kt +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/internal/utils/SubscriptionTrieTest.kt @@ -6,6 +6,7 @@ package org.zeromq.internal.utils import kotlinx.coroutines.test.* +import kotlinx.io.bytestring.* import kotlin.test.* internal class SubscriptionTrieTest { @@ -63,8 +64,8 @@ internal class SubscriptionTrieTest { } private fun SubscriptionTrie.add(prefix: String, element: T) = - add(prefix.encodeToByteArray(), element) + add(prefix.encodeToByteString(), element) private fun SubscriptionTrie.remove(prefix: String, element: T) = - remove(prefix.encodeToByteArray(), element) + remove(prefix.encodeToByteString(), element) } diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/Matchers.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/Matchers.kt index 682079c..15a7b24 100644 --- a/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/Matchers.kt +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/Matchers.kt @@ -5,37 +5,43 @@ package org.zeromq.utils -import io.kotest.assertions.* -import io.kotest.assertions.print.* -import io.kotest.matchers.collections.* +import io.kotest.matchers.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import org.zeromq.* import org.zeromq.internal.* +import org.zeromq.test.* import kotlin.jvm.* +import kotlin.time.Duration.Companion.seconds @JvmName("messageChannelShouldReceiveExactly") -internal suspend infix fun ReceiveChannel.shouldReceiveExactly(expected: List) { +internal suspend infix fun ReceiveChannel.shouldReceiveExactly(expected: List) { shouldReceiveExactly(expected) { receive() } } @JvmName("commandOrMessageChannelShouldReceiveExactly") -internal suspend infix fun ReceiveChannel.shouldReceiveExactly(expected: List) { +internal suspend infix fun ReceiveChannel.shouldReceiveExactly(expected: List) { shouldReceiveExactly(expected) { receive().messageOrThrow() } } -private suspend fun shouldReceiveExactly(expected: List, receive: suspend () -> Message) { - val received = mutableListOf() - try { - repeat(expected.size) { - received += receive() - } - received shouldContainExactly expected - } catch (e: TimeoutCancellationException) { - throw failure( - Expected(expected.print()), - Actual(received.print()), - "Only ${received.size} of the expected ${expected.size} messages were received.", - ) - } +@JvmName("messageChannelSend") +internal suspend fun SendChannel.send(template: MessageTemplate) = + send(buildMessage { template.frames.forEach { writeFrame(it) } }) + +@JvmName("commandOrMessageChannelSend") +internal suspend fun SendChannel.send(template: MessageTemplate) = + send(CommandOrMessage(buildMessage { template.frames.forEach { writeFrame(it) } })) + +internal suspend infix fun (suspend () -> Message).shouldReceiveExactly(expected: List) { + shouldReceiveExactly(expected) { this() } +} + +internal suspend fun (suspend () -> Message).shouldReceiveNothing() { + withTimeoutOrNull(1.seconds) { + this@shouldReceiveNothing() + } shouldBe null +} + +internal suspend infix fun (suspend (Message) -> Unit).send(expected: MessageTemplate) { + this(expected.buildMessage()) } diff --git a/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/SocketHandlerTest.kt b/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/SocketHandlerTest.kt new file mode 100644 index 0000000..43a649b --- /dev/null +++ b/kzmq-cio/src/commonTest/kotlin/org/zeromq/utils/SocketHandlerTest.kt @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.utils + +import io.kotest.core.test.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import org.zeromq.* +import org.zeromq.internal.* + +internal typealias SocketHandlerTest = + suspend TestScope.( + peerEvents: SendChannel, + send: suspend (Message) -> Unit, + receive: suspend () -> Message, + ) -> Unit + +internal suspend fun TestScope.withSocketHandler( + handler: SocketHandler, + block: SocketHandlerTest, +) = coroutineScope { + val peerEvents = Channel() + + val handlerJob = launch { handler.handle(peerEvents) } + + block(peerEvents, handler::send, handler::receive) + + handlerJob.cancelAndJoin() +} diff --git a/kzmq-core/build.gradle.kts b/kzmq-core/build.gradle.kts index b482a67..6e56893 100644 --- a/kzmq-core/build.gradle.kts +++ b/kzmq-core/build.gradle.kts @@ -17,6 +17,14 @@ kotlin { jvmTargets() jsTargets() nativeTargets { it.isSupportedByCIO || it.isSupportedByLibzmq } + + sourceSets { + commonMain { + dependencies { + implementation(libs.kotlinx.io.core) + } + } + } } tasks.withType { diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/DealerSocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/DealerSocket.kt index 5c841ac..45ceff4 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/DealerSocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/DealerSocket.kt @@ -5,6 +5,8 @@ package org.zeromq +import kotlinx.io.bytestring.* + /** * A ZeroMQ socket of type [DEALER][Type.DEALER]. * Peers must be [ReplySocket]s or [RouterSocket]s. @@ -52,7 +54,7 @@ public interface DealerSocket : Socket, SendSocket, ReceiveSocket { * * See [ZMQ_ROUTING_ID](http://api.zeromq.org/master:zmq-getsockopt) */ - public var routingId: ByteArray? + public var routingId: ByteString? /** * When set to `true`, the socket will automatically send an empty message when a new diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/Message.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/Message.kt index 26daa95..436954f 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/Message.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/Message.kt @@ -5,6 +5,9 @@ package org.zeromq +import kotlinx.io.* +import kotlinx.io.bytestring.* + /** * A ZeroMQ message container. Messages carry application data and are not generally created, modified, * or filtered by the ZMTP implementation except in some cases. Messages consist of one or more frames @@ -12,61 +15,79 @@ package org.zeromq * * @param frames the frames of the message. */ -public data class Message(val frames: List) { - init { - require(frames.isNotEmpty()) { "A message should contain at least one frame" } - } - +public class Message(private var frames: List) : ReadScope, WriteScope { /** * Builds a ZeroMQ message. * * @param frames the frames of the message. */ - public constructor(vararg frames: ByteArray) : this(frames.toList()) + public constructor(vararg frames: Buffer) : this(frames.toList()) /** * Returns `true` if this message contains a single frame. */ - val isSingle: Boolean get() = frames.size == 1 + public val isSingle: Boolean get() = frames.size == 1 /** * Returns `true` if this message contains more than one frame. */ - val isMultipart: Boolean get() = frames.size > 1 + public val isMultipart: Boolean get() = frames.size > 1 /** * Returns the single frame of this message, or throws if this message is multipart. */ - public fun singleOrThrow(): ByteArray = if (isSingle) frames[0] else error("Message is multipart") + public fun singleOrThrow(): Source { + if (!isSingle) error("Message is multipart") + return readFrame() + } /** - * Returns the first frame. + * Peeks the first frame. */ - public fun first(): ByteArray = frames.first() + public fun peekFirstFrame(): Source = frames.first().copy().peek() - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other == null || this::class != other::class) return false + public override fun readFrame(): Buffer { + val frame = frames.first() + this.frames = frames.drop(1) + return frame + } - other as Message + override fun ignoreRemainingFrames() { + readFrames().forEach { frame -> frame.skip(frame.size) } + } - if (frames.size != other.frames.size) return false - for (i in frames.indices) { - if (!frames[i].contentEquals(other.frames[i])) return false - } + override fun ensureNoRemainingFrames() { + if (frames.isNotEmpty()) error("Remaining ${frames.size} frame(s) in $this") + } - return true + public fun readFrames(): List { + val frames = this.frames + this.frames = listOf() + return frames } - override fun hashCode(): Int { - var result = 0 - for (frame in frames) { - result = 31 * result + frame.contentHashCode() - } - return result + public override fun writeFrame(buffer: Buffer) { + frames += buffer } + public fun writeFrames(sources: List) { + frames += sources + } + + public fun copy(): Message { + return Message(frames.map { it.copy() }) + } + + @OptIn(ExperimentalStdlibApi::class) override fun toString(): String { - return "Message(parts=${frames.joinToString { it.contentToString() }})" + return "Message(frames=${frames.joinToString { it.copy().readByteString().toHexString() }})" } } + +public fun Message(frames: List): Message = Message(frames.map { Buffer().apply { write(it) } }) + +public fun Message(frame: ByteString): Message = Message(listOf(frame)) + +public fun Message(frame: String): Message = Message(listOf(frame.encodeToByteString())) + +public fun buildMessage(writer: WriteScope.() -> Unit): Message = Message().apply(writer) diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/MessageUtils.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/MessageUtils.kt deleted file mode 100644 index b7e60fb..0000000 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/MessageUtils.kt +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2024 Didier Villevalois and Kzmq contributors. - * Use of this source code is governed by the Apache 2.0 license. - */ - -package org.zeromq - -/** - * Builds a subscription or unsubscription message. - * - * @param subscribe `true` if the message is a subscription message, false if it is an unsubscription message. - * @param topic the topic to subscription/unsubscription message. - */ -public fun subscriptionMessageOf(subscribe: Boolean, topic: ByteArray): Message { - val bytes = ByteArray(topic.size + 1) { index -> - if (index == 0) if (subscribe) 1 else 0 - else topic[index - 1] - } - return Message(bytes) -} - -/** - * Destructures a subscription or unsubscription message. - * - * @return a pair consisting of a boolean indicating if the message is a subscription message or an unsubscription - * message, and a topic. Returns `null` if the message is not a subscription/unsubscription message. - */ -public fun destructureSubscriptionMessage(message: Message): Pair? { - if (message.isSingle) { - val bytes = message.singleOrThrow() - val firstByte = bytes[0].toInt() - if (firstByte == 0 || firstByte == 1) { - val subscribe = firstByte == 1 - val topic = bytes.sliceArray(1 until bytes.size) - return subscribe to topic - } - } - return null -} diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/ReadScope.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/ReadScope.kt new file mode 100644 index 0000000..ba01ca6 --- /dev/null +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/ReadScope.kt @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq + +import kotlinx.io.* + +public interface ReadScope { + + public fun readFrame(): Buffer + + public fun ignoreRemainingFrames() + + public fun ensureNoRemainingFrames() +} + +public inline fun ReadScope.readFrame(reader: Source.() -> T): T { + val frame = readFrame() + val result = frame.reader() + if (!frame.exhausted()) error("Message frame is not exhausted: $frame") + return result +} diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/ReceiveSocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/ReceiveSocket.kt index f716804..29ddf0c 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/ReceiveSocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/ReceiveSocket.kt @@ -68,3 +68,18 @@ public interface ReceiveSocket { */ public var receiveTimeout: Int } + +public suspend inline fun ReceiveSocket.receive(crossinline block: ReadScope.() -> T): T = + receive().let { it.checkingNoRemainingFrames { block() } } + +public suspend inline fun ReceiveSocket.receiveCatching(crossinline block: ReadScope.() -> T): SocketResult = + receiveCatching().map { it.checkingNoRemainingFrames { block() } } + +public inline fun ReceiveSocket.tryReceive(crossinline block: ReadScope.() -> T): SocketResult = + tryReceive().map { it.checkingNoRemainingFrames { block() } } + +public inline fun Message.checkingNoRemainingFrames(crossinline block: ReadScope.() -> T): T { + val result = block() + ensureNoRemainingFrames() + return result +} diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/ReplySocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/ReplySocket.kt index 37ee2fb..52b85b4 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/ReplySocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/ReplySocket.kt @@ -5,6 +5,8 @@ package org.zeromq +import kotlinx.io.bytestring.* + /** * A ZeroMQ socket of type [REP][Type.REP]. * Peers must be [RequestSocket]s or [DealerSocket]s. @@ -35,5 +37,5 @@ public interface ReplySocket : Socket, SendSocket, ReceiveSocket { * * See [ZMQ_ROUTING_ID](http://api.zeromq.org/master:zmq-getsockopt) */ - public var routingId: ByteArray? + public var routingId: ByteString? } diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/RequestSocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/RequestSocket.kt index 2e81650..799e2a7 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/RequestSocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/RequestSocket.kt @@ -5,6 +5,8 @@ package org.zeromq +import kotlinx.io.bytestring.* + /** * A ZeroMQ socket of type [REQ][Type.REQ]. * Peers must be [ReplySocket]s or [RouterSocket]s. @@ -39,7 +41,7 @@ public interface RequestSocket : Socket, SendSocket, ReceiveSocket { * * See [ZMQ_ROUTING_ID](http://api.zeromq.org/master:zmq-getsockopt) */ - public var routingId: ByteArray? + public var routingId: ByteString? /** * When set to `true`, the socket will automatically send an empty message when a new diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/RouterSocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/RouterSocket.kt index 9f7b6c3..4d23973 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/RouterSocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/RouterSocket.kt @@ -5,6 +5,8 @@ package org.zeromq +import kotlinx.io.bytestring.* + /** * A ZeroMQ socket of type [REP][Type.ROUTER]. * Peers must be [RequestSocket]s or [DealerSocket]s. @@ -55,7 +57,7 @@ public interface RouterSocket : Socket, SendSocket, ReceiveSocket { * * See [ZMQ_ROUTING_ID](http://api.zeromq.org/master:zmq-getsockopt) */ - public var routingId: ByteArray? + public var routingId: ByteString? /** * When set to `true`, the socket will automatically send an empty message when a new diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/SendSocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/SendSocket.kt index 3079dd7..4944bc4 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/SendSocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/SendSocket.kt @@ -5,6 +5,8 @@ package org.zeromq +import kotlinx.io.bytestring.* + /** * A socket that can send messages. */ @@ -76,3 +78,13 @@ public interface SendSocket { */ public var sendTimeout: Int } + +public suspend fun SendSocket.send(sender: WriteScope.() -> Unit) { + send(Message(listOf()).apply { sender() }) +} + +public suspend fun SendSocket.sendCatching(sender: WriteScope.() -> Unit): SocketResult = + sendCatching(Message(listOf()).apply { sender() }) + +public fun SendSocket.trySend(sender: WriteScope.() -> Unit): SocketResult = + trySend(Message(listOf()).apply { sender() }) diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/SocketResult.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/SocketResult.kt index c8b2987..30092be 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/SocketResult.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/SocketResult.kt @@ -58,3 +58,8 @@ public sealed class SocketResult { public fun failure(cause: Throwable? = null): SocketResult = Failure(cause) } } + +public fun SocketResult.map(transform: (T) -> R): SocketResult = when (this) { + is SocketResult.Success -> SocketResult.success(transform(value)) + is SocketResult.Failure -> this +} diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriberSocket.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriberSocket.kt index b33693a..542f7d1 100644 --- a/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriberSocket.kt +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriberSocket.kt @@ -5,6 +5,8 @@ package org.zeromq +import kotlinx.io.bytestring.* + /** * A ZeroMQ socket of type [SUB][Type.SUB]. * Peers must be [PublisherSocket]s or [XPublisherSocket]s. @@ -46,7 +48,7 @@ public interface SubscriberSocket : Socket, ReceiveSocket { * * @param topics the topics to subscribe to */ - public suspend fun subscribe(vararg topics: ByteArray) + public suspend fun subscribe(vararg topics: ByteString) /** * Establishes a new message filter. Newly created [SubscriberSocket] sockets will filter out @@ -78,7 +80,7 @@ public interface SubscriberSocket : Socket, ReceiveSocket { * * @param topics the topics to unsubscribe from */ - public suspend fun unsubscribe(vararg topics: ByteArray) + public suspend fun unsubscribe(vararg topics: ByteString) /** * Removes the specified existing message filter previously established with [subscribe]. diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriptionMessage.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriptionMessage.kt new file mode 100644 index 0000000..26c97d5 --- /dev/null +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/SubscriptionMessage.kt @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq + +import kotlinx.io.* +import kotlinx.io.bytestring.* + +public data class SubscriptionMessage( + val subscribe: Boolean, + val topic: ByteString, +) + +/** + * Builds a subscription or unsubscription message. + */ +public fun SubscriptionMessage.toMessage(): Message { + return Message(Buffer().apply { + writeByte(if (subscribe) 1 else 0) + write(topic) + }) +} + +/** + * Destructures a subscription or unsubscription message. + * + * @return a pair consisting of a boolean indicating if the message is a subscription message or an unsubscription + * message, and a topic. Returns `null` if the message is not a subscription/unsubscription message. + */ +public fun Message.toSubscriptionMessage(): SubscriptionMessage? { + if (isSingle) { + val bytes = singleOrThrow() + val firstByte = bytes.readByte().toInt() + if (firstByte == 0 || firstByte == 1) { + val subscribe = firstByte == 1 + val topic = bytes.readByteString() + return SubscriptionMessage(subscribe, topic) + } + } + return null +} diff --git a/kzmq-core/src/commonMain/kotlin/org/zeromq/WriteScope.kt b/kzmq-core/src/commonMain/kotlin/org/zeromq/WriteScope.kt new file mode 100644 index 0000000..839222b --- /dev/null +++ b/kzmq-core/src/commonMain/kotlin/org/zeromq/WriteScope.kt @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq + +import kotlinx.io.* +import kotlinx.io.bytestring.* + +public interface WriteScope { + + public fun writeFrame(source: Buffer) +} + +public inline fun WriteScope.writeFrame(writer: Sink.() -> Unit) { + val frame = Buffer() + frame.writer() + writeFrame(frame) +} + +public fun WriteScope.writeEmptyFrame() { + writeFrame(ByteString()) +} + +public fun WriteScope.writeFrame(byteString: ByteString) { + writeFrame { write(byteString) } +} + +public fun WriteScope.writeFrame(string: String) { + writeFrame { writeString(string) } +} diff --git a/kzmq-core/src/commonTest/kotlin/org/zeromq/ReceiveSocketOpsTests.kt b/kzmq-core/src/commonTest/kotlin/org/zeromq/ReceiveSocketOpsTests.kt index fe7007f..9fc723b 100644 --- a/kzmq-core/src/commonTest/kotlin/org/zeromq/ReceiveSocketOpsTests.kt +++ b/kzmq-core/src/commonTest/kotlin/org/zeromq/ReceiveSocketOpsTests.kt @@ -13,7 +13,7 @@ import kotlinx.coroutines.flow.* class ReceiveSocketOpsTests : FunSpec({ test("consumeAsFlow") { - val messages = List(10) { Message("message-$it".encodeToByteArray()) } + val messages = List(10) { Message("message-$it") } val socket = mock { val messageIterator = messages.iterator() diff --git a/kzmq-core/src/commonTest/kotlin/org/zeromq/SendSocketOpsTests.kt b/kzmq-core/src/commonTest/kotlin/org/zeromq/SendSocketOpsTests.kt index a67d895..92e5715 100644 --- a/kzmq-core/src/commonTest/kotlin/org/zeromq/SendSocketOpsTests.kt +++ b/kzmq-core/src/commonTest/kotlin/org/zeromq/SendSocketOpsTests.kt @@ -16,7 +16,7 @@ class SendSocketOpsTests : FunSpec({ val socket = mock { everySuspend { send(any()) } returns Unit } - val messages = List(10) { Message("message-$it".encodeToByteArray()) } + val messages = List(10) { Message("message-$it") } messages.asFlow().collectToSocket(socket) diff --git a/kzmq-jeromq/build.gradle.kts b/kzmq-jeromq/build.gradle.kts index d49d7fd..1d8d558 100644 --- a/kzmq-jeromq/build.gradle.kts +++ b/kzmq-jeromq/build.gradle.kts @@ -12,8 +12,12 @@ kotlin { sourceSets { jvmMain { + languageSettings { + languageVersion = "2.1" + } dependencies { implementation(project(":kzmq-core")) + implementation(libs.kotlinx.io.core) implementation(libs.jeromq) } } diff --git a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQDealerSocket.kt b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQDealerSocket.kt index 41bdc1c..707f523 100644 --- a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQDealerSocket.kt +++ b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQDealerSocket.kt @@ -5,12 +5,14 @@ package org.zeromq +import kotlinx.io.bytestring.* + internal class JeroMQDealerSocket internal constructor( factory: (type: SocketType) -> ZMQ.Socket, ) : JeroMQSocket(factory, SocketType.DEALER, Type.DEALER), DealerSocket { override var conflate: Boolean by underlying::conflate - override var routingId: ByteArray? by underlying::identity + override var routingId: ByteString? by underlying::identity.converted() // TODO there no getter for setProbeRouter in underlying socket override var probeRouter: Boolean by notImplementedProperty() diff --git a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQReplySocket.kt b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQReplySocket.kt index 690e5f8..493b3bf 100644 --- a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQReplySocket.kt +++ b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQReplySocket.kt @@ -5,9 +5,11 @@ package org.zeromq +import kotlinx.io.bytestring.* + internal class JeroMQReplySocket internal constructor( factory: (type: SocketType) -> ZMQ.Socket, ) : JeroMQSocket(factory, SocketType.REP, Type.REP), ReplySocket { - override var routingId: ByteArray? by underlying::identity + override var routingId: ByteString? by underlying::identity.converted() } diff --git a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRequestSocket.kt b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRequestSocket.kt index a86b9e8..df783e4 100644 --- a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRequestSocket.kt +++ b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRequestSocket.kt @@ -5,11 +5,13 @@ package org.zeromq +import kotlinx.io.bytestring.* + internal class JeroMQRequestSocket internal constructor( factory: (type: SocketType) -> ZMQ.Socket, ) : JeroMQSocket(factory, SocketType.REQ, Type.REQ), RequestSocket { - override var routingId: ByteArray? by underlying::identity + override var routingId: ByteString? by underlying::identity.converted() override var probeRouter: Boolean by notImplementedProperty() override var correlate: Boolean by notImplementedProperty() override var relaxed: Boolean by notImplementedProperty() diff --git a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRouterSocket.kt b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRouterSocket.kt index e8eb331..887d29f 100644 --- a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRouterSocket.kt +++ b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQRouterSocket.kt @@ -5,11 +5,13 @@ package org.zeromq +import kotlinx.io.bytestring.* + internal class JeroMQRouterSocket internal constructor( factory: (type: SocketType) -> ZMQ.Socket, ) : JeroMQSocket(factory, SocketType.ROUTER, Type.ROUTER), RouterSocket { - override var routingId: ByteArray? by underlying::identity + override var routingId: ByteString? by underlying::identity.converted() // TODO there no getter for setProbeRouter in underlying socket override var probeRouter: Boolean by notImplementedProperty() diff --git a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQSocket.kt b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQSocket.kt index ca1a52d..c63c52f 100644 --- a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQSocket.kt +++ b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/JeroMQSocket.kt @@ -6,6 +6,9 @@ package org.zeromq import kotlinx.coroutines.selects.* +import kotlinx.io.* +import kotlinx.io.bytestring.* +import zmq.* @Suppress("RedundantSuspendModifier") internal abstract class JeroMQSocket internal constructor( @@ -25,9 +28,9 @@ internal abstract class JeroMQSocket internal constructor( suspend fun subscribe(): Unit = wrapping { underlying.subscribe(byteArrayOf()) } - suspend fun subscribe(vararg topics: ByteArray): Unit = wrapping { + suspend fun subscribe(vararg topics: ByteString): Unit = wrapping { if (topics.isEmpty()) underlying.subscribe(byteArrayOf()) - else topics.forEach { underlying.subscribe(it) } + else topics.forEach { underlying.subscribe(it.toByteArray()) } } suspend fun subscribe(vararg topics: String): Unit = wrapping { @@ -37,9 +40,9 @@ internal abstract class JeroMQSocket internal constructor( suspend fun unsubscribe(): Unit = wrapping { underlying.unsubscribe(byteArrayOf()) } - suspend fun unsubscribe(vararg topics: ByteArray): Unit = wrapping { + suspend fun unsubscribe(vararg topics: ByteString): Unit = wrapping { if (topics.isEmpty()) underlying.unsubscribe(byteArrayOf()) - else topics.forEach { underlying.unsubscribe(it) } + else topics.forEach { underlying.unsubscribe(it.toByteArray()) } } suspend fun unsubscribe(vararg topics: String): Unit = wrapping { @@ -52,29 +55,32 @@ internal abstract class JeroMQSocket internal constructor( fun trySend(message: Message): SocketResult = catching { sendImmediate(message) } private suspend fun sendSuspend(message: Message) = trace("sendSuspend") { - val parts = message.frames + val parts = message.readFrames() val lastIndex = parts.size - 1 for ((index, part) in parts.withIndex()) { sendPartSuspend(part, index < lastIndex) } } - private suspend fun sendPartSuspend(part: ByteArray, sendMore: Boolean) { - suspendOnIO { underlying.send(part, if (sendMore) ZMQ.SNDMORE else 0) } + private suspend fun sendPartSuspend(part: Buffer, sendMore: Boolean) { + suspendOnIO { underlying.sendMsg(part.toMsg(), if (sendMore) ZMQ.SNDMORE else 0) } } private fun sendImmediate(message: Message) = trace("sendImmediate") { - val parts = message.frames + val parts = message.readFrames() val lastIndex = parts.size - 1 for ((index, part) in parts.withIndex()) { sendPartImmediate(part, index < lastIndex) } } - private fun sendPartImmediate(part: ByteArray, sendMore: Boolean) { - underlying.send(part, ZMQ.DONTWAIT or if (sendMore) ZMQ.SNDMORE else 0) + private fun sendPartImmediate(part: Buffer, sendMore: Boolean) { + underlying.sendMsg(part.toMsg(), ZMQ.DONTWAIT or if (sendMore) ZMQ.SNDMORE else 0) } + // TODO optimize? + private fun Buffer.toMsg(): Msg = Msg.Builder().apply { put(readByteArray()) }.build() + // TODO multicastHops is a long in underlying socket var multicastHops: Int by notImplementedProperty() var sendBufferSize: Int by underlying::sendBufferSize @@ -89,29 +95,31 @@ internal abstract class JeroMQSocket internal constructor( get() = throw NotImplementedError("Not supported on JeroMQ engine") private suspend fun receiveSuspend(): Message = trace("receiveSuspend") { - val parts = mutableListOf() + val parts = mutableListOf() do { parts.add(receivePartSuspend()) } while (underlying.hasReceiveMore()) - return Message(*parts.toTypedArray()) + return Message(parts) } - private suspend fun receivePartSuspend(): ByteArray { - return suspendOnIO { underlying.recv(0) } + private suspend fun receivePartSuspend(): Buffer { + return suspendOnIO { underlying.recvMsg().toPart() } } private fun receiveImmediate(): Message = trace("receiveImmediate") { - val parts = mutableListOf() + val parts = mutableListOf() do { parts.add(receivePartImmediate() ?: error("No message received")) } while (underlying.hasReceiveMore()) - return Message(*parts.toTypedArray()) + return Message(parts) } - private fun receivePartImmediate(): ByteArray? { - return underlying.recv(ZMQ.DONTWAIT) + private fun receivePartImmediate(): Buffer? { + return underlying.recvMsg(ZMQ.DONTWAIT)?.toPart() } + private fun Msg.toPart(): Buffer = Buffer().transferFrom(buf()) + var receiveBufferSize: Int by underlying::receiveBufferSize var receiveHighWaterMark: Int by underlying::rcvHWM var receiveTimeout: Int by underlying::receiveTimeOut diff --git a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/delegates.kt b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/delegates.kt index 8297444..d7cf354 100644 --- a/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/delegates.kt +++ b/kzmq-jeromq/src/jvmMain/kotlin/org/zeromq/delegates.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.io.bytestring.* import kotlin.reflect.* internal fun notImplementedProperty() = NotImplementedPropertyDelegate() @@ -16,3 +17,21 @@ internal class NotImplementedPropertyDelegate() { operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T): Unit = TODO("JeroMQ does not implement ${property.name}") } + +internal fun KMutableProperty0.converted(into: (V) -> U, from: (U) -> V) = + MappedPropertyDelegate(this, into, from) + +internal class MappedPropertyDelegate( + private val delegate: KMutableProperty0, + private val into: (V) -> U, + private val from: (U) -> V, +) { + operator fun getValue(thisRef: Any?, property: KProperty<*>): U = into(delegate.get()) + + operator fun setValue(thisRef: Any?, property: KProperty<*>, value: U): Unit = delegate.set(from(value)) +} + +internal fun KMutableProperty0.converted() = converted( + into = { it?.let { ByteString(it) } }, + from = { it?.toByteArray() } +) diff --git a/kzmq-jeromq/src/jvmTest/kotlin/org/zeromq/SimpleTest.kt b/kzmq-jeromq/src/jvmTest/kotlin/org/zeromq/SimpleTest.kt index d15b7c7..d384616 100644 --- a/kzmq-jeromq/src/jvmTest/kotlin/org/zeromq/SimpleTest.kt +++ b/kzmq-jeromq/src/jvmTest/kotlin/org/zeromq/SimpleTest.kt @@ -6,6 +6,7 @@ package org.zeromq import kotlinx.coroutines.test.* +import kotlinx.io.* import kotlin.test.* class SimpleTest { @@ -19,7 +20,7 @@ class SimpleTest { val pull = ctx2.createPull().apply { connect(address) } val messageContent = "Hello" - push.send(Message(messageContent.encodeToByteArray())) - assertEquals(messageContent, pull.receive().frames.getOrNull(0)?.decodeToString()) + push.send { writeFrame { writeString(messageContent) } } + assertEquals(messageContent, pull.receive { readFrame { readString() } }) } } diff --git a/kzmq-libzmq/src/nativeMain/kotlin/org/zeromq/LibzmqSocket.kt b/kzmq-libzmq/src/nativeMain/kotlin/org/zeromq/LibzmqSocket.kt index 8bca76e..a7cc152 100644 --- a/kzmq-libzmq/src/nativeMain/kotlin/org/zeromq/LibzmqSocket.kt +++ b/kzmq-libzmq/src/nativeMain/kotlin/org/zeromq/LibzmqSocket.kt @@ -8,6 +8,8 @@ package org.zeromq import kotlinx.cinterop.* import kotlinx.coroutines.* import kotlinx.coroutines.selects.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.internal.libzmq.* @OptIn(ExperimentalForeignApi::class) @@ -34,9 +36,9 @@ internal abstract class LibzmqSocket internal constructor( subscribe(byteArrayOf().toCValues()) } - suspend fun subscribe(vararg topics: ByteArray) { + suspend fun subscribe(vararg topics: ByteString) { if (topics.isEmpty()) subscribe(byteArrayOf().toCValues()) - else topics.forEach { subscribe(it.toCValues()) } + else topics.forEach { subscribe(it.toByteArray().toCValues()) } } suspend fun subscribe(vararg topics: String) { @@ -59,9 +61,9 @@ internal abstract class LibzmqSocket internal constructor( unsubscribe(byteArrayOf().toCValues()) } - suspend fun unsubscribe(vararg topics: ByteArray) { + suspend fun unsubscribe(vararg topics: ByteString) { if (topics.isEmpty()) unsubscribe(byteArrayOf().toCValues()) - else topics.forEach { unsubscribe(it.toCValues()) } + else topics.forEach { unsubscribe(it.toByteArray().toCValues()) } } suspend fun unsubscribe(vararg topics: String) { @@ -93,14 +95,14 @@ internal abstract class LibzmqSocket internal constructor( } private fun doSend(message: Message, blocking: Boolean) { - val parts = message.frames - val lastPartIndex = parts.lastIndex + val frames = message.readFrames() + val lastFrameIndex = frames.lastIndex val baseFlags = if (blocking) 0 else ZMQ_DONTWAIT - for ((index, part) in parts.withIndex()) { - val nativeData = part.toCValues() - val flags = baseFlags or if (index < lastPartIndex) ZMQ_SNDMORE else 0 + for ((index, frame) in frames.withIndex()) { + val nativeData = frame.readByteArray().toCValues() + val flags = baseFlags or if (index < lastFrameIndex) ZMQ_SNDMORE else 0 checkNativeError(zmq_send(underlying, nativeData, nativeData.size.toULong(), flags)) } } @@ -132,12 +134,12 @@ internal abstract class LibzmqSocket internal constructor( val onReceive: SelectClause1 get() = TODO() private fun doReceiveMessage(blocking: Boolean): Message { - val parts = mutableListOf() + val frames = mutableListOf() do { - val part = doReceiveMessagePart(blocking) ?: continue - parts += part + val frame = doReceiveMessagePart(blocking) ?: continue + frames += Buffer().apply { write(frame) } } while (hasMoreParts) - return Message(parts) + return Message(frames) } private val hasMoreParts: Boolean by socketOption(underlying, ZMQ_RCVMORE, booleanConverter) diff --git a/kzmq-test/build.gradle.kts b/kzmq-test/build.gradle.kts new file mode 100644 index 0000000..ff30457 --- /dev/null +++ b/kzmq-test/build.gradle.kts @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021-2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +import org.jetbrains.dokka.gradle.* +import java.net.* + +plugins { + id("plugin.library") + id("plugin.atomicfu") + id("plugin.kotest") + id("plugin.mocks") +} + +kotlin { + jvmTargets() + jsTargets() + nativeTargets { it.isSupportedByCIO || it.isSupportedByLibzmq } + + sourceSets { + commonMain { + dependencies { + implementation(libs.kotlinx.io.core) + implementation(project(":kzmq-core")) + + implementation(libs.kotest.framework.engine) + implementation(libs.kotest.framework.datatest) + implementation(libs.kotest.assertions.core) + } + } + } +} + +tasks.withType { + dokkaSourceSets { + named("commonMain") { + sourceLink { + localDirectory.set(file("src/commonMain/kotlin")) + remoteUrl.set(URL("https://github.com/ptitjes/kzmq/tree/master/kzmq-core/src/commonMain/kotlin")) + remoteLineSuffix.set("#L") + } + } + } +} diff --git a/kzmq-test/src/commonMain/kotlin/org/zeromq/test/Matchers.kt b/kzmq-test/src/commonMain/kotlin/org/zeromq/test/Matchers.kt new file mode 100644 index 0000000..5ca5b33 --- /dev/null +++ b/kzmq-test/src/commonMain/kotlin/org/zeromq/test/Matchers.kt @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.test + +import io.kotest.assertions.* +import io.kotest.assertions.print.* +import io.kotest.matchers.equals.* +import kotlinx.coroutines.* +import kotlinx.io.* +import kotlinx.io.bytestring.* +import org.zeromq.* +import kotlin.jvm.* + +@JvmName("shouldReceiveExactlyListMessageTemplate") +public suspend fun shouldReceiveExactly( + expected: List, + receive: suspend () -> Message, +) { + shouldReceiveExactly(expected.map { it.frames }, receive) +} + +@JvmName("shouldReceiveExactlyListListByteString") +private suspend fun shouldReceiveExactly( + expected: List>, + receive: suspend () -> Message, +) { + val received = mutableListOf>() + try { + repeat(expected.size) { + val message = receive() + received += message.readFrames().map { it.readByteString() } + } + + received shouldBeEqual expected + } catch (e: TimeoutCancellationException) { + throw failure( + Expected(expected.print()), + Actual(received.print()), + "Only ${received.size} of the expected ${expected.size} messages were received.", + ) + } +} diff --git a/kzmq-test/src/commonMain/kotlin/org/zeromq/test/MessageTemplate.kt b/kzmq-test/src/commonMain/kotlin/org/zeromq/test/MessageTemplate.kt new file mode 100644 index 0000000..1fdefb5 --- /dev/null +++ b/kzmq-test/src/commonMain/kotlin/org/zeromq/test/MessageTemplate.kt @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Didier Villevalois and Kzmq contributors. + * Use of this source code is governed by the Apache 2.0 license. + */ + +package org.zeromq.test + +import kotlinx.io.* +import kotlinx.io.bytestring.* +import org.zeromq.* + +public data class MessageTemplate(val frames: List) + +public fun MessageTemplate.buildMessage(): Message = Message().apply { + frames.forEach { writeFrame(it) } +} + +public fun messages(count: Int, writer: WriteScope.(index: Int) -> Unit): List { + return List(count) { index -> message { writer(index) } } +} + +public fun message(writer: WriteScope.() -> Unit): MessageTemplate = + MessageTemplate(TemplateWriteScope().apply { writer() }.frames) + +private class TemplateWriteScope : WriteScope { + val frames = mutableListOf() + + override fun writeFrame(source: Buffer) { + frames += source.readByteString() + } +} + +public fun message(frame: ByteString): MessageTemplate = message { writeFrame(frame) } + +public suspend fun SendSocket.send(message: MessageTemplate): Unit = send(message.buildMessage()) diff --git a/kzmq-tests/build.gradle.kts b/kzmq-tests/build.gradle.kts index ba2769c..f620c6d 100644 --- a/kzmq-tests/build.gradle.kts +++ b/kzmq-tests/build.gradle.kts @@ -22,6 +22,8 @@ kotlin { commonMain { dependencies { implementation(project(":kzmq-core")) + implementation(project(":kzmq-test")) + implementation(libs.kotlinx.io.core) } } diff --git a/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/Matchers.kt b/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/Matchers.kt index 393f15f..617b7fe 100644 --- a/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/Matchers.kt +++ b/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/Matchers.kt @@ -5,28 +5,11 @@ package org.zeromq.tests.utils -import io.kotest.assertions.* -import io.kotest.assertions.print.* -import io.kotest.matchers.collections.* -import kotlinx.coroutines.* import org.zeromq.* +import org.zeromq.test.* -suspend infix fun ReceiveSocket.shouldReceiveExactly(expected: List) { - shouldReceiveExactly(expected) { receive() } -} +suspend infix fun ReceiveSocket.shouldReceive(expected: MessageTemplate) = + shouldReceiveExactly(listOf(expected)) { receive() } -private suspend fun shouldReceiveExactly(expected: List, receive: suspend () -> Message) { - val received = mutableListOf() - try { - repeat(expected.size) { - received += receive() - } - received shouldContainExactly expected - } catch (e: TimeoutCancellationException) { - throw failure( - Expected(expected.print()), - Actual(received.print()), - "Only ${received.size} of the expected ${expected.size} messages were received.", - ) - } -} +suspend infix fun ReceiveSocket.shouldReceiveExactly(expected: List) = + shouldReceiveExactly(expected) { receive() } diff --git a/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/generators.kt b/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/generators.kt index c482a98..6992243 100644 --- a/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/generators.kt +++ b/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/generators.kt @@ -5,8 +5,6 @@ package org.zeromq.tests.utils -import org.zeromq.* - private val characters = listOf('_') + ('a'..'z').toList() + ('A'..'Z').toList() + ('0'..'9').toList() @@ -16,7 +14,7 @@ enum class Protocol { TCP, } -suspend fun randomAddress(protocol: Protocol = Protocol.TCP): String { +suspend fun randomEndpoint(protocol: Protocol = Protocol.TCP): String { return when (protocol) { Protocol.INPROC -> "inproc://${randomAddressSuffix()}" Protocol.IPC -> "ipc:///tmp/${randomAddressSuffix()}" @@ -25,5 +23,3 @@ suspend fun randomAddress(protocol: Protocol = Protocol.TCP): String { } private fun randomAddressSuffix() = List(16) { characters.random() }.joinToString("") - -fun generateMessages(messageCount: Int) = List(messageCount) { Message("message-$it".encodeToByteArray()) } diff --git a/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/messages.kt b/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/messages.kt deleted file mode 100644 index 48d56fa..0000000 --- a/kzmq-tests/src/commonMain/kotlin/org/zeromq/tests/utils/messages.kt +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2021-2024 Didier Villevalois and Kzmq contributors. - * Use of this source code is governed by the Apache 2.0 license. - */ - -package org.zeromq.tests.utils - -import org.zeromq.* - -object MessageComparator : Comparator { - override fun compare(a: Message, b: Message): Int { - val aAsString = a.frames.joinToString { it.decodeToString() } - val bAsString = b.frames.joinToString { it.decodeToString() } - return if (aAsString < bAsString) -1 - else if (aAsString > bAsString) +1 - else 0 - } -} diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/DealerRouterTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/DealerRouterTests.kt index 39b98fb..60c84ee 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/DealerRouterTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/DealerRouterTests.kt @@ -7,8 +7,11 @@ package org.zeromq.tests.sockets import io.kotest.core.spec.style.* import io.kotest.matchers.* +import io.kotest.matchers.equals.* import kotlinx.atomicfu.* import kotlinx.coroutines.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.* import org.zeromq.tests.utils.* @@ -20,26 +23,23 @@ class DealerRouterTests : FunSpec({ withContexts("base").config( // TODO fix when testing more Dealer and Router logic - skip = setOf("inproc"), +// skip = setOf("jeromq"), ) { ctx1, ctx2, protocol -> val dealerCount = 2 val routerCount = 3 - val addresses = Array(routerCount) { randomAddress(protocol) } + val addresses = Array(routerCount) { randomEndpoint(protocol) } val routers = addresses.map { - ctx2.createRouter().apply { - bind(it) - } + ctx2.createRouter().apply { bind(it) } } val dealers = Array(dealerCount) { index -> ctx1.createDealer().apply { - routingId = index.encodeRoutingId() + routingId = index.encodeAsRoutingId() addresses.forEach { connect(it) } + waitForConnections(addresses.size) } } - waitForConnections(dealerCount * routerCount) - class Trace { val receivedReplyIds = atomic(setOf()) } @@ -49,70 +49,68 @@ class DealerRouterTests : FunSpec({ coroutineScope { launch { repeat(dealerCount * routerCount) { requestId -> - val dealer = dealers[requestId % dealers.size] - val requestData = byteArrayOf(requestId.toByte()) - - dealer.send( - Message( - REQUEST_MARKER.encodeToByteArray(), - requestData - ) - ) + val dealerId = requestId % dealers.size + val dealer = dealers[dealerId] + + dealer.send { + writeFrame(REQUEST_MARKER) + writeFrame { writeByte(requestId.toByte()) } + } } } routers.forEach { router -> launch { repeat(dealerCount) { - val request = router.receive() - - request.frames.size shouldBe 3 - val dealerIdFrame = request.frames[0] - request.frames[1].decodeToString() shouldBe REQUEST_MARKER - val requestIdFrame = request.frames[2] - - router.send( - Message( - dealerIdFrame, - REPLY_MARKER.encodeToByteArray(), - requestIdFrame, - dealerIdFrame - ) - ) + val (dealerId, requestId) = router.receive { + val dealerId = readFrame { readByteString() } + readFrame { readString() shouldBe REQUEST_MARKER } + val requestId = readFrame { readByte() } + dealerId to requestId + } + + router.send { + writeFrame(dealerId) + writeFrame(REPLY_MARKER) + writeFrame { writeByte(requestId) } + writeFrame(dealerId) + } } } } dealers.forEach { dealer -> launch { repeat(routerCount) { - val reply = dealer.receive() - - reply.frames.size shouldBe 3 - reply.frames[0].decodeToString() shouldBe REPLY_MARKER - val requestIdFrame = reply.frames[1] - val dealerIdFrame = reply.frames[2] - - val requestId = requestIdFrame[0].toInt() - val dealerId = dealerIdFrame.decodeRoutingId() - - dealerId shouldBe dealer.routingId?.decodeRoutingId() - dealerId shouldBe requestId % dealerCount - - trace.receivedReplyIds.getAndUpdate { it + requestId } + val (dealerId, requestId) = dealer.receive { + readFrame { readString() shouldBe REPLY_MARKER } + val requestId = readFrame { readByte() } + val dealerId = readFrame { readByteString() } + dealerId to requestId + } + + val realDealerId = dealerId.decodeFromRoutingId() + realDealerId shouldBe dealer.routingId?.decodeFromRoutingId() + realDealerId shouldBe requestId % dealerCount + + trace.receivedReplyIds.getAndUpdate { it + requestId.toInt() } } } } } - trace.receivedReplyIds.value shouldBe (0 until 6).toSet() + trace.receivedReplyIds.value shouldBeEqual (0 until 6).toSet() } }) /* * TODO Remove when https://github.com/zeromq/zeromq.js/issues/506 is fixed. */ -private fun Int.encodeRoutingId(): ByteArray = byteArrayOf(1, (this + 1).toByte()) -private fun ByteArray.decodeRoutingId(): Int { - require(size == 2) //{ "Size should be 2, but is $size" } +private fun Int.encodeAsRoutingId(): ByteString = buildByteString { + append(1.toByte()) + append((this@encodeAsRoutingId + 1).toByte()) +} + +private fun ByteString.decodeFromRoutingId(): Int { + require(size == 2) { "Size should be 2, but is $size" } require(this[0] == 1.toByte()) return this[1].toInt() - 1 } diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PairTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PairTests.kt index 956db36..acbf1af 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PairTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PairTests.kt @@ -6,16 +6,19 @@ package org.zeromq.tests.sockets import io.kotest.core.spec.style.* -import io.kotest.matchers.* +import kotlinx.io.bytestring.* import org.zeromq.* +import org.zeromq.test.* import org.zeromq.tests.utils.* @Suppress("unused") class PairTests : FunSpec({ withContexts("bind-connect") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) - val message = Message("Hello 0MQ!".encodeToByteArray()) + val address = randomEndpoint(protocol) + val message = message { + writeFrame("Hello 0MQ!".encodeToByteString()) + } val pair1 = ctx1.createPair().apply { bind(address) } val pair2 = ctx2.createPair().apply { connect(address) } @@ -23,15 +26,17 @@ class PairTests : FunSpec({ waitForConnections() pair1.send(message) - pair2.receive() shouldBe message + pair2 shouldReceive message pair2.send(message) - pair1.receive() shouldBe message + pair1 shouldReceive message } withContexts("connect-bind") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) - val message = Message("Hello 0MQ!".encodeToByteArray()) + val address = randomEndpoint(protocol) + val message = message { + writeFrame("Hello 0MQ!".encodeToByteString()) + } val pair2 = ctx2.createPair().apply { bind(address) } val pair1 = ctx1.createPair().apply { connect(address) } @@ -39,9 +44,9 @@ class PairTests : FunSpec({ waitForConnections() pair1.send(message) - pair2.receive() shouldBe message + pair2 shouldReceive message pair2.send(message) - pair1.receive() shouldBe message + pair1 shouldReceive message } }) diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherSubscriberTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherSubscriberTests.kt index 27601f8..33820de 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherSubscriberTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherSubscriberTests.kt @@ -6,18 +6,22 @@ package org.zeromq.tests.sockets import io.kotest.core.spec.style.* -import io.kotest.matchers.* import io.kotest.matchers.collections.* import kotlinx.coroutines.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.* +import org.zeromq.test.* import org.zeromq.tests.utils.* @Suppress("unused") class PublisherSubscriberTests : FunSpec({ withContexts("bind-connect") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) - val message = Message("Hello 0MQ!".encodeToByteArray()) + val address = randomEndpoint(protocol) + val template = message { + writeFrame("Hello, 0MQ!".encodeToByteString()) + } val publisher = ctx1.createPublisher().apply { bind(address) } val subscriber = ctx2.createSubscriber().apply { connect(address) } @@ -29,19 +33,19 @@ class PublisherSubscriberTests : FunSpec({ waitForSubscriptions() coroutineScope { - launch { publisher.send(message) } - launch { subscriber.receive() shouldBe message } + launch { publisher.send(template) } + launch { subscriber shouldReceive template } } } // TODO Figure out why this test is hanging with JeroMQ and ZeroMQ.js - // TODO Figure out why this test makes all jvmTest hang withContexts("connect-bind").config( skip = setOf("jeromq", "zeromq.js"), - only = setOf("tcp") ) { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) - val message = Message("Hello 0MQ!".encodeToByteArray()) + val address = randomEndpoint(protocol) + val template = message { + writeFrame("Hello, 0MQ!".encodeToByteString()) + } val subscriber = ctx2.createSubscriber().apply { bind(address) } val publisher = ctx1.createPublisher().apply { connect(address) } @@ -53,13 +57,13 @@ class PublisherSubscriberTests : FunSpec({ waitForSubscriptions() coroutineScope { - launch { publisher.send(message) } - launch { subscriber.receive() shouldBe message } + launch { publisher.send(template) } + launch { subscriber shouldReceive template } } } withContexts("subscription filter") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val sent = listOf("prefixed data", "non-prefixed data", "prefix is good") val expected = sent.filter { it.startsWith("prefix") } @@ -75,13 +79,13 @@ class PublisherSubscriberTests : FunSpec({ coroutineScope { launch { - sent.forEach { publisher.send(Message(it.encodeToByteArray())) } + sent.forEach { publisher.send(Message(it.encodeToByteString())) } } launch { val received = mutableListOf() repeat(2) { - received += subscriber.receive().singleOrThrow().decodeToString() + received += subscriber.receive().singleOrThrow().readByteArray().decodeToString() } received shouldContainExactly expected } diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherXSubscriberTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherXSubscriberTests.kt index 6e3d808..8de91f5 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherXSubscriberTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PublisherXSubscriberTests.kt @@ -8,6 +8,8 @@ package org.zeromq.tests.sockets import io.kotest.core.spec.style.* import io.kotest.matchers.collections.* import kotlinx.coroutines.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.* import org.zeromq.tests.utils.* @@ -15,7 +17,7 @@ import org.zeromq.tests.utils.* class PublisherXSubscriberTests : FunSpec({ withContexts("subscription filter") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val sent = listOf("prefixed data", "non-prefixed data", "prefix is good") val expected = sent.filter { it.startsWith("prefix") } @@ -25,19 +27,19 @@ class PublisherXSubscriberTests : FunSpec({ waitForConnections() - subscriber.send(subscriptionMessageOf(true, "prefix".encodeToByteArray())) + subscriber.send(SubscriptionMessage(true, "prefix".encodeToByteString()).toMessage()) waitForSubscriptions() coroutineScope { launch { - sent.forEach { publisher.send(Message(it.encodeToByteArray())) } + sent.forEach { publisher.send(Message(it.encodeToByteString())) } } launch { val received = mutableListOf() repeat(2) { - received += subscriber.receive().singleOrThrow().decodeToString() + received += subscriber.receive().singleOrThrow().readByteArray().decodeToString() } received shouldContainExactly expected } diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PullTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PullTests.kt index bc0c5c3..6809314 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PullTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PullTests.kt @@ -8,7 +8,9 @@ package org.zeromq.tests.sockets import io.kotest.assertions.* import io.kotest.common.* import io.kotest.core.spec.style.* +import kotlinx.io.bytestring.* import org.zeromq.* +import org.zeromq.test.* import org.zeromq.tests.utils.* @OptIn(ExperimentalKotest::class) @@ -22,7 +24,7 @@ class PullTests : FunSpec({ // TODO Investigate why this fails with CIO native if (platform == Platform.Native) return@config - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val pushSocketCount = 5 val pullSocket = ctx2.createPull().apply { bind(address) } @@ -30,18 +32,16 @@ class PullTests : FunSpec({ waitForConnections(pushSocketCount) - val messages = List(10) { index -> Message(ByteArray(1) { index.toByte() }) } + val templates = messages(10) { index -> + writeFrame(ByteString(index.toByte())) + } pushSockets.forEach { pushSocket -> - messages.forEach { message -> - pushSocket.send(message) - } + templates.forEach { pushSocket.send(it) } } all { - messages.forEach { message -> - pullSocket shouldReceiveExactly List(pushSocketCount) { message } - } + pullSocket shouldReceiveExactly templates.flatMap { template -> List(pushSocketCount) { template } } } } }) diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PushTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PushTests.kt index 71a1a00..62f2bf2 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PushTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/PushTests.kt @@ -10,7 +10,9 @@ import io.kotest.common.* import io.kotest.core.spec.style.* import io.kotest.matchers.* import kotlinx.coroutines.* +import kotlinx.io.bytestring.* import org.zeromq.* +import org.zeromq.test.* import org.zeromq.tests.utils.* import kotlin.time.Duration.Companion.seconds @@ -19,32 +21,36 @@ import kotlin.time.Duration.Companion.seconds class PushTests : FunSpec({ withContexts("simple connect-bind") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val pushSocket = ctx1.createPush().apply { connect(address) } val pullSocket = ctx2.createPull().apply { bind(address) } waitForConnections() - val message = Message("Hello, 0MQ!".encodeToByteArray()) - pushSocket.send(message) - pullSocket shouldReceiveExactly listOf(message) + val template = message { + writeFrame("Hello, 0MQ!".encodeToByteString()) + } + pushSocket.send(template) + pullSocket shouldReceive template pushSocket.close() pullSocket.close() } withContexts("simple bind-connect") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val pushSocket = ctx1.createPush().apply { bind(address) } val pullSocket = ctx2.createPull().apply { connect(address) } waitForConnections() - val message = Message("Hello, 0MQ!".encodeToByteArray()) - pushSocket.send(message) - pullSocket shouldReceiveExactly listOf(message) + val template = message { + writeFrame("Hello, 0MQ!".encodeToByteString()) + } + pushSocket.send(template) + pullSocket shouldReceive template pushSocket.close() pullSocket.close() @@ -54,7 +60,7 @@ class PushTests : FunSpec({ // TODO investigate why these pairs are flaky skip = setOf("cio-jeromq", "jeromq-cio"), ) { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val messageCount = 100 val pullSocket = ctx2.createPull().apply { bind(address) } @@ -62,21 +68,14 @@ class PushTests : FunSpec({ waitForConnections() - var sent = 0 - while (sent < messageCount) { - val message = Message(sent.encodeToByteArray()) - pushSocket.send(message) - sent++ + val templates = messages(messageCount) { index -> + writeFrame(ByteString(index.toByte())) } + + templates.forEach { pushSocket.send(it) } pushSocket.disconnect(address) - var received = 0 - while (received < messageCount) { - val message = pullSocket.receive() - message.singleOrThrow().decodeToInt() shouldBe received - received++ - } - received shouldBe messageCount + pullSocket shouldReceiveExactly templates pushSocket.close() pullSocket.close() @@ -86,7 +85,7 @@ class PushTests : FunSpec({ // TODO investigate why these tests are flaky skip = setOf("cio"), ) { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val messageCount = 100 val pullSocket = ctx2.createPull().apply { bind(address) } @@ -94,31 +93,28 @@ class PushTests : FunSpec({ waitForConnections() - var sent = 0 - while (sent < messageCount) { - val message = Message(sent.encodeToByteArray()) - pushSocket.send(message) - sent++ + val templates = messages(messageCount) { index -> + writeFrame(ByteString(index.toByte())) } + + templates.forEach { pushSocket.send(it) } pushSocket.close() - var received = 0 - while (received < messageCount) { - val message = pullSocket.receive() - message.singleOrThrow().decodeToInt() shouldBe received - received++ - } - received shouldBe messageCount + pullSocket shouldReceiveExactly templates pullSocket.close() } withContexts("SHALL consider a peer as available only when it has an outgoing queue that is not full") { ctx1, ctx2, protocol -> - val address1 = randomAddress(protocol) - val address2 = randomAddress(protocol) + val address1 = randomEndpoint(protocol) + val address2 = randomEndpoint(protocol) - val firstBatch = List(5) { index -> Message(ByteArray(1) { index.toByte() }) } - val secondBatch = List(10) { index -> Message(ByteArray(1) { (index + 10).toByte() }) } + val firstBatch = messages(5) { index -> + writeFrame(ByteString(index.toByte())) + } + val secondBatch = messages(10) { index -> + writeFrame(ByteString((index + 10).toByte())) + } val push = ctx1.createPush() val pull1 = ctx2.createPull() @@ -135,9 +131,9 @@ class PushTests : FunSpec({ waitForConnections() // Send each message of the first batch once per receiver - firstBatch.forEach { message -> repeat(2) { push.send(message) } } + firstBatch.forEach { template -> repeat(2) { push.send(template) } } // Send each message of the second batch once - secondBatch.forEach { message -> push.send(message) } + secondBatch.forEach { template -> push.send(template) } pull2.apply { bind(address2) } waitForConnections() @@ -151,21 +147,23 @@ class PushTests : FunSpec({ withContexts("SHALL route outgoing messages to available peers using a round-robin strategy") { ctx1, ctx2, protocol -> val pullCount = 5 - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val push = ctx1.createPush().apply { bind(address) } val pulls = List(pullCount) { ctx2.createPull().apply { connect(address) } } waitForConnections(pullCount) - val messages = List(10) { index -> Message(ByteArray(1) { index.toByte() }) } + val templates = messages(10) { index -> + writeFrame(ByteString(index.toByte())) + } // Send each message once per receiver - messages.forEach { message -> repeat(pulls.size) { push.send(message) } } + templates.forEach { template -> repeat(pulls.size) { push.send(template) } } all { // Check each receiver got every messages - pulls.forEach { it shouldReceiveExactly messages } + pulls.forEach { it shouldReceiveExactly templates } } } @@ -175,7 +173,7 @@ class PushTests : FunSpec({ ) { ctx, _ -> val push = ctx.createPush() - val message = Message("Won't be sent".encodeToByteArray()) + val message = Message("Won't be sent".encodeToByteString()) withTimeoutOrNull(1.seconds) { push.send(message) @@ -189,27 +187,29 @@ class PushTests : FunSpec({ ) { ctx, _ -> val push = ctx.createPush() - val message = Message("Won't be sent".encodeToByteArray()) + val message = Message("Won't be sent".encodeToByteString()) withTimeoutOrNull(1.seconds) { push.send(message) } shouldBe null } - withContexts("SHALL NOT discard messages that it cannot queue") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + withContexts("SHALL NOT discard messages that it cannot queue").config( + only = setOf(), + ) { ctx1, ctx2, protocol -> + val address = randomEndpoint(protocol) val push = ctx1.createPush().apply { connect(address) } - val messages = List(10) { index -> Message(ByteArray(1) { index.toByte() }) } + val templates = messages(10) { index -> listOf(ByteString(index.toByte() )) } // Send each message once - messages.forEach { message -> push.send(message) } + templates.forEach { push.send(it) } val pull = ctx2.createPull().apply { bind(address) } waitForConnections() // Check each receiver got every messages - pull shouldReceiveExactly messages + pull shouldReceiveExactly templates } }) diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/RequestReplyTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/RequestReplyTests.kt index 9f9c20f..5731d5d 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/RequestReplyTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/RequestReplyTests.kt @@ -6,130 +6,162 @@ package org.zeromq.tests.sockets import io.kotest.core.spec.style.* -import io.kotest.matchers.* +import io.kotest.matchers.equals.* import kotlinx.coroutines.* -import kotlinx.coroutines.selects.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.* +import org.zeromq.test.* import org.zeromq.tests.utils.* @Suppress("unused") class RequestReplyTests : FunSpec({ withContexts("bind-connect") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) - val requestMessage = Message("Hello 0MQ!".encodeToByteArray()) - val replyMessage = Message("Hello back!".encodeToByteArray()) + val address = randomEndpoint(protocol) + val requestTemplate = message { + writeFrame("Hello, 0MQ!".encodeToByteString()) + } + val replyTemplate = message { + writeFrame("Hello back!".encodeToByteString()) + } val request = ctx1.createRequest().apply { bind(address) } val reply = ctx2.createReply().apply { connect(address) } waitForConnections() - request.send(requestMessage) - reply.receive() shouldBe requestMessage + request.send(requestTemplate) + reply shouldReceive requestTemplate - reply.send(replyMessage) - request.receive() shouldBe replyMessage + reply.send(replyTemplate) + request shouldReceive replyTemplate } withContexts("connect-bind") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) - val requestMessage = Message("Hello 0MQ!".encodeToByteArray()) - val replyMessage = Message("Hello back!".encodeToByteArray()) + val address = randomEndpoint(protocol) + val requestTemplate = message { + writeFrame("Hello, 0MQ!".encodeToByteString()) + } + val replyTemplate = message { + writeFrame("Hello back!".encodeToByteString()) + } val request = ctx1.createRequest().apply { bind(address) } val reply = ctx2.createReply().apply { connect(address) } waitForConnections() - request.send(requestMessage) - reply.receive() shouldBe requestMessage + request.send(requestTemplate) + reply shouldReceive requestTemplate - reply.send(replyMessage) - request.receive() shouldBe replyMessage + reply.send(replyTemplate) + request shouldReceive replyTemplate } withContexts("round-robin connected reply sockets").config( skip = setOf("jeromq", "zeromq.js"), ) { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val request = ctx1.createRequest().apply { bind(address) } + val reply1 = ctx2.createReply().apply { connect(address) } + waitForConnections(2) val reply2 = ctx2.createReply().apply { connect(address) } - waitForConnections(2) - var lastReplier: ReplySocket? = null - repeat(10) { i -> - val requestMessage = Message("Hello $i".encodeToByteArray()) - val replyMessage = Message("Hello back $i".encodeToByteArray()) - - request.send(requestMessage) - - select { - reply1.onReceive { message -> - lastReplier shouldNotBe reply1 - lastReplier = reply1 - - message shouldBe requestMessage - reply1.send(replyMessage) + val replies = listOf(reply1, reply2) + + val replyJob = launch { + replies.forEachIndexed { replyIndex, reply -> + launch { + while (true) { + val index = reply.receive { + readFrame { readString() } shouldBeEqual "some request" + readFrame { readByte() } + } + + reply.send { + writeFrame("some reply") + writeFrame { writeByte(index) } + writeFrame { writeByte(replyIndex.toByte()) } + } + } } - reply2.onReceive { message -> - lastReplier shouldNotBe reply2 - lastReplier = reply2 + } + } - message shouldBe requestMessage - reply2.send(replyMessage) - } + repeat(10) { index -> + request.send { + writeFrame("some request") + writeFrame { writeByte(index.toByte()) } } - request.receive() shouldBe replyMessage + request shouldReceive message { + writeFrame("some reply") + writeFrame { writeByte(index.toByte()) } + writeFrame { writeByte((index % 2).toByte()) } + } } + + replyJob.cancelAndJoin() } withContexts("fair-queuing request sockets").config( - skip = setOf("cio", "jeromq", "zeromq.js"), + skip = setOf("jeromq", "zeromq.js"), ) { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val reply = ctx2.createReply().apply { bind(address) } + val request1 = ctx1.createRequest().apply { connect(address) } + waitForConnections(2) val request2 = ctx1.createRequest().apply { connect(address) } - waitForConnections(2) - coroutineScope { - launch { - var lastRequester: String? = null - repeat(10) { - val request = reply.receive() - val requester = request.singleOrThrow().decodeToString().substringAfterLast(" ") - - lastRequester shouldNotBe requester - lastRequester = requester + val requests = listOf(request1, request2) - reply.send(Message("Hello back to $requester".encodeToByteArray())) + val replyJob = launch { + repeat(10) { replyMessageIndex -> + val (requestIndex, messageIndex) = reply.receive { + readFrame { readString() } shouldBeEqual "some request" + val requestIndex = readFrame { readByte() } + val messageIndex = readFrame { readByte() } + requestIndex to messageIndex } - } - launch { - val requestMessage = Message("Hello from request1".encodeToByteArray()) - val replyMessage = Message("Hello back to request1".encodeToByteArray()) - repeat(5) { - request1.send(requestMessage) - request1.receive() shouldBe replyMessage + reply.send { + writeFrame("some reply") + writeFrame { writeByte(requestIndex) } + writeFrame { writeByte(messageIndex) } + writeFrame { writeByte(replyMessageIndex.toByte()) } } } - launch { - val requestMessage = Message("Hello from request2".encodeToByteArray()) - val replyMessage = Message("Hello back to request2".encodeToByteArray()) + } - repeat(5) { - request2.send(requestMessage) - request2.receive() shouldBe replyMessage + coroutineScope { + requests.forEachIndexed { requestIndex, request -> + launch { + repeat(5) { messageIndex -> + request.send { + writeFrame("some request") + writeFrame { writeByte(requestIndex.toByte()) } + writeFrame { writeByte(messageIndex.toByte()) } + } + + val expectedMessageIndex = messageIndex * 2 + requestIndex + request shouldReceive message { + writeFrame("some reply") + writeFrame { writeByte(requestIndex.toByte()) } + writeFrame { writeByte(messageIndex.toByte()) } + writeFrame { writeByte(expectedMessageIndex.toByte()) } + } + } } } } + + replyJob.cancelAndJoin() } }) diff --git a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/XPublisherSubscriberTests.kt b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/XPublisherSubscriberTests.kt index 0617bb0..18af671 100644 --- a/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/XPublisherSubscriberTests.kt +++ b/kzmq-tests/src/commonTest/kotlin/org/zeromq/tests/sockets/XPublisherSubscriberTests.kt @@ -9,6 +9,8 @@ import io.kotest.core.spec.style.* import io.kotest.matchers.* import io.kotest.matchers.collections.* import kotlinx.coroutines.* +import kotlinx.io.* +import kotlinx.io.bytestring.* import org.zeromq.* import org.zeromq.tests.utils.* @@ -16,7 +18,7 @@ import org.zeromq.tests.utils.* class XPublisherSubscriberTests : FunSpec({ withContexts("subscription filter") { ctx1, ctx2, protocol -> - val address = randomAddress(protocol) + val address = randomEndpoint(protocol) val sent = listOf("prefixed data", "non-prefixed data", "prefix is good") val expected = sent.filter { it.startsWith("prefix") } @@ -33,9 +35,9 @@ class XPublisherSubscriberTests : FunSpec({ launch { val message = publisher.receive() - val subscribeTopicPair = destructureSubscriptionMessage(message) - subscribeTopicPair shouldNotBe null - subscribeTopicPair?.let { (subscribe, topic) -> + val subscriptionMessage = message.toSubscriptionMessage() + subscriptionMessage shouldNotBe null + subscriptionMessage?.let { (subscribe, topic) -> subscribe shouldBe true topic.decodeToString() shouldBe "prefix" } @@ -46,13 +48,13 @@ class XPublisherSubscriberTests : FunSpec({ coroutineScope { launch { - sent.forEach { publisher.send(Message(it.encodeToByteArray())) } + sent.forEach { publisher.send(Message(it.encodeToByteString())) } } launch { val received = mutableListOf() repeat(2) { - received += subscriber.receive().singleOrThrow().decodeToString() + received += subscriber.receive().singleOrThrow().readByteArray().decodeToString() } received shouldContainExactly expected } diff --git a/kzmq-tools/build.gradle.kts b/kzmq-tools/build.gradle.kts index 61477d9..d799df2 100644 --- a/kzmq-tools/build.gradle.kts +++ b/kzmq-tools/build.gradle.kts @@ -35,6 +35,7 @@ kotlin { dependencies { implementation(project(":kzmq-core")) implementation(project(":kzmq-cio")) + implementation(libs.kotlinx.io.core) implementation(libs.kotlinx.cli) } } diff --git a/kzmq-tools/src/commonMain/kotlin/org/zeromq/tools/Throughput.kt b/kzmq-tools/src/commonMain/kotlin/org/zeromq/tools/Throughput.kt index 2c5fe9e..2825fc4 100644 --- a/kzmq-tools/src/commonMain/kotlin/org/zeromq/tools/Throughput.kt +++ b/kzmq-tools/src/commonMain/kotlin/org/zeromq/tools/Throughput.kt @@ -7,6 +7,7 @@ package org.zeromq.tools import kotlinx.cli.* import kotlinx.coroutines.* +import kotlinx.io.bytestring.* import org.zeromq.* import kotlin.time.* @@ -37,7 +38,7 @@ fun main(args: Array) = runBlocking { val engine = engines.find { it.name == engineName } ?: error("No such engine: $engineName") - val message = Message(ByteArray(messageSize)) + val message = Message(buildByteString(messageSize) {}) val handler = CoroutineExceptionHandler { _, throwable -> throwable.printStackTrace() } val context = Context(engine, coroutineContext + handler + dispatcher) @@ -87,7 +88,6 @@ private suspend fun Context.push( } } -@OptIn(ExperimentalTime::class) private suspend fun Context.pull( messageCount: Int, verbose: Boolean, diff --git a/kzmq-zeromqjs/build.gradle.kts b/kzmq-zeromqjs/build.gradle.kts index e1c9538..3c97aeb 100644 --- a/kzmq-zeromqjs/build.gradle.kts +++ b/kzmq-zeromqjs/build.gradle.kts @@ -14,6 +14,7 @@ kotlin { jsMain { dependencies { implementation(project(":kzmq-core")) + implementation(libs.kotlinx.io.core) implementation(npm("zeromq", libs.versions.zeromqjs.get())) } } diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsDealerSocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsDealerSocket.kt index cd1345f..f1eb447 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsDealerSocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsDealerSocket.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.io.bytestring.* import org.zeromq.internal.zeromqjs.Dealer as ZDealer internal class ZeroMQJsDealerSocket internal constructor( @@ -17,7 +18,7 @@ internal class ZeroMQJsDealerSocket internal constructor( override var conflate: Boolean by underlying::conflate - override var routingId: ByteArray? by underlying::routingId.asNullableByteArrayProperty() + override var routingId: ByteString? by underlying::routingId.asNullableByteStringProperty() override var probeRouter: Boolean by underlying::probeRouter } diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReceiveSocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReceiveSocket.kt index f4fed12..5b92773 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReceiveSocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReceiveSocket.kt @@ -7,12 +7,13 @@ package org.zeromq import kotlinx.coroutines.* import kotlinx.coroutines.selects.* +import kotlinx.io.bytestring.* import org.zeromq.internal.zeromqjs.* internal class ZeroMQJsReceiveSocket(private val underlying: Readable) : ReceiveSocket { override suspend fun receive(): Message = - Message(underlying.receive().await().map { it.toByteArray() }) + Message(underlying.receive().await().map { ByteString(it.toByteArray()) }) override suspend fun receiveCatching(): SocketResult = try { SocketResult.success(receive()) diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReplySocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReplySocket.kt index 6c4e389..d456d70 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReplySocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsReplySocket.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.io.bytestring.* import org.zeromq.internal.zeromqjs.Reply as ZReply internal class ZeroMQJsReplySocket internal constructor( @@ -15,5 +16,5 @@ internal class ZeroMQJsReplySocket internal constructor( SendSocket by ZeroMQJsSendSocket(underlying), ReplySocket { - override var routingId: ByteArray? by underlying::routingId.asNullableByteArrayProperty() + override var routingId: ByteString? by underlying::routingId.asNullableByteStringProperty() } diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRequestSocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRequestSocket.kt index b4e35ea..2e262ef 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRequestSocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRequestSocket.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.io.bytestring.* import org.zeromq.internal.zeromqjs.Request as ZRequest internal class ZeroMQJsRequestSocket internal constructor( @@ -15,7 +16,7 @@ internal class ZeroMQJsRequestSocket internal constructor( SendSocket by ZeroMQJsSendSocket(underlying), RequestSocket { - override var routingId: ByteArray? by underlying::routingId.asNullableByteArrayProperty() + override var routingId: ByteString? by underlying::routingId.asNullableByteStringProperty() override var probeRouter: Boolean by underlying::probeRouter override var correlate: Boolean by underlying::correlate override var relaxed: Boolean by underlying::relaxed diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRouterSocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRouterSocket.kt index 515ffc4..3d1c89d 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRouterSocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsRouterSocket.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.io.bytestring.* import org.zeromq.internal.zeromqjs.Router as ZRouter internal class ZeroMQJsRouterSocket internal constructor( @@ -15,7 +16,7 @@ internal class ZeroMQJsRouterSocket internal constructor( ReceiveSocket by ZeroMQJsReceiveSocket(underlying), RouterSocket { - override var routingId: ByteArray? by underlying::routingId.asNullableByteArrayProperty() + override var routingId: ByteString? by underlying::routingId.asNullableByteStringProperty() override var probeRouter: Boolean by underlying::probeRouter override var mandatory: Boolean by underlying::mandatory override var handover: Boolean by underlying::handover diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSendSocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSendSocket.kt index f91e8ad..dd93a5e 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSendSocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSendSocket.kt @@ -6,12 +6,13 @@ package org.zeromq import kotlinx.coroutines.* +import kotlinx.io.* import org.zeromq.internal.zeromqjs.* internal class ZeroMQJsSendSocket(private val underlying: Writable) : SendSocket { override suspend fun send(message: Message): Unit = - underlying.send(message.frames.map { it.toBuffer() }.toTypedArray()).await() + underlying.send(message.readFrames().map { it.readByteArray().toBuffer() }.toTypedArray()).await() override suspend fun sendCatching(message: Message): SocketResult = try { SocketResult.success(send(message)) diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSubscriberSocket.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSubscriberSocket.kt index 5034b5c..f8064d1 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSubscriberSocket.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/ZeroMQJsSubscriberSocket.kt @@ -5,6 +5,7 @@ package org.zeromq +import kotlinx.io.bytestring.* import org.zeromq.internal.zeromqjs.Subscriber as ZSubscriber internal class ZeroMQJsSubscriberSocket internal constructor(override val underlying: ZSubscriber = ZSubscriber()) : @@ -16,7 +17,7 @@ internal class ZeroMQJsSubscriberSocket internal constructor(override val underl underlying.subscribe() } - override suspend fun subscribe(vararg topics: ByteArray) { + override suspend fun subscribe(vararg topics: ByteString) { underlying.subscribe(*topics.map { it.decodeToString() }.toTypedArray()) } @@ -28,7 +29,7 @@ internal class ZeroMQJsSubscriberSocket internal constructor(override val underl underlying.unsubscribe() } - override suspend fun unsubscribe(vararg topics: ByteArray) { + override suspend fun unsubscribe(vararg topics: ByteString) { underlying.unsubscribe(*topics.map { it.decodeToString() }.toTypedArray()) } diff --git a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/utils.kt b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/utils.kt index 7fad5ab..7e3741a 100644 --- a/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/utils.kt +++ b/kzmq-zeromqjs/src/jsMain/kotlin/org/zeromq/utils.kt @@ -6,6 +6,7 @@ package org.zeromq import Buffer +import kotlinx.io.bytestring.* import org.khronos.webgl.* import kotlin.properties.* import kotlin.reflect.* @@ -28,12 +29,12 @@ internal fun Buffer.toByteArray(): ByteArray { ).unsafeCast() } -internal fun KMutableProperty0.asNullableByteArrayProperty(): ReadWriteProperty { - return object : ReadWriteProperty { - override fun getValue(thisRef: R, property: KProperty<*>): ByteArray? = - this@asNullableByteArrayProperty.get()?.encodeToByteArray() +internal fun KMutableProperty0.asNullableByteStringProperty(): ReadWriteProperty { + return object : ReadWriteProperty { + override fun getValue(thisRef: R, property: KProperty<*>): ByteString? = + this@asNullableByteStringProperty.get()?.encodeToByteString() - override fun setValue(thisRef: R, property: KProperty<*>, value: ByteArray?) = - this@asNullableByteArrayProperty.set(value?.decodeToString()) + override fun setValue(thisRef: R, property: KProperty<*>, value: ByteString?) = + this@asNullableByteStringProperty.set(value?.decodeToString()) } } diff --git a/settings.gradle.kts b/settings.gradle.kts index bbcbd5a..37baebd 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -22,6 +22,7 @@ dependencyResolutionManagement { rootProject.name = "kzmq" include(":kzmq-core") +include(":kzmq-test") include(":kzmq-zeromqjs") include(":kzmq-jeromq") include(":kzmq-libzmq")