diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index 4a4aed7aa3e4b7..e12ea20afa4918 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -1003,22 +1003,21 @@ declare module "bun" { * * **Experimental API** * - * Prefetch a hostname and port. + * Prefetch a hostname. * * This will be used by fetch() and Bun.connect() to avoid DNS lookups. * * @param hostname The hostname to prefetch - * @param port The port to prefetch * * @example * ```js * import { dns } from 'bun'; - * dns.prefetch('example.com', 443); + * dns.prefetch('example.com'); * // ... something expensive * await fetch('https://example.com'); * ``` */ - prefetch(hostname: string, port: number): void; + prefetch(hostname: string): void; /** * **Experimental API** diff --git a/packages/bun-usockets/src/bsd.c b/packages/bun-usockets/src/bsd.c index 5a64df6e8ee34c..50ea790cb5bac1 100644 --- a/packages/bun-usockets/src/bsd.c +++ b/packages/bun-usockets/src/bsd.c @@ -888,11 +888,12 @@ int bsd_disconnect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd) { // return 0; // no ecn defaults to 0 // } -static int bsd_do_connect_raw(struct addrinfo *rp, LIBUS_SOCKET_DESCRIPTOR fd) +static int bsd_do_connect_raw(struct sockaddr_storage *addr, LIBUS_SOCKET_DESCRIPTOR fd) { + int namelen = addr->ss_family == AF_INET ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); #ifdef _WIN32 do { - if (connect(fd, rp->ai_addr, rp->ai_addrlen) == 0 || WSAGetLastError() == WSAEINPROGRESS) { + if (connect(fd, (struct sockaddr *)addr, namelen) == 0 || WSAGetLastError() == WSAEINPROGRESS) { return 0; } } while (WSAGetLastError() == WSAEINTR); @@ -900,7 +901,7 @@ static int bsd_do_connect_raw(struct addrinfo *rp, LIBUS_SOCKET_DESCRIPTOR fd) return WSAGetLastError(); #else do { - if (connect(fd, rp->ai_addr, rp->ai_addrlen) == 0 || errno == EINPROGRESS) { + if (connect(fd, (struct sockaddr *)addr, namelen) == 0 || errno == EINPROGRESS) { return 0; } } while (errno == EINTR); @@ -909,77 +910,34 @@ static int bsd_do_connect_raw(struct addrinfo *rp, LIBUS_SOCKET_DESCRIPTOR fd) #endif } -static int bsd_do_connect(struct addrinfo *rp, LIBUS_SOCKET_DESCRIPTOR *fd) -{ - int lastErr = 0; - while (rp != NULL) { - lastErr = bsd_do_connect_raw(rp, *fd); - if (lastErr == 0) { - return 0; - } - - rp = rp->ai_next; - bsd_close_socket(*fd); - - if (rp == NULL) { - if (lastErr != 0) { - errno = lastErr; - } - return LIBUS_SOCKET_ERROR; - } - - LIBUS_SOCKET_DESCRIPTOR resultFd = bsd_create_socket(rp->ai_family, SOCK_STREAM, 0); - if (resultFd < 0) { - return LIBUS_SOCKET_ERROR; - } - *fd = resultFd; - } - - if (lastErr != 0) { - errno = lastErr; - } - - return LIBUS_SOCKET_ERROR; -} - #ifdef _WIN32 -static int convert_null_addr(struct addrinfo *addrinfo, struct addrinfo* result, struct sockaddr_storage *inaddr) { +static int convert_null_addr(const struct sockaddr_storage *addr, struct sockaddr_storage* result) { // 1. check that all addrinfo results are 0.0.0.0 or :: - if (addrinfo->ai_family == AF_INET) { - struct sockaddr_in *addr = (struct sockaddr_in *) addrinfo->ai_addr; - if (addr->sin_addr.s_addr == htonl(INADDR_ANY)) { - memcpy(inaddr, addr, sizeof(struct sockaddr_in)); - ((struct sockaddr_in *) inaddr)->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - - memcpy(result, addrinfo, sizeof(struct addrinfo)); - result->ai_addr = (struct sockaddr *) inaddr; - result->ai_next = NULL; - + if (addr->ss_family == AF_INET) { + struct sockaddr_in *addr4 = (struct sockaddr_in *) addr; + if (addr4->sin_addr.s_addr == htonl(INADDR_ANY)) { + memcpy(result, addr, sizeof(struct sockaddr_in)); + ((struct sockaddr_in *) result)->sin_addr.s_addr = htonl(INADDR_LOOPBACK); return 1; } - } else if (addrinfo->ai_family == AF_INET6) { - struct sockaddr_in6 *addr = (struct sockaddr_in6 *) addrinfo->ai_addr; - if (memcmp(&addr->sin6_addr, &in6addr_any, sizeof(struct in6_addr)) == 0) { - memcpy(inaddr, addr, sizeof(struct sockaddr_in6)); - memcpy(&((struct sockaddr_in6 *) inaddr)->sin6_addr, &in6addr_loopback, sizeof(struct in6_addr)); - - memcpy(result, addrinfo, sizeof(struct addrinfo)); - result->ai_addr = (struct sockaddr *) inaddr; - result->ai_next = NULL; - + } else if (addr->ss_family == AF_INET6) { + struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *) addr; + if (memcmp(&addr6->sin6_addr, &in6addr_any, sizeof(struct in6_addr)) == 0) { + memcpy(result, addr, sizeof(struct sockaddr_in6)); + memcpy(&((struct sockaddr_in6 *) result)->sin6_addr, &in6addr_loopback, sizeof(struct in6_addr)); return 1; } } return 0; } -static int is_loopback(struct addrinfo *addrinfo) { - if (addrinfo->ai_family == AF_INET) { - struct sockaddr_in *addr = (struct sockaddr_in *) addrinfo->ai_addr; +static int is_loopback(struct sockaddr_storage *sockaddr) { + if (sockaddr->ss_family == AF_INET) { + struct sockaddr_in *addr = (struct sockaddr_in *) sockaddr; return addr->sin_addr.s_addr == htonl(INADDR_LOOPBACK); - } else if (addrinfo->ai_family == AF_INET6) { - struct sockaddr_in6 *addr = (struct sockaddr_in6 *) addrinfo->ai_addr; + } else if (sockaddr->ss_family == AF_INET6) { + struct sockaddr_in6 *addr = (struct sockaddr_in6 *) sockaddr; return memcmp(&addr->sin6_addr, &in6addr_loopback, sizeof(struct in6_addr)) == 0; } else { return 0; @@ -987,8 +945,8 @@ static int is_loopback(struct addrinfo *addrinfo) { } #endif -LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int options) { - LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(addrinfo->ai_family, SOCK_STREAM, 0); +LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct sockaddr_storage *addr, int options) { + LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(addr->ss_family, SOCK_STREAM, 0); if (fd == LIBUS_SOCKET_ERROR) { return LIBUS_SOCKET_ERROR; } @@ -997,11 +955,9 @@ LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int // On windows we can't connect to the null address directly. // To match POSIX behavior, we need to connect to localhost instead. - struct addrinfo alt_result; - struct sockaddr_storage storage; - - if (convert_null_addr(addrinfo, &alt_result, &storage)) { - addrinfo = &alt_result; + struct sockaddr_storage converted; + if (convert_null_addr(addr, &converted)) { + addr = &converted; } // This sets the socket to fail quickly if no connection can be established to localhost, @@ -1010,7 +966,7 @@ LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int // see https://github.com/libuv/libuv/blob/bf61390769068de603e6deec8e16623efcbe761a/src/win/tcp.c#L806 TCP_INITIAL_RTO_PARAMETERS retransmit_ioctl; DWORD bytes; - if (is_loopback(addrinfo)) { + if (is_loopback(addr)) { memset(&retransmit_ioctl, 0, sizeof(retransmit_ioctl)); retransmit_ioctl.Rtt = TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS; retransmit_ioctl.MaxSynRetransmissions = TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS; @@ -1027,10 +983,10 @@ LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int #endif - if (bsd_do_connect(addrinfo, &fd) != 0) { + if (bsd_do_connect_raw(addr, fd) != 0) { + bsd_close_socket(fd); return LIBUS_SOCKET_ERROR; } - return fd; } diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index d747285402457f..ddc60b13330c87 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -345,22 +345,12 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock } -struct us_socket_t* us_socket_context_connect_resolved_dns(struct us_socket_context_t *context, void* request, int options, int socket_ext_size) { - struct addrinfo_result *result = Bun__addrinfo_getRequestResult(request); - if (result->error) { - errno = result->error; - Bun__addrinfo_freeRequest(request, 1); - return NULL; - } - - LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(result->info, options); +struct us_socket_t* us_socket_context_connect_resolved_dns(struct us_socket_context_t *context, struct sockaddr_storage* addr, int options, int socket_ext_size) { + LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(addr, options); if (connect_socket_fd == LIBUS_SOCKET_ERROR) { - int err = errno; - Bun__addrinfo_freeRequest(request, err); return NULL; } - Bun__addrinfo_freeRequest(request, 0); bsd_socket_nodelay(connect_socket_fd, 1); /* Connect sockets are semi-sockets just like listen sockets */ @@ -381,7 +371,19 @@ struct us_socket_t* us_socket_context_connect_resolved_dns(struct us_socket_cont return socket; } -struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size, int* is_connecting) { +static void init_addr_with_port(struct addrinfo* info, int port, struct sockaddr_storage *addr) { + if (info->ai_family == AF_INET) { + struct sockaddr_in *addr_in = (struct sockaddr_in *) addr; + memcpy(addr_in, info->ai_addr, info->ai_addrlen); + addr_in->sin_port = htons(port); + } else { + struct sockaddr_in6 *addr_in6 = (struct sockaddr_in6 *) addr; + memcpy(addr_in6, info->ai_addr, info->ai_addrlen); + addr_in6->sin6_port = htons(port); + } +} + +void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size, int* is_connecting) { #ifndef LIBUS_NO_SSL if (ssl == 1) { return us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, options, socket_ext_size, is_connecting); @@ -390,12 +392,25 @@ struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_sock struct us_loop_t* loop = us_socket_context_loop(ssl, context); - void* ptr; - if (Bun__addrinfo_get(loop, host, port, &ptr) == 0) { - // Fast-path: it's already cached. - // Avoid the connection logic. - *is_connecting = 1; - return (struct us_connecting_socket_t *) us_socket_context_connect_resolved_dns(context, ptr, options, socket_ext_size); + struct addrinfo_request* ai_req; + if (Bun__addrinfo_get(loop, host, &ai_req) == 0) { + struct addrinfo_result *result = Bun__addrinfo_getRequestResult(ai_req); + // fast failure path + if (result->error) { + errno = result->error; + Bun__addrinfo_freeRequest(ai_req, 1); + return NULL; + } + + // if there is only one result we can immediately connect + if (result->info && result->info->ai_next == NULL) { + struct sockaddr_storage addr; + init_addr_with_port(result->info, port, &addr); + *is_connecting = 1; + struct us_socket_t *s = us_socket_context_connect_resolved_dns(context, &addr, options, socket_ext_size); + Bun__addrinfo_freeRequest(ai_req, s == NULL); + return s; + } } struct us_connecting_socket_t *c = us_calloc(1, sizeof(struct us_connecting_socket_t) + socket_ext_size); @@ -406,6 +421,7 @@ struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_sock c->timeout = 255; c->long_timeout = 255; c->pending_resolve_callback = 1; + c->port = port; #ifdef _WIN32 loop->uv_loop->active_handles++; @@ -413,7 +429,7 @@ struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_sock loop->num_polls++; #endif - Bun__addrinfo_set(ptr, c); + Bun__addrinfo_set(ai_req, c); return c; } @@ -440,36 +456,110 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) { us_connecting_socket_close(0, c); return; } - LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(result->info, c->options); - if (connect_socket_fd == LIBUS_SOCKET_ERROR) { - c->error = errno; - c->context->on_connect_error(c, errno); + + int error = 0; + for (struct addrinfo *info = result->info; info; info = info->ai_next) { + struct sockaddr_storage addr; + init_addr_with_port(info, c->port, &addr); + LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(&addr, c->options); + if (connect_socket_fd == LIBUS_SOCKET_ERROR) { + continue; + } + bsd_socket_nodelay(connect_socket_fd, 1); + + struct us_socket_t *s = (struct us_socket_t *)us_create_poll(c->context->loop, 0, sizeof(struct us_socket_t) + c->socket_ext_size); + s->context = c->context; + s->timeout = c->timeout; + s->long_timeout = c->long_timeout; + s->low_prio_state = 0; + /* Link it into context so that timeout fires properly */ + us_internal_socket_context_link_socket(s->context, s); + + // TODO check this, specifically how it interacts with the SSL code + // does this work when we create multiple sockets at once? will we need multiple SSL contexts? + // no, we won't need multiple contexts - the context is only initialized on_open + memcpy(us_socket_ext(0, s), us_connecting_socket_ext(0, c), c->socket_ext_size); + + // store the socket so we can close it if we need to + s->connect_next = c->connecting_head; + c->connecting_head = s; + + s->connect_state = c; + + /* Connect sockets are semi-sockets just like listen sockets */ + us_poll_init(&s->p, connect_socket_fd, POLL_TYPE_SEMI_SOCKET); + us_poll_start(&s->p, s->context->loop, LIBUS_SOCKET_WRITABLE); + } + + if (!c->connecting_head) { + c->error = error; + c->context->on_connect_error(c, error); Bun__addrinfo_freeRequest(c->addrinfo_req, 1); us_connecting_socket_close(0, c); return; } Bun__addrinfo_freeRequest(c->addrinfo_req, 0); - bsd_socket_nodelay(connect_socket_fd, 1); +} - struct us_socket_t *s = (struct us_socket_t *)us_create_poll(c->context->loop, 0, sizeof(struct us_socket_t) + c->socket_ext_size); - s->context = c->context; - s->timeout = c->timeout; - s->long_timeout = c->long_timeout; - s->low_prio_state = 0; - /* Link it into context so that timeout fires properly */ - us_internal_socket_context_link_socket(s->context, s); - // TODO check this, specifically how it interacts with the SSL code - memcpy(us_socket_ext(0, s), us_connecting_socket_ext(0, c), c->socket_ext_size); +void us_internal_socket_after_open(struct us_socket_t *s, int error) { + struct us_connecting_socket_t *c = s->connect_state; + /* It is perfectly possible to come here with an error */ + if (error) { - // store the socket so we can close it if we need to - c->socket = s; - s->connect_state = c; + /* Emit error, close without emitting on_close */ - /* Connect sockets are semi-sockets just like listen sockets */ - us_poll_init(&s->p, connect_socket_fd, POLL_TYPE_SEMI_SOCKET); - us_poll_start(&s->p, s->context->loop, LIBUS_SOCKET_WRITABLE); + /* There are two possible states here: + 1. It's a us_connecting_socket_t*. DNS resolution failed, or a connection failed. + 2. It's a us_socket_t* + + We differentiate between these two cases by checking if the connect_state is null. + */ + if (c) { + // remove this connecting socket from the list of connecting sockets + // if it was the last one, signal the error to the user + for (struct us_socket_t **next = &c->connecting_head; *next; next = &(*next)->connect_next) { + if (*next == s) { + *next = s->connect_next; + break; + } + } + us_socket_close(0, s, 0, 0); + if (!c->connecting_head) { + c->context->on_connect_error(c, error); + us_connecting_socket_close(0, c); + } + } else { + s->context->on_socket_connect_error(s, error); + // It's expected that close is called by the caller + } + } else { + /* All sockets poll for readable */ + us_poll_change(&s->p, s->context->loop, LIBUS_SOCKET_READABLE); + + /* We always use nodelay */ + bsd_socket_nodelay(us_poll_fd(&s->p), 1); + + /* We are now a proper socket */ + us_internal_poll_set_type(&s->p, POLL_TYPE_SOCKET); + + /* If we used a connection timeout we have to reset it here */ + us_socket_timeout(0, s, 0); + // if there is a connect_state, we need to close all other connection attempts that are currently in progress + if (c) { + for (struct us_socket_t *next = c->connecting_head; next; next = next->connect_next) { + if (next != s) { + us_socket_close(0, next, 0, 0); + } + } + // now that the socket is open, we can release the associated us_connecting_socket_t if it exists + us_connecting_socket_free(c); + s->connect_state = NULL; + } + + s->context->on_open(s, 1, 0, 0); + } } struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_context_t *context, const char *server_path, size_t pathlen, int options, int socket_ext_size) { @@ -539,7 +629,7 @@ struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_con if (ext_size != -1) { new_s = (struct us_socket_t *) us_poll_resize(&s->p, s->context->loop, sizeof(struct us_socket_t) + ext_size); if (c) { - c->socket = new_s; + c->connecting_head = new_s; c->context = context; } } diff --git a/packages/bun-usockets/src/internal/internal.h b/packages/bun-usockets/src/internal/internal.h index ca2f000126430a..4ba1aa88fc115f 100644 --- a/packages/bun-usockets/src/internal/internal.h +++ b/packages/bun-usockets/src/internal/internal.h @@ -75,15 +75,16 @@ enum { void Bun__lock(uint32_t *lock); void Bun__unlock(uint32_t *lock); +struct addrinfo_request; struct addrinfo_result { struct addrinfo *info; int error; }; -extern int Bun__addrinfo_get(struct us_loop_t* loop, const char* host, int port, void** ptr); -extern int Bun__addrinfo_set(void* ptr, struct us_connecting_socket_t* socket); -extern void Bun__addrinfo_freeRequest(void* addrinfo_req, int error); -extern struct addrinfo_result *Bun__addrinfo_getRequestResult(void* addrinfo_req); +extern int Bun__addrinfo_get(struct us_loop_t* loop, const char* host, struct addrinfo_request** ptr); +extern int Bun__addrinfo_set(struct addrinfo_request* ptr, struct us_connecting_socket_t* socket); +extern void Bun__addrinfo_freeRequest(struct addrinfo_request* addrinfo_req, int error); +extern struct addrinfo_result *Bun__addrinfo_getRequestResult(struct addrinfo_request* addrinfo_req); /* Loop related */ @@ -128,6 +129,7 @@ void us_internal_socket_context_unlink_socket( struct us_socket_context_t *context, struct us_socket_t *s); void us_internal_socket_after_resolve(struct us_connecting_socket_t *s); +void us_internal_socket_after_open(struct us_socket_t *s, int error); int us_internal_handle_dns_results(struct us_loop_t *loop); /* Sockets are polls */ @@ -140,19 +142,21 @@ struct us_socket_t { = was in low-prio queue in this iteration */ struct us_socket_context_t *context; struct us_socket_t *prev, *next; + struct us_socket_t *connect_next; struct us_connecting_socket_t *connect_state; }; struct us_connecting_socket_t { - alignas(LIBUS_EXT_ALIGNMENT) void *addrinfo_req; + alignas(LIBUS_EXT_ALIGNMENT) struct addrinfo_request *addrinfo_req; struct us_socket_context_t *context; struct us_connecting_socket_t *next; - struct us_socket_t *socket; + struct us_socket_t *connecting_head; int options; int socket_ext_size; unsigned int closed : 1, shutdown : 1, ssl : 1, shutdown_read : 1, pending_resolve_callback : 1; unsigned char timeout; unsigned char long_timeout; + uint16_t port; int error; }; diff --git a/packages/bun-usockets/src/internal/networking/bsd.h b/packages/bun-usockets/src/internal/networking/bsd.h index ba968380d49dbb..d3bc65e16d42fb 100644 --- a/packages/bun-usockets/src/internal/networking/bsd.h +++ b/packages/bun-usockets/src/internal/networking/bsd.h @@ -150,7 +150,7 @@ LIBUS_SOCKET_DESCRIPTOR bsd_create_udp_socket(const char *host, int port); int bsd_connect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd, const char *host, int port); int bsd_disconnect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd); -LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int options); +LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct sockaddr_storage *addr, int options); LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket_unix(const char *server_path, size_t pathlen, int options); diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index f3e50165129c83..e33ba6e80801e3 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -270,8 +270,16 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock /* listen_socket.c/.h */ void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls); -/* Land in on_open or on_connection_error or return null or return socket */ -struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, +/* + Returns one of + - struct us_socket_t * - indicated by the value at on_connecting being set to 1 + This is the fast path where the DNS result is available immediately and only a single remote + address is available + - struct us_connecting_socket_t * - indicated by the value at on_connecting being set to 0 + This is the slow path where we must either go through DNS resolution or create multiple sockets + per the happy eyeballs algorithm +*/ +void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size, int *is_connecting); struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_context_t *context, diff --git a/packages/bun-usockets/src/loop.c b/packages/bun-usockets/src/loop.c index 08eb7f744a00dd..ae9abf241f20fe 100644 --- a/packages/bun-usockets/src/loop.c +++ b/packages/bun-usockets/src/loop.c @@ -280,50 +280,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) /* Both connect and listen sockets are semi-sockets * but they poll for different events */ if (us_poll_events(p) == LIBUS_SOCKET_WRITABLE) { - struct us_socket_t *s = (struct us_socket_t *) p; - - /* It is perfectly possible to come here with an error */ - if (error) { - struct us_connecting_socket_t *c = s->connect_state; - - /* Emit error, close without emitting on_close */ - - /* There are two possible states here: - 1. It's a us_connecting_socket_t*. DNS resolution failed, or a connection failed. - 2. It's a us_socket_t* - - We differentiate between these two cases by checking if the connect_state is null. - */ - if (c) { - s->context->on_connect_error(s->connect_state, error); - us_connecting_socket_close(c->ssl, c); - } else { - s->context->on_socket_connect_error(s, error); - // It's expected that close is called by the caller - } - - s = NULL; - } else { - /* All sockets poll for readable */ - us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE); - - /* We always use nodelay */ - bsd_socket_nodelay(us_poll_fd(p), 1); - - /* We are now a proper socket */ - us_internal_poll_set_type(p, POLL_TYPE_SOCKET); - - /* If we used a connection timeout we have to reset it here */ - us_socket_timeout(0, s, 0); - - s->context->on_open(s, 1, 0, 0); - - if (s->connect_state) { - // now that the socket is open, we can release the associated us_connecting_socket_t if it exists - us_connecting_socket_free(s->connect_state); - s->connect_state = NULL; - } - } + us_internal_socket_after_open((struct us_socket_t *) p, error); } else { struct us_listen_socket_t *listen_socket = (struct us_listen_socket_t *) p; struct bsd_addr_t addr; diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index c3ad62611b3413..688dfd98a6828c 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -136,10 +136,7 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { if (c->closed) return; c->closed = 1; - struct us_socket_t *s = c->socket; - if (s) { - c->socket = NULL; - + for (struct us_socket_t *s = c->connecting_head; s; s = s->connect_next) { us_internal_socket_context_unlink_socket(s->context, s); us_poll_stop((struct us_poll_t *) s, s->context->loop); bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); diff --git a/src/bun.js/api/bun/dns_resolver.zig b/src/bun.js/api/bun/dns_resolver.zig index bd09be53d4bbef..984a4407445bb6 100644 --- a/src/bun.js/api/bun/dns_resolver.zig +++ b/src/bun.js/api/bun/dns_resolver.zig @@ -1200,16 +1200,14 @@ pub const InternalDNS = struct { pub usingnamespace bun.New(@This()); const Key = struct { host: ?[:0]const u8, - port: u16, hash: u64, - pub fn init(name: ?[:0]const u8, port: u16) @This() { + pub fn init(name: ?[:0]const u8) @This() { const hash = if (name) |n| brk: { break :brk bun.hash(n); } else 0; return .{ .host = name, - .port = port, .hash = hash, }; } @@ -1219,7 +1217,6 @@ pub const InternalDNS = struct { const host_copy = bun.default_allocator.dupeZ(u8, host) catch bun.outOfMemory(); return .{ .host = host_copy, - .port = this.port, .hash = this.hash, }; } else { @@ -1258,8 +1255,6 @@ pub const InternalDNS = struct { /// Not a precise timestamp. created_at: u32 = std.math.maxInt(u32), - lock: bun.Lock = bun.Lock.init(), - valid: bool = true, libinfo: if (Environment.isMac) MacAsyncDNS else void = if (Environment.isMac) .{} else {}, @@ -1319,7 +1314,7 @@ pub const InternalDNS = struct { var i: usize = 0; while (i < len) { var entry = this.cache[i]; - if (entry.key.hash == key.hash and entry.key.port == key.port and entry.valid) { + if (entry.key.hash == key.hash and entry.valid) { if (entry.isExpired(timestamp_to_store)) { log("get: expired entry", .{}); _ = this.deleteEntryAt(len, i); @@ -1438,8 +1433,7 @@ pub const InternalDNS = struct { }; fn afterResult(req: *Request, info: ?*std.c.addrinfo, err: c_int) void { - // Only lock while - req.lock.lock(); + global_cache.lock.lock(); req.result = .{ .info = info, @@ -1449,7 +1443,9 @@ pub const InternalDNS = struct { defer notify.deinit(bun.default_allocator); req.notify = .{}; req.refcount -= 1; - req.lock.unlock(); + + // is this correct, or should it go after the loop? + global_cache.lock.unlock(); for (notify.items) |query| { query.notifyThreadsafe(req); @@ -1457,11 +1453,6 @@ pub const InternalDNS = struct { } fn workPoolCallback(req: *Request) void { - var port_buf: [128]u8 = undefined; - const port = std.fmt.bufPrintIntToSlice(&port_buf, req.key.port, 10, .lower, .{}); - port_buf[port.len] = 0; - const portZ = port_buf[0..port.len :0]; - if (Environment.isWindows) { const wsa = std.os.windows.ws2_32; const wsa_hints = wsa.addrinfo{ @@ -1478,7 +1469,7 @@ pub const InternalDNS = struct { var addrinfo: ?*wsa.addrinfo = null; const err = wsa.getaddrinfo( if (req.key.host) |host| host.ptr else null, - if (port.len > 0) portZ.ptr else null, + null, &wsa_hints, &addrinfo, ); @@ -1487,7 +1478,7 @@ pub const InternalDNS = struct { var addrinfo: ?*std.c.addrinfo = null; const err = std.c.getaddrinfo( if (req.key.host) |host| host.ptr else null, - if (port.len > 0) portZ.ptr else null, + null, &hints, &addrinfo, ); @@ -1498,16 +1489,11 @@ pub const InternalDNS = struct { pub fn lookupLibinfo(req: *Request, loop: JSC.EventLoopHandle) bool { const getaddrinfo_async_start_ = LibInfo.getaddrinfo_async_start() orelse return false; - var port_buf: [128]u8 = undefined; - const port = std.fmt.bufPrintIntToSlice(&port_buf, req.key.port, 10, .lower, .{}); - port_buf[port.len] = 0; - const portZ = port_buf[0..port.len :0]; - var machport: ?*anyopaque = null; const errno = getaddrinfo_async_start_( &machport, if (req.key.host) |host| host.ptr else null, - if (port.len > 0) portZ.ptr else null, + null, &hints, libinfoCallback, req, @@ -1560,9 +1546,9 @@ pub const InternalDNS = struct { return object; } - pub fn getaddrinfo(loop: *bun.uws.Loop, host: ?[:0]const u8, port: u16, is_cache_hit: ?*bool) ?*Request { + pub fn getaddrinfo(loop: *bun.uws.Loop, host: ?[:0]const u8, is_cache_hit: ?*bool) ?*Request { const preload = is_cache_hit == null; - const key = Request.Key.init(host, port); + const key = Request.Key.init(host); global_cache.lock.lock(); getaddrinfo_calls += 1; var timestamp_to_store: u32 = 0; @@ -1574,19 +1560,17 @@ pub const InternalDNS = struct { return null; } - entry.lock.lock(); entry.refcount += 1; if (entry.result != null) { is_cache_hit.?.* = true; - log("getaddrinfo({s}:{d}) = cache hit", .{ host orelse "", port }); + log("getaddrinfo({s}) = cache hit", .{host orelse ""}); dns_cache_hits_completed += 1; } else { - log("getaddrinfo({s}:{d}) = cache hit (inflight)", .{ host orelse "", port }); + log("getaddrinfo({s}) = cache hit (inflight)", .{host orelse ""}); dns_cache_hits_inflight += 1; } - entry.lock.unlock(); global_cache.lock.unlock(); return entry; @@ -1607,17 +1591,16 @@ pub const InternalDNS = struct { dns_cache_size = global_cache.len; global_cache.lock.unlock(); - // doesn't work yet if (comptime Environment.isMac) { if (!bun.getRuntimeFeatureFlag("BUN_FEATURE_FLAG_DISABLE_DNS_CACHE_LIBINFO")) { const res = lookupLibinfo(req, loop.internal_loop_data.getParent()); - log("getaddrinfo({s}:{d}) = cache miss (libinfo)", .{ host orelse "", port }); + log("getaddrinfo({s}) = cache miss (libinfo)", .{host orelse ""}); if (res) return req; // if we were not able to use libinfo, we fall back to the work pool } } - log("getaddrinfo({s}:{d}) = cache miss (libc)", .{ host orelse "", port }); + log("getaddrinfo({s}) = cache miss (libc)", .{host orelse ""}); // schedule the request to be executed on the work pool bun.JSC.WorkPool.go(bun.default_allocator, *Request, req, workPoolCallback) catch bun.outOfMemory(); return req; @@ -1635,22 +1618,9 @@ pub const InternalDNS = struct { var hostname_slice = JSC.ZigString.Slice.empty; defer hostname_slice.deinit(); - var port: u16 = 0; if (hostname_or_url.isString()) { hostname_slice = hostname_or_url.toSlice(globalThis, bun.default_allocator); - - if (arguments.len > 1 and arguments[1].isAnyInt()) { - const portI = arguments[1].coerce(i32, globalThis); - if (portI < 0 or portI > 65535) { - globalThis.throwInvalidArguments("port must be between 0 and 65535", .{}); - return .zero; - } - port = @intCast(portI); - } else { - globalThis.throwInvalidArguments("port must be an integer", .{}); - return .zero; - } } else { globalThis.throwInvalidArguments("hostname must be a string", .{}); return .zero; @@ -1662,18 +1632,18 @@ pub const InternalDNS = struct { }; defer bun.default_allocator.free(hostname_z); - prefetch(JSC.VirtualMachine.get().uwsLoop(), hostname_z, port); + prefetch(JSC.VirtualMachine.get().uwsLoop(), hostname_z); return .undefined; } - pub fn prefetch(loop: *bun.uws.Loop, hostname: ?[:0]const u8, port: u16) void { - _ = getaddrinfo(loop, hostname, port, null); + pub fn prefetch(loop: *bun.uws.Loop, hostname: ?[:0]const u8) void { + _ = getaddrinfo(loop, hostname, null); } - fn us_getaddrinfo(loop: *bun.uws.Loop, _host: ?[*:0]const u8, port: u16, socket: *?*anyopaque) callconv(.C) c_int { + fn us_getaddrinfo(loop: *bun.uws.Loop, _host: ?[*:0]const u8, socket: *?*anyopaque) callconv(.C) c_int { const host: ?[:0]const u8 = std.mem.span(_host); var is_cache_hit: bool = false; - const req = getaddrinfo(loop, host, port, &is_cache_hit).?; + const req = getaddrinfo(loop, host, &is_cache_hit).?; socket.* = req; return if (is_cache_hit) 0 else 1; } @@ -1682,33 +1652,30 @@ pub const InternalDNS = struct { request: *Request, socket: *bun.uws.ConnectingSocket, ) callconv(.C) void { - request.lock.lock(); + global_cache.lock.lock(); + defer global_cache.lock.unlock(); const query = DNSRequestOwner{ .socket = socket, }; if (request.result != null) { - request.lock.unlock(); query.notify(request); return; } request.notify.append(bun.default_allocator, .{ .socket = socket }) catch bun.outOfMemory(); - request.lock.unlock(); } fn freeaddrinfo(req: *Request, err: c_int) callconv(.C) void { - req.lock.lock(); - defer req.lock.unlock(); + global_cache.lock.lock(); + defer global_cache.lock.unlock(); req.valid = err == 0; dns_cache_errors += @as(usize, @intFromBool(err != 0)); + bun.assert(req.refcount > 0); req.refcount -= 1; if (req.refcount == 0 and (global_cache.isNearlyFull() or !req.valid)) { - global_cache.lock.lock(); log("cache --", .{}); - - defer global_cache.lock.unlock(); global_cache.remove(req); req.deinit(); } diff --git a/src/deps/uws.zig b/src/deps/uws.zig index c5235888589b65..1374585f14f578 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -589,8 +589,16 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { var stack_fallback = std.heap.stackFallback(1024, bun.default_allocator); var allocator = stack_fallback.get(); - const host_ = allocator.dupeZ(u8, host) catch return null; - defer allocator.free(host_); + + // remove brackets from IPv6 addresses, as getaddrinfo doesn't understand them + const clean_host = if (host.len > 1 and host[0] == '[' and host[host.len - 1] == ']') + host[1 .. host.len - 1] + else + host; + + const host_ = allocator.dupeZ(u8, clean_host) catch bun.outOfMemory(); + defer allocator.free(host); + var did_dns_resolve: i32 = 0; const socket = us_socket_context_connect(comptime ssl_int, socket_ctx, host_, port, 0, @sizeOf(Context), &did_dns_resolve) orelse return null; const socket_ = if (did_dns_resolve == 1) @@ -682,22 +690,20 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { var stack_fallback = std.heap.stackFallback(1024, bun.default_allocator); var allocator = stack_fallback.get(); - const host: ?[*:0]u8 = brk: { - // getaddrinfo expects `node` to be null if localhost - if (raw_host.len < 6 and (bun.strings.eqlComptime(raw_host, "[::1]") or bun.strings.eqlComptime(raw_host, "[::]"))) { - break :brk null; - } - - break :brk allocator.dupeZ(u8, raw_host) catch bun.outOfMemory(); - }; + // remove brackets from IPv6 addresses, as getaddrinfo doesn't understand them + const clean_host = if (raw_host.len > 1 and raw_host[0] == '[' and raw_host[raw_host.len - 1] == ']') + raw_host[1 .. raw_host.len - 1] + else + raw_host; - defer if (host) |allocated_host| allocator.free(allocated_host[0..raw_host.len]); + const host = allocator.dupeZ(u8, clean_host) catch bun.outOfMemory(); + defer allocator.free(host); var did_dns_resolve: i32 = 0; const socket_ptr = us_socket_context_connect( comptime ssl_int, socket_ctx, - host, + host.ptr, port, 0, @sizeOf(*anyopaque), @@ -1425,7 +1431,7 @@ extern fn us_socket_context_ext(ssl: i32, context: ?*SocketContext) ?*anyopaque; pub extern fn us_socket_context_listen(ssl: i32, context: ?*SocketContext, host: ?[*:0]const u8, port: i32, options: i32, socket_ext_size: i32) ?*ListenSocket; pub extern fn us_socket_context_listen_unix(ssl: i32, context: ?*SocketContext, path: [*:0]const u8, pathlen: usize, options: i32, socket_ext_size: i32) ?*ListenSocket; -pub extern fn us_socket_context_connect(ssl: i32, context: ?*SocketContext, host: ?[*:0]const u8, port: i32, options: i32, socket_ext_size: i32, has_dns_resolved: *i32) ?*anyopaque; +pub extern fn us_socket_context_connect(ssl: i32, context: ?*SocketContext, host: [*:0]const u8, port: i32, options: i32, socket_ext_size: i32, has_dns_resolved: *i32) ?*anyopaque; pub extern fn us_socket_context_connect_unix(ssl: i32, context: ?*SocketContext, path: [*c]const u8, pathlen: usize, options: i32, socket_ext_size: i32) ?*Socket; pub extern fn us_socket_is_established(ssl: i32, s: ?*Socket) i32; pub extern fn us_socket_context_loop(ssl: i32, context: ?*SocketContext) ?*Loop; diff --git a/src/install/install.zig b/src/install/install.zig index b557df8befd3b0..2bd67c900c865a 100644 --- a/src/install/install.zig +++ b/src/install/install.zig @@ -10391,7 +10391,7 @@ pub const PackageManager = struct { const allocator = hostname_stack.get(); const hostname = try allocator.dupeZ(u8, manager.options.scope.url.hostname); defer allocator.free(hostname); - bun.dns.internal.prefetch(manager.event_loop.loop(), hostname, manager.options.scope.url.getPortAuto()); + bun.dns.internal.prefetch(manager.event_loop.loop(), hostname); } var load_lockfile_result: Lockfile.LoadFromDiskResult = if (manager.options.do.load_lockfile) diff --git a/test/js/bun/dns/dns-prefetch.test.ts b/test/js/bun/dns/dns-prefetch.test.ts index 1b40f48b976d01..b8455e71643ab7 100644 --- a/test/js/bun/dns/dns-prefetch.test.ts +++ b/test/js/bun/dns/dns-prefetch.test.ts @@ -4,7 +4,7 @@ import { describe, expect, it } from "bun:test"; describe("dns.prefetch", () => { it("should prefetch", async () => { const currentStats = dns.getCacheStats(); - dns.prefetch("example.com", 80); + dns.prefetch("example.com"); await Bun.sleep(32); // Must set keepalive: false to ensure it doesn't reuse the socket. @@ -25,11 +25,5 @@ describe("dns.prefetch", () => { const newStats2 = dns.getCacheStats(); // Ensure it's cached. expect(newStats2.cacheHitsCompleted).toBeGreaterThan(currentStats.cacheHitsCompleted); - - dns.prefetch("example.com", 443); - await Bun.sleep(32); - - // Verify the cache key includes the port number. - expect(dns.getCacheStats().cacheMisses).toBeGreaterThan(currentStats.cacheMisses); }); }); diff --git a/test/js/bun/net/echo.js b/test/js/bun/net/echo.js index 97c61d011a8b38..d9a9167d2a8839 100644 --- a/test/js/bun/net/echo.js +++ b/test/js/bun/net/echo.js @@ -35,7 +35,9 @@ function createOptions(type, message, closeOnDone) { } return { - hostname: "localhost", + // we don't use localhost here to ensure that only one connection is made + // because we perform exact matching on the printed output + hostname: "127.0.0.1", port: 0, socket: { close() { diff --git a/test/js/bun/net/socket-leak-fixture.js b/test/js/bun/net/socket-leak-fixture.js new file mode 100644 index 00000000000000..e029f9732e446e --- /dev/null +++ b/test/js/bun/net/socket-leak-fixture.js @@ -0,0 +1,43 @@ +import { openSync, closeSync } from "node:fs"; +import { expect } from "bun:test"; + +const server = Bun.listen({ + port: 0, + hostname: "localhost", + socket: { + open(socket) { + socket.end(); + }, + data(socket, data) {}, + }, +}); + +let connected = 0; +async function callback() { + await Bun.connect({ + port: server.port, + hostname: "localhost", + socket: { + open(socket) { + connected += 1; + }, + data(socket, data) {}, + }, + }); +} + +const fd_before = openSync("/dev/null", "w"); +closeSync(fd_before); + +// start 100 connections +const connections = await Promise.all(new Array(100).fill(0).map(callback)); + +expect(connected).toBe(100); + +const fd = openSync("/dev/null", "w"); +closeSync(fd); + +// ensure that we don't leak sockets when we initiate multiple connections +expect(fd - fd_before).toBeLessThan(5); + +server.stop(); diff --git a/test/js/bun/net/socket.test.ts b/test/js/bun/net/socket.test.ts index 4e61cea5091c2a..b948989a1d53df 100644 --- a/test/js/bun/net/socket.test.ts +++ b/test/js/bun/net/socket.test.ts @@ -352,3 +352,89 @@ it("it should not crash when returning a Error on client socket open", async () expect(result?.message).toBe("CustomError"); } }); + +it("it should only call open once", async () => { + const server = Bun.listen({ + port: 0, + hostname: "localhost", + socket: { + open(socket) { + socket.end("Hello"); + }, + data(socket, data) {}, + }, + }); + + const { resolve, reject, promise } = Promise.withResolvers(); + + let client: Socket | null = null; + let opened = false; + client = await Bun.connect({ + port: server.port, + hostname: "localhost", + socket: { + open(socket) { + expect(opened).toBe(false); + opened = true; + }, + connectError(socket, error) { + expect().fail("connectError should not be called"); + }, + close(socket) { + server.stop(); + resolve(); + }, + data(socket, data) {}, + }, + }); + + await promise; + expect(opened).toBe(true); +}); + +it.skipIf(isWindows)("should not leak file descriptors when connecting", async () => { + expect([fileURLToPath(new URL("./socket-leak-fixture.js", import.meta.url))]).toRun(); +}); + +it("should not call open if the connection had an error", async () => { + const server = Bun.listen({ + port: 0, + hostname: "0.0.0.0", + socket: { + open(socket) { + socket.end(); + }, + data(socket, data) {}, + }, + }); + + const { resolve, reject, promise } = Promise.withResolvers(); + + let client: Socket | null = null; + let hadError = false; + try { + client = await Bun.connect({ + port: server.port, + hostname: "::1", + socket: { + open(socket) { + expect().fail("open should not be called, the connection should fail"); + }, + connectError(socket, error) { + expect(hadError).toBe(false); + hadError = true; + resolve(); + }, + close(socket) { + expect().fail("close should not be called, the connection should fail"); + }, + data(socket, data) {}, + }, + }); + } catch (e) {} + + await Bun.sleep(50); + await promise; + server.stop(); + expect(hadError).toBe(true); +}); diff --git a/test/js/node/tls/node-tls-cert.test.ts b/test/js/node/tls/node-tls-cert.test.ts index 02cbb855988be9..067468a2f6eceb 100644 --- a/test/js/node/tls/node-tls-cert.test.ts +++ b/test/js/node/tls/node-tls-cert.test.ts @@ -77,7 +77,7 @@ function connect(options: any) { resolveOrReject(); }) .listen(0, function () { - const optClient = { ...options.client, port: server.server.address().port }; + const optClient = { ...options.client, port: server.server.address().port, host: "127.0.0.1" }; try { const conn = tls .connect(optClient, () => { diff --git a/test/js/node/tls/node-tls-context.test.ts b/test/js/node/tls/node-tls-context.test.ts index 54ae259cd25ae6..e68399360d2c48 100644 --- a/test/js/node/tls/node-tls-context.test.ts +++ b/test/js/node/tls/node-tls-context.test.ts @@ -115,6 +115,7 @@ describe("tls.Server", () => { { ...clientOptionsBase, port: (server.address() as AddressInfo).port, + host: "127.0.0.1", servername, }, () => { @@ -174,6 +175,7 @@ describe("tls.Server", () => { server.listen(0, () => { const options = { port: (server?.address() as AddressInfo).port, + host: "127.0.0.1", key: agent1Key, cert: agent1Cert, ca: [ca1], @@ -278,6 +280,7 @@ describe("tls.Server", () => { { ...options, port: (server.address() as AddressInfo).port, + host: "127.0.0.1", rejectUnauthorized: false, }, () => {