Skip to content

Commit

Permalink
Merge pull request #63 from compscidr/jason/non-blocking-server
Browse files Browse the repository at this point in the history
Switch server to a non-blocking datagram channel
  • Loading branch information
compscidr authored Dec 16, 2024
2 parents 9c28aeb + c3cb49f commit f55a97e
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class KAnonVpnService: VpnService(), VpnUiService, ConnectedUsersChangedCallback
private val packetDumper = PcapNgTcpServerPacketDumper(callback = this, isSimple = false)
private lateinit var viewModel: KAnonViewModel

private val server = Server(IcmpAndroid, protector = this)
private lateinit var server: Server
private lateinit var client: AndroidClient
private lateinit var vpnFileDescriptor: ParcelFileDescriptor

Expand All @@ -59,6 +59,10 @@ class KAnonVpnService: VpnService(), VpnUiService, ConnectedUsersChangedCallback
override fun startVPN() {
// todo: put an atomic boolean here to prevent multiple starts

val serverChannel = DatagramChannel.open()
serverChannel.configureBlocking(false)
serverChannel.bind(InetSocketAddress(DEFAULT_PORT))
server = Server(datagramChannel = serverChannel, icmp = IcmpAndroid, protector = this)
server.start()

val builder = Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.jasonernst.kanonproxy.tuntap.TunTapDevice
import com.jasonernst.packetdumper.AbstractPacketDumper
import com.jasonernst.packetdumper.DummyPacketDumper
import com.jasonernst.packetdumper.serverdumper.PcapNgTcpServerPacketDumper
import org.slf4j.LoggerFactory
import sun.misc.Signal
import java.net.InetSocketAddress
import java.nio.channels.DatagramChannel
Expand All @@ -13,28 +14,31 @@ class LinuxClient(
datagramChannel: DatagramChannel,
packetDumper: AbstractPacketDumper = DummyPacketDumper,
) : Client(datagramChannel, packetDumper) {
private val logger = LoggerFactory.getLogger(javaClass)
private val tunTapDevice = TunTapDevice()

init {
tunTapDevice.open()
}

companion object {
private val staticLogger = LoggerFactory.getLogger(LinuxClient::class.java)

@JvmStatic
fun main(args: Array<String>) {
val packetDumper = PcapNgTcpServerPacketDumper()
packetDumper.start()

val client =
if (args.isEmpty()) {
println("Using default server: 127.0.0.1 $DEFAULT_PORT")
staticLogger.debug("Using default server: 127.0.0.1 $DEFAULT_PORT")
val datagramChannel = DatagramChannel.open()
datagramChannel.configureBlocking(false)
datagramChannel.connect(InetSocketAddress("127.0.0.1", DEFAULT_PORT))
LinuxClient(datagramChannel = datagramChannel, packetDumper = packetDumper)
} else {
if (args.size != 2) {
println("Usage: Client <server> <port>")
staticLogger.warn("Usage: Client <server> <port>")
return
}
val server = args[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ class TunTapDevice {
readBytes: ByteArray,
bytesToRead: Int,
): Int {
if (isRunning.get().not()) {
return -1
}
// this will block until the selector puts something here, to unblock when we're shutting
// down, just stick an empty buffer in the outgoing queue
val buffer = outgoingQueue.take()
Expand Down
15 changes: 3 additions & 12 deletions server/src/main/kotlin/com/jasonernst/kanonproxy/ProxySession.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,13 @@ import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicBoolean

class ProxySession(
private val clientAddress: InetSocketAddress,
private val kAnonProxy: KAnonProxy,
private val socket: DatagramSocket,
private val sessionManager: ProxySessionManager,
private val packetDumper: AbstractPacketDumper = DummyPacketDumper,
) {
Expand Down Expand Up @@ -50,15 +47,9 @@ class ProxySession(
break
}
// logger.debug("Received response from proxy for client: $clientAddress, sending datagram back")
val buffer = response.toByteArray()
packetDumper.dumpBuffer(ByteBuffer.wrap(buffer), etherType = EtherType.DETECT)
val datagramPacket = DatagramPacket(buffer, buffer.size, clientAddress)
try {
socket.send(datagramPacket)
} catch (e: Exception) {
logger.debug("Error trying to write to proxy server, probably shutting down: $e")
break
}
val buffer = ByteBuffer.wrap(response.toByteArray())
packetDumper.dumpBuffer(buffer, etherType = EtherType.DETECT)
sessionManager.enqueueOutgoing(clientAddress, buffer)
}
sessionManager.removeSession(clientAddress)
isRunning.set(false)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package com.jasonernst.kanonproxy

import java.net.InetSocketAddress
import java.nio.ByteBuffer

interface ProxySessionManager {
fun enqueueOutgoing(
clientAddress: InetSocketAddress,
buffer: ByteBuffer,
)

fun removeSession(clientAddress: InetSocketAddress)
}
187 changes: 129 additions & 58 deletions server/src/main/kotlin/com/jasonernst/kanonproxy/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package com.jasonernst.kanonproxy

import com.jasonernst.icmp.common.Icmp
import com.jasonernst.icmp.linux.IcmpLinux
import com.jasonernst.kanonproxy.ChangeRequest.Companion.CHANGE_OPS
import com.jasonernst.kanonproxy.ChangeRequest.Companion.REGISTER
import com.jasonernst.knet.Packet
import com.jasonernst.packetdumper.AbstractPacketDumper
import com.jasonernst.packetdumper.DummyPacketDumper
Expand All @@ -15,48 +17,62 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import sun.misc.Signal
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.nio.channels.SelectionKey.OP_READ
import java.nio.channels.SelectionKey.OP_WRITE
import java.nio.channels.Selector
import java.util.LinkedList
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicBoolean

/**
* The datagram channel should be configured as non-blocking and already bound to the port it is
* listening on before being passed to here.
*/
class Server(
icmp: Icmp,
private val port: Int = KAnonProxy.DEFAULT_PORT,
private val datagramChannel: DatagramChannel,
private val packetDumper: AbstractPacketDumper = DummyPacketDumper,
protector: VpnProtector = DummyProtector,
) : ProxySessionManager {
private val logger = LoggerFactory.getLogger(javaClass)
private lateinit var socket: DatagramSocket
private val isRunning = AtomicBoolean(false)
private val kAnonProxy = KAnonProxy(icmp, protector)
private val sessions = ConcurrentHashMap<InetSocketAddress, ProxySession>()

private lateinit var readFromClientJob: CompletableJob
private lateinit var readFromClientJobScope: CoroutineScope
private lateinit var selector: Selector
private lateinit var selectorJob: CompletableJob
private lateinit var selectorScope: CoroutineScope
private val outgoingQueue = LinkedBlockingDeque<Pair<InetSocketAddress, ByteBuffer>>() // queue of data to be sent to clients
private val changeRequests = LinkedList<ChangeRequest>()

companion object {
private const val MAX_STREAM_BUFFER_SIZE = 1048576 // max we can write into the stream without parsing
private val staticLogger = LoggerFactory.getLogger(Server::class.java)
private const val MAX_RECEIVE_BUFFER_SIZE = 1500 // max amount we can recv in one read (should be the MTU or bigger probably)

@JvmStatic
fun main(args: Array<String>) {
// listen on one port higher so we don't conflict with the client
val packetDumper = PcapNgTcpServerPacketDumper(listenPort = PcapNgTcpServerPacketDumper.DEFAULT_PORT + 1)
val server =
if (args.isEmpty()) {
println("Using default port: ${KAnonProxy.DEFAULT_PORT}")
Server(IcmpLinux)
} else {
if (args.size != 1) {
println("Usage: Server <port>")
return
}
val port = args[0].toInt()
Server(IcmpLinux, port)

val datagramChannel = DatagramChannel.open()
if (args.isEmpty()) {
staticLogger.debug("Server listening on default port: ${KAnonProxy.DEFAULT_PORT}")
datagramChannel.bind(InetSocketAddress(KAnonProxy.DEFAULT_PORT))
} else {
if (args.size != 1) {
staticLogger.warn("Usage: Server <port>")
return
}
val port = args[0].toInt()
datagramChannel.bind(InetSocketAddress(port))
}
datagramChannel.configureBlocking(false)

val server = Server(icmp = IcmpLinux, datagramChannel = datagramChannel)
packetDumper.start()
server.start()

Expand All @@ -76,58 +92,112 @@ class Server(
}
isRunning.set(true)
kAnonProxy.start()
readFromClientJob = SupervisorJob()
readFromClientJobScope = CoroutineScope(Dispatchers.IO + readFromClientJob)
readFromClientJobScope.launch {
readFromClientWriteToProxy()
selector = Selector.open()

selectorJob = SupervisorJob()
selectorScope = CoroutineScope(Dispatchers.IO + selectorJob)
selectorScope.launch {
selectorLoop()
}
}

private fun waitUntilShutdown() {
runBlocking {
readFromClientJob.join()
selectorJob.join()
}
}

private fun readFromClientWriteToProxy() {
Thread.currentThread().name = "Server proxy listener"
logger.debug("Starting server on port: $port")
socket = DatagramSocket(port)

val buffer = ByteArray(MAX_RECEIVE_BUFFER_SIZE)
val packet = DatagramPacket(buffer, buffer.size)
val stream = ByteBuffer.allocate(MAX_STREAM_BUFFER_SIZE)
private fun selectorLoop() {
datagramChannel.register(selector, OP_READ)

while (isRunning.get()) {
synchronized(changeRequests) {
for (changeRequest in changeRequests) {
when (changeRequest.type) {
REGISTER -> {
// logger.debug("Processing REGISTER: ${changeRequest.ops}")
changeRequest.channel.register(selector, changeRequest.ops)
}

CHANGE_OPS -> {
// logger.debug("Processing CHANGE_OPS: ${changeRequest.ops}")
val key = changeRequest.channel.keyFor(selector)
key.interestOps(changeRequest.ops)
}
}
}
changeRequests.clear()
}

try {
socket.receive(packet)
val numKeys = selector.select()
// we won't get any keys if we wakeup the selector before we select
// (ie, when we make changes to the keys or interest-ops)
if (numKeys > 0) {
val selectedKeys = selector.selectedKeys()
val keyStream = selectedKeys.parallelStream()
keyStream.forEach {
if (it.isReadable && it.isValid) {
logger.debug("READ")
readFromClient()
}
if (it.isWritable && it.isValid) {
if (outgoingQueue.isNotEmpty()) {
val outgoingPair = outgoingQueue.take()
val clientAddress = outgoingPair.first
val buffer = outgoingPair.second
while (buffer.hasRemaining()) {
datagramChannel.send(buffer, clientAddress)
}
} else {
it.interestOps(OP_READ)
}
}
}
selectedKeys.clear()
}
} catch (e: Exception) {
logger.warn("Error trying to receive on server socket, probably shutting down: $e")
logger.warn("Exception on select, probably shutting down: $e")
break
}
stream.put(buffer, 0, packet.length)
stream.flip()
val packets = Packet.parseStream(stream)
for (p in packets) {
packetDumper.dumpBuffer(ByteBuffer.wrap(p.toByteArray()), etherType = EtherType.DETECT)
}
val clientAddress = InetSocketAddress(packet.address, packet.port)
kAnonProxy.handlePackets(packets, clientAddress)
var newSession = false
sessions.getOrPut(clientAddress) {
newSession = true
val session = ProxySession(clientAddress, kAnonProxy, socket, this, packetDumper)
session.start()
session
}
if (newSession) {
logger.warn("New proxy session for client: $clientAddress")
} else {
// logger.debug("Continuing to use existing proxy session for client: $clientAddress")
}
}
logger.warn("Server no longer listening")
readFromClientJob.complete()
selectorJob.complete()
}

private fun readFromClient() {
// since each of these receives could be potentially from separate clients, we can't try
// to parse different subsequent reads together - it MUST all fit in a single read.
val buffer = ByteBuffer.allocate(MAX_RECEIVE_BUFFER_SIZE)
val clientAddress = datagramChannel.receive(buffer) as InetSocketAddress
buffer.flip()
val packets = Packet.parseStream(buffer)
for (p in packets) {
packetDumper.dumpBuffer(ByteBuffer.wrap(p.toByteArray()), etherType = EtherType.DETECT)
}
kAnonProxy.handlePackets(packets, clientAddress)
var newSession = false
sessions.getOrPut(clientAddress) {
newSession = true
val session = ProxySession(clientAddress, kAnonProxy, this, packetDumper)
session.start()
session
}
if (newSession) {
logger.warn("New proxy session for client: $clientAddress")
} else {
// logger.debug("Continuing to use existing proxy session for client: $clientAddress")
}
}

override fun enqueueOutgoing(
clientAddress: InetSocketAddress,
buffer: ByteBuffer,
) {
outgoingQueue.add(Pair(clientAddress, buffer))
synchronized(changeRequests) {
changeRequests.add(ChangeRequest(datagramChannel, CHANGE_OPS, OP_WRITE))
}
selector.wakeup()
}

override fun removeSession(clientAddress: InetSocketAddress) {
Expand All @@ -138,13 +208,14 @@ class Server(
fun stop() {
logger.debug("Stopping server")
isRunning.set(false)
socket.close()
datagramChannel.close()
kAnonProxy.stop()
selector.close()
logger.debug("Stopping outstanding sessions")
sessions.values.forEach { it.stop() }
logger.debug("All sessions stopped, stopping client reader job")
logger.debug("All sessions stopped, stopping selector job")
runBlocking {
readFromClientJob.join()
selectorJob.join()
}
logger.debug("Server stopped")
}
Expand Down

0 comments on commit f55a97e

Please sign in to comment.