Skip to content

Commit

Permalink
fix: make handshake abortable (#442)
Browse files Browse the repository at this point in the history
To allow doing things like having a single `AbortSignal` that can be
used as a timeout for incoming connection establishment, allow passing
it as an option to the `ConnectionEncrypter` `secureOutbound` and
`secureInbound` methods.

Previously we'd wrap the stream to be secured in an `AbortableSource`,
however this has some [serious performance implications](ChainSafe/js-libp2p-gossipsub#361)
and it's generally better to just use a signal to cancel an ongoing
operation instead of racing every chunk that comes out of the source.
  • Loading branch information
achingbrain authored Aug 29, 2024
1 parent dadcd1c commit 35ce15d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 21 deletions.
33 changes: 20 additions & 13 deletions src/noise.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { unmarshalPrivateKey } from '@libp2p/crypto/keys'
import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId } from '@libp2p/interface'
import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId, type AbortOptions } from '@libp2p/interface'
import { peerIdFromKeys } from '@libp2p/peer-id'
import { decode } from 'it-length-prefixed'
import { lpStream, type LengthPrefixedStream } from 'it-length-prefixed-stream'
Expand Down Expand Up @@ -72,10 +72,10 @@ export class Noise implements INoiseConnection {
* @param connection - streaming iterable duplex that will be encrypted
* @param remotePeer - PeerId of the remote peer. Used to validate the integrity of the remote peer.
*/
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (...args: any[]): Promise<SecuredConnection<Stream, NoiseExtensions>> {
const { localPeer, connection, remotePeer } = this.parseArgs<Stream>(args)
const { localPeer, connection, remotePeer, signal } = this.parseArgs<Stream>(args)

const wrappedConnection = lpStream(
connection,
Expand All @@ -96,7 +96,9 @@ export class Noise implements INoiseConnection {
const handshake = await this.performHandshakeInitiator(
wrappedConnection,
privateKey,
remoteIdentityKey
remoteIdentityKey, {
signal
}
)
const conn = await this.createSecureConnection(wrappedConnection, handshake)

Expand All @@ -117,10 +119,10 @@ export class Noise implements INoiseConnection {
* @param connection - streaming iterable duplex that will be encrypted.
* @param remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades.
*/
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (...args: any[]): Promise<SecuredConnection<Stream, NoiseExtensions>> {
const { localPeer, connection, remotePeer } = this.parseArgs<Stream>(args)
const { localPeer, connection, remotePeer, signal } = this.parseArgs<Stream>(args)

const wrappedConnection = lpStream(
connection,
Expand All @@ -141,7 +143,9 @@ export class Noise implements INoiseConnection {
const handshake = await this.performHandshakeResponder(
wrappedConnection,
privateKey,
remoteIdentityKey
remoteIdentityKey, {
signal
}
)
const conn = await this.createSecureConnection(wrappedConnection, handshake)

Expand All @@ -162,7 +166,8 @@ export class Noise implements INoiseConnection {
connection: LengthPrefixedStream,
// TODO: pass private key in noise constructor via Components
privateKey: PrivateKey,
remoteIdentityKey?: Uint8Array | Uint8ArrayList
remoteIdentityKey?: Uint8Array | Uint8ArrayList,
options?: AbortOptions
): Promise<HandshakeResult> {
let result: HandshakeResult
try {
Expand All @@ -175,7 +180,7 @@ export class Noise implements INoiseConnection {
prologue: this.prologue,
s: this.staticKey,
extensions: this.extensions
})
}, options)
this.metrics?.xxHandshakeSuccesses.increment()
} catch (e: unknown) {
this.metrics?.xxHandshakeErrors.increment()
Expand All @@ -192,7 +197,8 @@ export class Noise implements INoiseConnection {
connection: LengthPrefixedStream,
// TODO: pass private key in noise constructor via Components
privateKey: PrivateKey,
remoteIdentityKey?: Uint8Array | Uint8ArrayList
remoteIdentityKey?: Uint8Array | Uint8ArrayList,
options?: AbortOptions
): Promise<HandshakeResult> {
let result: HandshakeResult
try {
Expand All @@ -205,7 +211,7 @@ export class Noise implements INoiseConnection {
prologue: this.prologue,
s: this.staticKey,
extensions: this.extensions
})
}, options)
this.metrics?.xxHandshakeSuccesses.increment()
} catch (e: unknown) {
this.metrics?.xxHandshakeErrors.increment()
Expand Down Expand Up @@ -241,7 +247,7 @@ export class Noise implements INoiseConnection {
* TODO: remove this after `libp2p@2.x.x` is released and only support the
* newer style
*/
private parseArgs <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId } {
private parseArgs <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId, signal?: AbortSignal } {
// if the first argument is a peer id, we're using the libp2p@1.x.x style
if (isPeerId(args[0])) {
return {
Expand All @@ -256,7 +262,8 @@ export class Noise implements INoiseConnection {
return {
localPeer: this.components.peerId,
connection: args[0],
remotePeer: args[1]
remotePeer: args[1]?.remotePeer,
signal: args[1]?.signal
}
}
}
Expand Down
17 changes: 9 additions & 8 deletions src/performHandshake.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import {
import { ZEROLEN, XXHandshakeState } from './protocol.js'
import { createHandshakePayload, decodeHandshakePayload } from './utils.js'
import type { HandshakeResult, HandshakeParams } from './types.js'
import type { AbortOptions } from '@libp2p/interface'

export async function performHandshakeInitiator (init: HandshakeParams): Promise<HandshakeResult> {
export async function performHandshakeInitiator (init: HandshakeParams, options?: AbortOptions): Promise<HandshakeResult> {
const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init

const payload = await createHandshakePayload(privateKey, s.publicKey, extensions)
Expand All @@ -23,12 +24,12 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise

logLocalStaticKeys(xx.s, log)
log.trace('Stage 0 - Initiator starting to send first message.')
await connection.write(xx.writeMessageA(ZEROLEN))
await connection.write(xx.writeMessageA(ZEROLEN), options)
log.trace('Stage 0 - Initiator finished sending first message.')
logLocalEphemeralKeys(xx.e, log)

log.trace('Stage 1 - Initiator waiting to receive first message from responder...')
const plaintext = xx.readMessageB(await connection.read())
const plaintext = xx.readMessageB(await connection.read(options))
log.trace('Stage 1 - Initiator received the message.')
logRemoteEphemeralKey(xx.re, log)
logRemoteStaticKey(xx.rs, log)
Expand All @@ -38,7 +39,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
log.trace('All good with the signature!')

log.trace('Stage 2 - Initiator sending third handshake message.')
await connection.write(xx.writeMessageC(payload))
await connection.write(xx.writeMessageC(payload), options)
log.trace('Stage 2 - Initiator sent message with signed payload.')

const [cs1, cs2] = xx.ss.split()
Expand All @@ -51,7 +52,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
}
}

export async function performHandshakeResponder (init: HandshakeParams): Promise<HandshakeResult> {
export async function performHandshakeResponder (init: HandshakeParams, options?: AbortOptions): Promise<HandshakeResult> {
const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init

const payload = await createHandshakePayload(privateKey, s.publicKey, extensions)
Expand All @@ -65,17 +66,17 @@ export async function performHandshakeResponder (init: HandshakeParams): Promise

logLocalStaticKeys(xx.s, log)
log.trace('Stage 0 - Responder waiting to receive first message.')
xx.readMessageA(await connection.read())
xx.readMessageA(await connection.read(options))
log.trace('Stage 0 - Responder received first message.')
logRemoteEphemeralKey(xx.re, log)

log.trace('Stage 1 - Responder sending out first message with signed payload and static key.')
await connection.write(xx.writeMessageB(payload))
await connection.write(xx.writeMessageB(payload), options)
log.trace('Stage 1 - Responder sent the second handshake message with signed payload.')
logLocalEphemeralKeys(xx.e, log)

log.trace('Stage 2 - Responder waiting for third handshake message...')
const plaintext = xx.readMessageC(await connection.read())
const plaintext = xx.readMessageC(await connection.read(options))
log.trace('Stage 2 - Responder received the message, finished handshake.')
const receivedPayload = await decodeHandshakePayload(plaintext, xx.rs, remoteIdentityKey)

Expand Down
27 changes: 27 additions & 0 deletions test/noise.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,31 @@ describe('Noise', () => {
assert(false, err.message)
}
})

it('should abort noise handshake', async () => {
const abortController = new AbortController()
abortController.abort()

const noiseInit = new Noise({
peerId: localPeer,
logger: defaultLogger()
}, { staticNoiseKey: undefined, extensions: undefined })
const noiseResp = new Noise({
peerId: remotePeer,
logger: defaultLogger()
}, { staticNoiseKey: undefined, extensions: undefined })

const [inboundConnection, outboundConnection] = duplexPair<Uint8Array | Uint8ArrayList>()

await expect(Promise.all([
noiseInit.secureOutbound(outboundConnection, {
remotePeer,
signal: abortController.signal
}),
noiseResp.secureInbound(inboundConnection, {
remotePeer: localPeer
})
])).to.eventually.be.rejected
.with.property('name', 'AbortError')
})
})

0 comments on commit 35ce15d

Please sign in to comment.