diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 04c4032b04e42a..ee52bcb260c271 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -20,6 +20,7 @@ const X509 = @import("./x509.zig"); const Async = bun.Async; const uv = bun.windows.libuv; const H2FrameParser = @import("./h2_frame_parser.zig").H2FrameParser; +const NodePath = @import("../../node/path.zig"); noinline fn getSSLException(globalThis: *JSC.JSGlobalObject, defaultMessage: []const u8) JSValue { var zig_str: ZigString = ZigString.init(""); var output_buf: [4096]u8 = undefined; @@ -452,6 +453,38 @@ pub const SocketConfig = struct { } }; +fn isValidPipeName(pipe_name: []const u8) bool { + if (!Environment.isWindows) { + return false; + } + // check for valid pipe names + // at minimum we need to have \\.\pipe\ or \\?\pipe\ + 1 char that is not a separator + return pipe_name.len > 9 and + NodePath.isSepWindowsT(u8, pipe_name[0]) and + NodePath.isSepWindowsT(u8, pipe_name[1]) and + (pipe_name[2] == '.' or pipe_name[2] == '?') and + NodePath.isSepWindowsT(u8, pipe_name[3]) and + strings.eql(pipe_name[4..8], "pipe") and + NodePath.isSepWindowsT(u8, pipe_name[8]) and + !NodePath.isSepWindowsT(u8, pipe_name[9]); +} + +fn normalizePipeName(pipe_name: []const u8, buffer: []u8) ?[]const u8 { + if (Environment.isWindows) { + bun.assert(pipe_name.len < buffer.len); + if (!isValidPipeName(pipe_name)) { + return null; + } + // normalize pipe name with can have mixed slashes + // pipes are simple and this will be faster than using node:path.resolve() + // we dont wanna to normalize the pipe name it self only the pipe identifier (//./pipe/, //?/pipe/, etc) + @memcpy(buffer[0..9], "\\\\.\\pipe\\"); + @memcpy(buffer[9..pipe_name.len], pipe_name[9..]); + return buffer[0..pipe_name.len]; + } else { + return null; + } +} pub const Listener = struct { pub const log = Output.scoped(.Listener, false); @@ -606,8 +639,9 @@ pub const Listener = struct { if (Environment.isWindows) { if (port == null) { // we check if the path is a named pipe otherwise we try to connect using AF_UNIX - const pipe_name = hostname_or_unix.slice(); - if (strings.startsWith(pipe_name, "\\\\.\\pipe\\") or strings.startsWith(pipe_name, "\\\\?\\pipe\\")) { + const slice = hostname_or_unix.slice(); + var buf: bun.PathBuffer = undefined; + if (normalizePipeName(slice, buf[0..])) |pipe_name| { const connection: Listener.UnixOrHost = .{ .unix = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch bun.outOfMemory()).slice() }; if (ssl_enabled) { if (ssl.?.protos) |p| { @@ -1099,9 +1133,14 @@ pub const Listener = struct { }; if (Environment.isWindows) { + var buf: bun.PathBuffer = undefined; + var pipe_name: ?[]const u8 = null; const isNamedPipe = switch (connection) { // we check if the path is a named pipe otherwise we try to connect using AF_UNIX - .unix => |pipe_name| strings.startsWith(pipe_name, "\\\\.\\pipe\\") or strings.startsWith(pipe_name, "\\\\?\\pipe\\"), + .unix => |slice| brk: { + pipe_name = normalizePipeName(slice, buf[0..]); + break :brk (pipe_name != null); + }, .fd => |fd| brk: { const uvfd = bun.uvfdcast(fd); const fd_type = uv.uv_guess_handle(uvfd); @@ -1146,7 +1185,7 @@ pub const Listener = struct { tls.poll_ref.ref(handlers.vm); tls.ref(); if (connection == .unix) { - const named_pipe = WindowsNamedPipeContext.connect(globalObject, connection.unix, ssl, .{ .tls = tls }) catch { + const named_pipe = WindowsNamedPipeContext.connect(globalObject, pipe_name.?, ssl, .{ .tls = tls }) catch { return promise_value; }; tls.socket = TLSSocket.Socket.fromNamedPipe(named_pipe); @@ -1172,7 +1211,7 @@ pub const Listener = struct { tcp.poll_ref.ref(handlers.vm); if (connection == .unix) { - const named_pipe = WindowsNamedPipeContext.connect(globalObject, connection.unix, null, .{ .tcp = tcp }) catch { + const named_pipe = WindowsNamedPipeContext.connect(globalObject, pipe_name.?, null, .{ .tcp = tcp }) catch { return promise_value; }; tcp.socket = TCPSocket.Socket.fromNamedPipe(named_pipe); diff --git a/test/js/node/net/node-net.test.ts b/test/js/node/net/node-net.test.ts index 9e4aa7ffc25b9d..ef56deafe9c59c 100644 --- a/test/js/node/net/node-net.test.ts +++ b/test/js/node/net/node-net.test.ts @@ -563,48 +563,57 @@ it("should not hang after destroy", async () => { } }); -it.if(isWindows)("should work with named pipes", async () => { - async function test(pipe_name: string) { - const { promise: messageReceived, resolve: resolveMessageReceived } = Promise.withResolvers(); - const { promise: clientReceived, resolve: resolveClientReceived } = Promise.withResolvers(); - let client: ReturnType | null = null; - let server: ReturnType | null = null; - try { - server = createServer(socket => { - socket.on("data", data => { - const message = data.toString(); - socket.end("Goodbye World!"); - resolveMessageReceived(message); +it.if(isWindows)( + "should work with named pipes", + async () => { + async function test(pipe_name: string) { + const { promise: messageReceived, resolve: resolveMessageReceived } = Promise.withResolvers(); + const { promise: clientReceived, resolve: resolveClientReceived } = Promise.withResolvers(); + let client: ReturnType | null = null; + let server: ReturnType | null = null; + try { + server = createServer(socket => { + socket.on("data", data => { + const message = data.toString(); + socket.end("Goodbye World!"); + resolveMessageReceived(message); + }); }); - }); - server.listen(pipe_name); - client = connect(pipe_name).on("data", data => { - const message = data.toString(); - resolveClientReceived(message); - }); + server.listen(pipe_name); + client = connect(pipe_name).on("data", data => { + const message = data.toString(); + resolveClientReceived(message); + }); - client?.write("Hello World!"); - const message = await messageReceived; - expect(message).toBe("Hello World!"); - const client_message = await clientReceived; - expect(client_message).toBe("Goodbye World!"); - } finally { - client?.destroy(); - server?.close(); + client?.write("Hello World!"); + const message = await messageReceived; + expect(message).toBe("Hello World!"); + const client_message = await clientReceived; + expect(client_message).toBe("Goodbye World!"); + } finally { + client?.destroy(); + server?.close(); + } } - } - const batch = []; - const before = heapStats().objectTypeCounts.TLSSocket || 0; - for (let i = 0; i < 200; i++) { - batch.push(test(`\\\\.\\pipe\\test\\${randomUUID()}`)); - batch.push(test(`\\\\?\\pipe\\test\\${randomUUID()}`)); - if (i % 50 === 0) { - await Promise.all(batch); - batch.length = 0; + const batch = []; + const before = heapStats().objectTypeCounts.TLSSocket || 0; + for (let i = 0; i < 100; i++) { + batch.push(test(`\\\\.\\pipe\\test\\${randomUUID()}`)); + batch.push(test(`\\\\?\\pipe\\test\\${randomUUID()}`)); + batch.push(test(`//?/pipe/test/${randomUUID()}`)); + batch.push(test(`//./pipe/test/${randomUUID()}`)); + batch.push(test(`/\\./pipe/test/${randomUUID()}`)); + batch.push(test(`/\\./pipe\\test/${randomUUID()}`)); + batch.push(test(`\\/.\\pipe/test\\${randomUUID()}`)); + if (i % 50 === 0) { + await Promise.all(batch); + batch.length = 0; + } } - } - await Promise.all(batch); - expectMaxObjectTypeCount(expect, "TCPSocket", before); -}); + await Promise.all(batch); + expectMaxObjectTypeCount(expect, "TCPSocket", before); + }, + 20_000, +);