Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More datagram tests #136

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions FlyingSocks/Sources/AsyncSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
// SOFTWARE.
//

#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

public protocol AsyncSocketPool: Sendable {

Expand Down Expand Up @@ -63,19 +67,35 @@ public extension AsyncSocketPool where Self == SocketPool<Poll> {
public struct AsyncSocket: Sendable {

public struct Message: Sendable {
public let peerAddress: any SocketAddress
public let bytes: [UInt8]
public let interfaceIndex: UInt32?
public let localAddress: (any SocketAddress)?
public var peerAddress: any SocketAddress
public var payload: Data
public var interfaceIndex: UInt32?
public var localAddress: (any SocketAddress)?

public init(
peerAddress: any SocketAddress,
payload: Data,
interfaceIndex: UInt32? = nil,
localAddress: (any SocketAddress)? = nil
) {
self.peerAddress = peerAddress
self.payload = payload
self.interfaceIndex = interfaceIndex
self.localAddress = localAddress
}

@available(*, deprecated, renamed: "payload")
public var bytes: [UInt8] { Array(payload) }

@available(*, deprecated, renamed: "init(peerAddress:payload:)")
public init(
peerAddress: any SocketAddress,
bytes: [UInt8],
interfaceIndex: UInt32? = nil,
localAddress: (any SocketAddress)? = nil
) {
self.peerAddress = peerAddress
self.bytes = bytes
self.payload = Data(bytes)
self.interfaceIndex = interfaceIndex
self.localAddress = localAddress
}
Expand Down Expand Up @@ -169,7 +189,7 @@ public struct AsyncSocket: Sendable {
repeat {
do {
let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length)
return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress)
return Message(peerAddress: peerAddress, payload: Data(bytes), interfaceIndex: interfaceIndex, localAddress: localAddress)
} catch SocketError.blocked {
try await pool.suspendSocket(socket, untilReadyFor: .read)
} catch {
Expand Down Expand Up @@ -228,11 +248,12 @@ public struct AsyncSocket: Sendable {

#if !canImport(WinSDK)
public func send(
message: [UInt8],
message: some Sequence<UInt8>,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (any SocketAddress)? = nil
) async throws {
let message = Array(message)
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}
Expand All @@ -241,18 +262,9 @@ public struct AsyncSocket: Sendable {
}
}

public func send(
message: Data,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
) async throws {
try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}

public func send(message: Message) async throws {
try await send(
message: message.bytes,
message: message.payload,
to: AnySocketAddress(message.peerAddress),
interfaceIndex: message.interfaceIndex,
from: message.localAddress
Expand Down
108 changes: 72 additions & 36 deletions FlyingSocks/Tests/AsyncSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -209,52 +209,43 @@ struct AsyncSocketTests {
}

#if !canImport(WinSDK)
#if canImport(Darwin)
@Test
func datagramSocketReceivesMessageTupleAPI_WhenAvailable() async throws {
let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair()
func messageSequence_sendsMessage_receivesTuple() async throws {
let (socket, port) = try await AsyncSocket.makeLoopbackDatagram()

async let d2: AsyncSocket.Message = s2.receive(atMost: 100)
#if canImport(Darwin)
try await s1.write("Swift".data(using: .utf8)!)
#else
try await s1.send(message: "Swift".data(using: .utf8)!, to: addr, from: addr)
#endif
let v2 = try await d2
#expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift")
async let received: (any SocketAddress, [UInt8]) = socket.receive(atMost: 100)

try s1.close()
try s2.close()
try? Socket.unlink(addr)
}
#endif
let client = try await AsyncSocket.makeLoopbackDatagram().0
let message = AsyncSocket.Message(peerAddress: .loopback(port: port), payload: "Chips 🍟")
try await client.send(message: message)

#if !canImport(WinSDK)
#expect(
try await received.1 == Array("Chips 🍟".data(using: .utf8)!)
)
}
#else
@Test
func datagramSocketReceivesMessageStructAPI_WhenAvailable() async throws {
func sendMessage_receivesTuple() async throws {
let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair()
let messageToSend = AsyncSocket.Message(
peerAddress: addr,
bytes: Array("Swift".data(using: .utf8)!),
localAddress: addr
)
defer {
try? s1.close()
try? s2.close()
try? Socket.unlink(addr)
}
async let received: (any SocketAddress, [UInt8]) = s2.receive(atMost: 100)

async let d2: AsyncSocket.Message = s2.receive(atMost: 100)
#if canImport(Darwin)
try await s1.write("Swift".data(using: .utf8)!)
#else
try await s1.send(message: messageToSend)
#endif
let v2 = try await d2
#expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift")
let message = AsyncSocket.Message(peerAddress: addr, payload: "Shrimp 🦐")
try await s1.send(message: message)

try s1.close()
try s2.close()
try? Socket.unlink(addr)
#expect(
try await received.1 == Array("Shrimp 🦐".data(using: .utf8)!)
)
}
#endif
#endif

@Test
func messageSequence_receives_messages() async throws {
func messageSequence_sendsData_receivesMessage() async throws {
let (socket, port) = try await AsyncSocket.makeLoopbackDatagram()
var messages = socket.messages

Expand All @@ -267,6 +258,44 @@ struct AsyncSocketTests {
try await received?.payloadString == "Fish 🐡"
)
}

#if canImport(Darwin)
@Test
func messageSequence_sendsMessage_receivesMessage() async throws {
let (socket, port) = try await AsyncSocket.makeLoopbackDatagram()
var messages = socket.messages

async let received = messages.next()

let client = try await AsyncSocket.makeLoopbackDatagram().0
let message = AsyncSocket.Message(peerAddress: .loopback(port: port), payload: "Chips 🍟")
try await client.send(message: message)

#expect(
try await received?.payloadString == "Chips 🍟"
)
}
#else
@Test
func sendMessage_receivesMessage() async throws {
let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair()
defer {
try? s1.close()
try? s2.close()
try? Socket.unlink(addr)
}

async let received: AsyncSocket.Message = s2.receive(atMost: 100)

let message = AsyncSocket.Message(peerAddress: addr, payload: "Shrimp 🦐")
try await s1.send(message: message)

#expect(
try await received.payloadString == "Shrimp 🦐"
)
}
#endif
#endif
}

extension AsyncSocket {
Expand Down Expand Up @@ -341,12 +370,19 @@ private extension AsyncSocket.Message {

var payloadString: String {
get throws {
guard let text = String(bytes: bytes, encoding: .utf8) else {
guard let text = String(data: payload, encoding: .utf8) else {
throw SocketError.disconnected
}
return text
}
}

init(peerAddress: some SocketAddress, payload: String) {
self.init(
peerAddress: peerAddress,
payload: payload.data(using: .utf8)!
)
}
}

struct DisconnectedPool: AsyncSocketPool {
Expand Down