diff --git a/ktor-client/ktor-client-plugins/ktor-client-logging/common/src/io/ktor/client/plugins/logging/Logging.kt b/ktor-client/ktor-client-plugins/ktor-client-logging/common/src/io/ktor/client/plugins/logging/Logging.kt index 18e55f3016..472c833eaf 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-logging/common/src/io/ktor/client/plugins/logging/Logging.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-logging/common/src/io/ktor/client/plugins/logging/Logging.kt @@ -96,29 +96,97 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", fun isHeaders(): Boolean = level == LogLevel.HEADERS fun isBody(): Boolean = level == LogLevel.BODY || level == LogLevel.ALL - suspend fun logOutgoingContent(content: OutgoingContent, process: (ByteReadChannel) -> ByteReadChannel = { it }): Pair { + suspend fun detectIfBinary(body: ByteReadChannel, contentLength: Long?, contentType: ContentType?, headers: Headers): Triple { + if (headers.contains(HttpHeaders.ContentEncoding)) { + return Triple(true, contentLength, body) + } + + val charset = if (contentType != null) { + contentType.charset() ?: Charsets.UTF_8 + } else { + Charsets.UTF_8 + } + + var isBinary = false + val firstChunk = ByteArray(4096) + val firstRead = body.readAvailable(firstChunk) + val buffer = Buffer().apply { writeFully(firstChunk, 0, firstRead) } + val firstChunkText = charset.newDecoder().decode(buffer, firstRead) + + var lastCharIndex = -1 + for (ch in firstChunkText) { + lastCharIndex += 1 + } + + for ((i, ch) in firstChunkText.withIndex()) { + if (ch == '\ufffd' && i != lastCharIndex) { + isBinary = true + break + } + } + + if (!isBinary) { + val channel = ByteChannel() + channel.writeFully(firstChunk, 0, firstRead) + val copied = body.copyTo(channel) + channel.flushAndClose() + return Triple(isBinary, copied + firstRead, channel) + } + + return Triple(isBinary, contentLength, body) + } + + suspend fun logRequestBody(content: OutgoingContent, contentLength: Long?, headers: Headers, method: HttpMethod, body: ByteReadChannel) { + val (isBinary, size, newBody) = detectIfBinary(body, contentLength, content.contentType, headers) + + if (!isBinary) { + val contentType = content.contentType + val charset = if (contentType != null) { + contentType.charset() ?: Charsets.UTF_8 + } else { + Charsets.UTF_8 + } + + logger.log(newBody.readRemaining().readText(charset = charset)) + logger.log("--> END ${method.value} ($size-byte body)") + } else { + var type = "binary" + if (headers.contains(HttpHeaders.ContentEncoding)) { + type = "encoded" + } + + if (size != null) { + logger.log("--> END ${method.value} ($type $size-byte body omitted)") + } else { + logger.log("--> END ${method.value} ($type body omitted)") + } + } + } + + suspend fun logOutgoingContent(content: OutgoingContent, method: HttpMethod, headers: Headers, process: (ByteReadChannel) -> ByteReadChannel = { it }): OutgoingContent? { return when(content) { is OutgoingContent.ByteArrayContent -> { - val text = process(ByteReadChannel(content.bytes())).readRemaining().readText() - logger.log(text) - Pair(null, text.length.toLong()) + val bytes = content.bytes() + logRequestBody(content, bytes.size.toLong(), headers, method, ByteReadChannel(bytes)) + null } is OutgoingContent.ContentWrapper -> { - logOutgoingContent(content.delegate(), process) + logOutgoingContent(content.delegate(), method, headers, process) } is OutgoingContent.NoContent -> { logger.log("") - Pair(null, 0L) + logger.log("--> END ${method.value} (0-byte body)") + null } is OutgoingContent.ProtocolUpgrade -> { logger.log("") - Pair(null, 0L) + logger.log("--> END ${method.value} (0-byte body)") + null } is OutgoingContent.ReadChannelContent -> { val (origChannel, newChannel) = content.readFrom().split(GlobalScope) - val text = process(newChannel).readRemaining().readText() - logger.log(text) - Pair(LoggedContent(content, origChannel), text.length.toLong()) + logRequestBody(content, content.contentLength, headers, method, newChannel) + LoggedContent(content, origChannel) } is OutgoingContent.WriteChannelContent -> { val channel = ByteChannel() @@ -126,15 +194,14 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", channel.close() val (origChannel, newChannel) = channel.split(GlobalScope) - val text = process(newChannel).readRemaining().readText() - logger.log(text) - Pair(LoggedContent(content, origChannel), text.length.toLong()) + logRequestBody(content, content.contentLength, headers, method, newChannel) + LoggedContent(content, origChannel) } } } - suspend fun logRequestStdFormat(request: HttpRequestBuilder) { - if (isNone()) return + suspend fun logRequestStdFormat(request: HttpRequestBuilder): OutgoingContent? { + if (isNone()) return null val uri = URLBuilder().takeFrom(request.url).build().pathQuery() val body = request.body @@ -171,7 +238,7 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", logger.log(startLine) if (!level.headers && level != LogLevel.BODY) { - return + return null } for ((name, values) in headers.entries()) { @@ -184,70 +251,43 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", if (!isBody() || request.method == HttpMethod.Get) { logger.log("--> END ${request.method.value}") - return + return null } logger.log("") if (body !is OutgoingContent) { logger.log("--> END ${request.method.value}") - return + return null } - val endLine = if (request.headers[HttpHeaders.ContentEncoding] == "gzip") { - val (newBody, size) = logOutgoingContent(body) { channel -> + val newContent = if (request.headers[HttpHeaders.ContentEncoding] == "gzip") { + logOutgoingContent(body, request.method, headers) { channel -> GZipEncoder.decode(channel) } - - "--> END ${request.method.value} ($size-byte, gzipped)" } else { - val (newBody, size) = logOutgoingContent(body) - "--> END ${request.method.value} ($size-byte)" + logOutgoingContent(body, request.method, headers) } - logger.log(endLine) + return newContent } suspend fun logResponseBody(response: HttpResponse, body: ByteReadChannel) { logger.log("") - val contentType = response.contentType() - - val charset = if (contentType != null) { - contentType.charset() ?: Charsets.UTF_8 - } else { - Charsets.UTF_8 - } - - var isBinary = false - val firstChunk = ByteArray(4096) - val firstRead = body.readAvailable(firstChunk) - val buffer = Buffer().apply { writeFully(firstChunk, 0, firstRead) } - val firstChunkText = charset.newDecoder().decode(buffer, firstRead) - - var lastCharIndex = -1 - for (ch in firstChunkText) { - lastCharIndex += 1 - } - - for ((i, ch) in firstChunkText.withIndex()) { - if (ch == '\ufffd' && i != lastCharIndex) { - isBinary = true - break - } - } + val (isBinary, size, newBody) = detectIfBinary(body, response.contentLength(), response.contentType(), response.headers) val duration = response.responseTime.timestamp - response.requestTime.timestamp - val contentLength = response.headers[HttpHeaders.ContentLength]?.toLongOrNull() if (!isBinary) { - val channel = ByteChannel() - channel.writeFully(firstChunk, 0, firstRead) - val copied = body.copyTo(channel) - channel.flushAndClose() + val contentType = response.contentType() + val charset = if (contentType != null) { + contentType.charset() ?: Charsets.UTF_8 + } else { + Charsets.UTF_8 + } - logger.log(channel.readRemaining().readText(charset = charset)) - val size = copied + firstRead + logger.log(newBody.readRemaining().readText(charset = charset)) logger.log("<-- END HTTP (${duration}ms, $size-byte body)") } else { var type = "binary" @@ -255,8 +295,8 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", type = "encoded" } - if (contentLength != null) { - logger.log("<-- END HTTP (${duration}ms, $type $contentLength-byte body omitted)") + if (size != null) { + logger.log("<-- END HTTP (${duration}ms, $type $size-byte body omitted)") } else { logger.log("<-- END HTTP (${duration}ms, $type body omitted)") } @@ -338,7 +378,8 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", get() = request.headers } this.response = object : HttpResponse() { - override val call: HttpClientCall = self + override val call: HttpClientCall + get() = self override val status: HttpStatusCode get() = response.status override val version: HttpProtocolVersion @@ -452,10 +493,14 @@ public val Logging: ClientPlugin = createClientPlugin("Logging", } if (stdFormat) { - logRequestStdFormat(request) + val content = logRequestStdFormat(request) try { - proceed() + if (content != null) { + proceedWith(content) + } else { + proceed() + } } catch (cause: Throwable) { logger.log("<-- HTTP FAILED: $cause") throw cause diff --git a/ktor-client/ktor-client-plugins/ktor-client-logging/jvm/test/io/ktor/client/plugins/logging/NewFormatTest.kt b/ktor-client/ktor-client-plugins/ktor-client-logging/jvm/test/io/ktor/client/plugins/logging/NewFormatTest.kt index 5d75590e66..872a0a53ca 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-logging/jvm/test/io/ktor/client/plugins/logging/NewFormatTest.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-logging/jvm/test/io/ktor/client/plugins/logging/NewFormatTest.kt @@ -16,9 +16,12 @@ import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode import io.ktor.http.content.OutgoingContent import io.ktor.http.content.TextContent +import io.ktor.http.contentType import io.ktor.util.GZipEncoder import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.readText +import io.ktor.utils.io.writeFully import io.ktor.utils.io.writeStringUtf8 import kotlinx.coroutines.Job import kotlinx.coroutines.test.runTest @@ -610,24 +613,126 @@ class NewFormatTest { .assertNoMoreLogs() } -// @Test -// fun bodyPost() = testWithLevel(LogLevel.BODY, handle = { respondWithLength() }) { client -> -// client.post("/") { -// setBody("test") -// } -// log.assertLogEqual("--> POST /") -// .assertLogEqual("Content-Type: text/plain; charset=UTF-8") -// .assertLogEqual("Content-Length: 4") -// .assertLogEqual("Accept-Charset: UTF-8") -// .assertLogEqual("Accept: */*") -// .assertLogEqual("") -// .assertLogEqual("test") -// .assertLogEqual("--> END POST (4-byte body)") -// .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) -// .assertLogEqual("Content-Length: 0") -// .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) -// .assertNoMoreLogs() -// } + @Test + fun bodyPost() = testWithLevel(LogLevel.BODY, handle = { respondWithLength() }) { client -> + client.post("/") { + setBody("test") + } + log.assertLogEqual("--> POST /") + .assertLogEqual("Content-Type: text/plain; charset=UTF-8") + .assertLogEqual("Content-Length: 4") + .assertLogEqual("Accept-Charset: UTF-8") + .assertLogEqual("Accept: */*") + .assertLogEqual("") + .assertLogEqual("test") + .assertLogEqual("--> END POST (4-byte body)") + .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) + .assertLogEqual("Content-Length: 0") + .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) + .assertNoMoreLogs() + } + + @Test + fun bodyPostReadChannel() = testWithLevel(LogLevel.BODY, handle = { respondWithLength() }) { client -> + client.post("/") { + setBody(ByteReadChannel("test")) + contentType(ContentType.Text.Plain) + } + log.assertLogEqual("--> POST / (unknown-byte body)") + .assertLogEqual("Content-Type: text/plain") + .assertLogEqual("Accept-Charset: UTF-8") + .assertLogEqual("Accept: */*") + .assertLogEqual("") + .assertLogEqual("test") + .assertLogEqual("--> END POST (4-byte body)") + .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) + .assertLogEqual("Content-Length: 0") + .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) + .assertNoMoreLogs() + } + + @Test + fun bodyPostReadChannelNotConsumed() = testWithLevel(LogLevel.BODY, handle = { + assertEquals("test", it.body.toByteReadPacket().readText()) + respondWithLength() + }) { client -> + client.post("/") { + setBody(ByteReadChannel("test")) + contentType(ContentType.Text.Plain) + } + log.assertLogEqual("--> POST / (unknown-byte body)") + .assertLogEqual("Content-Type: text/plain") + .assertLogEqual("Accept-Charset: UTF-8") + .assertLogEqual("Accept: */*") + .assertLogEqual("") + .assertLogEqual("test") + .assertLogEqual("--> END POST (4-byte body)") + .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) + .assertLogEqual("Content-Length: 0") + .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) + .assertNoMoreLogs() + } + + // TODO: Test cancellation while logging the request (GlobalScope splitting) + // TODO: Test consumption of request body + + @Test + fun bodyPostBinaryReadChannel() = testWithLevel(LogLevel.BODY, handle = { respondWithLength() }) { client -> + client.post("/") { + setBody(ByteReadChannel(byteArrayOf(0xC3.toByte(), 0x28))) + } + log.assertLogEqual("--> POST / (unknown-byte body)") + .assertLogEqual("Content-Type: application/octet-stream") + .assertLogEqual("Accept-Charset: UTF-8") + .assertLogEqual("Accept: */*") + .assertLogEqual("") + .assertLogEqual("--> END POST (binary body omitted)") + .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) + .assertLogEqual("Content-Length: 0") + .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) + .assertNoMoreLogs() + } + + @Test + fun bodyPostBinaryWriteChannel() = testWithLevel(LogLevel.BODY, handle = { respondWithLength() }) { client -> + client.post("/") { + setBody(object : OutgoingContent.WriteChannelContent() { + override suspend fun writeTo(channel: ByteWriteChannel) { + channel.writeFully(byteArrayOf(0xC3.toByte(), 0x28)) + channel.flushAndClose() + } + }) + } + log.assertLogEqual("--> POST / (unknown-byte body)") + .assertLogEqual("Accept-Charset: UTF-8") + .assertLogEqual("Accept: */*") + .assertLogEqual("") + .assertLogEqual("--> END POST (binary body omitted)") + .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) + .assertLogEqual("Content-Length: 0") + .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) + .assertNoMoreLogs() + } + + @Test + fun bodyPostBinaryArrayContent() = testWithLevel(LogLevel.BODY, handle = { respondWithLength() }) { client -> + client.post("/") { + setBody(object : OutgoingContent.ByteArrayContent() { + override fun bytes(): ByteArray { + return byteArrayOf(0xC3.toByte(), 0x28) + } + }) + } + log.assertLogEqual("--> POST / (2-byte body)") + .assertLogEqual("Accept-Charset: UTF-8") + .assertLogEqual("Accept: */*") + .assertLogEqual("") + .assertLogEqual("--> END POST (binary 2-byte body omitted)") + .assertLogMatch(Regex("""<-- 200 OK / \(\d+ms\)""")) + .assertLogEqual("Content-Length: 0") + .assertLogMatch(Regex("""<-- END HTTP \(\d+ms, 0-byte body\)""")) + .assertNoMoreLogs() + } @Test fun bodyGetWithResponseBody() = testWithLevel(LogLevel.BODY, handle = { respondWithLength("hello!") }) { client -> @@ -872,8 +977,6 @@ class NewFormatTest { .assertNoMoreLogs() } - // TODO: Check partial content - private fun MockRequestHandleScope.respondWithLength(): HttpResponseData { return respond("", headers = Headers.build { append("Content-Length", "0")