Skip to content

Commit

Permalink
Merge pull request #61 from compscidr/jason/channels-instead-of-socke…
Browse files Browse the repository at this point in the history
…ts-non-blocking-tun-tap

Client channels instead of sockets, non-blocking IO
  • Loading branch information
compscidr authored Dec 13, 2024
2 parents b5ba9f0 + 5f89dac commit c86249b
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ import android.os.ParcelFileDescriptor.AutoCloseOutputStream
import com.jasonernst.packetdumper.AbstractPacketDumper
import com.jasonernst.packetdumper.DummyPacketDumper
import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.channels.DatagramChannel

class AndroidClient(
socketAddress: InetSocketAddress = InetSocketAddress("127.0.0.1", 8080),
datagramChannel: DatagramChannel,
packetDumper: AbstractPacketDumper = DummyPacketDumper,
vpnFileDescriptor: ParcelFileDescriptor,
onlyDestinations: List<InetAddress> = emptyList(),
onlyProtocols: List<UByte> = emptyList()
) : Client(socketAddress, packetDumper, onlyDestinations, onlyProtocols) {
) : Client(datagramChannel, packetDumper, onlyDestinations, onlyProtocols) {

private val inputStream = AutoCloseInputStream(vpnFileDescriptor)
private val outputStream = AutoCloseOutputStream(vpnFileDescriptor)
Expand All @@ -27,5 +27,4 @@ class AndroidClient(
outputStream.write(writeBytes)
outputStream.flush()
}

}
1 change: 1 addition & 0 deletions client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jacoco {
}

