Skip to content

Commit

Permalink
KTOR-2036 Fix CIO connection limit (#3140)
Browse files Browse the repository at this point in the history
* KTOR-2036 Fix CIO connection limit
  • Loading branch information
e5l authored Sep 2, 2022
1 parent 67ba0f2 commit aab530a
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.ktor.client.engine.cio

import io.ktor.client.call.*
import io.ktor.client.network.sockets.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
Expand Down Expand Up @@ -147,13 +148,19 @@ class CIORequestTest : TestWithKtor() {
}

test { client ->
var fail: Throwable? = null
for (i in 0..1000) {
try {
client.get("http://something.wrong").body<String>()
} catch (cause: UnresolvedAddressException) {
// ignore
} catch (cause: Throwable) {
fail = cause
}
}

assertNotNull(fail)
if (fail !is ConnectTimeoutException && fail !is UnresolvedAddressException) {
fail("Expected ConnectTimeoutException or UnresolvedAddressException, got $fail", fail)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ internal class CIOEngine(

private val selectorManager: SelectorManager by lazy { SelectorManager(dispatcher) }

private val connectionFactory = ConnectionFactory(selectorManager, config.maxConnectionsCount)
private val connectionFactory = ConnectionFactory(
selectorManager,
config.maxConnectionsCount,
config.endpoint.maxConnectionsPerRoute
)

private val requestsJob: CoroutineContext

Expand All @@ -42,6 +46,7 @@ internal class CIOEngine(
private val proxy: ProxyConfig? = when (val type = config.proxy?.type) {
ProxyType.SOCKS,
null -> null

ProxyType.HTTP -> config.proxy
else -> throw IllegalStateException("CIO engine does not currently support $type proxies.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,37 @@ package io.ktor.client.engine.cio

import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.ktor.util.collections.*
import kotlinx.coroutines.sync.*

internal class ConnectionFactory(
private val selector: SelectorManager,
maxConnectionsCount: Int
connectionsLimit: Int,
private val addressConnectionsLimit: Int
) {
private val semaphore = Semaphore(maxConnectionsCount)
private val limit = Semaphore(connectionsLimit)
private val addressLimit = ConcurrentMap<InetSocketAddress, Semaphore>()

suspend fun connect(
address: InetSocketAddress,
configuration: SocketOptions.TCPClientSocketOptions.() -> Unit = {}
): Socket {
semaphore.acquire()
limit.acquire()
val addressSemaphore = addressLimit.computeIfAbsent(address) { Semaphore(addressConnectionsLimit) }
addressSemaphore.acquire()

return try {
aSocket(selector).tcpNoDelay().tcp().connect(address, configuration)
} catch (cause: Throwable) {
// a failure or cancellation
semaphore.release()
addressSemaphore.release()
limit.release()
throw cause
}
}

fun release() {
semaphore.release()
fun release(address: InetSocketAddress) {
addressLimit[address]!!.release()
limit.release()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ internal class Endpoint(
} catch (_: Throwable) {
}

connectionFactory.release()
connectionFactory.release(address)
throw cause
}
}
Expand Down Expand Up @@ -229,7 +229,8 @@ internal class Endpoint(
}

private fun releaseConnection() {
connectionFactory.release()
val address = InetSocketAddress(host, port)
connectionFactory.release(address)
connections.decrementAndGet()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public suspend fun OutgoingContent.toByteArray(): ByteArray = when (this) {
else -> ByteArray(0)
}

@OptIn(DelicateCoroutinesApi::class)
@Suppress("KDocMissingDocumentation")
@OptIn(DelicateCoroutinesApi::class)
public suspend fun OutgoingContent.toByteReadPacket(): ByteReadPacket = when (this) {
is OutgoingContent.ByteArrayContent -> ByteReadPacket(bytes())
is OutgoingContent.ReadChannelContent -> readFrom().readRemaining()
Expand Down

0 comments on commit aab530a

Please sign in to comment.