Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch server to a non-blocking datagram channel #63

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class KAnonVpnService: VpnService(), VpnUiService, ConnectedUsersChangedCallback
private val packetDumper = PcapNgTcpServerPacketDumper(callback = this, isSimple = false)
private lateinit var viewModel: KAnonViewModel

private val server = Server(IcmpAndroid, protector = this)
private lateinit var server: Server
private lateinit var client: AndroidClient
private lateinit var vpnFileDescriptor: ParcelFileDescriptor

Expand All @@ -59,6 +59,10 @@ class KAnonVpnService: VpnService(), VpnUiService, ConnectedUsersChangedCallback
override fun startVPN() {
// todo: put an atomic boolean here to prevent multiple starts

val serverChannel = DatagramChannel.open()
serverChannel.configureBlocking(false)
serverChannel.bind(InetSocketAddress(DEFAULT_PORT))
server = Server(datagramChannel = serverChannel, icmp = IcmpAndroid, protector = this)
server.start()

val builder = Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.jasonernst.kanonproxy.tuntap.TunTapDevice
import com.jasonernst.packetdumper.AbstractPacketDumper
import com.jasonernst.packetdumper.DummyPacketDumper
import com.jasonernst.packetdumper.serverdumper.PcapNgTcpServerPacketDumper
import org.slf4j.LoggerFactory
import sun.misc.Signal
import java.net.InetSocketAddress
import java.nio.channels.DatagramChannel
Expand All @@ -13,28 +14,31 @@ class LinuxClient(
datagramChannel: DatagramChannel,
packetDumper: AbstractPacketDumper = DummyPacketDumper,
) : Client(datagramChannel, packetDumper) {
private val logger = LoggerFactory.getLogger(javaClass)
private val tunTapDevice = TunTapDevice()

init {
tunTapDevice.open()
}

companion object {
private val staticLogger = LoggerFactory.getLogger(LinuxClient::class.java)

@JvmStatic
fun main(args: Array<String>) {
val packetDumper = PcapNgTcpServerPacketDumper()
packetDumper.start()

val client =
if (args.isEmpty()) {
println("Using default server: 127.0.0.1 $DEFAULT_PORT")
staticLogger.debug("Using default server: 127.0.0.1 $DEFAULT_PORT")
val datagramChannel = DatagramChannel.open()
datagramChannel.configureBlocking(false)
datagramChannel.connect(InetSocketAddress("127.0.0.1", DEFAULT_PORT))
LinuxClient(datagramChannel = datagramChannel, packetDumper = packetDumper)
} else {
if (args.size != 2) {
println("Usage: Client <server> <port>")
staticLogger.warn("Usage: Client <server> <port>")
return
}
val server = args[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ class TunTapDevice {
readBytes: ByteArray,
bytesToRead: Int,
): Int {
if (isRunning.get().not()) {
return -1
}
// this will block until the selector puts something here, to unblock when we're shutting
// down, just stick an empty buffer in the outgoing queue
val buffer = outgoingQueue.take()
Expand Down
15 changes: 3 additions & 12 deletions server/src/main/kotlin/com/jasonernst/kanonproxy/ProxySession.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,13 @@ import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicBoolean

class ProxySession(
private val clientAddress: InetSocketAddress,
private val kAnonProxy: KAnonProxy,
private val socket: DatagramSocket,
private val sessionManager: ProxySessionManager,
private val packetDumper: AbstractPacketDumper = DummyPacketDumper,
) {
Expand Down Expand Up @@ -50,15 +47,9 @@ class ProxySession(
break
}
// logger.debug("Received response from proxy for client: $clientAddress, sending datagram back")
val buffer = response.toByteArray()
packetDumper.dumpBuffer(ByteBuffer.wrap(buffer), etherType = EtherType.DETECT)
val datagramPacket = DatagramPacket(buffer, buffer.size, clientAddress)
try {
socket.send(datagramPacket)
} catch (e: Exception) {
logger.debug("Error trying to write to proxy server, probably shutting down: $e")
break
}
val buffer = ByteBuffer.wrap(response.toByteArray())
packetDumper.dumpBuffer(buffer, etherType = EtherType.DETECT)
sessionManager.enqueueOutgoing(clientAddress, buffer)
}
sessionManager.removeSession(clientAddress)
isRunning.set(false)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package com.jasonernst.kanonproxy

import java.net.InetSocketAddress
import java.nio.ByteBuffer

interface ProxySessionManager {
fun enqueueOutgoing(
clientAddress: InetSocketAddress,
buffer: ByteBuffer,
)

fun removeSession(clientAddress: InetSocketAddress)
}
187 changes: 129 additions & 58 deletions server/src/main/kotlin/com/jasonernst/kanonproxy/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package com.jasonernst.kanonproxy

import com.jasonernst.icmp.common.Icmp
import com.jasonernst.icmp.linux.IcmpLinux
import com.jasonernst.kanonproxy.ChangeRequest.Companion.CHANGE_OPS
import com.jasonernst.kanonproxy.ChangeRequest.Companion.REGISTER
import com.jasonernst.knet.Packet
import com.jasonernst.packetdumper.AbstractPacketDumper
import com.jasonernst.packetdumper.DummyPacketDumper
Expand All @@ -15,48 +17,62 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import sun.misc.Signal
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.nio.channels.SelectionKey.OP_READ
import java.nio.channels.SelectionKey.OP_WRITE
import java.nio.channels.Selector
import java.util.LinkedList
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicBoolean

/**
* The datagram channel should be configured as non-blocking and already bound to the port it is
* listening on before being passed to here.
*/
class Server(
icmp: Icmp,
private val port: Int = KAnonProxy.DEFAULT_PORT,
private val datagramChannel: DatagramChannel,
private val packetDumper: AbstractPacketDumper = DummyPacketDumper,
protector: VpnProtector = DummyProtector,
) : ProxySessionManager {
private val logger = LoggerFactory.getLogger(javaClass)
private lateinit var socket: DatagramSocket
private val isRunning = AtomicBoolean(false)
private val kAnonProxy = KAnonProxy(icmp, protector)
private val sessions = ConcurrentHashMap<InetSocketAddress, ProxySession>()

private lateinit var readFromClientJob: CompletableJob
private lateinit var readFromClientJobScope: CoroutineScope
private lateinit var selector: Selector
private lateinit var selectorJob: CompletableJob
private lateinit var selectorScope: CoroutineScope
private val outgoingQueue = LinkedBlockingDeque<Pair<InetSocketAddress, ByteBuffer>>() // queue of data to be sent to clients
private val changeRequests = LinkedList<ChangeRequest>()

companion object {
private const val MAX_STREAM_BUFFER_SIZE = 1048576 // max we can write into the stream without parsing
private val staticLogger = LoggerFactory.getLogger(Server::class.java)
private const val MAX_RECEIVE_BUFFER_SIZE = 1500 // max amount we can recv in one read (should be the MTU or bigger probably)

@JvmStatic
fun main(args: Array<String>) {
// listen on one port higher so we don't conflict with the client
val packetDumper = PcapNgTcpServerPacketDumper(listenPort = PcapNgTcpServerPacketDumper.DEFAULT_PORT + 1)
val server =
if (args.isEmpty()) {
println("Using default port: ${KAnonProxy.DEFAULT_PORT}")
Server(IcmpLinux)
} else {
if (args.size != 1) {
println("Usage: Server <port>")
return
}
val port = args[0].toInt()
Server(IcmpLinux, port)

val datagramChannel = DatagramChannel.open()
if (args.isEmpty()) {
staticLogger.debug("Server listening on default port: ${KAnonProxy.DEFAULT_PORT}")
datagramChannel.bind(InetSocketAddress(KAnonProxy.DEFAULT_PORT))
} else {
if (args.size != 1) {
staticLogger.warn("Usage: Server <port>")
return
}
val port = args[0].toInt()
datagramChannel.bind(InetSocketAddress(port))
}
datagramChannel.configureBlocking(false)

val server = Server(icmp = IcmpLinux, datagramChannel = datagramChannel)
packetDumper.start()
server.start()

Expand All @@ -76,58 +92,112 @@ class Server(
}
isRunning.set(true)
kAnonProxy.start()
readFromClientJob = SupervisorJob()
readFromClientJobScope = CoroutineScope(Dispatchers.IO + readFromClientJob)
readFromClientJobScope.launch {
readFromClientWriteToProxy()
selector = Selector.open()

selectorJob = SupervisorJob()
selectorScope = CoroutineScope(Dispatchers.IO + selectorJob)
selectorScope.launch {
selectorLoop()
}
}

private fun waitUntilShutdown() {
runBlocking {
readFromClientJob.join()
selectorJob.join()
}
}

private fun readFromClientWriteToProxy() {
Thread.currentThread().name = "Server proxy listener"
logger.debug("Starting server on port: $port")
socket = DatagramSocket(port)

val buffer = ByteArray(MAX_RECEIVE_BUFFER_SIZE)
val packet = DatagramPacket(buffer, buffer.size)
val stream = ByteBuffer.allocate(MAX_STREAM_BUFFER_SIZE)
private fun selectorLoop() {
datagramChannel.register(selector, OP_READ)

while (isRunning.get()) {
synchronized(changeRequests) {
for (changeRequest in changeRequests) {
when (changeRequest.type) {
REGISTER -> {
// logger.debug("Processing REGISTER: ${changeRequest.ops}")
changeRequest.channel.register(selector, changeRequest.ops)
}

CHANGE_OPS -> {
// logger.debug("Processing CHANGE_OPS: ${changeRequest.ops}")
val key = changeRequest.channel.keyFor(selector)
key.interestOps(changeRequest.ops)
}
}
}
changeRequests.clear()
}

try {
socket.receive(packet)
val numKeys = selector.select()
// we won't get any keys if we wakeup the selector before we select
// (ie, when we make changes to the keys or interest-ops)
if (numKeys > 0) {
val selectedKeys = selector.selectedKeys()
val keyStream = selectedKeys.parallelStream()
keyStream.forEach {
if (it.isReadable && it.isValid) {
logger.debug("READ")
readFromClient()
}
if (it.isWritable && it.isValid) {
if (outgoingQueue.isNotEmpty()) {
val outgoingPair = outgoingQueue.take()
val clientAddress = outgoingPair.first
val buffer = outgoingPair.second
while (buffer.hasRemaining()) {
datagramChannel.send(buffer, clientAddress)
}
} else {
it.interestOps(OP_READ)
}
}
}
selectedKeys.clear()
}
} catch (e: Exception) {
logger.warn("Error trying to receive on server socket, probably shutting down: $e")
logger.warn("Exception on select, probably shutting down: $e")
break
}
stream.put(buffer, 0, packet.length)
stream.flip()
val packets = Packet.parseStream(stream)
for (p in packets) {
packetDumper.dumpBuffer(ByteBuffer.wrap(p.toByteArray()), etherType = EtherType.DETECT)
}
val clientAddress = InetSocketAddress(packet.address, packet.port)
kAnonProxy.handlePackets(packets, clientAddress)
var newSession = false
sessions.getOrPut(clientAddress) {
newSession = true
val session = ProxySession(clientAddress, kAnonProxy, socket, this, packetDumper)
session.start()
session
}
if (newSession) {
logger.warn("New proxy session for client: $clientAddress")
} else {
// logger.debug("Continuing to use existing proxy session for client: $clientAddress")
}
}
logger.warn("Server no longer listening")
readFromClientJob.complete()
selectorJob.complete()
}

private fun readFromClient() {
// since each of these receives could be potentially from separate clients, we can't try
// to parse different subsequent reads together - it MUST all fit in a single read.
val buffer = ByteBuffer.allocate(MAX_RECEIVE_BUFFER_SIZE)
val clientAddress = datagramChannel.receive(buffer) as InetSocketAddress
buffer.flip()
val packets = Packet.parseStream(buffer)
for (p in packets) {
packetDumper.dumpBuffer(ByteBuffer.wrap(p.toByteArray()), etherType = EtherType.DETECT)
}
kAnonProxy.handlePackets(packets, clientAddress)
var newSession = false
sessions.getOrPut(clientAddress) {
newSession = true
val session = ProxySession(clientAddress, kAnonProxy, this, packetDumper)
session.start()
session
}
if (newSession) {
logger.warn("New proxy session for client: $clientAddress")
} else {
// logger.debug("Continuing to use existing proxy session for client: $clientAddress")
}
}

override fun enqueueOutgoing(
clientAddress: InetSocketAddress,
buffer: ByteBuffer,
) {
outgoingQueue.add(Pair(clientAddress, buffer))
synchronized(changeRequests) {
changeRequests.add(ChangeRequest(datagramChannel, CHANGE_OPS, OP_WRITE))
}
selector.wakeup()
}

override fun removeSession(clientAddress: InetSocketAddress) {
Expand All @@ -138,13 +208,14 @@ class Server(
fun stop() {
logger.debug("Stopping server")
isRunning.set(false)
socket.close()
datagramChannel.close()
kAnonProxy.stop()
selector.close()
logger.debug("Stopping outstanding sessions")
sessions.values.forEach { it.stop() }
logger.debug("All sessions stopped, stopping client reader job")
logger.debug("All sessions stopped, stopping selector job")
runBlocking {
readFromClientJob.join()
selectorJob.join()
}
logger.debug("Server stopped")
}
Expand Down
Loading