From ecc2fd06807d4229f589ed3cad407123415ba04f Mon Sep 17 00:00:00 2001 From: Alex Potsides Date: Thu, 23 Nov 2023 16:30:04 +0000 Subject: [PATCH] fix: use optimistic protocol negotation (#2253) When negotiating connection encrypters, multiplexers or stream protocols, if we only have one protocol to negotiate there's no point sending it, then waiting for a response, then sending some data, we can just send it, and the data, then assuming the remote doesn't immediately close the negotiation channel, read the response at our leisure. This saves a round trip for the first chunk of stream data and reduces our connection latency in the [libp2p perf tests](https://observablehq.com/@libp2p-workspace/performance-dashboard?branch=c8022cb77397759bd7e71a73e93f9074854989fe) from 0.45 to 0.3ms. It changes stream behaviour a little, since we now don't start the protocol negotiation until we interact with the stream (e.g. try to read or write data) and most of our tests assume that negotiation has succeeded when the stream is returned so it's not been a straightforward fix. --- .../test/circuit-relay.node.ts | 2 +- .../src/pubsub/two-nodes.ts | 6 +- packages/libp2p/.aegir.js | 9 +- .../src/connection-manager/dial-queue.ts | 2 +- packages/libp2p/src/upgrader.ts | 79 +++++-- .../test/connection-manager/direct.node.ts | 23 +- .../test/connection-manager/index.node.ts | 21 +- .../libp2p/test/content-routing/dht/utils.ts | 3 +- packages/libp2p/test/fixtures/echo-service.ts | 42 ++++ .../libp2p/test/upgrading/upgrader.spec.ts | 32 ++- packages/multistream-select/package.json | 1 + packages/multistream-select/src/index.ts | 2 +- .../multistream-select/src/multistream.ts | 4 +- packages/multistream-select/src/select.ts | 203 +++++++++++++++--- .../multistream-select/test/dialer.spec.ts | 39 +++- .../test/integration.spec.ts | 12 +- packages/transport-webrtc/test/basics.spec.ts | 11 +- packages/utils/src/abstract-stream.ts | 9 + packages/utils/test/abstract-stream.spec.ts | 5 +- 19 files changed, 408 insertions(+), 97 deletions(-) create mode 100644 packages/libp2p/test/fixtures/echo-service.ts diff --git a/packages/integration-tests/test/circuit-relay.node.ts b/packages/integration-tests/test/circuit-relay.node.ts index 4d43bbd83e..3752774ba4 100644 --- a/packages/integration-tests/test/circuit-relay.node.ts +++ b/packages/integration-tests/test/circuit-relay.node.ts @@ -522,7 +522,7 @@ describe('circuit-relay', () => { expect(conns).to.have.lengthOf(1) // this should fail as the local peer has HOP disabled - await expect(conns[0].newStream(RELAY_V2_HOP_CODEC)) + await expect(conns[0].newStream([RELAY_V2_HOP_CODEC, '/other/1.0.0'])) .to.be.rejected() // we should still be connected to the relay diff --git a/packages/interface-compliance-tests/src/pubsub/two-nodes.ts b/packages/interface-compliance-tests/src/pubsub/two-nodes.ts index fe86ec1050..c68d81dedd 100644 --- a/packages/interface-compliance-tests/src/pubsub/two-nodes.ts +++ b/packages/interface-compliance-tests/src/pubsub/two-nodes.ts @@ -73,7 +73,7 @@ export default (common: TestSetup): void => { expect(psA.getTopics()).to.deep.equal([topic]) expect(psB.getPeers()).to.have.lengthOf(1) expect(psB.getSubscribers(topic).map(p => p.toString())).to.deep.equal([componentsA.peerId.toString()]) - expect(changedPeerId).to.deep.equal(psB.getPeers()[0]) + expect(changedPeerId.toString()).to.equal(psB.getPeers()[0].toString()) expect(changedSubs).to.have.lengthOf(1) expect(changedSubs[0].topic).to.equal(topic) expect(changedSubs[0].subscribe).to.equal(true) @@ -243,7 +243,7 @@ export default (common: TestSetup): void => { const { peerId: changedPeerId, subscriptions: changedSubs } = evt.detail expect(psB.getPeers()).to.have.lengthOf(1) expect(psB.getTopics()).to.be.empty() - expect(changedPeerId).to.deep.equal(psB.getPeers()[0]) + expect(changedPeerId.toString()).to.equal(psB.getPeers()[0].toString()) expect(changedSubs).to.have.lengthOf(1) expect(changedSubs[0].topic).to.equal(topic) expect(changedSubs[0].subscribe).to.equal(true) @@ -252,7 +252,7 @@ export default (common: TestSetup): void => { const { peerId: changedPeerId, subscriptions: changedSubs } = evt.detail expect(psB.getPeers()).to.have.lengthOf(1) expect(psB.getTopics()).to.be.empty() - expect(changedPeerId).to.deep.equal(psB.getPeers()[0]) + expect(changedPeerId.toString()).to.equal(psB.getPeers()[0].toString()) expect(changedSubs).to.have.lengthOf(1) expect(changedSubs[0].topic).to.equal(topic) expect(changedSubs[0].subscribe).to.equal(false) diff --git a/packages/libp2p/.aegir.js b/packages/libp2p/.aegir.js index b0f7674cd0..8372a3dfc2 100644 --- a/packages/libp2p/.aegir.js +++ b/packages/libp2p/.aegir.js @@ -18,6 +18,7 @@ export default { const { plaintext } = await import('@libp2p/plaintext') const { circuitRelayServer, circuitRelayTransport } = await import('@libp2p/circuit-relay-v2') const { identify } = await import('@libp2p/identify') + const { echo, ECHO_PROTOCOL } = await import('./dist/test/fixtures/echo-service.js') const peerId = await createEd25519PeerId() const libp2p = await createLibp2p({ @@ -49,14 +50,10 @@ export default { reservations: { maxReservations: Infinity } - }) + }), + echo: echo() } }) - // Add the echo protocol - await libp2p.handle('/echo/1.0.0', ({ stream }) => { - pipe(stream, stream) - .catch() // sometimes connections are closed before multistream-select finishes which causes an error - }) return { libp2p, diff --git a/packages/libp2p/src/connection-manager/dial-queue.ts b/packages/libp2p/src/connection-manager/dial-queue.ts index ddc90b6870..9550c95dfe 100644 --- a/packages/libp2p/src/connection-manager/dial-queue.ts +++ b/packages/libp2p/src/connection-manager/dial-queue.ts @@ -461,7 +461,7 @@ export class DialQueue { // internal peer dial queue - only one dial per peer at a time const peerDialQueue = new PQueue({ concurrency: 1 }) peerDialQueue.on('error', (err) => { - this.log.error('error dialing [%s] %o', pendingDial.multiaddrs, err) + this.log.error('error dialing %s %o', pendingDial.multiaddrs, err) }) const conn = await Promise.any(pendingDial.multiaddrs.map(async (addr, i) => { diff --git a/packages/libp2p/src/upgrader.ts b/packages/libp2p/src/upgrader.ts index b014702b50..94fcd04d4d 100644 --- a/packages/libp2p/src/upgrader.ts +++ b/packages/libp2p/src/upgrader.ts @@ -117,14 +117,12 @@ export class DefaultUpgrader implements Upgrader { private readonly muxers: Map private readonly inboundUpgradeTimeout: number private readonly events: TypedEventTarget - private readonly logger: ComponentLogger private readonly log: Logger constructor (components: DefaultUpgraderComponents, init: UpgraderInit) { this.components = components this.connectionEncryption = new Map() this.log = components.logger.forComponent('libp2p:upgrader') - this.logger = components.logger init.connectionEncryption.forEach(encrypter => { this.connectionEncryption.set(encrypter.protocol, encrypter) @@ -415,6 +413,21 @@ export class DefaultUpgrader implements Upgrader { muxedStream.sink = stream.sink muxedStream.protocol = protocol + // allow closing the write end of a not-yet-negotiated stream + if (stream.closeWrite != null) { + muxedStream.closeWrite = stream.closeWrite + } + + // allow closing the read end of a not-yet-negotiated stream + if (stream.closeRead != null) { + muxedStream.closeRead = stream.closeRead + } + + // make sure we don't try to negotiate a stream we are closing + if (stream.close != null) { + muxedStream.close = stream.close + } + // If a protocol stream has been successfully negotiated and is to be passed to the application, // the peerstore should ensure that the peer is registered with that protocol await this.components.peerStore.merge(remotePeer, { @@ -426,7 +439,7 @@ export class DefaultUpgrader implements Upgrader { this._onStream({ connection, stream: muxedStream, protocol }) }) .catch(async err => { - this.log.error('error handling incoming stream id %d', muxedStream.id, err.message, err.code, err.stack) + this.log.error('error handling incoming stream id %s', muxedStream.id, err.message, err.code, err.stack) if (muxedStream.timeline.close == null) { await muxedStream.close() @@ -440,13 +453,13 @@ export class DefaultUpgrader implements Upgrader { throw new CodeError('Stream is not multiplexed', codes.ERR_MUXER_UNAVAILABLE) } - connection.log('starting new stream for protocols [%s]', protocols) + connection.log('starting new stream for protocols %s', protocols) const muxedStream = await muxer.newStream() - connection.log.trace('starting new stream %s for protocols [%s]', muxedStream.id, protocols) + connection.log.trace('started new stream %s for protocols %s', muxedStream.id, protocols) try { if (options.signal == null) { - this.log('No abort signal was passed while trying to negotiate protocols [%s] falling back to default timeout', protocols) + this.log('No abort signal was passed while trying to negotiate protocols %s falling back to default timeout', protocols) const signal = AbortSignal.timeout(DEFAULT_PROTOCOL_SELECT_TIMEOUT) setMaxListeners(Infinity, signal) @@ -457,13 +470,18 @@ export class DefaultUpgrader implements Upgrader { } } - const { stream, protocol } = await mss.select(muxedStream, protocols, { + muxedStream.log.trace('selecting protocol from protocols %s', protocols) + + const { + stream, + protocol + } = await mss.select(muxedStream, protocols, { ...options, log: muxedStream.log, - yieldBytes: false + yieldBytes: true }) - connection.log('negotiated protocol stream %s with id %s', protocol, muxedStream.id) + muxedStream.log('selected protocol %s', protocol) const outgoingLimit = findOutgoingStreamLimit(protocol, this.components.registrar, options) const streamCount = countStreams(protocol, 'outbound', connection) @@ -487,6 +505,21 @@ export class DefaultUpgrader implements Upgrader { muxedStream.sink = stream.sink muxedStream.protocol = protocol + // allow closing the write end of a not-yet-negotiated stream + if (stream.closeWrite != null) { + muxedStream.closeWrite = stream.closeWrite + } + + // allow closing the read end of a not-yet-negotiated stream + if (stream.closeRead != null) { + muxedStream.closeRead = stream.closeRead + } + + // make sure we don't try to negotiate a stream we are closing + if (stream.close != null) { + muxedStream.close = stream.close + } + this.components.metrics?.trackProtocolStream(muxedStream, connection) return muxedStream @@ -637,16 +670,23 @@ export class DefaultUpgrader implements Upgrader { this.log('selecting outbound crypto protocol', protocols) try { - const { stream, protocol } = await mss.select(connection, protocols, { - log: this.logger.forComponent('libp2p:mss:select') + connection.log.trace('selecting encrypter from %s', protocols) + + const { + stream, + protocol + } = await mss.select(connection, protocols, { + log: connection.log, + yieldBytes: true }) + const encrypter = this.connectionEncryption.get(protocol) if (encrypter == null) { throw new Error(`no crypto module found for ${protocol}`) } - this.log('encrypting outbound connection to %p', remotePeerId) + connection.log('encrypting outbound connection to %p using %p', remotePeerId) return { ...await encrypter.secureOutbound(this.components.peerId, stream, remotePeerId), @@ -665,15 +705,22 @@ export class DefaultUpgrader implements Upgrader { const protocols = Array.from(muxers.keys()) this.log('outbound selecting muxer %s', protocols) try { - const { stream, protocol } = await mss.select(connection, protocols, { - log: this.logger.forComponent('libp2p:mss:select') + connection.log.trace('selecting stream muxer from %s', protocols) + + const { + stream, + protocol + } = await mss.select(connection, protocols, { + log: connection.log, + yieldBytes: true }) - this.log('%s selected as muxer protocol', protocol) + + connection.log('selected %s as muxer protocol', protocol) const muxerFactory = muxers.get(protocol) return { stream, muxerFactory } } catch (err: any) { - this.log.error('error multiplexing outbound stream', err) + connection.log.error('error multiplexing outbound stream', err) throw new CodeError(String(err), codes.ERR_MUXER_UNAVAILABLE) } } diff --git a/packages/libp2p/test/connection-manager/direct.node.ts b/packages/libp2p/test/connection-manager/direct.node.ts index 6893d81c3b..10d4f9da09 100644 --- a/packages/libp2p/test/connection-manager/direct.node.ts +++ b/packages/libp2p/test/connection-manager/direct.node.ts @@ -33,6 +33,7 @@ import { DefaultConnectionManager } from '../../src/connection-manager/index.js' import { codes as ErrorCodes } from '../../src/errors.js' import { createLibp2pNode, type Libp2pNode } from '../../src/libp2p.js' import { DefaultTransportManager } from '../../src/transport-manager.js' +import { ECHO_PROTOCOL, echo } from '../fixtures/echo-service.js' import type { PeerId } from '@libp2p/interface/peer-id' import type { TransportManager } from '@libp2p/interface-internal/transport-manager' import type { Multiaddr } from '@multiformats/multiaddr' @@ -303,10 +304,10 @@ describe('libp2p.dialer (direct, TCP)', () => { ], connectionEncryption: [ plaintext() - ] - }) - await remoteLibp2p.handle('/echo/1.0.0', ({ stream }) => { - void pipe(stream, stream) + ], + services: { + echo: echo() + } }) await remoteLibp2p.start() @@ -348,9 +349,9 @@ describe('libp2p.dialer (direct, TCP)', () => { const connection = await libp2p.dial(remotePeerId) expect(connection).to.exist() - const stream = await connection.newStream('/echo/1.0.0') + const stream = await connection.newStream(ECHO_PROTOCOL) expect(stream).to.exist() - expect(stream).to.have.property('protocol', '/echo/1.0.0') + expect(stream).to.have.property('protocol', ECHO_PROTOCOL) await connection.close() }) @@ -388,7 +389,7 @@ describe('libp2p.dialer (direct, TCP)', () => { const connection = await libp2p.dial(remoteLibp2p.getMultiaddrs()) // Create local to remote streams - const stream = await connection.newStream('/echo/1.0.0') + const stream = await connection.newStream([ECHO_PROTOCOL, '/other/1.0.0']) await connection.newStream('/stream-count/3') await libp2p.dialProtocol(remoteLibp2p.peerId, '/stream-count/4') @@ -398,8 +399,8 @@ describe('libp2p.dialer (direct, TCP)', () => { source.push(uint8ArrayFromString('hello')) // Create remote to local streams - await remoteLibp2p.dialProtocol(libp2p.peerId, '/stream-count/1') - await remoteLibp2p.dialProtocol(libp2p.peerId, '/stream-count/2') + await remoteLibp2p.dialProtocol(libp2p.peerId, ['/stream-count/1', '/other/1.0.0']) + await remoteLibp2p.dialProtocol(libp2p.peerId, ['/stream-count/2', '/other/1.0.0']) // Verify stream count const remoteConn = remoteLibp2p.getConnections(libp2p.peerId) @@ -497,9 +498,9 @@ describe('libp2p.dialer (direct, TCP)', () => { const connection = await libp2p.dial(remoteAddr) expect(connection).to.exist() - const stream = await connection.newStream('/echo/1.0.0') + const stream = await connection.newStream(ECHO_PROTOCOL) expect(stream).to.exist() - expect(stream).to.have.property('protocol', '/echo/1.0.0') + expect(stream).to.have.property('protocol', ECHO_PROTOCOL) await connection.close() expect(protectorProtectSpy.callCount).to.equal(1) }) diff --git a/packages/libp2p/test/connection-manager/index.node.ts b/packages/libp2p/test/connection-manager/index.node.ts index cac8b990d3..d15d695bb3 100644 --- a/packages/libp2p/test/connection-manager/index.node.ts +++ b/packages/libp2p/test/connection-manager/index.node.ts @@ -5,6 +5,8 @@ import { start } from '@libp2p/interface/startable' import { mockConnection, mockDuplex, mockMultiaddrConnection } from '@libp2p/interface-compliance-tests/mocks' import { expect } from 'aegir/chai' import delay from 'delay' +import all from 'it-all' +import { pipe } from 'it-pipe' import pWaitFor from 'p-wait-for' import sinon from 'sinon' import { stubInterface } from 'sinon-ts' @@ -13,6 +15,7 @@ import { DefaultConnectionManager } from '../../src/connection-manager/index.js' import { codes } from '../../src/errors.js' import { createBaseOptions } from '../fixtures/base-options.browser.js' import { createNode, createPeerId } from '../fixtures/creators/peer.js' +import { ECHO_PROTOCOL, echo } from '../fixtures/echo-service.js' import type { Libp2p } from '../../src/index.js' import type { Libp2pNode } from '../../src/libp2p.js' import type { ConnectionGater } from '@libp2p/interface/connection-gater' @@ -401,6 +404,9 @@ describe('libp2p.connections', () => { peerId: peerIds[1], addresses: { listen: ['/ip4/127.0.0.1/tcp/0/ws'] + }, + services: { + echo: echo() } }) }) @@ -591,16 +597,23 @@ describe('libp2p.connections', () => { }, connectionGater: { denyInboundUpgradedConnection + }, + services: { + echo: echo() } }) }) await remoteLibp2p.peerStore.patch(libp2p.peerId, { multiaddrs: libp2p.getMultiaddrs() }) - await remoteLibp2p.dial(libp2p.peerId) + const connection = await remoteLibp2p.dial(libp2p.peerId) + const stream = await connection.newStream(ECHO_PROTOCOL) + const input = [Uint8Array.from([0])] + const output = await pipe(input, stream, async (source) => all(source)) expect(denyInboundUpgradedConnection.called).to.be.true() expect(denyInboundUpgradedConnection.getCall(0)).to.have.nested.property('args[0].multihash.digest').that.equalBytes(remoteLibp2p.peerId.multihash.digest) + expect(output.map(b => b.subarray())).to.deep.equal(input) }) it('intercept outbound upgraded', async () => { @@ -620,10 +633,14 @@ describe('libp2p.connections', () => { await libp2p.peerStore.patch(remoteLibp2p.peerId, { multiaddrs: remoteLibp2p.getMultiaddrs() }) - await libp2p.dial(remoteLibp2p.peerId) + const connection = await libp2p.dial(remoteLibp2p.peerId) + const stream = await connection.newStream(ECHO_PROTOCOL) + const input = [Uint8Array.from([0])] + const output = await pipe(input, stream, async (source) => all(source)) expect(denyOutboundUpgradedConnection.called).to.be.true() expect(denyOutboundUpgradedConnection.getCall(0)).to.have.nested.property('args[0].multihash.digest').that.equalBytes(remoteLibp2p.peerId.multihash.digest) + expect(output.map(b => b.subarray())).to.deep.equal(input) }) }) }) diff --git a/packages/libp2p/test/content-routing/dht/utils.ts b/packages/libp2p/test/content-routing/dht/utils.ts index 323637deba..d370b64478 100644 --- a/packages/libp2p/test/content-routing/dht/utils.ts +++ b/packages/libp2p/test/content-routing/dht/utils.ts @@ -1,3 +1,4 @@ export const subsystemMulticodecs = [ - '/ipfs/lan/kad/1.0.0' + '/ipfs/lan/kad/1.0.0', + '/other/1.0.0' ] diff --git a/packages/libp2p/test/fixtures/echo-service.ts b/packages/libp2p/test/fixtures/echo-service.ts new file mode 100644 index 0000000000..85ef053a53 --- /dev/null +++ b/packages/libp2p/test/fixtures/echo-service.ts @@ -0,0 +1,42 @@ +import { pipe } from 'it-pipe' +import type { Startable } from '@libp2p/interface/startable' +import type { Registrar } from '@libp2p/interface-internal/registrar' + +export const ECHO_PROTOCOL = '/echo/1.0.0' + +export interface EchoInit { + protocol?: string +} + +export interface EchoComponents { + registrar: Registrar +} + +class EchoService implements Startable { + private readonly protocol: string + private readonly registrar: Registrar + + constructor (components: EchoComponents, init: EchoInit = {}) { + this.protocol = init.protocol ?? ECHO_PROTOCOL + this.registrar = components.registrar + } + + async start (): Promise { + await this.registrar.handle(this.protocol, ({ stream }) => { + void pipe(stream, stream) + // sometimes connections are closed before multistream-select finishes + // which causes an error + .catch() + }) + } + + async stop (): Promise { + await this.registrar.unhandle(this.protocol) + } +} + +export function echo (init: EchoInit = {}): (components: EchoComponents) => unknown { + return (components) => { + return new EchoService(components, init) + } +} diff --git a/packages/libp2p/test/upgrading/upgrader.spec.ts b/packages/libp2p/test/upgrading/upgrader.spec.ts index 6863c01e21..81103edfae 100644 --- a/packages/libp2p/test/upgrading/upgrader.spec.ts +++ b/packages/libp2p/test/upgrading/upgrader.spec.ts @@ -345,12 +345,21 @@ describe('Upgrader', () => { } } + class OtherOtherMuxerFactory implements StreamMuxerFactory { + protocol = '/muxer-local-other' + + createStreamMuxer (init?: StreamMuxerInit): StreamMuxer { + return new OtherMuxer() + } + } + localUpgrader = new DefaultUpgrader(localComponents, { connectionEncryption: [ plaintext()(localComponents) ], muxers: [ - new OtherMuxerFactory() + new OtherMuxerFactory(), + new OtherOtherMuxerFactory() ], inboundUpgradeTimeout: 1000 }) @@ -390,9 +399,10 @@ describe('Upgrader', () => { expect(connections).to.have.length(2) // Create a few streams, at least 1 in each direction - await connections[0].newStream('/echo/1.0.0') - await connections[1].newStream('/echo/1.0.0') - await connections[0].newStream('/echo/1.0.0') + // use multiple protocols to trigger regular multistream select + await connections[0].newStream(['/echo/1.0.0', '/echo/1.0.1']) + await connections[1].newStream(['/echo/1.0.0', '/echo/1.0.1']) + await connections[0].newStream(['/echo/1.0.0', '/echo/1.0.1']) connections.forEach(conn => { expect(conn.streams).to.have.length(3) }) @@ -495,7 +505,7 @@ describe('Upgrader', () => { }) } - await expect(connections[0].newStream('/echo/1.0.0', { + await expect(connections[0].newStream(['/echo/1.0.0', '/echo/1.0.1'], { signal })) .to.eventually.be.rejected.with.property('code', 'ABORT_ERR') @@ -514,7 +524,7 @@ describe('Upgrader', () => { expect(connections[0].streams).to.have.lengthOf(0) expect(connections[1].streams).to.have.lengthOf(0) - await expect(connections[0].newStream('/echo/1.0.0')) + await expect(connections[0].newStream(['/echo/1.0.0', '/echo/1.0.1'])) .to.eventually.be.rejected.with.property('code', 'ERR_UNSUPPORTED_PROTOCOL') // wait for remote to close @@ -558,7 +568,7 @@ describe('Upgrader', () => { expect(connections).to.have.length(2) - const stream = await connections[0].newStream('/echo/1.0.0') + const stream = await connections[0].newStream(['/echo/1.0.0', '/echo/1.0.1']) expect(stream).to.have.property('protocol', '/echo/1.0.0') const hello = uint8ArrayFromString('hello there!') @@ -704,7 +714,7 @@ describe('libp2p.upgrader', () => { ]) const remoteLibp2pUpgraderOnStreamSpy = sinon.spy(remoteComponents.upgrader as DefaultUpgrader, '_onStream') - const stream = await localConnection.newStream(['/echo/1.0.0']) + const stream = await localConnection.newStream(['/echo/1.0.0', '/echo/1.0.1']) expect(stream).to.include.keys(['id', 'sink', 'source']) const [arg0] = remoteLibp2pUpgraderOnStreamSpy.getCall(0).args @@ -852,7 +862,7 @@ describe('libp2p.upgrader', () => { expect(streamCount).to.equal(0) - await localToRemote.newStream(protocol) + await localToRemote.newStream([protocol, '/other/1.0.0']) expect(streamCount).to.equal(1) @@ -930,7 +940,7 @@ describe('libp2p.upgrader', () => { expect(streamCount).to.equal(0) - await localToRemote.newStream(protocol) + await localToRemote.newStream([protocol, '/other/1.0.0']) expect(streamCount).to.equal(1) @@ -1004,7 +1014,7 @@ describe('libp2p.upgrader', () => { expect(streamCount).to.equal(0) for (let i = 0; i < limit; i++) { - await localToRemote.newStream(protocol, { + await localToRemote.newStream([protocol, '/other/1.0.0'], { maxOutboundStreams: limit }) } diff --git a/packages/multistream-select/package.json b/packages/multistream-select/package.json index 1a1a20171b..f0b3195d9b 100644 --- a/packages/multistream-select/package.json +++ b/packages/multistream-select/package.json @@ -58,6 +58,7 @@ "it-length-prefixed-stream": "^1.1.1", "it-pipe": "^3.0.1", "it-stream-types": "^2.0.1", + "p-defer": "^4.0.0", "uint8-varint": "^2.0.2", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.6" diff --git a/packages/multistream-select/src/index.ts b/packages/multistream-select/src/index.ts index 677dee48ca..acd5f60e99 100644 --- a/packages/multistream-select/src/index.ts +++ b/packages/multistream-select/src/index.ts @@ -35,5 +35,5 @@ export interface MultistreamSelectInit extends AbortOptions, LoggerOptions, Part } -export { select, lazySelect } from './select.js' +export { select } from './select.js' export { handle } from './handle.js' diff --git a/packages/multistream-select/src/multistream.ts b/packages/multistream-select/src/multistream.ts index 75d976c3e1..1f76f3dbbd 100644 --- a/packages/multistream-select/src/multistream.ts +++ b/packages/multistream-select/src/multistream.ts @@ -12,14 +12,14 @@ const NewLine = uint8ArrayFromString('\n') /** * `write` encodes and writes a single buffer */ -export async function write (writer: LengthPrefixedStream, Source>>, buffer: Uint8Array | Uint8ArrayList, options?: MultistreamSelectInit): Promise { +export async function write (writer: LengthPrefixedStream, Source>>, buffer: Uint8Array | Uint8ArrayList, options?: MultistreamSelectInit): Promise { await writer.write(buffer, options) } /** * `writeAll` behaves like `write`, except it encodes an array of items as a single write */ -export async function writeAll (writer: LengthPrefixedStream, Source>>, buffers: Uint8Array[], options?: MultistreamSelectInit): Promise { +export async function writeAll (writer: LengthPrefixedStream, Source>>, buffers: Uint8Array[], options?: MultistreamSelectInit): Promise { await writer.writeV(buffers, options) } diff --git a/packages/multistream-select/src/select.ts b/packages/multistream-select/src/select.ts index 97d4c4aca3..526c6a4d2b 100644 --- a/packages/multistream-select/src/select.ts +++ b/packages/multistream-select/src/select.ts @@ -1,5 +1,6 @@ import { CodeError } from '@libp2p/interface/errors' import { lpStream } from 'it-length-prefixed-stream' +import pDefer from 'p-defer' import * as varint from 'uint8-varint' import { Uint8ArrayList } from 'uint8arraylist' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' @@ -7,8 +8,16 @@ import { MAX_PROTOCOL_LENGTH } from './constants.js' import * as multistream from './multistream.js' import { PROTOCOL_ID } from './index.js' import type { MultistreamSelectInit, ProtocolStream } from './index.js' +import type { AbortOptions } from '@libp2p/interface' import type { Duplex } from 'it-stream-types' +export interface SelectStream extends Duplex { + readStatus?: string + closeWrite?(options?: AbortOptions): Promise + closeRead?(options?: AbortOptions): Promise + close?(options?: AbortOptions): Promise +} + /** * Negotiate a protocol to use from a list of protocols. * @@ -52,8 +61,13 @@ import type { Duplex } from 'it-stream-types' * // } * ``` */ -export async function select > (stream: Stream, protocols: string | string[], options: MultistreamSelectInit): Promise> { +export async function select (stream: Stream, protocols: string | string[], options: MultistreamSelectInit): Promise> { protocols = Array.isArray(protocols) ? [...protocols] : [protocols] + + if (protocols.length === 1) { + return optimisticSelect(stream, protocols[0], options) + } + const lp = lpStream(stream, { ...options, maxDataLength: MAX_PROTOCOL_LENGTH @@ -102,17 +116,28 @@ export async function select > (stream: Str } /** - * Lazily negotiates a protocol. + * Optimistically negotiates a protocol. * * It *does not* block writes waiting for the other end to respond. Instead, it * simply assumes the negotiation went successfully and starts writing data. * * Use when it is known that the receiver supports the desired protocol. */ -export function lazySelect > (stream: Stream, protocol: string, options: MultistreamSelectInit): ProtocolStream { +function optimisticSelect (stream: Stream, protocol: string, options: MultistreamSelectInit): ProtocolStream { const originalSink = stream.sink.bind(stream) const originalSource = stream.source - let selected = false + + let negotiated = false + let negotiating = false + const doneNegotiating = pDefer() + + let sentProtocol = false + let sendingProtocol = false + const doneSendingProtocol = pDefer() + + let readProtocol = false + let readingProtocol = false + const doneReadingProtocol = pDefer() const lp = lpStream({ sink: originalSink, @@ -126,11 +151,19 @@ export function lazySelect > (stream: Strea const { sink } = lp.unwrap() await sink(async function * () { + let sentData = false + for await (const buf of source) { - // if writing before selecting, send selection with first data chunk - if (!selected) { - selected = true - options?.log.trace('lazy: write ["%s", "%s", data] in sink', PROTOCOL_ID, protocol) + // started reading before the source yielded, wait for protocol send + if (sendingProtocol) { + await doneSendingProtocol.promise + } + + // writing before reading, send the protocol and the first chunk of data + if (!sentProtocol) { + sendingProtocol = true + + options?.log.trace('optimistic: write ["%s", "%s", data(%d)] in sink', PROTOCOL_ID, protocol, buf.byteLength) const protocolString = `${protocol}\n` @@ -143,44 +176,164 @@ export function lazySelect > (stream: Strea buf ).subarray() - options?.log.trace('lazy: wrote ["%s", "%s", data] in sink', PROTOCOL_ID, protocol) + options?.log.trace('optimistic: wrote ["%s", "%s", data(%d)] in sink', PROTOCOL_ID, protocol, buf.byteLength) + + sentProtocol = true + sendingProtocol = false + doneSendingProtocol.resolve() } else { yield buf } + + sentData = true + } + + // special case - the source passed to the sink has ended but we didn't + // negotiated the protocol yet so do it now + if (!sentData) { + await negotiate() } }()) } - stream.source = (async function * () { - // if reading before selecting, send selection before first data chunk - if (!selected) { - selected = true - options?.log.trace('lazy: write ["%s", "%s", data] in source', PROTOCOL_ID, protocol) + async function negotiate (): Promise { + if (negotiating) { + options?.log.trace('optimistic: already negotiating %s stream', protocol) + await doneNegotiating.promise + return + } + + negotiating = true + + try { + // we haven't sent the protocol yet, send it now + if (!sentProtocol) { + options?.log.trace('optimistic: doing send protocol for %s stream', protocol) + await doSendProtocol() + } + + // if we haven't read the protocol response yet, do it now + if (!readProtocol) { + options?.log.trace('optimistic: doing read protocol for %s stream', protocol) + await doReadProtocol() + } + } finally { + negotiating = false + negotiated = true + doneNegotiating.resolve() + } + } + + async function doSendProtocol (): Promise { + if (sendingProtocol) { + await doneSendingProtocol.promise + return + } + + sendingProtocol = true + + try { + options?.log.trace('optimistic: write ["%s", "%s", data] in source', PROTOCOL_ID, protocol) await lp.writeV([ uint8ArrayFromString(`${PROTOCOL_ID}\n`), uint8ArrayFromString(`${protocol}\n`) ]) - options?.log.trace('lazy: wrote ["%s", "%s", data] in source', PROTOCOL_ID, protocol) + options?.log.trace('optimistic: wrote ["%s", "%s", data] in source', PROTOCOL_ID, protocol) + } finally { + sentProtocol = true + sendingProtocol = false + doneSendingProtocol.resolve() } + } - options?.log.trace('lazy: reading multistream select header') - let response = await multistream.readString(lp, options) - options?.log.trace('lazy: read multistream select header "%s"', response) - - if (response === PROTOCOL_ID) { - response = await multistream.readString(lp, options) + async function doReadProtocol (): Promise { + if (readingProtocol) { + await doneReadingProtocol.promise + return } - options?.log.trace('lazy: read protocol "%s", expecting "%s"', response, protocol) + readingProtocol = true - if (response !== protocol) { - throw new CodeError('protocol selection failed', 'ERR_UNSUPPORTED_PROTOCOL') + try { + options?.log.trace('optimistic: reading multistream select header') + let response = await multistream.readString(lp, options) + options?.log.trace('optimistic: read multistream select header "%s"', response) + + if (response === PROTOCOL_ID) { + response = await multistream.readString(lp, options) + } + + options?.log.trace('optimistic: read protocol "%s", expecting "%s"', response, protocol) + + if (response !== protocol) { + throw new CodeError('protocol selection failed', 'ERR_UNSUPPORTED_PROTOCOL') + } + } finally { + readProtocol = true + readingProtocol = false + doneReadingProtocol.resolve() } + } + + stream.source = (async function * () { + // make sure we've done protocol negotiation before we read stream data + await negotiate() - options?.log.trace('lazy: reading rest of "%s" stream', protocol) + options?.log.trace('optimistic: reading data from "%s" stream', protocol) yield * lp.unwrap().source })() + if (stream.closeRead != null) { + const originalCloseRead = stream.closeRead.bind(stream) + + stream.closeRead = async (opts) => { + // we need to read & write to negotiate the protocol so ensure we've done + // this before closing the readable end of the stream + if (!negotiated) { + await negotiate().catch(err => { + options?.log.error('could not negotiate protocol before close read', err) + }) + } + + // protocol has been negotiated, ok to close the readable end + await originalCloseRead(opts) + } + } + + if (stream.closeWrite != null) { + const originalCloseWrite = stream.closeWrite.bind(stream) + + stream.closeWrite = async (opts) => { + // we need to read & write to negotiate the protocol so ensure we've done + // this before closing the writable end of the stream + if (!negotiated) { + await negotiate().catch(err => { + options?.log.error('could not negotiate protocol before close write', err) + }) + } + + // protocol has been negotiated, ok to close the writable end + await originalCloseWrite(opts) + } + } + + if (stream.close != null) { + const originalClose = stream.close.bind(stream) + + stream.close = async (opts) => { + // the stream is being closed, don't try to negotiate a protocol if we + // haven't already + if (!negotiated) { + negotiated = true + negotiating = false + doneNegotiating.resolve() + } + + // protocol has been negotiated, ok to close the writable end + await originalClose(opts) + } + } + return { stream, protocol diff --git a/packages/multistream-select/test/dialer.spec.ts b/packages/multistream-select/test/dialer.spec.ts index 5ff4aeb4e2..6a75fc391e 100644 --- a/packages/multistream-select/test/dialer.spec.ts +++ b/packages/multistream-select/test/dialer.spec.ts @@ -14,11 +14,11 @@ import * as mss from '../src/index.js' describe('Dialer', () => { describe('dialer.select', () => { - it('should select from single protocol', async () => { + it('should select from single protocol on outgoing stream', async () => { const protocol = '/echo/1.0.0' const [outgoingStream, incomingStream] = duplexPair() - void mss.handle(incomingStream, protocol, { + const handled = mss.handle(incomingStream, protocol, { log: logger('mss:test-incoming') }) @@ -27,13 +27,42 @@ describe('Dialer', () => { }) expect(selection.protocol).to.equal(protocol) - // Ensure stream is usable after selection + // Ensure stream is usable after selection - send data outgoing -> incoming const input = [randomBytes(10), randomBytes(64), randomBytes(3)] void pipe(input, selection.stream) + + // wait for incoming end to have completed negotiation + await handled + const output = await all(incomingStream.source) expect(new Uint8ArrayList(...output).slice()).to.eql(new Uint8ArrayList(...input).slice()) }) + it('should select from single protocol on incoming stream', async () => { + const protocol = '/echo/1.0.0' + const [outgoingStream, incomingStream] = duplexPair() + const input = [randomBytes(10), randomBytes(64), randomBytes(3)] + + void mss.select(outgoingStream, protocol, { + log: logger('mss:test-outgoing') + }) + + // have to interact with the stream to start protocol negotiation + const outgoingSourceData = all(outgoingStream.source) + + const selection = await mss.handle(incomingStream, protocol, { + log: logger('mss:test-incoming') + }) + + expect(selection.protocol).to.equal(protocol) + + // Ensure stream is usable after selection - send data incoming -> outgoing + void pipe(input, selection.stream) + + const output = await outgoingSourceData + expect(new Uint8ArrayList(...output).slice()).to.eql(new Uint8ArrayList(...input).slice()) + }) + it('should fail to select twice', async () => { const protocol = '/echo/1.0.0' const protocol2 = '/echo/2.0.0' @@ -49,7 +78,7 @@ describe('Dialer', () => { expect(selection.protocol).to.equal(protocol) // A second select will timeout - await pTimeout(mss.select(outgoingStream, protocol2, { + await pTimeout(mss.select(outgoingStream, [protocol, protocol2], { log: logger('mss:test-outgoing') }), { milliseconds: 1e3 @@ -101,7 +130,7 @@ describe('Dialer', () => { const protocol = '/echo/1.0.0' const [outgoingStream, incomingStream] = duplexPair() - const selection = mss.lazySelect(outgoingStream, protocol, { + const selection = await mss.select(outgoingStream, [protocol], { log: logger('mss:test-lazy') }) expect(selection.protocol).to.equal(protocol) diff --git a/packages/multistream-select/test/integration.spec.ts b/packages/multistream-select/test/integration.spec.ts index 0ff0389e24..479f798ad3 100644 --- a/packages/multistream-select/test/integration.spec.ts +++ b/packages/multistream-select/test/integration.spec.ts @@ -45,7 +45,7 @@ describe('Dialer and Listener integration', () => { mss.handle(pair[1], selectedProtocol, { log: logger('mss:test') }), - (async () => mss.select(pair[0], selectedProtocol, { + (async () => mss.select(pair[0], protocols, { log: logger('mss:test') }))() ]) @@ -92,7 +92,7 @@ describe('Dialer and Listener integration', () => { const protocol = '/echo/1.0.0' const pair = duplexPair() - const dialerSelection = mss.lazySelect(pair[0], protocol, { + const dialerSelection = await mss.select(pair[0], [protocol], { log: logger('mss:test') }) expect(dialerSelection.protocol).to.equal(protocol) @@ -118,7 +118,7 @@ describe('Dialer and Listener integration', () => { const otherProtocol = '/echo/2.0.0' const pair = duplexPair() - const dialerSelection = mss.lazySelect(pair[0], protocol, { + const dialerSelection = await mss.select(pair[0], [protocol], { log: logger('mss:test') }) expect(dialerSelection.protocol).to.equal(protocol) @@ -138,7 +138,7 @@ describe('Dialer and Listener integration', () => { const protocol = '/echo/1.0.0' const pair = duplexPair() - const dialerSelection = mss.lazySelect(pair[0], protocol, { + const dialerSelection = await mss.select(pair[0], [protocol], { log: logger('mss:dialer') }) expect(dialerSelection.protocol).to.equal(protocol) @@ -168,7 +168,7 @@ describe('Dialer and Listener integration', () => { const pair = duplexPair() // lazy succeeds - const dialerSelection = mss.lazySelect(pair[0], protocol, { + const dialerSelection = await mss.select(pair[0], [protocol], { log: logger('mss:dialer') }) expect(dialerSelection.protocol).to.equal(protocol) @@ -187,7 +187,7 @@ describe('Dialer and Listener integration', () => { const protocol = '/echo/1.0.0' const pair = duplexPair() - const dialerSelection = mss.lazySelect(pair[0], protocol, { + const dialerSelection = await mss.select(pair[0], [protocol], { log: logger('mss:test') }) expect(dialerSelection.protocol).to.equal(protocol) diff --git a/packages/transport-webrtc/test/basics.spec.ts b/packages/transport-webrtc/test/basics.spec.ts index 171a78cfd4..a74384a2a1 100644 --- a/packages/transport-webrtc/test/basics.spec.ts +++ b/packages/transport-webrtc/test/basics.spec.ts @@ -242,6 +242,9 @@ describe('basics', () => { runOnTransientConnection: true }) + // close the write end immediately + const p = stream.closeWrite() + const remoteStream = await getRemoteStream.promise // close the readable end of the remote stream await remoteStream.closeRead() @@ -250,8 +253,6 @@ describe('basics', () => { const remoteInputStream = pushable() void remoteStream.sink(remoteInputStream) - const p = stream.closeWrite() - // wait for remote to receive local close-write await pRetry(() => { if (remoteStream.readStatus !== 'closed') { @@ -302,15 +303,15 @@ describe('basics', () => { runOnTransientConnection: true }) + // keep the remote write end open, this should delay the FIN_ACK reply to the local stream + const p = stream.sink([]) + const remoteStream = await getRemoteStream.promise // close the readable end of the remote stream await remoteStream.closeRead() // readable end should finish await drain(remoteStream.source) - // keep the remote write end open, this should delay the FIN_ACK reply to the local stream - const p = stream.sink([]) - // wait for remote to receive local close-write await pRetry(() => { if (remoteStream.readStatus !== 'closed') { diff --git a/packages/utils/src/abstract-stream.ts b/packages/utils/src/abstract-stream.ts index 9ad31c5f0f..a1baf9cf37 100644 --- a/packages/utils/src/abstract-stream.ts +++ b/packages/utils/src/abstract-stream.ts @@ -96,6 +96,7 @@ export abstract class AbstractStream implements Stream { private readonly sinkController: AbortController private readonly sinkEnd: DeferredPromise + private readonly closed: DeferredPromise private endErr: Error | undefined private readonly streamSource: Pushable private readonly onEnd?: (err?: Error | undefined) => void @@ -108,6 +109,7 @@ export abstract class AbstractStream implements Stream { constructor (init: AbstractStreamInit) { this.sinkController = new AbortController() this.sinkEnd = defer() + this.closed = defer() this.log = init.log // stream status @@ -237,6 +239,8 @@ export abstract class AbstractStream implements Stream { if (this.onEnd != null) { this.onEnd(this.endErr) } + + this.closed.resolve() } else { this.log.trace('source ended, waiting for sink to end') } @@ -267,6 +271,8 @@ export abstract class AbstractStream implements Stream { if (this.onEnd != null) { this.onEnd(this.endErr) } + + this.closed.resolve() } else { this.log.trace('sink ended, waiting for source to end') } @@ -283,6 +289,9 @@ export abstract class AbstractStream implements Stream { this.closeWrite(options) ]) + // wait for read and write ends to close + await raceSignal(this.closed.promise, options?.signal) + this.status = 'closed' this.log.trace('closed gracefully') diff --git a/packages/utils/test/abstract-stream.spec.ts b/packages/utils/test/abstract-stream.spec.ts index e88f17c929..581c24e1dd 100644 --- a/packages/utils/test/abstract-stream.spec.ts +++ b/packages/utils/test/abstract-stream.spec.ts @@ -37,7 +37,8 @@ describe('abstract stream', () => { stream = new TestStream({ id: 'test', direction: 'outbound', - log: logger('test') + log: logger('test'), + onEnd: () => {} }) }) @@ -68,11 +69,13 @@ describe('abstract stream', () => { it('closes', async () => { const sendCloseReadSpy = Sinon.spy(stream, 'sendCloseRead') const sendCloseWriteSpy = Sinon.spy(stream, 'sendCloseWrite') + const onEndSpy = Sinon.spy(stream as any, 'onEnd') await stream.close() expect(sendCloseReadSpy.calledOnce).to.be.true() expect(sendCloseWriteSpy.calledOnce).to.be.true() + expect(onEndSpy.calledOnce).to.be.true() expect(stream).to.have.property('status', 'closed') expect(stream).to.have.property('writeStatus', 'closed')