diff --git a/lib/internal/streams/operators.js b/lib/internal/streams/operators.js index 80a0f9f731e89a..fe863bf3ed013d 100644 --- a/lib/internal/streams/operators.js +++ b/lib/internal/streams/operators.js @@ -10,7 +10,10 @@ const { PromisePrototypeThen, PromiseReject, PromiseResolve, + SafeSet, Symbol, + SymbolAsyncIterator, + SymbolIterator, } = primordials; const { AbortController, AbortSignal } = require('internal/abort_controller'); @@ -39,6 +42,7 @@ const { isWritable, isNodeStream } = require('internal/streams/utils'); const kEmpty = Symbol('kEmpty'); const kEof = Symbol('kEof'); +const kFlatMap = Symbol('kFlatMap'); function compose(stream, options) { if (options != null) { @@ -92,11 +96,20 @@ function map(fn, options) { highWaterMark += concurrency; + const flatMap = options?.[kFlatMap] != null; + return async function* map() { const signal = AbortSignal.any([options?.signal].filter(Boolean)); const stream = this; const queue = []; const signalOpt = { signal }; + const baseIterator = (async function* baseIterator() { + for await (const value of stream) { + // wrap in an object to avoid awaitng if result is a promise + yield { result: fn(value, signalOpt) }; + } + })(); + const iterators = new SafeSet([baseIterator]); let next; let resume; @@ -125,45 +138,54 @@ function map(fn, options) { } } + function addIterator(result) { + if (result && (result[SymbolAsyncIterator] || result[SymbolIterator])) { + const iterator = result[SymbolAsyncIterator] ? result[SymbolAsyncIterator]() : result[SymbolIterator](); + iterators.add(iterator); + return kEmpty; + } + return result; + } + async function pump() { try { - for await (let val of stream) { - if (done) { - return; - } - - if (signal.aborted) { - throw new AbortError(); - } - - try { - val = fn(val, signalOpt); - - if (val === kEmpty) { - continue; + while (iterators.size > 0) { + for (const iterator of iterators) { + if (done) { + return; } - val = PromiseResolve(val); - } catch (err) { - val = PromiseReject(err); - } - - cnt += 1; + if (signal.aborted) { + throw new AbortError(); + } + let val = PromisePrototypeThen(PromiseResolve(iterator.next()), ({ value, done }) => { + if (done) { + iterators.delete(iterator); + return kEmpty; + } + return iterator === baseIterator ? value.result : value; + }); - PromisePrototypeThen(val, afterItemProcessed, onCatch); + if (flatMap && baseIterator === iterator) { + val = PromisePrototypeThen(val, addIterator); + } + PromisePrototypeThen(val, afterItemProcessed, onCatch); + cnt += 1; + queue.push(val); - queue.push(val); - if (next) { - next(); - next = null; - } + if (next) { + next(); + next = null; + } - if (!done && (queue.length >= highWaterMark || cnt >= concurrency)) { - await new Promise((resolve) => { - resume = resolve; - }); + if (!done && (queue.length >= highWaterMark || cnt >= concurrency)) { + await new Promise((resolve) => { + resume = resolve; + }); + } } } + queue.push(kEof); } catch (err) { const val = PromiseReject(err); @@ -343,12 +365,10 @@ async function toArray(options) { } function flatMap(fn, options) { - const values = map.call(this, fn, options); - return async function* flatMap() { - for await (const val of values) { - yield* val; - } - }.call(this); + if (options != null) { + validateObject(options, 'options'); + } + return map.call(this, fn, { ...options, [kFlatMap]: true }); } function toIntegerOrInfinity(number) { diff --git a/test/parallel/test-stream-flatMap.js b/test/parallel/test-stream-flatMap.js index 0e55119f7a767d..869065c6f65e98 100644 --- a/test/parallel/test-stream-flatMap.js +++ b/test/parallel/test-stream-flatMap.js @@ -73,9 +73,9 @@ function oneTo5() { { // Concurrency + AbortSignal const ac = new AbortController(); - const stream = oneTo5().flatMap(common.mustNotCall(async (_, { signal }) => { + const stream = oneTo5().flatMap(common.mustCall(async (_, { signal }) => { await setTimeout(100, { signal }); - }), { signal: ac.signal, concurrency: 2 }); + }, 2), { signal: ac.signal, concurrency: 2 }); // pump assert.rejects(async () => { for await (const item of stream) {