dependencies {
implementation(project(":core")) // only really for the DEFAULT_PORT
implementation(libs.jna)
implementation(libs.jnr.enxio)
implementation(libs.knet)
Expand Down
207 changes: 131 additions & 76 deletions client/src/main/kotlin/com/jasonernst/kanonproxy/Client.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.jasonernst.kanonproxy

import com.jasonernst.knet.Packet
import com.jasonernst.kanonproxy.ChangeRequest.Companion.CHANGE_OPS
import com.jasonernst.kanonproxy.ChangeRequest.Companion.REGISTER
import com.jasonernst.knet.Packet.Companion.parseStream
import com.jasonernst.knet.SentinelPacket
import com.jasonernst.packetdumper.AbstractPacketDumper
Expand All @@ -13,71 +14,74 @@ import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.nio.channels.SelectionKey.OP_READ
import java.nio.channels.SelectionKey.OP_WRITE
import java.nio.channels.Selector
import java.util.LinkedList
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.math.min

/**
* Abstract client that can support Linux, Android, etc implementations that are specific to their
* tun/tap device.
*
* @param datagramChannel - A datagram channel which has already been set into non-blocking mode
* and connected to the server (ie, just have the server destination addressed associated with
* the channel since UDP sockets can't "connect").
*/
abstract class Client(
private val socketAddress: InetSocketAddress = InetSocketAddress("127.0.0.1", 8080),
private val datagramChannel: DatagramChannel,
private val packetDumper: AbstractPacketDumper = DummyPacketDumper,
private val onlyDestinations: List<InetAddress> = emptyList(),
private val onlyProtocols: List<UByte> = emptyList(),
) {
private val logger = LoggerFactory.getLogger(javaClass)
private val socket = DatagramSocket()

private val isConnected = AtomicBoolean(false)
private lateinit var selector: Selector
private lateinit var selectorJob: CompletableJob
private lateinit var selectorScope: CoroutineScope
private val outgoingQueue = LinkedBlockingDeque<ByteBuffer>() // queue of data for the server
private val changeRequests = LinkedList<ChangeRequest>()
private val fromProxyStream: ByteBuffer = ByteBuffer.allocate(MAX_STREAM_BUFFER_SIZE)

private val isRunning = AtomicBoolean(false)
private lateinit var readFromTunJob: CompletableJob
private lateinit var readFromTunJobScope: CoroutineScope
private lateinit var readFromProxyJob: CompletableJob
private lateinit var readFromProxyJobScope: CoroutineScope

companion object {
private const val MAX_STREAM_BUFFER_SIZE = 1048576 // max we can write into the stream without parsing
private const val MAX_RECEIVE_BUFFER_SIZE = 1500 // max amount we can recv in one read (should be the MTU or bigger probably)
}

fun connect() {
if (isConnected.get()) {
logger.debug("Client is already connected")
fun start() {
if (isRunning.get()) {
logger.warn("Already running")
return
}
isRunning.set(true)
selector = Selector.open()

readFromProxyJob = SupervisorJob()
readFromProxyJobScope = CoroutineScope(Dispatchers.IO + readFromProxyJob)
readFromProxyJobScope.launch {
logger.debug("Connecting to server: {}", socketAddress)
try {
socket.connect(socketAddress)
logger.debug("Connected to server: {}", socketAddress)
isConnected.set(true)

readFromTunJob = SupervisorJob()
readFromTunJobScope = CoroutineScope(Dispatchers.IO + readFromTunJob)
readFromTunJobScope.launch {
readFromTunWriteToProxy()
}
selectorJob = SupervisorJob()
selectorScope = CoroutineScope(Dispatchers.IO + selectorJob)
selectorScope.launch {
selectorLoop()
}

readFromProxyWriteToTun()
} catch (e: Exception) {
logger.error("Failed to connect to server")
}
readFromTunJob = SupervisorJob()
readFromTunJobScope = CoroutineScope(Dispatchers.IO + readFromTunJob)
readFromTunJobScope.launch {
readFromTunWriteToProxy()
}
readFromProxyJob.complete()
}

fun waitUntilShutdown() {
// block until the read jobs are finished
runBlocking {
readFromProxyJob.join()
selectorJob.join()
readFromTunJob.join()
}
}
Expand All @@ -89,47 +93,76 @@ abstract class Client(

abstract fun tunWrite(writeBytes: ByteArray)

private fun readFromProxyWriteToTun() {
val buffer = ByteArray(MAX_RECEIVE_BUFFER_SIZE)
val datagram = DatagramPacket(buffer, buffer.size)
val stream = ByteBuffer.allocate(MAX_STREAM_BUFFER_SIZE)
private fun selectorLoop() {
datagramChannel.register(selector, OP_READ)

while (isRunning.get()) {
synchronized(changeRequests) {
for (changeRequest in changeRequests) {
when (changeRequest.type) {
REGISTER -> {
logger.debug("Processing REGISTER: ${changeRequest.ops}")
changeRequest.channel.register(selector, changeRequest.ops)
}
CHANGE_OPS -> {
logger.debug("Processing CHANGE_OPS: ${changeRequest.ops}")
val key = changeRequest.channel.keyFor(selector)
key.interestOps(changeRequest.ops)
}
}
}
changeRequests.clear()
}

while (isConnected.get()) {
// logger.debug("Waiting for response from server")
try {
socket.receive(datagram)
val numKeys = selector.select()
// we won't get any keys if we wakeup the selector before we select
// (ie, when we make changes to the keys or interest-ops)
if (numKeys > 0) {
val selectedKeys = selector.selectedKeys()
val keyStream = selectedKeys.parallelStream()
keyStream.forEach {
if (it.isReadable && it.isValid) {
readFromProxy()
}
if (it.isWritable && it.isValid) {
if (outgoingQueue.isNotEmpty()) {
val buffer = outgoingQueue.take()
while (buffer.hasRemaining()) {
datagramChannel.write(buffer)
}
} else {
it.interestOps(OP_READ)
}
}
}
selectedKeys.clear()
}
} catch (e: Exception) {
logger.error("Error receiving from server: $e")
logger.warn("Exception on select, probably shutting down: $e")
break
}
stream.put(buffer, 0, datagram.length)
stream.flip()
val packets = parseStream(stream)
for (packet in packets) {
if (packet is SentinelPacket) {
logger.debug("Sentinel packet, skip")
continue
}
logger.debug("From proxy: $packet")
packetDumper.dumpBuffer(ByteBuffer.wrap(packet.toByteArray()), etherType = EtherType.DETECT)
tunWrite(packet.toByteArray())
}
}
logger.warn("No longer reading from server")
selectorJob.complete()
}

private fun writePackets(packets: List<Packet>) {
packets.forEach { packet ->
val buffer = packet.toByteArray()
val datagramPacket = DatagramPacket(buffer, buffer.size, socketAddress)
packetDumper.dumpBuffer(ByteBuffer.wrap(buffer), etherType = EtherType.DETECT)
try {
socket.send(datagramPacket)
} catch (e: Exception) {
logger.warn("IO error writing to proxy, probably shutting down")
return@forEach
private fun readFromProxy() {
val recvBuffer = ByteBuffer.allocate(MAX_RECEIVE_BUFFER_SIZE)
datagramChannel.read(recvBuffer)
recvBuffer.flip()

fromProxyStream.put(recvBuffer)
fromProxyStream.flip()

val packets = parseStream(fromProxyStream)
for (packet in packets) {
if (packet is SentinelPacket) {
logger.debug("Sentinel packet, skip")
continue
}
// logger.debug("From OS: $packet")
logger.debug("From proxy: $packet")
packetDumper.dumpBuffer(ByteBuffer.wrap(packet.toByteArray()), etherType = EtherType.DETECT)
tunWrite(packet.toByteArray())
}
}

Expand All @@ -142,7 +175,7 @@ abstract class Client(
logger.warn("Filters enabled, not sending all packets to proxy")
}

while (isConnected.get()) {
while (isRunning.get()) {
val bytesToRead = min(MAX_RECEIVE_BUFFER_SIZE, stream.remaining())
val bytesRead =
try {
Expand All @@ -162,49 +195,71 @@ abstract class Client(
// logger.debug("After flip: position: {} remaining {}", stream.position(), stream.remaining())
val packets = parseStream(stream)

var numPackets = 0
if (filters) {
val packetsToForward: MutableList<Packet> = mutableListOf()
for (packet in packets) {
if (onlyDestinations.isNotEmpty()) {
if (packet.ipHeader?.destinationAddress in onlyDestinations) {
if (onlyProtocols.isNotEmpty()) {
if (packet.ipHeader?.protocol in onlyProtocols) {
packetsToForward.add(packet)
outgoingQueue.add(ByteBuffer.wrap(packet.toByteArray()))
numPackets++
// logger.debug("To proxy: $packet")
}
} else {
packetsToForward.add(packet)
outgoingQueue.add(ByteBuffer.wrap(packet.toByteArray()))
numPackets++
// logger.debug("To proxy: $packet")
}
}
} else {
if (onlyProtocols.isNotEmpty()) {
if (packet.ipHeader?.protocol in onlyProtocols) {
packetsToForward.add(packet)
outgoingQueue.add(ByteBuffer.wrap(packet.toByteArray()))
numPackets++
// logger.debug("To proxy: $packet")
}
}
}
}
writePackets(packetsToForward)
} else {
writePackets(packets)
for (packet in packets) {
outgoingQueue.add(ByteBuffer.wrap(packet.toByteArray()))
numPackets++
}
}
if (numPackets > 0) {
logger.debug("Added packets, switching to WRITE mode")
synchronized(changeRequests) {
changeRequests.add(ChangeRequest(datagramChannel, CHANGE_OPS, OP_WRITE))
}
selector.wakeup()
}
}
}

logger.warn("No longer reading from TUN adapter")
readFromTunJob.complete()
}

open fun close() {
open fun stop() {
if (isRunning.get().not()) {
logger.warn("Trying to stop when we're not running")
return
}
logger.debug("Stopping client")
isConnected.set(false)
socket.close()
isRunning.set(false)
selector.close()
try {
datagramChannel.close()
} catch (e: Exception) {
logger.warn("Error closing datagram channel: $e")
}
runBlocking {
logger.debug("Waiting for tun reader to stop")
readFromTunJob.join()
logger.debug("Stopped, waiting for proxy reader to stop")
readFromProxyJob.join()
logger.debug("Stopped, waiting for selector job to stop")
selectorJob.join()
logger.debug("Stopped")
}
logger.debug("Client stopped")
Expand Down
Loading

0 comments on commit c86249b

Please sign in to comment.