Skip to content

Commit

Permalink
Fix bug on okHttpStompIntegrationTest
Browse files Browse the repository at this point in the history
Change queue to use topic
Add overloading. method awaitCountAndCheck with sleep strategy
Update StompIntegrationTest

Add message to require

Fix code style

Rewrite stomp message decoder

* Rewrite stomp message decoder using okio buffer instead of java buffer
* Update okio version

Add method awaitCountAndCheck with wait strategy

Change using system time millis to use nanoTime

* change using system time millis to use nanoTime
* add test for websocket connection
* fix code style

Remove extra method

Remove extra comment
  • Loading branch information
ZaltsmanNikita committed Jul 29, 2020
1 parent 7517609 commit 2287e18
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 92 deletions.
2 changes: 1 addition & 1 deletion dependencies.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ext {
okHttpServerSentEvent = "com.squareup.okhttp3:okhttp-sse:$okHttpVersion"
okHttpLoggingInterceptor = "com.squareup.okhttp3:logging-interceptor:$okHttpVersion"

okio = 'com.squareup.okio:okio:1.13.0'
okio = 'com.squareup.okio:okio:2.5.0'
mockWebServer = 'com.squareup.okhttp3:mockwebserver:3.11.0'
timber = 'com.jakewharton.timber:timber:4.6.0'
okSse = 'com.github.heremaps:oksse:0.9.0'
Expand Down
1 change: 1 addition & 0 deletions scarlet-protocol-stomp/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies {

implementation project(':scarlet-core-internal')
implementation rootProject.ext.rxJava
implementation rootProject.ext.okio
implementation rootProject.ext.kotlinStdlib

api rootProject.ext.okHttp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class OkHttpStompMessageChannel(
}

override fun createMessageQueue(listener: MessageQueue.Listener): MessageQueue {
require(messageQueueListener == null)
require(messageQueueListener == null) { "message queue was already created" }
messageQueueListener = listener
return this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ import com.tinder.scarlet.stomp.okhttp.support.StompMessageDecoder
import com.tinder.scarlet.stomp.okhttp.support.StompMessageEncoder
import okhttp3.WebSocket
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit

/**
* Okhttp websocket based implementation of {@link Connection}.
*/
class WebSocketConnection(
private val webSocket: WebSocket
private val webSocket: WebSocket,
private val executor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor()
) : Connection, MessageHandler {

@Volatile
Expand All @@ -25,8 +27,6 @@ class WebSocketConnection(
@Volatile
private var lastWriteTime: Long = -1

private val executor = Executors.newSingleThreadScheduledExecutor()

private val messageEncoder = StompMessageEncoder()
private val messageDecoder = StompMessageDecoder()

Expand All @@ -42,34 +42,36 @@ class WebSocketConnection(
override fun sendMessage(message: StompMessage): Boolean {
val lastWriteTime = lastWriteTime
if (lastWriteTime != -1L) {
this.lastWriteTime = System.currentTimeMillis()
this.lastWriteTime = System.nanoTime()
}
val encodedMessage = messageEncoder.encode(message)
return webSocket.send(String(encodedMessage))
val encodedMessage = messageEncoder.encode(message).toString(Charsets.UTF_8)
return webSocket.send(encodedMessage)
}

/**
* {@inheritDoc}
*/
override fun onReceiveInactivity(duration: Long, runnable: () -> Unit) {
lastReadTime = System.currentTimeMillis()
check(duration > 0) { "Duration must be more than 0" }
lastReadTime = System.nanoTime()
executor.scheduleWithFixedDelay({
if (System.currentTimeMillis() - lastReadTime > duration) {
if ((System.nanoTime() - lastReadTime) > TimeUnit.MILLISECONDS.toNanos(duration)) {
runnable.invoke()
}
}, 0, duration / 2, TimeUnit.MILLISECONDS)
}, 0, duration / 2L, TimeUnit.MILLISECONDS)
}

/**
* {@inheritDoc}
*/
override fun onWriteInactivity(duration: Long, runnable: () -> Unit) {
lastWriteTime = System.currentTimeMillis()
check(duration > 0) { "Duration must be more than 0" }
lastWriteTime = System.nanoTime()
executor.scheduleWithFixedDelay({
if (System.currentTimeMillis() - lastWriteTime > duration) {
if ((System.nanoTime() - lastWriteTime) > TimeUnit.MILLISECONDS.toNanos(duration)) {
runnable.invoke()
}
}, 0, duration / 2, TimeUnit.MILLISECONDS)
}, 0, duration / 2L, TimeUnit.MILLISECONDS)
}

/**
Expand All @@ -94,7 +96,7 @@ class WebSocketConnection(
override fun handle(data: ByteArray): StompMessage? {
val lastReadTime = lastReadTime
if (lastReadTime != -1L) {
this.lastReadTime = System.currentTimeMillis()
this.lastReadTime = System.nanoTime()
}
return messageDecoder.decode(data)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,4 @@ class StompMessage private constructor(
"headers=$headers" +
")"
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ package com.tinder.scarlet.stomp.okhttp.support

import com.tinder.scarlet.stomp.okhttp.models.StompCommand
import com.tinder.scarlet.stomp.okhttp.models.StompMessage
import okio.BufferedSource
import okio.buffer
import okio.source
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

/**
* An decoder for STOMP frames.
Expand All @@ -18,18 +20,18 @@ class StompMessageDecoder {
* @param array the array to decode
*/
fun decode(array: ByteArray): StompMessage? {
val byteBuffer = ByteBuffer.wrap(array)
return decode(byteBuffer)
val buffer = array.inputStream().source().buffer()
return decode(buffer)
}

private fun decode(byteBuffer: ByteBuffer): StompMessage? {
fun decode(byteBuffer: BufferedSource): StompMessage? {
skipLeadingEol(byteBuffer)
val stompCommand = readCommand(byteBuffer) ?: return null

return if (stompCommand != StompCommand.HEARTBEAT) {
val headerAccessor = StompHeaderAccessor.of()

val payload = if (byteBuffer.isNotEmpty()) {
val payload = if (!byteBuffer.exhausted()) {
readHeaders(byteBuffer, headerAccessor)
readPayloadOrNull(byteBuffer, headerAccessor) ?: return null
} else {
Expand All @@ -46,62 +48,64 @@ class StompMessageDecoder {
}
}

private fun skipLeadingEol(byteBuffer: ByteBuffer) {
private fun skipLeadingEol(byteBuffer: BufferedSource) {
while (true) {
if (!tryConsumeEndOfLine(byteBuffer)) break
}
}

private fun readPayloadOrNull(
byteBuffer: ByteBuffer,
bufferedSource: BufferedSource,
headerAccessor: StompHeaderAccessor
): ByteArray? {
val contentLength = headerAccessor.contentLength
return if (contentLength != null && contentLength >= 0) {
readPayloadWithContentLength(byteBuffer, contentLength)
readPayloadWithContentLength(bufferedSource, contentLength)
} else {
readPayloadWithoutContentLength(byteBuffer)
readPayloadWithoutContentLength(bufferedSource)
}
}

private fun readPayloadWithContentLength(
byteBuffer: ByteBuffer,
bufferedSource: BufferedSource,
contentLength: Int
) = byteBuffer
.takeIf { buffer -> buffer.remaining() > contentLength }
?.let { buffer ->
val payload = ByteArray(contentLength)
buffer.get(payload)

val lastSymbolIsNullOctet = byteBuffer.get().toInt() == 0
check(lastSymbolIsNullOctet) { "Frame must be terminated with a null octet" }
payload
}
): ByteArray? {
if (bufferedSource.exhausted()) return null

val payload = ByteArray(contentLength)
bufferedSource.read(payload)

if (bufferedSource.exhausted()) return null
val lastSymbolIsNullOctet = bufferedSource.readUtf8CodePoint() == 0
check(lastSymbolIsNullOctet) { "Frame must be terminated with a null octet" }

private fun readPayloadWithoutContentLength(byteBuffer: ByteBuffer): ByteArray? {
return payload
}

private fun readPayloadWithoutContentLength(buffer: BufferedSource): ByteArray? {
val payload = ByteArrayOutputStream(256)
while (byteBuffer.isNotEmpty()) {
val byte = byteBuffer.get()
if (byte.toInt() != 0) {
payload.write(byte.toInt())
while (!buffer.exhausted()) {
val codePoint = buffer.readUtf8CodePoint()
if (codePoint != 0) {
payload.write(codePoint)
} else {
return payload.toByteArray()
}
}
return null
}

private fun readHeaders(byteBuffer: ByteBuffer, headerAccessor: StompHeaderAccessor) {
private fun readHeaders(byteBuffer: BufferedSource, headerAccessor: StompHeaderAccessor) {
while (true) {
val headerStream = ByteArrayOutputStream(256)
var headerComplete = false

while (byteBuffer.hasRemaining()) {
while (!byteBuffer.exhausted()) {
if (tryConsumeEndOfLine(byteBuffer)) {
headerComplete = true
break
}
headerStream.write(byteBuffer.get().toInt())
headerStream.write(byteBuffer.readUtf8CodePoint())
}

if (headerStream.size() > 0 && headerComplete) {
Expand All @@ -114,18 +118,18 @@ class StompMessageDecoder {

headerAccessor[headerName] = headerValue
} else {
check(byteBuffer.isEmpty()) { "Illegal header: '$header'. A header must be of the form <name>:[<value>]." }
check(byteBuffer.exhausted()) { "Illegal header: '$header'. A header must be of the form <name>:[<value>]." }
}
} else {
break
}
}
}

private fun readCommand(byteBuffer: ByteBuffer): StompCommand? {
private fun readCommand(byteBuffer: BufferedSource): StompCommand? {
val command = ByteArrayOutputStream(256)
while (byteBuffer.isNotEmpty() && !tryConsumeEndOfLine(byteBuffer)) {
command.write(byteBuffer.get().toInt())
while (!byteBuffer.exhausted() && !tryConsumeEndOfLine(byteBuffer)) {
command.write(byteBuffer.readUtf8CodePoint())
}
val commandString = command.toByteArray().toString(Charsets.UTF_8)
return try {
Expand All @@ -143,18 +147,24 @@ class StompMessageDecoder {
* Try to read an EOL incrementing the buffer position if successful.
* @return whether an EOL was consumed
*/
private fun tryConsumeEndOfLine(byteBuffer: ByteBuffer): Boolean = byteBuffer
.takeIf { buffer -> buffer.isNotEmpty() }
?.let { buffer ->
when (byteBuffer.get()) {
'\n'.toByte() -> true
'\r'.toByte() -> checkSequence(byteBuffer)
else -> {
buffer.position(buffer.position() - 1)
false
}
private fun tryConsumeEndOfLine(bufferedSource: BufferedSource): Boolean {
if (bufferedSource.exhausted()) return false
val peekSource = bufferedSource.peek()

return when (peekSource.readUtf8CodePoint().toChar()) {
'\n' -> {
bufferedSource.skip(1)
true
}
} ?: false
'\r' -> {
val nextChartIsNewLine = peekSource.readUtf8CodePoint().toChar() == '\n'
check(!peekSource.exhausted() && nextChartIsNewLine) { "'\\r' must be followed by '\\n'" }
bufferedSource.skip(2)
true
}
else -> false
}
}

/**
* See STOMP Spec 1.2:
Expand All @@ -181,14 +191,4 @@ class StompMessageDecoder {
stringBuilder.append(inString.substring(pos))
return stringBuilder.toString()
}

private fun checkSequence(byteBuffer: ByteBuffer): Boolean {
val nextChartIsNewLine = byteBuffer.get() == '\n'.toByte()
check(byteBuffer.remaining() > 0 && nextChartIsNewLine) { "'\\r' must be followed by '\\n'" }
return true
}

private fun ByteBuffer.isNotEmpty(): Boolean = remaining() > 0

private fun ByteBuffer.isEmpty(): Boolean = remaining() == 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.tinder.scarlet.testutils.rule.GozirraStompConnection
import com.tinder.scarlet.testutils.test
import com.tinder.scarlet.ws.Receive
import com.tinder.scarlet.ws.Send
import org.apache.activemq.command.ActiveMQDestination
import org.apache.activemq.junit.EmbeddedActiveMQBroker
import org.apache.activemq.transport.stomp.StompConnection
import org.junit.Rule
Expand All @@ -20,6 +21,8 @@ class StompIntegrationTest {
@get:Rule
val broker = object : EmbeddedActiveMQBroker() {
override fun configure() {
val destination = ActiveMQDestination.createDestination(SERVER_DESTINATION, 0)
brokerService.destinations = arrayOf(destination)
brokerService.addConnector(BROKER_URL)
}
}
Expand All @@ -31,7 +34,7 @@ class StompIntegrationTest {
PORT,
LOGIN,
PASSWORD,
DESTINATION
CLIENT_DESTINATION
)
)
@get:Rule
Expand All @@ -42,7 +45,7 @@ class StompIntegrationTest {
PORT,
LOGIN,
PASSWORD,
DESTINATION
CLIENT_DESTINATION
)
)

Expand All @@ -58,7 +61,7 @@ class StompIntegrationTest {
connection2.open()

LOGGER.info("${queueTextObserver.values}")
queueTextObserver.awaitCountAtLeast(1) // because broker has a bug and it loses messages sometimes
queueTextObserver.awaitCountAndCheck(2)
}

@Test
Expand All @@ -73,8 +76,8 @@ class StompIntegrationTest {
PASSWORD
)
connection1.begin("tx1")
connection1.send(DESTINATION, "message1", "tx1", null)
connection1.send(DESTINATION, "message2", "tx1", null)
connection1.send(CLIENT_DESTINATION, "message1", "tx1", null)
connection1.send(CLIENT_DESTINATION, "message2", "tx1", null)
connection1.commit("tx1")
connection1.disconnect()

Expand All @@ -95,7 +98,8 @@ class StompIntegrationTest {
private const val LOGIN = "system"
private const val PASSWORD = "manager"
private const val BROKER_URL = "stomp://$HOST:$PORT"
private const val DESTINATION = "/queue/test"
private const val SERVER_DESTINATION = "queue://test"
private const val CLIENT_DESTINATION = "/queue/test"

interface StompQueueTestService {
@Receive
Expand Down
Loading

0 comments on commit 2287e18

Please sign in to comment.