Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dcutr): update the DCUtR initiator transport direction to Inbound #994

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libp2p/dial.nim
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ method connect*(
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async, base.} =
dir = Direction.Out) {.async, base.} =
## connect remote peer without negotiating
## a protocol
##
Expand Down
26 changes: 15 additions & 11 deletions libp2p/dialer.nim
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ proc dialAndUpgrade(
peerId: Opt[PeerId],
hostname: string,
address: MultiAddress,
upgradeDir = Direction.Out):
dir = Direction.Out):
Future[Muxer] {.async.} =

for transport in self.transports: # for each transport
Expand All @@ -75,23 +75,27 @@ proc dialAndUpgrade(

let mux =
try:
dialed.transportDir = upgradeDir
await transport.upgrade(dialed, upgradeDir, peerId)
# This is for the very specific case of a simultaneous dial during DCUtR. In this case, both sides will have
# an Outbound direction at the transport level. Therefore we update the DCUtR initiator transport direction to Inbound.
# The if below is more general and might handle other use cases in the future.
if dialed.dir != dir:
dialed.dir = dir
await transport.upgrade(dialed, peerId)
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeeded through another - no use in trying again
await dialed.close()
debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
if exc isnot CancelledError:
if upgradeDir == Direction.Out:
if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()

# Try other address
return nil

doAssert not isNil(mux), "connection died after upgrade " & $upgradeDir
doAssert not isNil(mux), "connection died after upgrade " & $dialed.dir
debug "Dial successful", peerId = mux.connection.peerId
return mux
return nil
Expand Down Expand Up @@ -128,7 +132,7 @@ proc dialAndUpgrade(
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress],
upgradeDir = Direction.Out):
dir = Direction.Out):
Future[Muxer] {.async.} =

debug "Dialing peer", peerId = peerId.get(default(PeerId))
Expand All @@ -146,7 +150,7 @@ proc dialAndUpgrade(
else: await self.nameResolver.resolveMAddress(expandedAddress)

for resolvedAddress in resolvedAddresses:
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, upgradeDir)
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir)
if not isNil(result):
return result

Expand All @@ -164,7 +168,7 @@ proc internalConnect(
addrs: seq[MultiAddress],
forceDial: bool,
reuseConnection = true,
upgradeDir = Direction.Out):
dir = Direction.Out):
Future[Muxer] {.async.} =
if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!")
Expand All @@ -182,7 +186,7 @@ proc internalConnect(
let slot = self.connManager.getOutgoingSlot(forceDial)
let muxed =
try:
await self.dialAndUpgrade(peerId, addrs, upgradeDir)
await self.dialAndUpgrade(peerId, addrs, dir)
except CatchableError as exc:
slot.release()
raise exc
Expand All @@ -209,15 +213,15 @@ method connect*(
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async.} =
dir = Direction.Out) {.async.} =
## connect remote peer without negotiating
## a protocol
##

if self.connManager.connCount(peerId) > 0 and reuseConnection:
return

discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, upgradeDir)
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, dir)

method connect*(
self: Dialer,
Expand Down
2 changes: 1 addition & 1 deletion libp2p/protocols/connectivity/dcutr/client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ proc startSync*(self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs:

if peerDialableAddrs.len > self.maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<self.maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.In))
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.In))
try:
discard await anyCompleted(futs).wait(self.connectTimeout)
debug "Dcutr initiator has directly connected to the remote peer."
Expand Down
2 changes: 1 addition & 1 deletion libp2p/protocols/connectivity/dcutr/server.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ proc new*(T: typedesc[Dcutr], switch: Switch, connectTimeout = 15.seconds, maxDi

if peerDialableAddrs.len > maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.Out))
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.Out))
try:
discard await anyCompleted(futs).wait(connectTimeout)
debug "Dcutr receiver has directly connected to the remote peer."
Expand Down
3 changes: 1 addition & 2 deletions libp2p/protocols/secure/secure.nim
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ method init*(s: Secure) =

method secure*(s: Secure,
conn: Connection,
initiator: bool,
peerId: Opt[PeerId]):
Future[Connection] {.base.} =
s.handleConn(conn, initiator, peerId)
s.handleConn(conn, conn.dir == Direction.Out, peerId)

method readOnce*(s: SecureConn,
pbytes: pointer,
Expand Down
6 changes: 3 additions & 3 deletions libp2p/switch.nim
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ method connect*(
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.public.} =
dir = Direction.Out): Future[void] {.public.} =
## Connects to a peer without opening a stream to it

s.dialer.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
s.dialer.connect(peerId, addrs, forceDial, reuseConnection, dir)

method connect*(
s: Switch,
Expand Down Expand Up @@ -213,7 +213,7 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil)
s.peerInfo.protocols.add(proto.codec)

proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} =
let muxed = await trans.upgrade(conn, Direction.In, Opt.none(PeerId))
let muxed = await trans.upgrade(conn, Opt.none(PeerId))
switch.connManager.storeMuxer(muxed)
await switch.peerStore.identify(muxed)
trace "Connection upgrade succeeded"
Expand Down
3 changes: 1 addition & 2 deletions libp2p/transports/transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,12 @@ proc dial*(
method upgrade*(
self: Transport,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform
## transport specific upgrades
##

self.upgrader.upgrade(conn, direction, peerId)
self.upgrader.upgrade(conn, peerId)

method handles*(
self: Transport,
Expand Down
14 changes: 6 additions & 8 deletions libp2p/upgrademngrs/muxedupgrade.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider =

proc mux*(
self: MuxedUpgrade,
conn: Connection,
direction: Direction): Future[Muxer] {.async, gcsafe.} =
conn: Connection): Future[Muxer] {.async, gcsafe.} =
## mux connection

trace "Muxing connection", conn
Expand All @@ -42,7 +41,7 @@ proc mux*(
return

let muxerName =
if direction == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
if conn.dir == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))

if muxerName.len == 0 or muxerName == "na":
Expand All @@ -62,16 +61,15 @@ proc mux*(
method upgrade*(
self: MuxedUpgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.async.} =
trace "Upgrading connection", conn, direction
trace "Upgrading connection", conn, direction = conn.dir

let sconn = await self.secure(conn, direction, peerId) # secure the connection
let sconn = await self.secure(conn, peerId) # secure the connection
if isNil(sconn):
raise newException(UpgradeFailedError,
"unable to secure connection, stopping upgrade")

let muxer = await self.mux(sconn, direction) # mux it if possible
let muxer = await self.mux(sconn) # mux it if possible
if muxer == nil:
raise newException(UpgradeFailedError,
"a muxer is required for outgoing connections")
Expand All @@ -84,7 +82,7 @@ method upgrade*(
raise newException(UpgradeFailedError,
"Connection closed or missing peer info, stopping upgrade")

trace "Upgraded connection", conn, sconn, direction
trace "Upgraded connection", conn, sconn, direction = conn.dir
return muxer

proc new*(
Expand Down
6 changes: 2 additions & 4 deletions libp2p/upgrademngrs/upgrade.nim
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,18 @@ type
method upgrade*(
self: Upgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base.} =
doAssert(false, "Not implemented!")

proc secure*(
self: Upgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
if self.secureManagers.len <= 0:
raise newException(UpgradeFailedError, "No secure managers registered!")

let codec =
if direction == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
if conn.dir == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
if codec.len == 0:
raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!")
Expand All @@ -65,4 +63,4 @@ proc secure*(
# let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0)

return await secureProtocol[0].secure(conn, direction == Out, peerId)
return await secureProtocol[0].secure(conn, peerId)
8 changes: 4 additions & 4 deletions tests/stubs/switchstub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ type
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.gcsafe, async.}
dir = Direction.Out): Future[void] {.gcsafe, async.}

method connect*(
self: SwitchStub,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async.} =
dir = Direction.Out) {.async.} =
if (self.connectStub != nil):
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, upgradeDir)
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, dir)
else:
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, dir)

proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: connectStubType = nil): T =
return SwitchStub(
Expand Down
8 changes: 4 additions & 4 deletions tests/testdcutr.nim
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
await sleepAsync(100.millis)

let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc)
Expand All @@ -115,7 +115,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
raise newException(CatchableError, "error")

let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc)
Expand Down Expand Up @@ -163,7 +163,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
await sleepAsync(100.millis)

await ductrServerTest(connectProc)
Expand All @@ -175,7 +175,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
raise newException(CatchableError, "error")

await ductrServerTest(connectProc)
Expand Down
2 changes: 1 addition & 1 deletion tests/testhpservice.nim
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ suite "Hole Punching":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
self.connectStub = nil # this stub should be called only once
raise newException(CatchableError, "error")

Expand Down
16 changes: 8 additions & 8 deletions tests/testnoise.nim
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ suite "Noise":

proc acceptHandler() {.async.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
try:
await sconn.write("Hello!")
finally:
Expand All @@ -115,7 +115,7 @@ suite "Noise":
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])

let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))

var msg = newSeq[byte](6)
await sconn.readExactly(addr msg[0], 6)
Expand Down Expand Up @@ -144,7 +144,7 @@ suite "Noise":
var conn: Connection
try:
conn = await transport1.accept()
discard await serverNoise.secure(conn, false, Opt.none(PeerId))
discard await serverNoise.secure(conn, Opt.none(PeerId))
except CatchableError:
discard
finally:
Expand All @@ -160,7 +160,7 @@ suite "Noise":

var sconn: Connection = nil
expect(NoiseDecryptTagError):
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
sconn = await clientNoise.secure(conn, Opt.some(conn.peerId))

await conn.close()
await handlerWait
Expand All @@ -180,7 +180,7 @@ suite "Noise":

proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
defer:
await sconn.close()
await conn.close()
Expand All @@ -196,7 +196,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))

await sconn.write("Hello!")
await acceptFut
Expand All @@ -223,7 +223,7 @@ suite "Noise":

proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
defer:
await sconn.close()
let msg = await sconn.readLp(1024*1024)
Expand All @@ -237,7 +237,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))

await sconn.writeLp(hugePayload)
await readTask
Expand Down
Loading