Skip to content

Commit

Permalink
Merge pull request #71 from compscidr/jason/add-traffic-accounting
Browse files Browse the repository at this point in the history
Traffic Accounting
  • Loading branch information
compscidr authored Dec 19, 2024
2 parents cd0e818 + 7d77fb3 commit 0e9a830
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.jasonernst.kanonproxy

object DummyTrafficAccount : TrafficAccounting {
override fun recordToInternet(bytes: Long) {
// do nothing
}

override fun recordFromInternet(bytes: Long) {
// do nothing
}
}
2 changes: 2 additions & 0 deletions core/src/main/kotlin/com/jasonernst/kanonproxy/KAnonProxy.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import java.util.concurrent.atomic.AtomicBoolean
class KAnonProxy(
val icmp: Icmp,
val protector: VpnProtector = DummyProtector,
val trafficAccounting: TrafficAccounting = DummyTrafficAccount,
) : SessionManager {
private val logger = LoggerFactory.getLogger(javaClass)

Expand Down Expand Up @@ -212,6 +213,7 @@ class KAnonProxy(
protector,
this,
clientAddress,
trafficAccounting,
)
}
if (isNewSession) {
Expand Down
8 changes: 7 additions & 1 deletion core/src/main/kotlin/com/jasonernst/kanonproxy/Session.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ abstract class Session(
val protector: VpnProtector,
val sessionManager: SessionManager,
val clientAddress: InetSocketAddress,
val trafficAccounting: TrafficAccounting = DummyTrafficAccount,
) {
private val logger = LoggerFactory.getLogger(javaClass)
abstract val channel: ByteChannel
Expand Down Expand Up @@ -90,6 +91,7 @@ abstract class Session(
protector: VpnProtector,
sessionManager: SessionManager,
clientAddress: InetSocketAddress,
trafficAccounting: TrafficAccounting,
): Session =
when (initialIPHeader.protocol) {
IpType.UDP.value -> {
Expand All @@ -101,6 +103,7 @@ abstract class Session(
protector,
sessionManager,
clientAddress,
trafficAccounting,
)
}
IpType.TCP.value -> {
Expand All @@ -116,6 +119,7 @@ abstract class Session(
protector,
sessionManager,
clientAddress,
trafficAccounting,
)
}
else -> {
Expand Down Expand Up @@ -245,7 +249,8 @@ abstract class Session(
val queue = outgoingQueue.take()
logger.debug("Writing ${queue.limit()} bytes to remote channel")
while (queue.hasRemaining()) {
channel.write(queue)
val bytesWritten = channel.write(queue)
trafficAccounting.recordToInternet(bytesWritten.toLong())
}
}
if (outgoingQueue.isNotEmpty()) {
Expand Down Expand Up @@ -320,6 +325,7 @@ abstract class Session(
val payload = ByteArray(len)
readBuffer.get(payload, 0, len)
logger.debug("Read {} bytes from {}", len, channel)
trafficAccounting.recordToInternet(len.toLong())
handlePayloadFromInternet(payload)
readBuffer.clear()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.jasonernst.kanonproxy

interface TrafficAccounting {
fun recordToInternet(bytes: Long)

fun recordFromInternet(bytes: Long)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.jasonernst.kanonproxy.tcp

import com.jasonernst.kanonproxy.ChangeRequest
import com.jasonernst.kanonproxy.SessionManager
import com.jasonernst.kanonproxy.TrafficAccounting
import com.jasonernst.kanonproxy.VpnProtector
import com.jasonernst.knet.Packet
import com.jasonernst.knet.network.ip.IpHeader
Expand All @@ -25,6 +26,7 @@ class AnonymousTcpSession(
protector: VpnProtector,
sessionManager: SessionManager,
clientAddress: InetSocketAddress,
trafficAccounting: TrafficAccounting,
) : TcpSession(
initialIpHeader = initialIpHeader,
initialTransportHeader = initialTransportHeader,
Expand All @@ -33,6 +35,7 @@ class AnonymousTcpSession(
protector = protector,
sessionManager = sessionManager,
clientAddress = clientAddress,
trafficAccounting = trafficAccounting,
) {
companion object {
const val CONNECTION_POLL_MS: Long = 500
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.jasonernst.kanonproxy.tcp

import com.jasonernst.kanonproxy.Session
import com.jasonernst.kanonproxy.SessionManager
import com.jasonernst.kanonproxy.TrafficAccounting
import com.jasonernst.kanonproxy.VpnProtector
import com.jasonernst.knet.Packet
import com.jasonernst.knet.network.ip.IpHeader
Expand All @@ -24,6 +25,7 @@ abstract class TcpSession(
protector: VpnProtector,
sessionManager: SessionManager,
clientAddress: InetSocketAddress,
trafficAccounting: TrafficAccounting,
) : Session(
initialIpHeader = initialIpHeader,
initialTransportHeader = initialTransportHeader,
Expand All @@ -32,6 +34,7 @@ abstract class TcpSession(
protector = protector,
sessionManager = sessionManager,
clientAddress = clientAddress,
trafficAccounting = trafficAccounting,
) {
private val logger = LoggerFactory.getLogger(javaClass)
val isPsh = AtomicBoolean(false) // set when we have accepted a PSH packet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.jasonernst.kanonproxy.udp
import com.jasonernst.kanonproxy.ChangeRequest
import com.jasonernst.kanonproxy.Session
import com.jasonernst.kanonproxy.SessionManager
import com.jasonernst.kanonproxy.TrafficAccounting
import com.jasonernst.kanonproxy.VpnProtector
import com.jasonernst.knet.Packet
import com.jasonernst.knet.network.ip.IpHeader
Expand Down Expand Up @@ -30,6 +31,7 @@ class UdpSession(
protector: VpnProtector,
sessionManager: SessionManager,
clientAddress: InetSocketAddress,
trafficAccounting: TrafficAccounting,
) : Session(
initialIpHeader = initialIpHeader,
initialTransportHeader = initialTransportHeader,
Expand All @@ -38,6 +40,7 @@ class UdpSession(
protector = protector,
sessionManager = sessionManager,
clientAddress = clientAddress,
trafficAccounting = trafficAccounting,
) {
private val logger = LoggerFactory.getLogger(javaClass)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.jasonernst.kanonproxy.tcp

import com.jasonernst.kanonproxy.DummyProtector
import com.jasonernst.kanonproxy.DummyTrafficAccount
import com.jasonernst.kanonproxy.SessionManager
import com.jasonernst.knet.Packet
import com.jasonernst.knet.network.ip.IpType
Expand Down Expand Up @@ -32,7 +33,17 @@ class AnonymousTcpSessionTest {
val sessionManager = mockk<SessionManager>()
every { sessionManager.isRunning() } returns true

val session = AnonymousTcpSession(ipHeader, tcpHeader, ByteArray(0), returnQueue, DummyProtector, sessionManager, clientAddress)
val session =
AnonymousTcpSession(
ipHeader,
tcpHeader,
ByteArray(0),
returnQueue,
DummyProtector,
sessionManager,
clientAddress,
DummyTrafficAccount,
)

// wait until its connecting
while (session.isConnecting.get().not()) {
Expand All @@ -53,7 +64,17 @@ class AnonymousTcpSessionTest {
)
val tcpHeader2 = TcpHeader(syn = true, destinationPort = 80u)
val returnQueue2 = LinkedBlockingDeque<Packet>()
val session2 = AnonymousTcpSession(ipHeader2, tcpHeader2, ByteArray(0), returnQueue2, DummyProtector, sessionManager, clientAddress)
val session2 =
AnonymousTcpSession(
ipHeader2,
tcpHeader2,
ByteArray(0),
returnQueue2,
DummyProtector,
sessionManager,
clientAddress,
DummyTrafficAccount,
)

// wait until its connecting
while (session2.isConnecting.get().not()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.jasonernst.kanonproxy.tcp
import com.jasonernst.icmp.common.v4.IcmpV4DestinationUnreachablePacket
import com.jasonernst.icmp.common.v6.IcmpV6DestinationUnreachablePacket
import com.jasonernst.kanonproxy.BidirectionalByteChannel
import com.jasonernst.kanonproxy.DummyTrafficAccount
import com.jasonernst.kanonproxy.KAnonProxy
import com.jasonernst.knet.Packet
import com.jasonernst.knet.SentinelPacket
Expand Down Expand Up @@ -54,6 +55,7 @@ class TcpClient(
mockk(relaxed = true),
mockk(relaxed = true),
clientAddress,
DummyTrafficAccount,
) {
private val clientId = UUID.randomUUID()
private val logger = LoggerFactory.getLogger(javaClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ class ProxyServer(
val keyStream = selectedKeys.parallelStream()
keyStream.forEach {
if (it.isReadable && it.isValid) {
logger.debug("READ")
readFromClient()
}
if (it.isWritable && it.isValid) {
Expand Down

0 comments on commit 0e9a830

Please sign in to comment.