diff --git a/cli/tests/unit_node/tls_test.ts b/cli/tests/unit_node/tls_test.ts index 79c1e634ca0cb4..7a270c60b3fa4d 100644 --- a/cli/tests/unit_node/tls_test.ts +++ b/cli/tests/unit_node/tls_test.ts @@ -56,6 +56,38 @@ Connection: close await serve; }); +// https://github.com/denoland/deno/pull/20120 +Deno.test("tls.connect mid-read tcp->tls upgrade", async () => { + const ctl = new AbortController(); + const serve = serveTls(() => new Response("hello"), { + port: 8443, + key, + cert, + signal: ctl.signal, + }); + + await delay(200); + + const conn = tls.connect({ + host: "localhost", + port: 8443, + secureContext: { + ca: rootCaCert, + // deno-lint-ignore no-explicit-any + } as any, + }); + + conn.setEncoding("utf8"); + conn.write(`GET / HTTP/1.1\nHost: www.google.com\n\n`); + + conn.on("data", (_) => { + conn.destroy(); + ctl.abort(); + }); + + await serve; +}); + Deno.test("tls.createServer creates a TLS server", async () => { const p = deferred(); const server = tls.createServer( diff --git a/ext/node/polyfills/_tls_wrap.ts b/ext/node/polyfills/_tls_wrap.ts index 39df239d0d2867..416bd4136f8f31 100644 --- a/ext/node/polyfills/_tls_wrap.ts +++ b/ext/node/polyfills/_tls_wrap.ts @@ -26,6 +26,11 @@ import { import { EventEmitter } from "node:events"; import { kEmptyObject } from "ext:deno_node/internal/util.mjs"; import { nextTick } from "ext:deno_node/_next_tick.ts"; +import { kHandle } from "ext:deno_node/internal/stream_base_commons.ts"; +import { + isAnyArrayBuffer, + isArrayBufferView, +} from "ext:deno_node/internal/util/types.ts"; const kConnectOptions = Symbol("connect-options"); const kIsVerified = Symbol("verified"); @@ -71,7 +76,11 @@ export class TLSSocket extends net.Socket { [kPendingSession]: any; [kConnectOptions]: any; ssl: any; - _start: any; + + _start() { + this[kHandle].afterConnect(); + } + constructor(socket: any, opts: any = kEmptyObject) { const tlsOptions = { ...opts }; @@ -84,6 +93,9 @@ export class TLSSocket extends net.Socket { let caCerts = tlsOptions?.secureContext?.ca; if (typeof caCerts === "string") caCerts = [caCerts]; + else if (isArrayBufferView(caCerts) || isAnyArrayBuffer(caCerts)) { + caCerts = [new TextDecoder().decode(caCerts)]; + } tlsOptions.caCerts = caCerts; super({ @@ -139,9 +151,9 @@ export class TLSSocket extends net.Socket { handle.afterConnect = async (req: any, status: number) => { try { const conn = await Deno.startTls(handle[kStreamBaseField], options); + handle[kStreamBaseField] = conn; tlssock.emit("secure"); tlssock.removeListener("end", onConnectEnd); - handle[kStreamBaseField] = conn; } catch { // TODO(kt3k): Handle this } diff --git a/ext/node/polyfills/internal_binding/stream_wrap.ts b/ext/node/polyfills/internal_binding/stream_wrap.ts index 528dd7c3f265a4..66ebbe682dfadb 100644 --- a/ext/node/polyfills/internal_binding/stream_wrap.ts +++ b/ext/node/polyfills/internal_binding/stream_wrap.ts @@ -314,9 +314,16 @@ export class LibuvStreamWrap extends HandleWrap { let buf = BUF; let nread: number | null; + const ridBefore = this[kStreamBaseField]!.rid; try { nread = await this[kStreamBaseField]!.read(buf); } catch (e) { + // Try to read again if the underlying stream resource + // changed. This can happen during TLS upgrades (eg. STARTTLS) + if (ridBefore != this[kStreamBaseField]!.rid) { + return this.#read(); + } + if ( e instanceof Deno.errors.Interrupted || e instanceof Deno.errors.BadResource @@ -365,15 +372,23 @@ export class LibuvStreamWrap extends HandleWrap { async #write(req: WriteWrap, data: Uint8Array) { const { byteLength } = data; + const ridBefore = this[kStreamBaseField]!.rid; + + let nwritten = 0; try { // TODO(crowlKats): duplicate from runtime/js/13_buffer.js - let nwritten = 0; while (nwritten < data.length) { nwritten += await this[kStreamBaseField]!.write( data.subarray(nwritten), ); } } catch (e) { + // Try to read again if the underlying stream resource + // changed. This can happen during TLS upgrades (eg. STARTTLS) + if (ridBefore != this[kStreamBaseField]!.rid) { + return this.#write(req, data.subarray(nwritten)); + } + let status: number; // TODO(cmorten): map err to status codes