diff --git a/src/bun.js/api/bun.zig b/src/bun.js/api/bun.zig index 2e6381c745b28..1e5a5e0041885 100644 --- a/src/bun.js/api/bun.zig +++ b/src/bun.js/api/bun.zig @@ -3794,6 +3794,8 @@ pub const Timer = struct { result.then(globalThis, this, CallbackJob__onResolve, CallbackJob__onReject); }, } + } else { + this.deinit(); } } }; diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 69d6611cb910f..329cc40e47bea 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -69,6 +69,11 @@ fn normalizeHost(input: anytype) @TypeOf(input) { const BinaryType = JSC.BinaryType; +const WrappedType = enum { + none, + tls, + tcp, +}; const Handlers = struct { onOpen: JSC.JSValue = .zero, onClose: JSC.JSValue = .zero, @@ -97,8 +102,8 @@ const Handlers = struct { handlers: *Handlers, socket_context: *uws.SocketContext, - pub fn exit(this: *Scope, ssl: bool) void { - this.handlers.markInactive(ssl, this.socket_context); + pub fn exit(this: *Scope, ssl: bool, wrapped: WrappedType) void { + this.handlers.markInactive(ssl, this.socket_context, wrapped); } }; @@ -123,19 +128,24 @@ const Handlers = struct { return true; } - pub fn markInactive(this: *Handlers, ssl: bool, ctx: *uws.SocketContext) void { + pub fn markInactive(this: *Handlers, ssl: bool, ctx: *uws.SocketContext, wrapped: WrappedType) void { Listener.log("markInactive", .{}); this.active_connections -= 1; - if (this.active_connections == 0 and this.is_server) { - var listen_socket: *Listener = @fieldParentPtr(Listener, "handlers", this); - // allow it to be GC'd once the last connection is closed and it's not listening anymore - if (listen_socket.listener == null) { - listen_socket.strong_self.clear(); + if (this.active_connections == 0) { + if (this.is_server) { + var listen_socket: *Listener = @fieldParentPtr(Listener, "handlers", this); + // allow it to be GC'd once the last connection is closed and it's not listening anymore + if (listen_socket.listener == null) { + listen_socket.strong_self.clear(); + } + } else { + this.unprotect(); + // will deinit when is not wrapped or when is the TCP wrapped connection + if (wrapped != .tls) { + ctx.deinit(ssl); + } + bun.default_allocator.destroy(this); } - } else if (this.active_connections == 0 and !this.is_server) { - this.unprotect(); - ctx.deinit(ssl); - bun.default_allocator.destroy(this); } } @@ -364,6 +374,7 @@ pub const Listener = struct { connection: UnixOrHost, socket_context: ?*uws.SocketContext = null, ssl: bool = false, + protos: ?[]const u8 = null, strong_data: JSC.Strong = .{}, strong_self: JSC.Strong = .{}, @@ -395,6 +406,19 @@ pub const Listener = struct { port: u16, }, + pub fn clone(this: UnixOrHost) UnixOrHost { + switch (this) { + .unix => |u| { + return .{ + .unix = (bun.default_allocator.dupe(u8, u) catch unreachable), + }; + }, + .host => |h| { + return .{ .host = .{ .host = (bun.default_allocator.dupe(u8, h.host) catch unreachable), .port = this.host.port } }; + }, + } + } + pub fn deinit(this: UnixOrHost) void { switch (this) { .unix => |u| { @@ -455,10 +479,12 @@ pub const Listener = struct { var socket_config = SocketConfig.fromJS(opts, globalObject, exception) orelse { return .zero; }; + var hostname_or_unix = socket_config.hostname_or_unix; var port = socket_config.port; var ssl = socket_config.ssl; var handlers = socket_config.handlers; + var protos: ?[]const u8 = null; const exclusive = socket_config.exclusive; handlers.is_server = true; @@ -496,6 +522,10 @@ pub const Listener = struct { }; if (ssl_enabled) { + if (ssl.?.protos) |p| { + protos = p[0..ssl.?.protos_len]; + } + uws.NewSocketHandler(true).configure( socket_context, true, @@ -593,6 +623,7 @@ pub const Listener = struct { .ssl = ssl_enabled, .socket_context = socket_context, .listener = listen_socket, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, }; socket.handlers.protect(); @@ -649,6 +680,8 @@ pub const Listener = struct { .handlers = &listener.handlers, .this_value = .zero, .socket = socket, + .protos = listener.protos, + .owned_protos = false, }; if (listener.strong_data.get()) |default_data| { const globalObject = listener.handlers.globalObject; @@ -715,6 +748,10 @@ pub const Listener = struct { this.handlers.unprotect(); this.connection.deinit(); + if (this.protos) |protos| { + this.protos = null; + bun.default_allocator.destroy(protos); + } bun.default_allocator.destroy(this); } @@ -775,13 +812,16 @@ pub const Listener = struct { const socket_config = SocketConfig.fromJS(opts, globalObject, exception) orelse { return .zero; }; + var hostname_or_unix = socket_config.hostname_or_unix; var port = socket_config.port; var ssl = socket_config.ssl; var handlers = socket_config.handlers; var default_data = socket_config.default_data; + var protos: ?[]const u8 = null; const ssl_enabled = ssl != null; + defer if (ssl != null) ssl.?.deinit(); handlers.protect(); @@ -797,6 +837,9 @@ pub const Listener = struct { }; if (ssl_enabled) { + if (ssl.?.protos) |p| { + protos = p[0..ssl.?.protos_len]; + } uws.NewSocketHandler(true).configure( socket_context, true, @@ -848,6 +891,7 @@ pub const Listener = struct { .this_value = .zero, .socket = undefined, .connection = connection, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, }; TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); @@ -871,6 +915,7 @@ pub const Listener = struct { .this_value = .zero, .socket = undefined, .connection = null, + .protos = null, }; TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); @@ -898,11 +943,41 @@ fn JSSocketType(comptime ssl: bool) type { } } +fn selectALPNCallback( + _: ?*BoringSSL.SSL, + out: [*c][*c]const u8, + outlen: [*c]u8, + in: [*c]const u8, + inlen: c_uint, + arg: ?*anyopaque, +) callconv(.C) c_int { + const this = bun.cast(*TLSSocket, arg); + if (this.protos) |protos| { + if (protos.len == 0) { + return BoringSSL.SSL_TLSEXT_ERR_NOACK; + } + + const status = BoringSSL.SSL_select_next_proto(bun.cast([*c][*c]u8, out), outlen, protos.ptr, @intCast(c_uint, protos.len), in, inlen); + + // Previous versions of Node.js returned SSL_TLSEXT_ERR_NOACK if no protocol + // match was found. This would neither cause a fatal alert nor would it result + // in a useful ALPN response as part of the Server Hello message. + // We now return SSL_TLSEXT_ERR_ALERT_FATAL in that case as per Section 3.2 + // of RFC 7301, which causes a fatal no_application_protocol alert. + const expected = if (comptime BoringSSL.OPENSSL_NPN_NEGOTIATED == 1) BoringSSL.SSL_TLSEXT_ERR_OK else BoringSSL.SSL_TLSEXT_ERR_ALERT_FATAL; + + return if (status == expected) 1 else 0; + } else { + return BoringSSL.SSL_TLSEXT_ERR_NOACK; + } +} + fn NewSocket(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); socket: Socket, detached: bool = false, + wrapped: WrappedType = .none, handlers: *Handlers, this_value: JSC.JSValue = .zero, poll_ref: JSC.PollRef = JSC.PollRef.init(), @@ -910,6 +985,8 @@ fn NewSocket(comptime ssl: bool) type { last_4: [4]u8 = .{ 0, 0, 0, 0 }, authorized: bool = false, connection: ?Listener.UnixOrHost = null, + protos: ?[]const u8, + owned_protos: bool = true, // TODO: switch to something that uses `visitAggregate` and have the // `Listener` keep a list of all the sockets JSValue in there @@ -1079,7 +1156,7 @@ fn NewSocket(comptime ssl: bool) type { var vm = this.handlers.vm; this.reffer.unref(vm); - this.handlers.markInactive(ssl, this.socket.context()); + this.handlers.markInactive(ssl, this.socket.context(), this.wrapped); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .Release); } @@ -1091,25 +1168,35 @@ fn NewSocket(comptime ssl: bool) type { // Add SNI support for TLS (mongodb and others requires this) if (comptime ssl) { - if (this.connection) |connection| { - if (connection == .host) { - const host = normalizeHost(connection.host.host); - if (host.len > 0) { - var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); - if (!ssl_ptr.isInitFinished()) { + var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); + if (!ssl_ptr.isInitFinished()) { + if (this.connection) |connection| { + if (connection == .host) { + const host = normalizeHost(connection.host.host); + if (host.len > 0) { var host__ = default_allocator.dupeZ(u8, host) catch unreachable; defer default_allocator.free(host__); ssl_ptr.setHostname(host__); } } } + if (this.protos) |protos| { + if (this.handlers.is_server) { + BoringSSL.SSL_CTX_set_alpn_select_cb(BoringSSL.SSL_get_SSL_CTX(ssl_ptr), selectALPNCallback, bun.cast(*anyopaque, this)); + } else { + _ = BoringSSL.SSL_set_alpn_protos(ssl_ptr, protos.ptr, @intCast(c_uint, protos.len)); + } + } } } this.poll_ref.ref(this.handlers.vm); this.detached = false; this.socket = socket; - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, this); + + if (this.wrapped == .none) { + socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, this); + } const handlers = this.handlers; const callback = handlers.onOpen; @@ -1174,7 +1261,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1211,7 +1298,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1255,7 +1342,6 @@ fn NewSocket(comptime ssl: bool) type { log("onClose", .{}); this.detached = true; defer this.markInactive(); - const handlers = this.handlers; this.poll_ref.unref(handlers.vm); @@ -1265,7 +1351,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); var globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1295,7 +1381,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); // const encoding = handlers.encoding; const result = callback.callWithThis(globalObject, this_value, &[_]JSValue{ @@ -1476,10 +1562,20 @@ fn NewSocket(comptime ssl: bool) type { } fn writeMaybeCorked(this: *This, buffer: []const u8, is_end: bool) i32 { - if (this.socket.isShutdown() or this.socket.isClosed()) { + if (this.detached or this.socket.isShutdown() or this.socket.isClosed()) { return -1; } // we don't cork yet but we might later + + if (comptime ssl) { + // TLS wrapped but in TCP mode + if (this.wrapped == .tcp) { + const res = this.socket.rawWrite(buffer, is_end); + log("write({d}, {any}) = {d}", .{ buffer.len, is_end, res }); + return res; + } + } + const res = this.socket.write(buffer, is_end); log("write({d}, {any}) = {d}", .{ buffer.len, is_end, res }); return res; @@ -1487,7 +1583,6 @@ fn NewSocket(comptime ssl: bool) type { fn writeOrEnd(this: *This, globalObject: *JSC.JSGlobalObject, args: []const JSC.JSValue, is_end: bool) WriteResult { if (args.len == 0) return .{ .success = .{} }; - if (args.ptr[0].asArrayBuffer(globalObject)) |array_buffer| { var slice = array_buffer.slice(); @@ -1681,9 +1776,6 @@ fn NewSocket(comptime ssl: bool) type { if (result.wrote == result.total) { this.socket.flush(); this.detached = true; - if (!this.socket.isClosed()) { - this.socket.close(0, null); - } this.markInactive(); } break :brk JSValue.jsNumber(result.wrote); @@ -1706,17 +1798,27 @@ fn NewSocket(comptime ssl: bool) type { pub fn finalize(this: *This) callconv(.C) void { log("finalize()", .{}); - if (this.detached) return; - this.detached = true; - if (!this.socket.isClosed()) { - this.socket.close(0, null); + if (!this.detached) { + this.detached = true; + if (!this.socket.isClosed()) { + this.socket.close(0, null); + } + this.markInactive(); + } + + this.poll_ref.unref(JSC.VirtualMachine.get()); + // need to deinit event without being attached + if (this.owned_protos) { + if (this.protos) |protos| { + this.protos = null; + default_allocator.free(protos); + } } + if (this.connection) |connection| { - connection.deinit(); this.connection = null; + connection.deinit(); } - this.markInactive(); - this.poll_ref.unref(JSC.VirtualMachine.get()); } pub fn reload(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { @@ -1756,8 +1858,376 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } + + pub fn getALPNProtocol( + this: *This, + globalObject: *JSC.JSGlobalObject, + ) callconv(.C) JSValue { + if (comptime ssl == false) { + return JSValue.jsBoolean(false); + } + + if (this.detached) { + return JSValue.jsBoolean(false); + } + + var alpn_proto: [*c]const u8 = null; + var alpn_proto_len: u32 = 0; + + var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle()); + BoringSSL.SSL_get0_alpn_selected(ssl_ptr, &alpn_proto, &alpn_proto_len); + if (alpn_proto == null or alpn_proto_len == 0) { + return JSValue.jsBoolean(false); + } + + const slice = alpn_proto[0..alpn_proto_len]; + if (strings.eql(slice, "h2")) { + return ZigString.static("h2").toValue(globalObject); + } + if (strings.eql(slice, "http/1.1")) { + return ZigString.static("http/1.1").toValue(globalObject); + } + return ZigString.fromUTF8(slice).toValueGC(globalObject); + } + + pub fn setServername( + this: *This, + globalObject: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + if (comptime ssl == false) { + return JSValue.jsUndefined(); + } + if (this.detached) { + return JSValue.jsUndefined(); + } + + if (this.handlers.is_server) { + globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); + return .zero; + } + + const args = callframe.arguments(1); + if (args.len < 1) { + globalObject.throw("Expected 1 argument", .{}); + return .zero; + } + + const server_name = args.ptr[0]; + if (!server_name.isString()) { + globalObject.throw("Expected \"serverName\" to be a string", .{}); + return .zero; + } + + const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator); + defer slice.deinit(); + const host = normalizeHost(slice.slice()); + if (host.len > 0) { + var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle()); + if (ssl_ptr.isInitFinished()) { + // match node.js exceptions + globalObject.throw("Already started.", .{}); + return .zero; + } + var host__ = default_allocator.dupeZ(u8, host) catch unreachable; + defer default_allocator.free(host__); + ssl_ptr.setHostname(host__); + } + + return JSValue.jsUndefined(); + } + + pub fn open( + this: *This, + _: *JSC.JSGlobalObject, + _: *JSC.CallFrame, + ) callconv(.C) JSValue { + JSC.markBinding(@src()); + this.socket.open(!this.handlers.is_server); + return JSValue.jsUndefined(); + } + + // this invalidates the current socket returning 2 new sockets + // one for non-TLS and another for TLS + // handlers for non-TLS are preserved + pub fn wrapTLS( + this: *This, + globalObject: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + JSC.markBinding(@src()); + if (comptime ssl) { + return JSValue.jsUndefined(); + } + + if (this.detached) { + return JSValue.jsUndefined(); + } + + const args = callframe.arguments(1); + + if (args.len < 1) { + globalObject.throw("Expected 1 arguments", .{}); + return .zero; + } + + var exception: JSC.C.JSValueRef = null; + + const opts = args.ptr[0]; + if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) { + globalObject.throw("Expected options object", .{}); + return .zero; + } + + var socket_obj = opts.get(globalObject, "socket") orelse { + globalObject.throw("Expected \"socket\" option", .{}); + return .zero; + }; + + var handlers = Handlers.fromJS(globalObject, socket_obj, &exception) orelse { + globalObject.throwValue(exception.?.value()); + return .zero; + }; + + var ssl_opts: ?JSC.API.ServerConfig.SSLConfig = null; + + if (opts.getTruthy(globalObject, "tls")) |tls| { + if (tls.isBoolean()) { + if (tls.toBoolean()) { + ssl_opts = JSC.API.ServerConfig.SSLConfig.zero; + } + } else { + if (JSC.API.ServerConfig.SSLConfig.inJS(globalObject, tls, &exception)) |ssl_config| { + ssl_opts = ssl_config; + } else if (exception != null) { + return .zero; + } + } + } + + if (ssl_opts == null) { + globalObject.throw("Expected \"tls\" option", .{}); + return .zero; + } + + var default_data = JSValue.zero; + if (opts.getTruthy(globalObject, "data")) |default_data_value| { + default_data = default_data_value; + default_data.ensureStillAlive(); + } + + var socket_config = ssl_opts.?; + defer socket_config.deinit(); + const options = socket_config.asUSockets(); + + const protos = socket_config.protos; + const protos_len = socket_config.protos_len; + + const ext_size = @sizeOf(WrappedSocket); + + var tls = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); + var handlers_ptr = handlers.vm.allocator.create(Handlers) catch @panic("OOM"); + handlers_ptr.* = handlers; + handlers_ptr.is_server = this.handlers.is_server; + handlers_ptr.protect(); + + tls.* = .{ + .handlers = handlers_ptr, + .this_value = .zero, + .socket = undefined, + .connection = if (this.connection) |c| c.clone() else null, + .wrapped = .tls, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null, + }; + + var tls_js_value = tls.getThisValue(globalObject); + TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); + + const TCPHandler = NewWrappedHandler(false); + + // reconfigure context to use the new wrapper handlers + Socket.unsafeConfigure(this.socket.context(), true, true, WrappedSocket, TCPHandler); + const old_context = this.socket.context(); + const TLSHandler = NewWrappedHandler(true); + const new_socket = this.socket.wrapTLS( + options, + ext_size, + true, + WrappedSocket, + TLSHandler, + ) orelse { + handlers_ptr.unprotect(); + handlers.vm.allocator.destroy(handlers_ptr); + bun.default_allocator.destroy(tls); + return JSValue.jsUndefined(); + }; + tls.socket = new_socket; + + var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); + var raw_handlers_ptr = handlers.vm.allocator.create(Handlers) catch @panic("OOM"); + this.handlers.unprotect(); + + var cloned_handlers: Handlers = .{ + .vm = globalObject.bunVM(), + .globalObject = globalObject, + .onOpen = this.handlers.onOpen, + .onClose = this.handlers.onClose, + .onData = this.handlers.onData, + .onWritable = this.handlers.onWritable, + .onTimeout = this.handlers.onTimeout, + .onConnectError = this.handlers.onConnectError, + .onEnd = this.handlers.onEnd, + .onError = this.handlers.onError, + .onHandshake = this.handlers.onHandshake, + .binary_type = this.handlers.binary_type, + }; + + raw_handlers_ptr.* = cloned_handlers; + raw_handlers_ptr.is_server = this.handlers.is_server; + raw_handlers_ptr.protect(); + raw.* = .{ + .handlers = raw_handlers_ptr, + .this_value = .zero, + .socket = new_socket, + .connection = if (this.connection) |c| c.clone() else null, + .wrapped = .tcp, + .protos = null, + }; + + var raw_js_value = raw.getThisValue(globalObject); + if (JSSocketType(ssl).dataGetCached(this.getThisValue(globalObject))) |raw_default_data| { + raw_default_data.ensureStillAlive(); + TLSSocket.dataSetCached(raw_js_value, globalObject, raw_default_data); + } + // marks both as active + raw.markActive(); + // this will keep tls alive until socket.open() is called to start TLS certificate and the handshake process + // open is not immediately called because we need to set bunSocketInternal + tls.markActive(); + + // mark both instances on socket data + new_socket.ext(WrappedSocket).?.* = .{ .tcp = raw, .tls = tls }; + + //detach and invalidate the old instance + this.detached = true; + if (this.reffer.has) { + var vm = this.handlers.vm; + this.reffer.unref(vm); + old_context.deinit(ssl); + bun.default_allocator.destroy(this.handlers); + this.poll_ref.unref(vm); + this.has_pending_activity.store(false, .Release); + } + + const array = JSC.JSValue.createEmptyArray(globalObject, 2); + array.putIndex(globalObject, 0, raw_js_value); + array.putIndex(globalObject, 1, tls_js_value); + return array; + } }; } pub const TCPSocket = NewSocket(false); pub const TLSSocket = NewSocket(true); + +pub const WrappedSocket = extern struct { + // both shares the same socket but one behaves as TLS and the other as TCP + tls: *TLSSocket, + tcp: *TLSSocket, +}; + +pub fn NewWrappedHandler(comptime tls: bool) type { + const Socket = uws.NewSocketHandler(true); + return struct { + pub fn onOpen( + this: WrappedSocket, + socket: Socket, + ) void { + // only TLS will call onOpen + if (comptime tls) { + TLSSocket.onOpen(this.tls, socket); + } + } + + pub fn onEnd( + this: WrappedSocket, + socket: Socket, + ) void { + if (comptime tls) { + TLSSocket.onEnd(this.tls, socket); + } else { + TLSSocket.onEnd(this.tcp, socket); + } + } + + pub fn onHandshake( + this: WrappedSocket, + socket: Socket, + success: i32, + ssl_error: uws.us_bun_verify_error_t, + ) void { + // only TLS will call onHandshake + if (comptime tls) { + TLSSocket.onHandshake(this.tls, socket, success, ssl_error); + } + } + + pub fn onClose( + this: WrappedSocket, + socket: Socket, + err: c_int, + data: ?*anyopaque, + ) void { + if (comptime tls) { + TLSSocket.onClose(this.tls, socket, err, data); + } else { + TLSSocket.onClose(this.tcp, socket, err, data); + } + } + + pub fn onData( + this: WrappedSocket, + socket: Socket, + data: []const u8, + ) void { + if (comptime tls) { + TLSSocket.onData(this.tls, socket, data); + } else { + TLSSocket.onData(this.tcp, socket, data); + } + } + + pub fn onWritable( + this: WrappedSocket, + socket: Socket, + ) void { + if (comptime tls) { + TLSSocket.onWritable(this.tls, socket); + } else { + TLSSocket.onWritable(this.tcp, socket); + } + } + pub fn onTimeout( + this: WrappedSocket, + socket: Socket, + ) void { + if (comptime tls) { + TLSSocket.onTimeout(this.tls, socket); + } else { + TLSSocket.onTimeout(this.tcp, socket); + } + } + + pub fn onConnectError( + this: WrappedSocket, + socket: Socket, + errno: c_int, + ) void { + if (comptime tls) { + TLSSocket.onConnectError(this.tls, socket, errno); + } else { + TLSSocket.onConnectError(this.tcp, socket, errno); + } + } + }; +} diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 140e62ce47c6b..f52c0830170d4 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -163,6 +163,8 @@ pub const ServerConfig = struct { request_cert: i32 = 0, reject_unauthorized: i32 = 0, ssl_ciphers: [*c]const u8 = null, + protos: [*c]const u8 = null, + protos_len: usize = 0, const log = Output.scoped(.SSLConfig, false); @@ -215,6 +217,7 @@ pub const ServerConfig = struct { "dh_params_file_name", "passphrase", "ssl_ciphers", + "protos", }; inline for (fields) |field| { @@ -270,6 +273,9 @@ pub const ServerConfig = struct { pub fn inJS(global: *JSC.JSGlobalObject, obj: JSC.JSValue, exception: JSC.C.ExceptionRef) ?SSLConfig { var result = zero; + var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); + defer arena.deinit(); + if (!obj.isObject()) { JSC.throwInvalidArguments("tls option expects an object", .{}, global, exception); return null; @@ -301,7 +307,6 @@ pub const ServerConfig = struct { var i: u32 = 0; var valid_count: u32 = 0; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); while (i < count) : (i += 1) { const item = js_obj.getIndex(global, i); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), item, exception)) |sb| { @@ -317,7 +322,6 @@ pub const ServerConfig = struct { valid_count += 1; any = true; } else { - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -325,7 +329,6 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("key argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all keys result.key = native_array; result.deinit(); @@ -333,8 +336,6 @@ pub const ServerConfig = struct { } } - arena.deinit(); - if (valid_count == 0) { bun.default_allocator.free(native_array); } else { @@ -356,7 +357,6 @@ pub const ServerConfig = struct { } } else { const native_array = bun.default_allocator.alloc([*c]const u8, 1) catch unreachable; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), js_obj, exception)) |sb| { const sliced = sb.slice(); if (sliced.len > 0) { @@ -369,14 +369,11 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("key argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all certs result.key = native_array; result.deinit(); return null; } - - arena.deinit(); } } @@ -394,6 +391,22 @@ pub const ServerConfig = struct { } } + if (obj.getTruthy(global, "ALPNProtocols")) |protocols| { + if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), protocols, exception)) |sb| { + const sliced = sb.slice(); + if (sliced.len > 0) { + result.protos = bun.default_allocator.dupeZ(u8, sliced) catch unreachable; + result.protos_len = sliced.len; + } + + any = true; + } else { + global.throwInvalidArguments("ALPNProtocols argument must be an string, Buffer or TypedArray", .{}); + result.deinit(); + return null; + } + } + if (obj.getTruthy(global, "cert")) |js_obj| { if (js_obj.jsType().isArray()) { const count = js_obj.getLength(global); @@ -403,7 +416,6 @@ pub const ServerConfig = struct { var i: u32 = 0; var valid_count: u32 = 0; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); while (i < count) : (i += 1) { const item = js_obj.getIndex(global, i); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), item, exception)) |sb| { @@ -419,7 +431,6 @@ pub const ServerConfig = struct { valid_count += 1; any = true; } else { - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -427,7 +438,6 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("cert argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all certs result.cert = native_array; result.deinit(); @@ -435,8 +445,6 @@ pub const ServerConfig = struct { } } - arena.deinit(); - if (valid_count == 0) { bun.default_allocator.free(native_array); } else { @@ -458,7 +466,6 @@ pub const ServerConfig = struct { } } else { const native_array = bun.default_allocator.alloc([*c]const u8, 1) catch unreachable; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), js_obj, exception)) |sb| { const sliced = sb.slice(); if (sliced.len > 0) { @@ -471,14 +478,11 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("cert argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all certs result.cert = native_array; result.deinit(); return null; } - - arena.deinit(); } } @@ -518,7 +522,6 @@ pub const ServerConfig = struct { var i: u32 = 0; var valid_count: u32 = 0; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); while (i < count) : (i += 1) { const item = js_obj.getIndex(global, i); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), item, exception)) |sb| { @@ -534,7 +537,6 @@ pub const ServerConfig = struct { valid_count += 1; any = true; } else { - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -542,7 +544,6 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("ca argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -550,8 +551,6 @@ pub const ServerConfig = struct { } } - arena.deinit(); - if (valid_count == 0) { bun.default_allocator.free(native_array); } else { @@ -573,7 +572,6 @@ pub const ServerConfig = struct { } } else { const native_array = bun.default_allocator.alloc([*c]const u8, 1) catch unreachable; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), js_obj, exception)) |sb| { const sliced = sb.slice(); if (sliced.len > 0) { @@ -586,13 +584,11 @@ pub const ServerConfig = struct { } } else { JSC.throwInvalidArguments("ca argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}, global, exception); - arena.deinit(); // mark and free all certs result.ca = native_array; result.deinit(); return null; } - arena.deinit(); } } diff --git a/src/bun.js/api/sockets.classes.ts b/src/bun.js/api/sockets.classes.ts index da07741a30951..0c7847e198fe5 100644 --- a/src/bun.js/api/sockets.classes.ts +++ b/src/bun.js/api/sockets.classes.ts @@ -15,10 +15,21 @@ function generate(ssl) { authorized: { getter: "getAuthorized", }, + alpnProtocol: { + getter: "getALPNProtocol", + }, write: { fn: "write", length: 3, }, + wrapTLS: { + fn: "wrapTLS", + length: 1, + }, + open: { + fn: "open", + length: 0, + }, end: { fn: "end", length: 3, @@ -82,6 +93,11 @@ function generate(ssl) { fn: "reload", length: 1, }, + + setServername: { + fn: "setServername", + length: 1, + }, }, finalize: true, construct: true, diff --git a/src/bun.js/bindings/JSSink.cpp b/src/bun.js/bindings/JSSink.cpp index 19bf055991aa7..5f99d3792773a 100644 --- a/src/bun.js/bindings/JSSink.cpp +++ b/src/bun.js/bindings/JSSink.cpp @@ -1,6 +1,6 @@ // AUTO-GENERATED FILE. DO NOT EDIT. -// Generated by 'make generate-sink' at 2023-06-25T17:34:54.187Z +// Generated by 'make generate-sink' at 2023-07-02T16:19:51.440Z // To regenerate this file, run: // // make generate-sink diff --git a/src/bun.js/bindings/JSSink.h b/src/bun.js/bindings/JSSink.h index 9bf5554c44ca5..41d7065dcf69a 100644 --- a/src/bun.js/bindings/JSSink.h +++ b/src/bun.js/bindings/JSSink.h @@ -1,6 +1,6 @@ // AUTO-GENERATED FILE. DO NOT EDIT. -// Generated by 'make generate-sink' at 2023-06-25T17:34:54.186Z +// Generated by 'make generate-sink' at 2023-07-02T16:19:51.438Z // #pragma once diff --git a/src/bun.js/bindings/JSSinkLookupTable.h b/src/bun.js/bindings/JSSinkLookupTable.h index f8518bc5e1255..e4ed8162951f6 100644 --- a/src/bun.js/bindings/JSSinkLookupTable.h +++ b/src/bun.js/bindings/JSSinkLookupTable.h @@ -1,4 +1,4 @@ -// Automatically generated from src/bun.js/bindings/JSSink.cpp using /Users/silas/Workspace/opensource/bun/src/bun.js/WebKit/Source/JavaScriptCore/create_hash_table. DO NOT EDIT! +// Automatically generated from src/bun.js/bindings/JSSink.cpp using /home/cirospaciari/Repos/bun/src/bun.js/WebKit/Source/JavaScriptCore/create_hash_table. DO NOT EDIT! diff --git a/src/bun.js/bindings/ZigGeneratedClasses.cpp b/src/bun.js/bindings/ZigGeneratedClasses.cpp index b7461b5f0e50d..866970e4d93e2 100644 --- a/src/bun.js/bindings/ZigGeneratedClasses.cpp +++ b/src/bun.js/bindings/ZigGeneratedClasses.cpp @@ -16872,6 +16872,9 @@ extern "C" void* TCPSocketClass__construct(JSC::JSGlobalObject*, JSC::CallFrame* JSC_DECLARE_CUSTOM_GETTER(jsTCPSocketConstructor); extern "C" void TCPSocketClass__finalize(void*); +extern "C" JSC::EncodedJSValue TCPSocketPrototype__getALPNProtocol(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); +JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__alpnProtocolGetterWrap); + extern "C" JSC::EncodedJSValue TCPSocketPrototype__getAuthorized(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__authorizedGetterWrap); @@ -16896,6 +16899,9 @@ JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__listenerGetterWrap); extern "C" JSC::EncodedJSValue TCPSocketPrototype__getLocalPort(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__localPortGetterWrap); +extern "C" EncodedJSValue TCPSocketPrototype__open(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__openCallback); + extern "C" JSC::EncodedJSValue TCPSocketPrototype__getReadyState(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__readyStateGetterWrap); @@ -16908,6 +16914,9 @@ JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__reloadCallback); extern "C" JSC::EncodedJSValue TCPSocketPrototype__getRemoteAddress(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__remoteAddressGetterWrap); +extern "C" EncodedJSValue TCPSocketPrototype__setServername(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__setServernameCallback); + extern "C" EncodedJSValue TCPSocketPrototype__shutdown(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__shutdownCallback); @@ -16917,12 +16926,16 @@ JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__timeoutCallback); extern "C" EncodedJSValue TCPSocketPrototype__unref(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__unrefCallback); +extern "C" EncodedJSValue TCPSocketPrototype__wrapTLS(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__wrapTLSCallback); + extern "C" EncodedJSValue TCPSocketPrototype__write(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__writeCallback); STATIC_ASSERT_ISO_SUBSPACE_SHARABLE(JSTCPSocketPrototype, JSTCPSocketPrototype::Base); static const HashTableValue JSTCPSocketPrototypeTableValues[] = { + { "alpnProtocol"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__alpnProtocolGetterWrap, 0 } }, { "authorized"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__authorizedGetterWrap, 0 } }, { "data"_s, static_cast(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__dataGetterWrap, TCPSocketPrototype__dataSetterWrap } }, { "end"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__endCallback, 3 } }, @@ -16930,13 +16943,16 @@ static const HashTableValue JSTCPSocketPrototypeTableValues[] = { { "getAuthorizationError"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__getAuthorizationErrorCallback, 0 } }, { "listener"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__listenerGetterWrap, 0 } }, { "localPort"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__localPortGetterWrap, 0 } }, + { "open"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__openCallback, 0 } }, { "readyState"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__readyStateGetterWrap, 0 } }, { "ref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__refCallback, 0 } }, { "reload"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__reloadCallback, 1 } }, { "remoteAddress"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__remoteAddressGetterWrap, 0 } }, + { "setServername"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__setServernameCallback, 1 } }, { "shutdown"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__shutdownCallback, 1 } }, { "timeout"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__timeoutCallback, 1 } }, { "unref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__unrefCallback, 0 } }, + { "wrapTLS"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__wrapTLSCallback, 1 } }, { "write"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__writeCallback, 3 } } }; @@ -16954,6 +16970,18 @@ JSC_DEFINE_CUSTOM_GETTER(jsTCPSocketConstructor, (JSGlobalObject * lexicalGlobal return JSValue::encode(globalObject->JSTCPSocketConstructor()); } +JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__alpnProtocolGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) +{ + auto& vm = lexicalGlobalObject->vm(); + Zig::GlobalObject* globalObject = reinterpret_cast(lexicalGlobalObject); + auto throwScope = DECLARE_THROW_SCOPE(vm); + JSTCPSocket* thisObject = jsCast(JSValue::decode(thisValue)); + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + JSC::EncodedJSValue result = TCPSocketPrototype__getALPNProtocol(thisObject->wrapped(), globalObject); + RETURN_IF_EXCEPTION(throwScope, {}); + RELEASE_AND_RETURN(throwScope, result); +} + JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__authorizedGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17113,6 +17141,33 @@ JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__localPortGetterWrap, (JSGlobalObjec RELEASE_AND_RETURN(throwScope, result); } +JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__openCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTCPSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TCPSocketPrototype__open(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__readyStateGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17210,6 +17265,33 @@ extern "C" EncodedJSValue TCPSocketPrototype__remoteAddressGetCachedValue(JSC::E return JSValue::encode(thisObject->m_remoteAddress.get()); } +JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__setServernameCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTCPSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TCPSocketPrototype__setServername(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__shutdownCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); @@ -17291,6 +17373,33 @@ JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__unrefCallback, (JSGlobalObject * le return TCPSocketPrototype__unref(thisObject->wrapped(), lexicalGlobalObject, callFrame); } +JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__wrapTLSCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTCPSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TCPSocketPrototype__wrapTLS(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__writeCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); @@ -17479,6 +17588,9 @@ extern "C" void* TLSSocketClass__construct(JSC::JSGlobalObject*, JSC::CallFrame* JSC_DECLARE_CUSTOM_GETTER(jsTLSSocketConstructor); extern "C" void TLSSocketClass__finalize(void*); +extern "C" JSC::EncodedJSValue TLSSocketPrototype__getALPNProtocol(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); +JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__alpnProtocolGetterWrap); + extern "C" JSC::EncodedJSValue TLSSocketPrototype__getAuthorized(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__authorizedGetterWrap); @@ -17503,6 +17615,9 @@ JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__listenerGetterWrap); extern "C" JSC::EncodedJSValue TLSSocketPrototype__getLocalPort(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__localPortGetterWrap); +extern "C" EncodedJSValue TLSSocketPrototype__open(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__openCallback); + extern "C" JSC::EncodedJSValue TLSSocketPrototype__getReadyState(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__readyStateGetterWrap); @@ -17515,6 +17630,9 @@ JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__reloadCallback); extern "C" JSC::EncodedJSValue TLSSocketPrototype__getRemoteAddress(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__remoteAddressGetterWrap); +extern "C" EncodedJSValue TLSSocketPrototype__setServername(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__setServernameCallback); + extern "C" EncodedJSValue TLSSocketPrototype__shutdown(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__shutdownCallback); @@ -17524,12 +17642,16 @@ JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__timeoutCallback); extern "C" EncodedJSValue TLSSocketPrototype__unref(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__unrefCallback); +extern "C" EncodedJSValue TLSSocketPrototype__wrapTLS(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__wrapTLSCallback); + extern "C" EncodedJSValue TLSSocketPrototype__write(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__writeCallback); STATIC_ASSERT_ISO_SUBSPACE_SHARABLE(JSTLSSocketPrototype, JSTLSSocketPrototype::Base); static const HashTableValue JSTLSSocketPrototypeTableValues[] = { + { "alpnProtocol"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__alpnProtocolGetterWrap, 0 } }, { "authorized"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__authorizedGetterWrap, 0 } }, { "data"_s, static_cast(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__dataGetterWrap, TLSSocketPrototype__dataSetterWrap } }, { "end"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__endCallback, 3 } }, @@ -17537,13 +17659,16 @@ static const HashTableValue JSTLSSocketPrototypeTableValues[] = { { "getAuthorizationError"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__getAuthorizationErrorCallback, 0 } }, { "listener"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__listenerGetterWrap, 0 } }, { "localPort"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__localPortGetterWrap, 0 } }, + { "open"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__openCallback, 0 } }, { "readyState"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__readyStateGetterWrap, 0 } }, { "ref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__refCallback, 0 } }, { "reload"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__reloadCallback, 1 } }, { "remoteAddress"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__remoteAddressGetterWrap, 0 } }, + { "setServername"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__setServernameCallback, 1 } }, { "shutdown"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__shutdownCallback, 1 } }, { "timeout"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__timeoutCallback, 1 } }, { "unref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__unrefCallback, 0 } }, + { "wrapTLS"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__wrapTLSCallback, 1 } }, { "write"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__writeCallback, 3 } } }; @@ -17561,6 +17686,18 @@ JSC_DEFINE_CUSTOM_GETTER(jsTLSSocketConstructor, (JSGlobalObject * lexicalGlobal return JSValue::encode(globalObject->JSTLSSocketConstructor()); } +JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__alpnProtocolGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) +{ + auto& vm = lexicalGlobalObject->vm(); + Zig::GlobalObject* globalObject = reinterpret_cast(lexicalGlobalObject); + auto throwScope = DECLARE_THROW_SCOPE(vm); + JSTLSSocket* thisObject = jsCast(JSValue::decode(thisValue)); + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + JSC::EncodedJSValue result = TLSSocketPrototype__getALPNProtocol(thisObject->wrapped(), globalObject); + RETURN_IF_EXCEPTION(throwScope, {}); + RELEASE_AND_RETURN(throwScope, result); +} + JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__authorizedGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17720,6 +17857,33 @@ JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__localPortGetterWrap, (JSGlobalObjec RELEASE_AND_RETURN(throwScope, result); } +JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__openCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTLSSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TLSSocketPrototype__open(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__readyStateGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17817,6 +17981,33 @@ extern "C" EncodedJSValue TLSSocketPrototype__remoteAddressGetCachedValue(JSC::E return JSValue::encode(thisObject->m_remoteAddress.get()); } +JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__setServernameCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTLSSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TLSSocketPrototype__setServername(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__shutdownCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); @@ -17898,6 +18089,33 @@ JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__unrefCallback, (JSGlobalObject * le return TLSSocketPrototype__unref(thisObject->wrapped(), lexicalGlobalObject, callFrame); } +JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__wrapTLSCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTLSSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TLSSocketPrototype__wrapTLS(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__writeCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); diff --git a/src/bun.js/bindings/generated_classes.zig b/src/bun.js/bindings/generated_classes.zig index 04a72d7ed0a83..a220b68140b1b 100644 --- a/src/bun.js/bindings/generated_classes.zig +++ b/src/bun.js/bindings/generated_classes.zig @@ -4426,6 +4426,9 @@ pub const JSTCPSocket = struct { @compileLog("TCPSocket.finalize is not a finalizer"); } + if (@TypeOf(TCPSocket.getALPNProtocol) != GetterType) + @compileLog("Expected TCPSocket.getALPNProtocol to be a getter"); + if (@TypeOf(TCPSocket.getAuthorized) != GetterType) @compileLog("Expected TCPSocket.getAuthorized to be a getter"); @@ -4446,6 +4449,8 @@ pub const JSTCPSocket = struct { if (@TypeOf(TCPSocket.getLocalPort) != GetterType) @compileLog("Expected TCPSocket.getLocalPort to be a getter"); + if (@TypeOf(TCPSocket.open) != CallbackType) + @compileLog("Expected TCPSocket.open to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.open))); if (@TypeOf(TCPSocket.getReadyState) != GetterType) @compileLog("Expected TCPSocket.getReadyState to be a getter"); @@ -4456,18 +4461,23 @@ pub const JSTCPSocket = struct { if (@TypeOf(TCPSocket.getRemoteAddress) != GetterType) @compileLog("Expected TCPSocket.getRemoteAddress to be a getter"); + if (@TypeOf(TCPSocket.setServername) != CallbackType) + @compileLog("Expected TCPSocket.setServername to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.setServername))); if (@TypeOf(TCPSocket.shutdown) != CallbackType) @compileLog("Expected TCPSocket.shutdown to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.shutdown))); if (@TypeOf(TCPSocket.timeout) != CallbackType) @compileLog("Expected TCPSocket.timeout to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.timeout))); if (@TypeOf(TCPSocket.unref) != CallbackType) @compileLog("Expected TCPSocket.unref to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.unref))); + if (@TypeOf(TCPSocket.wrapTLS) != CallbackType) + @compileLog("Expected TCPSocket.wrapTLS to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.wrapTLS))); if (@TypeOf(TCPSocket.write) != CallbackType) @compileLog("Expected TCPSocket.write to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.write))); if (!JSC.is_bindgen) { @export(TCPSocket.end, .{ .name = "TCPSocketPrototype__end" }); @export(TCPSocket.finalize, .{ .name = "TCPSocketClass__finalize" }); @export(TCPSocket.flush, .{ .name = "TCPSocketPrototype__flush" }); + @export(TCPSocket.getALPNProtocol, .{ .name = "TCPSocketPrototype__getALPNProtocol" }); @export(TCPSocket.getAuthorizationError, .{ .name = "TCPSocketPrototype__getAuthorizationError" }); @export(TCPSocket.getAuthorized, .{ .name = "TCPSocketPrototype__getAuthorized" }); @export(TCPSocket.getData, .{ .name = "TCPSocketPrototype__getData" }); @@ -4476,12 +4486,15 @@ pub const JSTCPSocket = struct { @export(TCPSocket.getReadyState, .{ .name = "TCPSocketPrototype__getReadyState" }); @export(TCPSocket.getRemoteAddress, .{ .name = "TCPSocketPrototype__getRemoteAddress" }); @export(TCPSocket.hasPendingActivity, .{ .name = "TCPSocket__hasPendingActivity" }); + @export(TCPSocket.open, .{ .name = "TCPSocketPrototype__open" }); @export(TCPSocket.ref, .{ .name = "TCPSocketPrototype__ref" }); @export(TCPSocket.reload, .{ .name = "TCPSocketPrototype__reload" }); @export(TCPSocket.setData, .{ .name = "TCPSocketPrototype__setData" }); + @export(TCPSocket.setServername, .{ .name = "TCPSocketPrototype__setServername" }); @export(TCPSocket.shutdown, .{ .name = "TCPSocketPrototype__shutdown" }); @export(TCPSocket.timeout, .{ .name = "TCPSocketPrototype__timeout" }); @export(TCPSocket.unref, .{ .name = "TCPSocketPrototype__unref" }); + @export(TCPSocket.wrapTLS, .{ .name = "TCPSocketPrototype__wrapTLS" }); @export(TCPSocket.write, .{ .name = "TCPSocketPrototype__write" }); } } @@ -4581,6 +4594,9 @@ pub const JSTLSSocket = struct { @compileLog("TLSSocket.finalize is not a finalizer"); } + if (@TypeOf(TLSSocket.getALPNProtocol) != GetterType) + @compileLog("Expected TLSSocket.getALPNProtocol to be a getter"); + if (@TypeOf(TLSSocket.getAuthorized) != GetterType) @compileLog("Expected TLSSocket.getAuthorized to be a getter"); @@ -4601,6 +4617,8 @@ pub const JSTLSSocket = struct { if (@TypeOf(TLSSocket.getLocalPort) != GetterType) @compileLog("Expected TLSSocket.getLocalPort to be a getter"); + if (@TypeOf(TLSSocket.open) != CallbackType) + @compileLog("Expected TLSSocket.open to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.open))); if (@TypeOf(TLSSocket.getReadyState) != GetterType) @compileLog("Expected TLSSocket.getReadyState to be a getter"); @@ -4611,18 +4629,23 @@ pub const JSTLSSocket = struct { if (@TypeOf(TLSSocket.getRemoteAddress) != GetterType) @compileLog("Expected TLSSocket.getRemoteAddress to be a getter"); + if (@TypeOf(TLSSocket.setServername) != CallbackType) + @compileLog("Expected TLSSocket.setServername to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.setServername))); if (@TypeOf(TLSSocket.shutdown) != CallbackType) @compileLog("Expected TLSSocket.shutdown to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.shutdown))); if (@TypeOf(TLSSocket.timeout) != CallbackType) @compileLog("Expected TLSSocket.timeout to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.timeout))); if (@TypeOf(TLSSocket.unref) != CallbackType) @compileLog("Expected TLSSocket.unref to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.unref))); + if (@TypeOf(TLSSocket.wrapTLS) != CallbackType) + @compileLog("Expected TLSSocket.wrapTLS to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.wrapTLS))); if (@TypeOf(TLSSocket.write) != CallbackType) @compileLog("Expected TLSSocket.write to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.write))); if (!JSC.is_bindgen) { @export(TLSSocket.end, .{ .name = "TLSSocketPrototype__end" }); @export(TLSSocket.finalize, .{ .name = "TLSSocketClass__finalize" }); @export(TLSSocket.flush, .{ .name = "TLSSocketPrototype__flush" }); + @export(TLSSocket.getALPNProtocol, .{ .name = "TLSSocketPrototype__getALPNProtocol" }); @export(TLSSocket.getAuthorizationError, .{ .name = "TLSSocketPrototype__getAuthorizationError" }); @export(TLSSocket.getAuthorized, .{ .name = "TLSSocketPrototype__getAuthorized" }); @export(TLSSocket.getData, .{ .name = "TLSSocketPrototype__getData" }); @@ -4631,12 +4654,15 @@ pub const JSTLSSocket = struct { @export(TLSSocket.getReadyState, .{ .name = "TLSSocketPrototype__getReadyState" }); @export(TLSSocket.getRemoteAddress, .{ .name = "TLSSocketPrototype__getRemoteAddress" }); @export(TLSSocket.hasPendingActivity, .{ .name = "TLSSocket__hasPendingActivity" }); + @export(TLSSocket.open, .{ .name = "TLSSocketPrototype__open" }); @export(TLSSocket.ref, .{ .name = "TLSSocketPrototype__ref" }); @export(TLSSocket.reload, .{ .name = "TLSSocketPrototype__reload" }); @export(TLSSocket.setData, .{ .name = "TLSSocketPrototype__setData" }); + @export(TLSSocket.setServername, .{ .name = "TLSSocketPrototype__setServername" }); @export(TLSSocket.shutdown, .{ .name = "TLSSocketPrototype__shutdown" }); @export(TLSSocket.timeout, .{ .name = "TLSSocketPrototype__timeout" }); @export(TLSSocket.unref, .{ .name = "TLSSocketPrototype__unref" }); + @export(TLSSocket.wrapTLS, .{ .name = "TLSSocketPrototype__wrapTLS" }); @export(TLSSocket.write, .{ .name = "TLSSocketPrototype__write" }); } } diff --git a/src/bun.js/bindings/webcore/JSEventEmitter.cpp b/src/bun.js/bindings/webcore/JSEventEmitter.cpp index 1957b404b7de4..231ae0db43a8c 100644 --- a/src/bun.js/bindings/webcore/JSEventEmitter.cpp +++ b/src/bun.js/bindings/webcore/JSEventEmitter.cpp @@ -149,7 +149,7 @@ static const HashTableValue JSEventEmitterPrototypeTableValues[] = { { "on"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_addListener, 2 } }, { "once"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_addOnceListener, 2 } }, { "prependListener"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_prependListener, 2 } }, - { "prependOnce"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_prependOnceListener, 2 } }, + { "prependOnceListener"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_prependOnceListener, 2 } }, { "removeListener"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_removeListener, 2 } }, { "off"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_removeListener, 2 } }, { "removeAllListeners"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_removeAllListeners, 1 } }, diff --git a/src/deps/uws b/src/deps/uws index d82c4a95de3af..875948226eede 160000 --- a/src/deps/uws +++ b/src/deps/uws @@ -1 +1 @@ -Subproject commit d82c4a95de3af01614ecb12bfff821611b4cc6b7 +Subproject commit 875948226eede72861a5170212ff6b43c4b7d7f9 diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 8ebe04ac0dee9..5dbe4f5d8a621 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -40,6 +40,129 @@ pub fn NewSocketHandler(comptime ssl: bool) type { return us_socket_timeout(comptime ssl_int, this.socket, seconds); } + pub fn open(this: ThisSocket, is_client: bool) void { + _ = us_socket_open(comptime ssl_int, this.socket, @intFromBool(is_client), null, 0); + } + + // Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context + // context ext will not be copied to the new context, new context will contain us_wrapped_socket_context_t on ext + pub fn wrapTLS( + this: ThisSocket, + options: us_bun_socket_context_options_t, + socket_ext_size: i32, + comptime deref: bool, + comptime ContextType: type, + comptime Fields: anytype, + ) ?NewSocketHandler(true) { + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + const TLSSocket = NewSocketHandler(true); + const SocketHandler = struct { + const alignment = if (ContextType == anyopaque) + @sizeOf(usize) + else + std.meta.alignment(ContextType); + const deref_ = deref; + const ValueType = if (deref) ContextType else *ContextType; + fn getValue(socket: *Socket) ValueType { + if (comptime ContextType == anyopaque) { + return us_socket_ext(1, socket).?; + } + + if (comptime deref_) { + return (TLSSocket{ .socket = socket }).ext(ContextType).?.*; + } + + return (TLSSocket{ .socket = socket }).ext(ContextType).?; + } + + pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + if (comptime @hasDecl(Fields, "onCreate")) { + if (is_client == 0) { + Fields.onCreate( + TLSSocket{ .socket = socket }, + ); + } + } + Fields.onOpen( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { + Fields.onClose( + getValue(socket), + TLSSocket{ .socket = socket }, + code, + reason, + ); + return socket; + } + pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { + Fields.onData( + getValue(socket), + TLSSocket{ .socket = socket }, + buf.?[0..@intCast(usize, len)], + ); + return socket; + } + pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { + Fields.onWritable( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { + Fields.onTimeout( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + Fields.onConnectError( + getValue(socket), + TLSSocket{ .socket = socket }, + code, + ); + return socket; + } + pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { + Fields.onEnd( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { + Fields.onHandshake(getValue(socket), TLSSocket{ .socket = socket }, success, verify_error); + } + }; + + var events: us_socket_events_t = .{}; + + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) + events.on_open = SocketHandler.on_open; + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) + events.on_close = SocketHandler.on_close; + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) + events.on_data = SocketHandler.on_data; + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) + events.on_writable = SocketHandler.on_writable; + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) + events.on_timeout = SocketHandler.on_timeout; + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) + events.on_connect_error = SocketHandler.on_connect_error; + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) + events.on_end = SocketHandler.on_end; + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) + events.on_handshake = SocketHandler.on_handshake; + + const socket = us_socket_wrap_with_tls(ssl_int, this.socket, options, events, socket_ext_size) orelse return null; + return NewSocketHandler(true).from(socket); + } + pub fn getNativeHandle(this: ThisSocket) *NativeSocketHandleType(ssl) { return @ptrCast(*NativeSocketHandleType(ssl), us_socket_get_native_handle(comptime ssl_int, this.socket).?); } @@ -95,6 +218,17 @@ pub fn NewSocketHandler(comptime ssl: bool) type { @as(i32, @intFromBool(msg_more)), ); } + + pub fn rawWrite(this: ThisSocket, data: []const u8, msg_more: bool) i32 { + return us_socket_raw_write( + comptime ssl_int, + this.socket, + data.ptr, + // truncate to 31 bits since sign bit exists + @intCast(i32, @truncate(u31, data.len)), + @as(i32, @intFromBool(msg_more)), + ); + } pub fn shutdown(this: ThisSocket) void { debug("us_socket_shutdown({d})", .{@intFromPtr(this.socket)}); return us_socket_shutdown( @@ -241,13 +375,126 @@ pub fn NewSocketHandler(comptime ssl: bool) type { return socket_; } + pub fn unsafeConfigure( + ctx: *SocketContext, + comptime ssl_type: bool, + comptime deref: bool, + comptime ContextType: type, + comptime Fields: anytype, + ) void { + const SocketHandlerType = NewSocketHandler(ssl_type); + const ssl_type_int: i32 = @intFromBool(ssl_type); + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + + const SocketHandler = struct { + const alignment = if (ContextType == anyopaque) + @sizeOf(usize) + else + std.meta.alignment(ContextType); + const deref_ = deref; + const ValueType = if (deref) ContextType else *ContextType; + fn getValue(socket: *Socket) ValueType { + if (comptime ContextType == anyopaque) { + return us_socket_ext(ssl_type_int, socket).?; + } + + if (comptime deref_) { + return (SocketHandlerType{ .socket = socket }).ext(ContextType).?.*; + } + + return (SocketHandlerType{ .socket = socket }).ext(ContextType).?; + } + + pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + if (comptime @hasDecl(Fields, "onCreate")) { + if (is_client == 0) { + Fields.onCreate( + SocketHandlerType{ .socket = socket }, + ); + } + } + Fields.onOpen( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { + Fields.onClose( + getValue(socket), + SocketHandlerType{ .socket = socket }, + code, + reason, + ); + return socket; + } + pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { + Fields.onData( + getValue(socket), + SocketHandlerType{ .socket = socket }, + buf.?[0..@intCast(usize, len)], + ); + return socket; + } + pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { + Fields.onWritable( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { + Fields.onTimeout( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + Fields.onConnectError( + getValue(socket), + SocketHandlerType{ .socket = socket }, + code, + ); + return socket; + } + pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { + Fields.onEnd( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { + Fields.onHandshake(getValue(socket), SocketHandlerType{ .socket = socket }, success, verify_error); + } + }; + + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) + us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) + us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) + us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) + us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) + us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) + us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) + us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) + us_socket_context_on_handshake(ssl_int, ctx, SocketHandler.on_handshake, null); + } + pub fn configure( ctx: *SocketContext, comptime deref: bool, comptime ContextType: type, comptime Fields: anytype, ) void { - const @"type" = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; const SocketHandler = struct { const alignment = if (ContextType == anyopaque) @@ -333,21 +580,21 @@ pub fn NewSocketHandler(comptime ssl: bool) type { } }; - if (comptime @hasDecl(@"type", "onOpen") and @typeInfo(@TypeOf(@"type".onOpen)) != .Null) + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); - if (comptime @hasDecl(@"type", "onClose") and @typeInfo(@TypeOf(@"type".onClose)) != .Null) + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); - if (comptime @hasDecl(@"type", "onData") and @typeInfo(@TypeOf(@"type".onData)) != .Null) + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); - if (comptime @hasDecl(@"type", "onWritable") and @typeInfo(@TypeOf(@"type".onWritable)) != .Null) + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); - if (comptime @hasDecl(@"type", "onTimeout") and @typeInfo(@TypeOf(@"type".onTimeout)) != .Null) + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); - if (comptime @hasDecl(@"type", "onConnectError") and @typeInfo(@TypeOf(@"type".onConnectError)) != .Null) + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); - if (comptime @hasDecl(@"type", "onEnd") and @typeInfo(@TypeOf(@"type".onEnd)) != .Null) + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); - if (comptime @hasDecl(@"type", "onHandshake") and @typeInfo(@TypeOf(@"type".onHandshake)) != .Null) + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) us_socket_context_on_handshake(ssl_int, ctx, SocketHandler.on_handshake, null); } @@ -659,6 +906,20 @@ pub const us_bun_verify_error_t = extern struct { reason: [*c]const u8 = null, }; +pub const us_socket_events_t = extern struct { + on_open: ?*const fn (*Socket, i32, [*c]u8, i32) callconv(.C) ?*Socket = null, + on_data: ?*const fn (*Socket, [*c]u8, i32) callconv(.C) ?*Socket = null, + on_writable: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_close: ?*const fn (*Socket, i32, ?*anyopaque) callconv(.C) ?*Socket = null, + + on_timeout: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_long_timeout: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_end: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_connect_error: ?*const fn (*Socket, i32) callconv(.C) ?*Socket = null, + on_handshake: ?*const fn (*Socket, i32, us_bun_verify_error_t, ?*anyopaque) callconv(.C) void = null, +}; + +pub extern fn us_socket_wrap_with_tls(ssl: i32, s: *Socket, options: us_bun_socket_context_options_t, events: us_socket_events_t, socket_ext_size: i32) ?*Socket; extern fn us_socket_verify_error(ssl: i32, context: *Socket) us_bun_verify_error_t; extern fn SocketContextimestamp(ssl: i32, context: ?*SocketContext) c_ushort; pub extern fn us_socket_context_add_server_name(ssl: i32, context: ?*SocketContext, hostname_pattern: [*c]const u8, options: us_socket_context_options_t, ?*anyopaque) void; @@ -777,11 +1038,16 @@ extern fn us_socket_ext(ssl: i32, s: ?*Socket) ?*anyopaque; extern fn us_socket_context(ssl: i32, s: ?*Socket) ?*SocketContext; extern fn us_socket_flush(ssl: i32, s: ?*Socket) void; extern fn us_socket_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; +extern fn us_socket_raw_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; extern fn us_socket_shutdown(ssl: i32, s: ?*Socket) void; extern fn us_socket_shutdown_read(ssl: i32, s: ?*Socket) void; extern fn us_socket_is_shut_down(ssl: i32, s: ?*Socket) i32; extern fn us_socket_is_closed(ssl: i32, s: ?*Socket) i32; extern fn us_socket_close(ssl: i32, s: ?*Socket, code: i32, reason: ?*anyopaque) ?*Socket; +// if a TLS socket calls this, it will start SSL instance and call open event will also do TLS handshake if required +// will have no effect if the socket is closed or is not TLS +extern fn us_socket_open(ssl: i32, s: ?*Socket, is_client: i32, ip: [*c]const u8, ip_length: i32) ?*Socket; + extern fn us_socket_local_port(ssl: i32, s: ?*Socket) i32; extern fn us_socket_remote_address(ssl: i32, s: ?*Socket, buf: [*c]u8, length: [*c]i32) void; pub const uws_app_s = opaque {}; diff --git a/src/js/node/net.js b/src/js/node/net.js index 430a0dfa23733..1b7742dd17a53 100644 --- a/src/js/node/net.js +++ b/src/js/node/net.js @@ -64,6 +64,7 @@ const bunTlsSymbol = Symbol.for("::buntls::"); const bunSocketServerHandlers = Symbol.for("::bunsocket_serverhandlers::"); const bunSocketServerConnections = Symbol.for("::bunnetserverconnections::"); const bunSocketServerOptions = Symbol.for("::bunnetserveroptions::"); +const bunSocketInternal = Symbol.for("::bunnetsocketinternal::"); var SocketClass; const Socket = (function (InternalSocket) { @@ -117,7 +118,7 @@ const Socket = (function (InternalSocket) { const self = socket.data; socket.timeout(self.timeout); socket.ref(); - self.#socket = socket; + self[bunSocketInternal] = socket; self.connecting = false; self.emit("connect", self); Socket.#Drain(socket); @@ -164,7 +165,7 @@ const Socket = (function (InternalSocket) { if (self.#closed) return; self.#closed = true; //socket cannot be used after close - self.#socket = null; + self[bunSocketInternal] = null; const queue = self.#readQueue; if (queue.isEmpty()) { if (self.push(null)) return; @@ -289,23 +290,33 @@ const Socket = (function (InternalSocket) { localAddress = "127.0.0.1"; #readQueue = createFIFO(); remotePort; - #socket; + [bunSocketInternal] = null; timeout = 0; #writeCallback; #writeChunk; #pendingRead; isServer = false; + _handle; + _parent; + _parentWrap; + #socket; constructor(options) { - const { signal, write, read, allowHalfOpen = false, ...opts } = options || {}; + const { socket, signal, write, read, allowHalfOpen = false, ...opts } = options || {}; super({ ...opts, allowHalfOpen, readable: true, writable: true, }); + this._handle = this; + this._parent = this; + this._parentWrap = this; this.#pendingRead = undefined; + if (socket instanceof Socket) { + this.#socket = socket; + } signal?.once("abort", () => this.destroy()); this.once("connect", () => this.emit("ready")); } @@ -327,7 +338,7 @@ const Socket = (function (InternalSocket) { socket.data = this; socket.timeout(this.timeout); socket.ref(); - this.#socket = socket; + this[bunSocketInternal] = socket; this.connecting = false; this.emit("connect", this); Socket.#Drain(socket); @@ -335,6 +346,7 @@ const Socket = (function (InternalSocket) { connect(port, host, connectListener) { var path; + var connection = this.#socket; if (typeof port === "string") { path = port; port = undefined; @@ -357,6 +369,7 @@ const Socket = (function (InternalSocket) { port, host, path, + socket, // TODOs localAddress, localPort, @@ -371,7 +384,11 @@ const Socket = (function (InternalSocket) { pauseOnConnect, servername, } = port; + this.servername = servername; + if (socket) { + connection = socket; + } } if (!pauseOnConnect) { @@ -399,41 +416,117 @@ const Socket = (function (InternalSocket) { } else { tls.rejectUnauthorized = rejectUnauthorized; tls.requestCert = true; + if (!connection && tls.socket) { + connection = tls.socket; + } + } + } + if (connection) { + if ( + typeof connection !== "object" || + !(connection instanceof Socket) || + typeof connection[bunTlsSymbol] === "function" + ) { + throw new TypeError("socket must be an instance of net.Socket"); } } - this.authorized = false; this.secureConnecting = true; this._secureEstablished = false; this._securePending = true; if (connectListener) this.on("secureConnect", connectListener); } else if (connectListener) this.on("connect", connectListener); - bunConnect( - path - ? { + // start using existing connection + + if (connection) { + const socket = connection[bunSocketInternal]; + + if (socket) { + const result = socket.wrapTLS({ + data: this, + tls, + socket: Socket.#Handlers, + }); + if (result) { + const [raw, tls] = result; + // replace socket + connection[bunSocketInternal] = raw; + raw.timeout(raw.timeout); + raw.connecting = false; + // set new socket + this[bunSocketInternal] = tls; + tls.timeout(tls.timeout); + tls.connecting = true; + this[bunSocketInternal] = socket; + // start tls + tls.open(); + } else { + this[bunSocketInternal] = null; + throw new Error("Invalid socket"); + } + } else { + // wait to be connected + connection.once("connect", () => { + const socket = connection[bunSocketInternal]; + if (!socket) return; + + const result = socket.wrapTLS({ data: this, - unix: path, - socket: Socket.#Handlers, tls, - } - : { - data: this, - hostname: host || "localhost", - port: port, socket: Socket.#Handlers, - tls, - }, - ); + }); + + if (result) { + const [raw, tls] = result; + // replace socket + connection[bunSocketInternal] = raw; + raw.timeout(raw.timeout); + raw.connecting = false; + // set new socket + this[bunSocketInternal] = tls; + tls.timeout(tls.timeout); + tls.connecting = true; + this[bunSocketInternal] = socket; + // start tls + tls.open(); + } else { + this[bunSocketInternal] = null; + throw new Error("Invalid socket"); + } + }); + } + } else if (path) { + // start using unix socket + bunConnect({ + data: this, + unix: path, + socket: Socket.#Handlers, + tls, + }).catch(error => { + this.emit("error", error); + }); + } else { + // default start + bunConnect({ + data: this, + hostname: host || "localhost", + port: port, + socket: Socket.#Handlers, + tls, + }).catch(error => { + this.emit("error", error); + }); + } return this; } _destroy(err, callback) { - this.#socket?.end(); + this[bunSocketInternal]?.end(); callback(err); } _final(callback) { - this.#socket?.end(); + this[bunSocketInternal]?.end(); callback(); } @@ -446,7 +539,7 @@ const Socket = (function (InternalSocket) { } get localPort() { - return this.#socket?.localPort; + return this[bunSocketInternal]?.localPort; } get pending() { @@ -472,11 +565,11 @@ const Socket = (function (InternalSocket) { } ref() { - this.#socket?.ref(); + this[bunSocketInternal]?.ref(); } get remoteAddress() { - return this.#socket?.remoteAddress; + return this[bunSocketInternal]?.remoteAddress; } get remoteFamily() { @@ -484,7 +577,7 @@ const Socket = (function (InternalSocket) { } resetAndDestroy() { - this.#socket?.end(); + this[bunSocketInternal]?.end(); } setKeepAlive(enable = false, initialDelay = 0) { @@ -498,19 +591,19 @@ const Socket = (function (InternalSocket) { } setTimeout(timeout, callback) { - this.#socket?.timeout(timeout); + this[bunSocketInternal]?.timeout(timeout); this.timeout = timeout; if (callback) this.once("timeout", callback); return this; } unref() { - this.#socket?.unref(); + this[bunSocketInternal]?.unref(); } _write(chunk, encoding, callback) { - if (typeof chunk == "string" && encoding !== "utf8") chunk = Buffer.from(chunk, encoding); - var written = this.#socket?.write(chunk); + if (typeof chunk == "string" && encoding !== "ascii") chunk = Buffer.from(chunk, encoding); + var written = this[bunSocketInternal]?.write(chunk); if (written == chunk.length) { callback(); } else if (this.#writeCallback) { diff --git a/src/js/node/tls.js b/src/js/node/tls.js index 356c25cbd848a..310a3662005c6 100644 --- a/src/js/node/tls.js +++ b/src/js/node/tls.js @@ -1,9 +1,30 @@ // Hardcoded module "node:tls" -import { isTypedArray } from "util/types"; +import { isArrayBufferView, isTypedArray } from "util/types"; import net, { Server as NetServer } from "node:net"; const InternalTCPSocket = net[Symbol.for("::bunternal::")]; - +const bunSocketInternal = Symbol.for("::bunnetsocketinternal::"); + +const { RegExp, Array, String } = globalThis[Symbol.for("Bun.lazy")]("primordials"); +const SymbolReplace = Symbol.replace; +const RegExpPrototypeSymbolReplace = RegExp.prototype[SymbolReplace]; +const RegExpPrototypeExec = RegExp.prototype.exec; + +const StringPrototypeStartsWith = String.prototype.startsWith; +const StringPrototypeSlice = String.prototype.slice; +const StringPrototypeIncludes = String.prototype.includes; +const StringPrototypeSplit = String.prototype.split; +const StringPrototypeIndexOf = String.prototype.indexOf; +const StringPrototypeSubstring = String.prototype.substring; +const StringPrototypeEndsWith = String.prototype.endsWith; + +const ArrayPrototypeIncludes = Array.prototype.includes; +const ArrayPrototypeJoin = Array.prototype.join; +const ArrayPrototypeForEach = Array.prototype.forEach; +const ArrayPrototypePush = Array.prototype.push; +const ArrayPrototypeSome = Array.prototype.some; +const ArrayPrototypeReduce = Array.prototype.reduce; function parseCertString() { + // Removed since JAN 2022 Node v18.0.0+ https://github.com/nodejs/node/pull/41479 throwNotImplemented("Not implemented"); } @@ -18,6 +39,164 @@ function isValidTLSArray(obj) { } } +function unfqdn(host) { + return RegExpPrototypeSymbolReplace(/[.]$/, host, ""); +} + +function splitHost(host) { + return StringPrototypeSplit.call(RegExpPrototypeSymbolReplace(/[A-Z]/g, unfqdn(host), toLowerCase), "."); +} + +function check(hostParts, pattern, wildcards) { + // Empty strings, null, undefined, etc. never match. + if (!pattern) return false; + + const patternParts = splitHost(pattern); + + if (hostParts.length !== patternParts.length) return false; + + // Pattern has empty components, e.g. "bad..example.com". + if (ArrayPrototypeIncludes.call(patternParts, "")) return false; + + // RFC 6125 allows IDNA U-labels (Unicode) in names but we have no + // good way to detect their encoding or normalize them so we simply + // reject them. Control characters and blanks are rejected as well + // because nothing good can come from accepting them. + const isBad = s => RegExpPrototypeExec.call(/[^\u0021-\u007F]/u, s) !== null; + if (ArrayPrototypeSome.call(patternParts, isBad)) return false; + + // Check host parts from right to left first. + for (let i = hostParts.length - 1; i > 0; i -= 1) { + if (hostParts[i] !== patternParts[i]) return false; + } + + const hostSubdomain = hostParts[0]; + const patternSubdomain = patternParts[0]; + const patternSubdomainParts = StringPrototypeSplit.call(patternSubdomain, "*"); + + // Short-circuit when the subdomain does not contain a wildcard. + // RFC 6125 does not allow wildcard substitution for components + // containing IDNA A-labels (Punycode) so match those verbatim. + if (patternSubdomainParts.length === 1 || StringPrototypeIncludes.call(patternSubdomain, "xn--")) + return hostSubdomain === patternSubdomain; + + if (!wildcards) return false; + + // More than one wildcard is always wrong. + if (patternSubdomainParts.length > 2) return false; + + // *.tld wildcards are not allowed. + if (patternParts.length <= 2) return false; + + const { 0: prefix, 1: suffix } = patternSubdomainParts; + + if (prefix.length + suffix.length > hostSubdomain.length) return false; + + if (!StringPrototypeStartsWith.call(hostSubdomain, prefix)) return false; + + if (!StringPrototypeEndsWith.call(hostSubdomain, suffix)) return false; + + return true; +} + +// This pattern is used to determine the length of escaped sequences within +// the subject alt names string. It allows any valid JSON string literal. +// This MUST match the JSON specification (ECMA-404 / RFC8259) exactly. +const jsonStringPattern = + // eslint-disable-next-line no-control-regex + /^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/; + +function splitEscapedAltNames(altNames) { + const result = []; + let currentToken = ""; + let offset = 0; + while (offset !== altNames.length) { + const nextSep = StringPrototypeIndexOf.call(altNames, ", ", offset); + const nextQuote = StringPrototypeIndexOf.call(altNames, '"', offset); + if (nextQuote !== -1 && (nextSep === -1 || nextQuote < nextSep)) { + // There is a quote character and there is no separator before the quote. + currentToken += StringPrototypeSubstring.call(altNames, offset, nextQuote); + const match = RegExpPrototypeExec.call(jsonStringPattern, StringPrototypeSubstring.call(altNames, nextQuote)); + if (!match) { + let error = new SyntaxError("ERR_TLS_CERT_ALTNAME_FORMAT: Invalid subject alternative name string"); + error.name = ERR_TLS_CERT_ALTNAME_FORMAT; + throw error; + } + currentToken += JSON.parse(match[0]); + offset = nextQuote + match[0].length; + } else if (nextSep !== -1) { + // There is a separator and no quote before it. + currentToken += StringPrototypeSubstring.call(altNames, offset, nextSep); + ArrayPrototypePush.call(result, currentToken); + currentToken = ""; + offset = nextSep + 2; + } else { + currentToken += StringPrototypeSubstring.call(altNames, offset); + offset = altNames.length; + } + } + ArrayPrototypePush.call(result, currentToken); + return result; +} +function checkServerIdentity(hostname, cert) { + const subject = cert.subject; + const altNames = cert.subjectaltname; + const dnsNames = []; + const ips = []; + + hostname = "" + hostname; + + if (altNames) { + const splitAltNames = StringPrototypeIncludes.call(altNames, '"') + ? splitEscapedAltNames(altNames) + : StringPrototypeSplit.call(altNames, ", "); + ArrayPrototypeForEach.call(splitAltNames, name => { + if (StringPrototypeStartsWith.call(name, "DNS:")) { + ArrayPrototypePush.call(dnsNames, StringPrototypeSlice.call(name, 4)); + } else if (StringPrototypeStartsWith.call(name, "IP Address:")) { + ArrayPrototypePush.call(ips, canonicalizeIP(StringPrototypeSlice.call(name, 11))); + } + }); + } + + let valid = false; + let reason = "Unknown reason"; + + hostname = unfqdn(hostname); // Remove trailing dot for error messages. + + if (net.isIP(hostname)) { + valid = ArrayPrototypeIncludes.call(ips, canonicalizeIP(hostname)); + if (!valid) reason = `IP: ${hostname} is not in the cert's list: ` + ArrayPrototypeJoin.call(ips, ", "); + } else if (dnsNames.length > 0 || subject?.CN) { + const hostParts = splitHost(hostname); + const wildcard = pattern => check(hostParts, pattern, true); + + if (dnsNames.length > 0) { + valid = ArrayPrototypeSome.call(dnsNames, wildcard); + if (!valid) reason = `Host: ${hostname}. is not in the cert's altnames: ${altNames}`; + } else { + // Match against Common Name only if no supported identifiers exist. + const cn = subject.CN; + + if (ArrayIsArray(cn)) valid = ArrayPrototypeSome.call(cn, wildcard); + else if (cn) valid = wildcard(cn); + + if (!valid) reason = `Host: ${hostname}. is not cert's CN: ${cn}`; + } + } else { + reason = "Cert does not contain a DNS name"; + } + + if (!valid) { + let error = new Error(`ERR_TLS_CERT_ALTNAME_INVALID: Hostname/IP does not match certificate's altnames: ${reason}`); + error.name = "ERR_TLS_CERT_ALTNAME_INVALID"; + error.reason = reason; + error.host = host; + error.cert = cert; + return error; + } +} + var InternalSecureContext = class SecureContext { context; @@ -83,6 +262,36 @@ function createSecureContext(options) { return new SecureContext(options); } +// Translate some fields from the handle's C-friendly format into more idiomatic +// javascript object representations before passing them back to the user. Can +// be used on any cert object, but changing the name would be semver-major. +function translatePeerCertificate(c) { + if (!c) return null; + + if (c.issuerCertificate != null && c.issuerCertificate !== c) { + c.issuerCertificate = translatePeerCertificate(c.issuerCertificate); + } + if (c.infoAccess != null) { + const info = c.infoAccess; + c.infoAccess = { __proto__: null }; + + // XXX: More key validation? + RegExpPrototypeSymbolReplace(/([^\n:]*):([^\n]*)(?:\n|$)/g, info, (all, key, val) => { + if (val.charCodeAt(0) === 0x22) { + // The translatePeerCertificate function is only + // used on internally created legacy certificate + // objects, and any value that contains a quote + // will always be a valid JSON string literal, + // so this should never throw. + val = JSONParse(val); + } + if (key in c.infoAccess) ArrayPrototypePush.call(c.infoAccess[key], val); + else c.infoAccess[key] = [val]; + }); + } + return c; +} + const buntls = Symbol.for("::buntls::"); var SocketClass; @@ -107,8 +316,22 @@ const TLSSocket = (function (InternalTLSSocket) { })( class TLSSocket extends InternalTCPSocket { #secureContext; - constructor(options) { - super(options); + ALPNProtocols; + #socket; + + constructor(socket, options) { + super(socket instanceof InternalTCPSocket ? options : options || socket); + options = options || socket || {}; + if (typeof options === "object") { + const { ALPNProtocols } = options; + if (ALPNProtocols) { + convertALPNProtocols(ALPNProtocols, this); + } + if (socket instanceof InternalTCPSocket) { + this.#socket = socket; + } + } + this.#secureContext = options.secureContext || createSecureContext(options); this.authorized = false; this.secureConnecting = true; @@ -123,28 +346,52 @@ const TLSSocket = (function (InternalTLSSocket) { secureConnecting = false; _SNICallback; servername; - alpnProtocol; authorized = false; authorizationError; encrypted = true; - exportKeyingMaterial() { - throw Error("Not implented in Bun yet"); + _start() { + // some frameworks uses this _start internal implementation is suposed to start TLS handshake + // on Bun we auto start this after on_open callback and when wrapping we start it after the socket is attached to the net.Socket/tls.Socket } - setMaxSendFragment() { + + exportKeyingMaterial(length, label, context) { + //SSL_export_keying_material throw Error("Not implented in Bun yet"); } - setServername() { + setMaxSendFragment(size) { + // SSL_set_max_send_fragment throw Error("Not implented in Bun yet"); } + setServername(name) { + if (this.isServer) { + let error = new Error("ERR_TLS_SNI_FROM_SERVER: Cannot issue SNI from a TLS server-side socket"); + error.name = "ERR_TLS_SNI_FROM_SERVER"; + throw error; + } + // if the socket is detached we can't set the servername but we set this property so when open will auto set to it + this.servername = name; + this[bunSocketInternal]?.setServername(name); + } setSession() { throw Error("Not implented in Bun yet"); } getPeerCertificate() { + // need to implement peerCertificate on socket.zig + // const cert = this[bunSocketInternal]?.peerCertificate; + // if(cert) { + // return translatePeerCertificate(cert); + // } throw Error("Not implented in Bun yet"); } getCertificate() { + // need to implement certificate on socket.zig + // const cert = this[bunSocketInternal]?.certificate; + // if(cert) { + // It's not a peer cert, but the formatting is identical. + // return translatePeerCertificate(cert); + // } throw Error("Not implented in Bun yet"); } getPeerX509Certificate() { @@ -154,16 +401,17 @@ const TLSSocket = (function (InternalTLSSocket) { throw Error("Not implented in Bun yet"); } - [buntls](port, host) { - var { servername } = this; - if (servername) { - return { - serverName: typeof servername === "string" ? servername : host, - ...this.#secureContext, - }; - } + get alpnProtocol() { + return this[bunSocketInternal]?.alpnProtocol; + } - return true; + [buntls](port, host) { + return { + socket: this.#socket, + ALPNProtocols: this.ALPNProtocols, + serverName: this.servername || host || "localhost", + ...this.#secureContext, + }; } }, ); @@ -177,9 +425,12 @@ class Server extends NetServer { _rejectUnauthorized; _requestCert; servername; + ALPNProtocols; + #checkServerIdentity; constructor(options, secureConnectionListener) { super(options, secureConnectionListener); + this.#checkServerIdentity = options?.checkServerIdentity || checkServerIdentity; this.setSecureContext(options); } emit(event, args) { @@ -197,6 +448,12 @@ class Server extends NetServer { options = options.context; } if (options) { + const { ALPNProtocols } = options; + + if (ALPNProtocols) { + convertALPNProtocols(ALPNProtocols, this); + } + let key = options.key; if (key) { if (!isValidTLSArray(key)) { @@ -277,6 +534,8 @@ class Server extends NetServer { // Client always is NONE on set_verify rejectUnauthorized: isClient ? false : this._rejectUnauthorized, requestCert: isClient ? false : this._requestCert, + ALPNProtocols: this.ALPNProtocols, + checkServerIdentity: this.#checkServerIdentity, }, SocketClass, ]; @@ -296,6 +555,11 @@ const CLIENT_RENEG_LIMIT = 3, DEFAULT_MAX_VERSION = "TLSv1.3", createConnection = (port, host, connectListener) => { if (typeof port === "object") { + port.checkServerIdentity || checkServerIdentity; + const { ALPNProtocols } = port; + if (ALPNProtocols) { + convertALPNProtocols(ALPNProtocols, port); + } // port is option pass Socket options and let connect handle connection options return new TLSSocket(port).connect(port, host, connectListener); } @@ -312,7 +576,55 @@ function getCurves() { return; } -function convertALPNProtocols(protocols, out) {} +// Convert protocols array into valid OpenSSL protocols list +// ("\x06spdy/2\x08http/1.1\x08http/1.0") +function convertProtocols(protocols) { + const lens = new Array(protocols.length); + const buff = Buffer.allocUnsafe( + ArrayPrototypeReduce.call( + protocols, + (p, c, i) => { + const len = Buffer.byteLength(c); + if (len > 255) { + throw new RangeError( + "The byte length of the protocol at index " + `${i} exceeds the maximum length.`, + "<= 255", + len, + true, + ); + } + lens[i] = len; + return p + 1 + len; + }, + 0, + ), + ); + + let offset = 0; + for (let i = 0, c = protocols.length; i < c; i++) { + buff[offset++] = lens[i]; + buff.write(protocols[i], offset); + offset += lens[i]; + } + + return buff; +} + +function convertALPNProtocols(protocols, out) { + // If protocols is Array - translate it into buffer + if (Array.isArray(protocols)) { + out.ALPNProtocols = convertProtocols(protocols); + } else if (isTypedArray(protocols)) { + // Copy new buffer not to be modified by user. + out.ALPNProtocols = Buffer.from(protocols); + } else if (isArrayBufferView(protocols)) { + out.ALPNProtocols = Buffer.from( + protocols.buffer.slice(protocols.byteOffset, protocols.byteOffset + protocols.byteLength), + ); + } else if (Buffer.isBuffer(protocols)) { + out.ALPNProtocols = protocols; + } +} var exports = { [Symbol.for("CommonJS")]: 0, @@ -351,6 +663,7 @@ export { getCurves, parseCertString, SecureContext, + checkServerIdentity, Server, TLSSocket, exports as default, diff --git a/src/js/out/modules/node/net.js b/src/js/out/modules/node/net.js index 164ec66774c0a..c34f86b04bce4 100644 --- a/src/js/out/modules/node/net.js +++ b/src/js/out/modules/node/net.js @@ -26,7 +26,7 @@ var isIPv4 = function(s) { self.emit("listening"); }, createServer = function(options, connectionListener) { return new Server(options, connectionListener); -}, v4Seg = "(?:[0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])", v4Str = `(${v4Seg}[.]){3}${v4Seg}`, IPv4Reg = new RegExp(`^${v4Str}$`), v6Seg = "(?:[0-9a-fA-F]{1,4})", IPv6Reg = new RegExp("^(" + `(?:${v6Seg}:){7}(?:${v6Seg}|:)|` + `(?:${v6Seg}:){6}(?:${v4Str}|:${v6Seg}|:)|` + `(?:${v6Seg}:){5}(?::${v4Str}|(:${v6Seg}){1,2}|:)|` + `(?:${v6Seg}:){4}(?:(:${v6Seg}){0,1}:${v4Str}|(:${v6Seg}){1,3}|:)|` + `(?:${v6Seg}:){3}(?:(:${v6Seg}){0,2}:${v4Str}|(:${v6Seg}){1,4}|:)|` + `(?:${v6Seg}:){2}(?:(:${v6Seg}){0,3}:${v4Str}|(:${v6Seg}){1,5}|:)|` + `(?:${v6Seg}:){1}(?:(:${v6Seg}){0,4}:${v4Str}|(:${v6Seg}){1,6}|:)|` + `(?::((?::${v6Seg}){0,5}:${v4Str}|(?::${v6Seg}){1,7}|:))` + ")(%[0-9a-zA-Z-.:]{1,})?$"), { Bun, createFIFO, Object } = globalThis[Symbol.for("Bun.lazy")]("primordials"), { connect: bunConnect } = Bun, { setTimeout } = globalThis, bunTlsSymbol = Symbol.for("::buntls::"), bunSocketServerHandlers = Symbol.for("::bunsocket_serverhandlers::"), bunSocketServerConnections = Symbol.for("::bunnetserverconnections::"), bunSocketServerOptions = Symbol.for("::bunnetserveroptions::"), SocketClass, Socket = function(InternalSocket) { +}, v4Seg = "(?:[0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])", v4Str = `(${v4Seg}[.]){3}${v4Seg}`, IPv4Reg = new RegExp(`^${v4Str}$`), v6Seg = "(?:[0-9a-fA-F]{1,4})", IPv6Reg = new RegExp("^(" + `(?:${v6Seg}:){7}(?:${v6Seg}|:)|` + `(?:${v6Seg}:){6}(?:${v4Str}|:${v6Seg}|:)|` + `(?:${v6Seg}:){5}(?::${v4Str}|(:${v6Seg}){1,2}|:)|` + `(?:${v6Seg}:){4}(?:(:${v6Seg}){0,1}:${v4Str}|(:${v6Seg}){1,3}|:)|` + `(?:${v6Seg}:){3}(?:(:${v6Seg}){0,2}:${v4Str}|(:${v6Seg}){1,4}|:)|` + `(?:${v6Seg}:){2}(?:(:${v6Seg}){0,3}:${v4Str}|(:${v6Seg}){1,5}|:)|` + `(?:${v6Seg}:){1}(?:(:${v6Seg}){0,4}:${v4Str}|(:${v6Seg}){1,6}|:)|` + `(?::((?::${v6Seg}){0,5}:${v4Str}|(?::${v6Seg}){1,7}|:))` + ")(%[0-9a-zA-Z-.:]{1,})?$"), { Bun, createFIFO, Object } = globalThis[Symbol.for("Bun.lazy")]("primordials"), { connect: bunConnect } = Bun, { setTimeout } = globalThis, bunTlsSymbol = Symbol.for("::buntls::"), bunSocketServerHandlers = Symbol.for("::bunsocket_serverhandlers::"), bunSocketServerConnections = Symbol.for("::bunnetserverconnections::"), bunSocketServerOptions = Symbol.for("::bunnetserveroptions::"), bunSocketInternal = Symbol.for("::bunnetsocketinternal::"), SocketClass, Socket = function(InternalSocket) { return SocketClass = InternalSocket, Object.defineProperty(SocketClass.prototype, Symbol.toStringTag, { value: "Socket", enumerable: !1 @@ -62,7 +62,7 @@ var isIPv4 = function(s) { }, open(socket) { const self = socket.data; - socket.timeout(self.timeout), socket.ref(), self.#socket = socket, self.connecting = !1, self.emit("connect", self), Socket2.#Drain(socket); + socket.timeout(self.timeout), socket.ref(), self[bunSocketInternal] = socket, self.connecting = !1, self.emit("connect", self), Socket2.#Drain(socket); }, handshake(socket, success, verifyError) { const { data: self } = socket; @@ -87,7 +87,7 @@ var isIPv4 = function(s) { const self = socket.data; if (self.#closed) return; - self.#closed = !0, self.#socket = null; + self.#closed = !0, self[bunSocketInternal] = null; const queue = self.#readQueue; if (queue.isEmpty()) { if (self.push(null)) @@ -163,21 +163,27 @@ var isIPv4 = function(s) { localAddress = "127.0.0.1"; #readQueue = createFIFO(); remotePort; - #socket; + [bunSocketInternal] = null; timeout = 0; #writeCallback; #writeChunk; #pendingRead; isServer = !1; + _handle; + _parent; + _parentWrap; + #socket; constructor(options) { - const { signal, write, read, allowHalfOpen = !1, ...opts } = options || {}; + const { socket, signal, write, read, allowHalfOpen = !1, ...opts } = options || {}; super({ ...opts, allowHalfOpen, readable: !0, writable: !0 }); - this.#pendingRead = void 0, signal?.once("abort", () => this.destroy()), this.once("connect", () => this.emit("ready")); + if (this._handle = this, this._parent = this, this._parentWrap = this, this.#pendingRead = void 0, socket instanceof Socket2) + this.#socket = socket; + signal?.once("abort", () => this.destroy()), this.once("connect", () => this.emit("ready")); } address() { return { @@ -190,10 +196,10 @@ var isIPv4 = function(s) { return this.writableLength; } #attach(port, socket) { - this.remotePort = port, socket.data = this, socket.timeout(this.timeout), socket.ref(), this.#socket = socket, this.connecting = !1, this.emit("connect", this), Socket2.#Drain(socket); + this.remotePort = port, socket.data = this, socket.timeout(this.timeout), socket.ref(), this[bunSocketInternal] = socket, this.connecting = !1, this.emit("connect", this), Socket2.#Drain(socket); } connect(port, host, connectListener) { - var path; + var path, connection = this.#socket; if (typeof port === "string") { if (path = port, port = void 0, typeof host === "function") connectListener = host, host = void 0; @@ -207,6 +213,7 @@ var isIPv4 = function(s) { port, host, path, + socket, localAddress, localPort, family, @@ -220,7 +227,8 @@ var isIPv4 = function(s) { pauseOnConnect, servername } = port; - this.servername = servername; + if (this.servername = servername, socket) + connection = socket; } if (!pauseOnConnect) this.resume(); @@ -228,36 +236,78 @@ var isIPv4 = function(s) { const bunTLS = this[bunTlsSymbol]; var tls = void 0; if (typeof bunTLS === "function") { - if (tls = bunTLS.call(this, port, host, !0), this._requestCert = !0, this._rejectUnauthorized = rejectUnauthorized, tls) + if (tls = bunTLS.call(this, port, host, !0), this._requestCert = !0, this._rejectUnauthorized = rejectUnauthorized, tls) { if (typeof tls !== "object") tls = { rejectUnauthorized, requestCert: !0 }; - else - tls.rejectUnauthorized = rejectUnauthorized, tls.requestCert = !0; + else if (tls.rejectUnauthorized = rejectUnauthorized, tls.requestCert = !0, !connection && tls.socket) + connection = tls.socket; + } + if (connection) { + if (typeof connection !== "object" || !(connection instanceof Socket2) || typeof connection[bunTlsSymbol] === "function") + throw new TypeError("socket must be an instance of net.Socket"); + } if (this.authorized = !1, this.secureConnecting = !0, this._secureEstablished = !1, this._securePending = !0, connectListener) this.on("secureConnect", connectListener); } else if (connectListener) this.on("connect", connectListener); - return bunConnect(path ? { - data: this, - unix: path, - socket: Socket2.#Handlers, - tls - } : { - data: this, - hostname: host || "localhost", - port, - socket: Socket2.#Handlers, - tls - }), this; + if (connection) { + const socket2 = connection[bunSocketInternal]; + if (socket2) { + const result = socket2.wrapTLS({ + data: this, + tls, + socket: Socket2.#Handlers + }); + if (result) { + const [raw, tls2] = result; + connection[bunSocketInternal] = raw, raw.timeout(raw.timeout), raw.connecting = !1, this[bunSocketInternal] = tls2, tls2.timeout(tls2.timeout), tls2.connecting = !0, this[bunSocketInternal] = socket2, tls2.open(); + } else + throw this[bunSocketInternal] = null, new Error("Invalid socket"); + } else + connection.once("connect", () => { + const socket3 = connection[bunSocketInternal]; + if (!socket3) + return; + const result = socket3.wrapTLS({ + data: this, + tls, + socket: Socket2.#Handlers + }); + if (result) { + const [raw, tls2] = result; + connection[bunSocketInternal] = raw, raw.timeout(raw.timeout), raw.connecting = !1, this[bunSocketInternal] = tls2, tls2.timeout(tls2.timeout), tls2.connecting = !0, this[bunSocketInternal] = socket3, tls2.open(); + } else + throw this[bunSocketInternal] = null, new Error("Invalid socket"); + }); + } else if (path) + bunConnect({ + data: this, + unix: path, + socket: Socket2.#Handlers, + tls + }).catch((error) => { + this.emit("error", error); + }); + else + bunConnect({ + data: this, + hostname: host || "localhost", + port, + socket: Socket2.#Handlers, + tls + }).catch((error) => { + this.emit("error", error); + }); + return this; } _destroy(err, callback) { - this.#socket?.end(), callback(err); + this[bunSocketInternal]?.end(), callback(err); } _final(callback) { - this.#socket?.end(), callback(); + this[bunSocketInternal]?.end(), callback(); } get localAddress() { return "127.0.0.1"; @@ -266,7 +316,7 @@ var isIPv4 = function(s) { return "IPv4"; } get localPort() { - return this.#socket?.localPort; + return this[bunSocketInternal]?.localPort; } get pending() { return this.connecting; @@ -289,16 +339,16 @@ var isIPv4 = function(s) { return this.writable ? "writeOnly" : "closed"; } ref() { - this.#socket?.ref(); + this[bunSocketInternal]?.ref(); } get remoteAddress() { - return this.#socket?.remoteAddress; + return this[bunSocketInternal]?.remoteAddress; } get remoteFamily() { return "IPv4"; } resetAndDestroy() { - this.#socket?.end(); + this[bunSocketInternal]?.end(); } setKeepAlive(enable = !1, initialDelay = 0) { return this; @@ -307,17 +357,17 @@ var isIPv4 = function(s) { return this; } setTimeout(timeout, callback) { - if (this.#socket?.timeout(timeout), this.timeout = timeout, callback) + if (this[bunSocketInternal]?.timeout(timeout), this.timeout = timeout, callback) this.once("timeout", callback); return this; } unref() { - this.#socket?.unref(); + this[bunSocketInternal]?.unref(); } _write(chunk, encoding, callback) { - if (typeof chunk == "string" && encoding !== "utf8") + if (typeof chunk == "string" && encoding !== "ascii") chunk = Buffer.from(chunk, encoding); - var written = this.#socket?.write(chunk); + var written = this[bunSocketInternal]?.write(chunk); if (written == chunk.length) callback(); else if (this.#writeCallback) diff --git a/src/js/out/modules/node/tls.js b/src/js/out/modules/node/tls.js index 4cceadc7fe286..ca8a13270221d 100644 --- a/src/js/out/modules/node/tls.js +++ b/src/js/out/modules/node/tls.js @@ -1,4 +1,4 @@ -import {isTypedArray} from "node:util/types"; +import {isArrayBufferView, isTypedArray} from "node:util/types"; import net, {Server as NetServer} from "node:net"; var parseCertString = function() { throwNotImplemented("Not implemented"); @@ -11,18 +11,127 @@ var parseCertString = function() { return !1; return !0; } +}, unfqdn = function(host2) { + return RegExpPrototypeSymbolReplace(/[.]$/, host2, ""); +}, splitHost = function(host2) { + return StringPrototypeSplit.call(RegExpPrototypeSymbolReplace(/[A-Z]/g, unfqdn(host2), toLowerCase), "."); +}, check = function(hostParts, pattern, wildcards) { + if (!pattern) + return !1; + const patternParts = splitHost(pattern); + if (hostParts.length !== patternParts.length) + return !1; + if (ArrayPrototypeIncludes.call(patternParts, "")) + return !1; + const isBad = (s) => RegExpPrototypeExec.call(/[^\u0021-\u007F]/u, s) !== null; + if (ArrayPrototypeSome.call(patternParts, isBad)) + return !1; + for (let i = hostParts.length - 1;i > 0; i -= 1) + if (hostParts[i] !== patternParts[i]) + return !1; + const hostSubdomain = hostParts[0], patternSubdomain = patternParts[0], patternSubdomainParts = StringPrototypeSplit.call(patternSubdomain, "*"); + if (patternSubdomainParts.length === 1 || StringPrototypeIncludes.call(patternSubdomain, "xn--")) + return hostSubdomain === patternSubdomain; + if (!wildcards) + return !1; + if (patternSubdomainParts.length > 2) + return !1; + if (patternParts.length <= 2) + return !1; + const { 0: prefix, 1: suffix } = patternSubdomainParts; + if (prefix.length + suffix.length > hostSubdomain.length) + return !1; + if (!StringPrototypeStartsWith.call(hostSubdomain, prefix)) + return !1; + if (!StringPrototypeEndsWith.call(hostSubdomain, suffix)) + return !1; + return !0; +}, splitEscapedAltNames = function(altNames) { + const result = []; + let currentToken = "", offset = 0; + while (offset !== altNames.length) { + const nextSep = StringPrototypeIndexOf.call(altNames, ", ", offset), nextQuote = StringPrototypeIndexOf.call(altNames, '"', offset); + if (nextQuote !== -1 && (nextSep === -1 || nextQuote < nextSep)) { + currentToken += StringPrototypeSubstring.call(altNames, offset, nextQuote); + const match = RegExpPrototypeExec.call(jsonStringPattern, StringPrototypeSubstring.call(altNames, nextQuote)); + if (!match) { + let error = new SyntaxError("ERR_TLS_CERT_ALTNAME_FORMAT: Invalid subject alternative name string"); + throw error.name = ERR_TLS_CERT_ALTNAME_FORMAT, error; + } + currentToken += JSON.parse(match[0]), offset = nextQuote + match[0].length; + } else if (nextSep !== -1) + currentToken += StringPrototypeSubstring.call(altNames, offset, nextSep), ArrayPrototypePush.call(result, currentToken), currentToken = "", offset = nextSep + 2; + else + currentToken += StringPrototypeSubstring.call(altNames, offset), offset = altNames.length; + } + return ArrayPrototypePush.call(result, currentToken), result; +}, checkServerIdentity = function(hostname, cert) { + const { subject, subjectaltname: altNames } = cert, dnsNames = [], ips = []; + if (hostname = "" + hostname, altNames) { + const splitAltNames = StringPrototypeIncludes.call(altNames, '"') ? splitEscapedAltNames(altNames) : StringPrototypeSplit.call(altNames, ", "); + ArrayPrototypeForEach.call(splitAltNames, (name) => { + if (StringPrototypeStartsWith.call(name, "DNS:")) + ArrayPrototypePush.call(dnsNames, StringPrototypeSlice.call(name, 4)); + else if (StringPrototypeStartsWith.call(name, "IP Address:")) + ArrayPrototypePush.call(ips, canonicalizeIP(StringPrototypeSlice.call(name, 11))); + }); + } + let valid = !1, reason = "Unknown reason"; + if (hostname = unfqdn(hostname), net.isIP(hostname)) { + if (valid = ArrayPrototypeIncludes.call(ips, canonicalizeIP(hostname)), !valid) + reason = `IP: ${hostname} is not in the cert's list: ` + ArrayPrototypeJoin.call(ips, ", "); + } else if (dnsNames.length > 0 || subject?.CN) { + const hostParts = splitHost(hostname), wildcard = (pattern) => check(hostParts, pattern, !0); + if (dnsNames.length > 0) { + if (valid = ArrayPrototypeSome.call(dnsNames, wildcard), !valid) + reason = `Host: ${hostname}. is not in the cert's altnames: ${altNames}`; + } else { + const cn = subject.CN; + if (ArrayIsArray(cn)) + valid = ArrayPrototypeSome.call(cn, wildcard); + else if (cn) + valid = wildcard(cn); + if (!valid) + reason = `Host: ${hostname}. is not cert's CN: ${cn}`; + } + } else + reason = "Cert does not contain a DNS name"; + if (!valid) { + let error = new Error(`ERR_TLS_CERT_ALTNAME_INVALID: Hostname/IP does not match certificate's altnames: ${reason}`); + return error.name = "ERR_TLS_CERT_ALTNAME_INVALID", error.reason = reason, error.host = host, error.cert = cert, error; + } }, SecureContext = function(options) { return new InternalSecureContext(options); }, createSecureContext = function(options) { return new SecureContext(options); -}, createServer = function(options, connectionListener) { +}; +var createServer = function(options, connectionListener) { return new Server(options, connectionListener); }, getCiphers = function() { return DEFAULT_CIPHERS.split(":"); }, getCurves = function() { return; +}, convertProtocols = function(protocols) { + const lens = new Array(protocols.length), buff = Buffer.allocUnsafe(ArrayPrototypeReduce.call(protocols, (p, c, i) => { + const len = Buffer.byteLength(c); + if (len > 255) + throw new RangeError("The byte length of the protocol at index " + `${i} exceeds the maximum length.`, "<= 255", len, !0); + return lens[i] = len, p + 1 + len; + }, 0)); + let offset = 0; + for (let i = 0, c = protocols.length;i < c; i++) + buff[offset++] = lens[i], buff.write(protocols[i], offset), offset += lens[i]; + return buff; }, convertALPNProtocols = function(protocols, out) { -}, InternalTCPSocket = net[Symbol.for("::bunternal::")], InternalSecureContext = class SecureContext2 { + if (Array.isArray(protocols)) + out.ALPNProtocols = convertProtocols(protocols); + else if (isTypedArray(protocols)) + out.ALPNProtocols = Buffer.from(protocols); + else if (isArrayBufferView(protocols)) + out.ALPNProtocols = Buffer.from(protocols.buffer.slice(protocols.byteOffset, protocols.byteOffset + protocols.byteLength)); + else if (Buffer.isBuffer(protocols)) + out.ALPNProtocols = protocols; +}, InternalTCPSocket = net[Symbol.for("::bunternal::")], bunSocketInternal = Symbol.for("::bunnetsocketinternal::"), { RegExp, Array, String } = globalThis[Symbol.for("Bun.lazy")]("primordials"), SymbolReplace = Symbol.replace, RegExpPrototypeSymbolReplace = RegExp.prototype[SymbolReplace], RegExpPrototypeExec = RegExp.prototype.exec, StringPrototypeStartsWith = String.prototype.startsWith, StringPrototypeSlice = String.prototype.slice, StringPrototypeIncludes = String.prototype.includes, StringPrototypeSplit = String.prototype.split, StringPrototypeIndexOf = String.prototype.indexOf, StringPrototypeSubstring = String.prototype.substring, StringPrototypeEndsWith = String.prototype.endsWith, ArrayPrototypeIncludes = Array.prototype.includes, ArrayPrototypeJoin = Array.prototype.join, ArrayPrototypeForEach = Array.prototype.forEach, ArrayPrototypePush = Array.prototype.push, ArrayPrototypeSome = Array.prototype.some, ArrayPrototypeReduce = Array.prototype.reduce, jsonStringPattern = /^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/, InternalSecureContext = class SecureContext2 { context; constructor(options) { const context = {}; @@ -73,8 +182,17 @@ var parseCertString = function() { }); }(class TLSSocket2 extends InternalTCPSocket { #secureContext; - constructor(options) { - super(options); + ALPNProtocols; + #socket; + constructor(socket, options) { + super(socket instanceof InternalTCPSocket ? options : options || socket); + if (options = options || socket || {}, typeof options === "object") { + const { ALPNProtocols } = options; + if (ALPNProtocols) + convertALPNProtocols(ALPNProtocols, this); + if (socket instanceof InternalTCPSocket) + this.#socket = socket; + } this.#secureContext = options.secureContext || createSecureContext(options), this.authorized = !1, this.secureConnecting = !0, this._secureEstablished = !1, this._securePending = !0; } _secureEstablished = !1; @@ -84,19 +202,24 @@ var parseCertString = function() { secureConnecting = !1; _SNICallback; servername; - alpnProtocol; authorized = !1; authorizationError; encrypted = !0; - exportKeyingMaterial() { - throw Error("Not implented in Bun yet"); + _start() { } - setMaxSendFragment() { + exportKeyingMaterial(length, label, context) { throw Error("Not implented in Bun yet"); } - setServername() { + setMaxSendFragment(size) { throw Error("Not implented in Bun yet"); } + setServername(name) { + if (this.isServer) { + let error = new Error("ERR_TLS_SNI_FROM_SERVER: Cannot issue SNI from a TLS server-side socket"); + throw error.name = "ERR_TLS_SNI_FROM_SERVER", error; + } + this.servername = name, this[bunSocketInternal]?.setServername(name); + } setSession() { throw Error("Not implented in Bun yet"); } @@ -112,14 +235,16 @@ var parseCertString = function() { getX509Certificate() { throw Error("Not implented in Bun yet"); } - [buntls](port, host) { - var { servername } = this; - if (servername) - return { - serverName: typeof servername === "string" ? servername : host, - ...this.#secureContext - }; - return !0; + get alpnProtocol() { + return this[bunSocketInternal]?.alpnProtocol; + } + [buntls](port, host2) { + return { + socket: this.#socket, + ALPNProtocols: this.ALPNProtocols, + serverName: this.servername || host2 || "localhost", + ...this.#secureContext + }; } }); @@ -132,9 +257,11 @@ class Server extends NetServer { _rejectUnauthorized; _requestCert; servername; + ALPNProtocols; + #checkServerIdentity; constructor(options, secureConnectionListener) { super(options, secureConnectionListener); - this.setSecureContext(options); + this.#checkServerIdentity = options?.checkServerIdentity || checkServerIdentity, this.setSecureContext(options); } emit(event, args) { if (super.emit(event, args), event === "connection") @@ -146,6 +273,9 @@ class Server extends NetServer { if (options instanceof InternalSecureContext) options = options.context; if (options) { + const { ALPNProtocols } = options; + if (ALPNProtocols) + convertALPNProtocols(ALPNProtocols, this); let key = options.key; if (key) { if (!isValidTLSArray(key)) @@ -194,26 +324,33 @@ class Server extends NetServer { setTicketKeys() { throw Error("Not implented in Bun yet"); } - [buntls](port, host, isClient) { + [buntls](port, host2, isClient) { return [ { - serverName: this.servername || host || "localhost", + serverName: this.servername || host2 || "localhost", key: this.key, cert: this.cert, ca: this.ca, passphrase: this.passphrase, secureOptions: this.secureOptions, rejectUnauthorized: isClient ? !1 : this._rejectUnauthorized, - requestCert: isClient ? !1 : this._requestCert + requestCert: isClient ? !1 : this._requestCert, + ALPNProtocols: this.ALPNProtocols, + checkServerIdentity: this.#checkServerIdentity }, SocketClass ]; } } -var CLIENT_RENEG_LIMIT = 3, CLIENT_RENEG_WINDOW = 600, DEFAULT_ECDH_CURVE = "auto", DEFAULT_CIPHERS = "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256", DEFAULT_MIN_VERSION = "TLSv1.2", DEFAULT_MAX_VERSION = "TLSv1.3", createConnection = (port, host, connectListener) => { - if (typeof port === "object") - return new TLSSocket(port).connect(port, host, connectListener); - return new TLSSocket().connect(port, host, connectListener); +var CLIENT_RENEG_LIMIT = 3, CLIENT_RENEG_WINDOW = 600, DEFAULT_ECDH_CURVE = "auto", DEFAULT_CIPHERS = "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256", DEFAULT_MIN_VERSION = "TLSv1.2", DEFAULT_MAX_VERSION = "TLSv1.3", createConnection = (port, host2, connectListener) => { + if (typeof port === "object") { + port.checkServerIdentity; + const { ALPNProtocols } = port; + if (ALPNProtocols) + convertALPNProtocols(ALPNProtocols, port); + return new TLSSocket(port).connect(port, host2, connectListener); + } + return new TLSSocket().connect(port, host2, connectListener); }, connect = createConnection, exports = { [Symbol.for("CommonJS")]: 0, CLIENT_RENEG_LIMIT, @@ -244,6 +381,7 @@ export { createConnection, convertALPNProtocols, connect, + checkServerIdentity, TLSSocket, Server, SecureContext, diff --git a/test/bun.lockb b/test/bun.lockb index f30ca197a2920..e3a2abdfa1b45 100755 Binary files a/test/bun.lockb and b/test/bun.lockb differ diff --git a/test/js/node/net/node-net-server.test.ts b/test/js/node/net/node-net-server.test.ts index 398959bd6053c..3cdaa17e13842 100644 --- a/test/js/node/net/node-net-server.test.ts +++ b/test/js/node/net/node-net-server.test.ts @@ -181,61 +181,6 @@ describe("net.createServer listen", () => { ); }); - it("should listen on the correct port", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49027, - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("::"); - expect(address.port).toStrictEqual(49027); - expect(address.family).toStrictEqual("IPv6"); - server.close(); - done(); - }), - ); - }); - - it("should listen on the correct port with IPV4", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49026, - "0.0.0.0", - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("0.0.0.0"); - expect(address.port).toStrictEqual(49026); - expect(address.family).toStrictEqual("IPv4"); - server.close(); - done(); - }), - ); - }); - it("should listen on unix domain socket", done => { const { mustCall, mustNotCall } = createCallCheckCtx(done); diff --git a/test/js/node/tls/node-tls-connect.test.ts b/test/js/node/tls/node-tls-connect.test.ts new file mode 100644 index 0000000000000..791dba88a567d --- /dev/null +++ b/test/js/node/tls/node-tls-connect.test.ts @@ -0,0 +1,32 @@ +import { TLSSocket, connect } from "tls"; + +it("should work with alpnProtocols", done => { + try { + let socket: TLSSocket | null = connect({ + ALPNProtocols: ["http/1.1"], + host: "bun.sh", + servername: "bun.sh", + port: 443, + rejectUnauthorized: false, + }); + + const timeout = setTimeout(() => { + socket?.end(); + done("timeout"); + }, 3000); + + socket.on("error", err => { + clearTimeout(timeout); + done(err); + }); + + socket.on("secureConnect", () => { + clearTimeout(timeout); + done(socket?.alpnProtocol === "http/1.1" ? undefined : "alpnProtocol is not http/1.1"); + socket?.end(); + socket = null; + }); + } catch (err) { + done(err); + } +}); diff --git a/test/js/node/tls/node-tls-server.test.ts b/test/js/node/tls/node-tls-server.test.ts index 6879d09273ecb..2a6101b9f71df 100644 --- a/test/js/node/tls/node-tls-server.test.ts +++ b/test/js/node/tls/node-tls-server.test.ts @@ -195,61 +195,6 @@ describe("tls.createServer listen", () => { ); }); - it("should listen on the correct port", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(COMMON_CERT); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49027, - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("::"); - expect(address.port).toStrictEqual(49027); - expect(address.family).toStrictEqual("IPv6"); - server.close(); - done(); - }), - ); - }); - - it("should listen on the correct port with IPV4", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(COMMON_CERT); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49026, - "0.0.0.0", - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("0.0.0.0"); - expect(address.port).toStrictEqual(49026); - expect(address.family).toStrictEqual("IPv4"); - server.close(); - done(); - }), - ); - }); - it("should listen on unix domain socket", done => { const { mustCall, mustNotCall } = createCallCheckCtx(done); diff --git a/test/js/third_party/nodemailer/nodemailer.test.ts b/test/js/third_party/nodemailer/nodemailer.test.ts new file mode 100644 index 0000000000000..2651126083e67 --- /dev/null +++ b/test/js/third_party/nodemailer/nodemailer.test.ts @@ -0,0 +1,15 @@ +import { test, expect, describe } from "bun:test"; +import { bunRun } from "harness"; +import path from "path"; + +describe("nodemailer", () => { + test("basic smtp", async () => { + try { + const info = bunRun(path.join(import.meta.dir, "process-nodemailer-fixture.js")); + expect(info.stdout).toBe("true"); + expect(info.stderr || "").toBe(""); + } catch (err: any) { + expect(err?.message || err).toBe(""); + } + }, 10000); +}); diff --git a/test/js/third_party/nodemailer/package.json b/test/js/third_party/nodemailer/package.json new file mode 100644 index 0000000000000..08e98074fdbe1 --- /dev/null +++ b/test/js/third_party/nodemailer/package.json @@ -0,0 +1,6 @@ +{ + "name": "nodemailer", + "dependencies": { + "nodemailer": "6.9.3" + } +} diff --git a/test/js/third_party/nodemailer/process-nodemailer-fixture.js b/test/js/third_party/nodemailer/process-nodemailer-fixture.js new file mode 100644 index 0000000000000..a54735f26db0d --- /dev/null +++ b/test/js/third_party/nodemailer/process-nodemailer-fixture.js @@ -0,0 +1,23 @@ +import nodemailer from "nodemailer"; +const account = await nodemailer.createTestAccount(); +const transporter = nodemailer.createTransport({ + host: account.smtp.host, + port: account.smtp.port, + secure: account.smtp.secure, + auth: { + user: account.user, // generated ethereal user + pass: account.pass, // generated ethereal password + }, +}); + +// send mail with defined transport object +let info = await transporter.sendMail({ + from: '"Fred Foo 👻" ', // sender address + to: "example@gmail.com", // list of receivers + subject: "Hello ✔", // Subject line + text: "Hello world?", // plain text body + html: "Hello world?", // html body +}); +const url = nodemailer.getTestMessageUrl(info); +console.log(typeof url === "string" && url.length > 0); +transporter.close(); diff --git a/test/js/web/timers/process-setImmediate-fixture.js b/test/js/web/timers/process-setImmediate-fixture.js new file mode 100644 index 0000000000000..6ffd91c8d6605 --- /dev/null +++ b/test/js/web/timers/process-setImmediate-fixture.js @@ -0,0 +1,9 @@ +setImmediate(() => { + console.log("setImmediate"); + return { + a: 1, + b: 2, + c: 3, + d: 4, + }; +}); diff --git a/test/js/web/timers/setImmediate.test.js b/test/js/web/timers/setImmediate.test.js index 9cd6fa1c99d01..d00224e0fc454 100644 --- a/test/js/web/timers/setImmediate.test.js +++ b/test/js/web/timers/setImmediate.test.js @@ -1,4 +1,6 @@ import { it, expect } from "bun:test"; +import { bunExe, bunEnv } from "harness"; +import path from "path"; it("setImmediate", async () => { var lastID = -1; @@ -45,3 +47,28 @@ it("clearImmediate", async () => { }); expect(called).toBe(false); }); + +it("setImmediate should not keep the process alive forever", async () => { + let process = null; + const success = async () => { + process = Bun.spawn({ + cmd: [bunExe(), "run", path.join(import.meta.dir, "process-setImmediate-fixture.js")], + stdout: "ignore", + env: { + ...bunEnv, + NODE_ENV: undefined, + }, + }); + await process.exited; + process = null; + return true; + }; + + const fail = async () => { + await Bun.sleep(500); + process?.kill(); + return false; + }; + + expect(await Promise.race([success(), fail()])).toBe(true); +}); diff --git a/test/package.json b/test/package.json index 1165718793942..db0053874273f 100644 --- a/test/package.json +++ b/test/package.json @@ -16,9 +16,10 @@ "iconv-lite": "0.6.3", "jest-extended": "4.0.0", "lodash": "4.17.21", + "nodemailer": "6.9.3", "prisma": "4.15.0", - "socket.io": "4.6.1", - "socket.io-client": "4.6.1", + "socket.io": "4.7.1", + "socket.io-client": "4.7.1", "supertest": "6.1.6", "svelte": "3.55.1", "typescript": "5.0.2",