diff --git a/include/misc/fdtable.h b/include/misc/fdtable.h index 8383bcc1678df4..d80cc550c071e6 100644 --- a/include/misc/fdtable.h +++ b/include/misc/fdtable.h @@ -89,6 +89,16 @@ void z_free_fd(int fd); */ void *z_get_fd_obj(int fd, const struct fd_op_vtable *vtable, int err); +/** + * @brief Get underlying object pointer and vtable pointer from file descriptor. + * + * @param fd File descriptor previously returned by z_reserve_fd() + * @param vtable A pointer to a pointer variable to store the vtable + * + * @return Object pointer or NULL, with errno set + */ +void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable); + /** * Request codes for fd_op_vtable.ioctl(). * @@ -101,6 +111,9 @@ enum { ZFD_IOCTL_CLOSE = 1, ZFD_IOCTL_FSYNC, ZFD_IOCTL_LSEEK, + ZFD_IOCTL_FCNTL, + ZFD_IOCTL_POLL_PREPARE, + ZFD_IOCTL_POLL_UPDATE, }; #ifdef __cplusplus diff --git a/include/net/socket.h b/include/net/socket.h index 75834282540bb8..efec03bcbb45da 100644 --- a/include/net/socket.h +++ b/include/net/socket.h @@ -169,160 +169,81 @@ int zsock_getaddrinfo(const char *host, const char *service, const struct zsock_addrinfo *hints, struct zsock_addrinfo **res); -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - -int ztls_socket(int family, int type, int proto); -int ztls_close(int sock); -int ztls_bind(int sock, const struct sockaddr *addr, socklen_t addrlen); -int ztls_connect(int sock, const struct sockaddr *addr, socklen_t addrlen); -int ztls_listen(int sock, int backlog); -int ztls_accept(int sock, struct sockaddr *addr, socklen_t *addrlen); -ssize_t ztls_send(int sock, const void *buf, size_t len, int flags); -ssize_t ztls_recv(int sock, void *buf, size_t max_len, int flags); -ssize_t ztls_sendto(int sock, const void *buf, size_t len, int flags, - const struct sockaddr *dest_addr, socklen_t addrlen); -ssize_t ztls_recvfrom(int sock, void *buf, size_t max_len, int flags, - struct sockaddr *src_addr, socklen_t *addrlen); -int ztls_fcntl(int sock, int cmd, int flags); -int ztls_poll(struct zsock_pollfd *fds, int nfds, int timeout); -int ztls_getsockopt(int sock, int level, int optname, - void *optval, socklen_t *optlen); -int ztls_setsockopt(int sock, int level, int optname, - const void *optval, socklen_t optlen); - -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ - #if defined(CONFIG_NET_SOCKETS_POSIX_NAMES) #define pollfd zsock_pollfd #if !defined(CONFIG_NET_SOCKETS_OFFLOAD) static inline int socket(int family, int type, int proto) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_socket(family, type, proto); -#else return zsock_socket(family, type, proto); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int close(int sock) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_close(sock); -#else return zsock_close(sock); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int bind(int sock, const struct sockaddr *addr, socklen_t addrlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_bind(sock, addr, addrlen); -#else return zsock_bind(sock, addr, addrlen); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int connect(int sock, const struct sockaddr *addr, socklen_t addrlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_connect(sock, addr, addrlen); -#else return zsock_connect(sock, addr, addrlen); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int listen(int sock, int backlog) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_listen(sock, backlog); -#else return zsock_listen(sock, backlog); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int accept(int sock, struct sockaddr *addr, socklen_t *addrlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_accept(sock, addr, addrlen); -#else return zsock_accept(sock, addr, addrlen); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline ssize_t send(int sock, const void *buf, size_t len, int flags) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_send(sock, buf, len, flags); -#else return zsock_send(sock, buf, len, flags); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline ssize_t recv(int sock, void *buf, size_t max_len, int flags) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_recv(sock, buf, max_len, flags); -#else return zsock_recv(sock, buf, max_len, flags); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } /* This conflicts with fcntl.h, so code must include fcntl.h before socket.h: */ -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) -#define fcntl ztls_fcntl -#else #define fcntl zsock_fcntl -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ static inline ssize_t sendto(int sock, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_sendto(sock, buf, len, flags, dest_addr, addrlen); -#else return zsock_sendto(sock, buf, len, flags, dest_addr, addrlen); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline ssize_t recvfrom(int sock, void *buf, size_t max_len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_recvfrom(sock, buf, max_len, flags, src_addr, addrlen); -#else return zsock_recvfrom(sock, buf, max_len, flags, src_addr, addrlen); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int poll(struct zsock_pollfd *fds, int nfds, int timeout) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_poll(fds, nfds, timeout); -#else return zsock_poll(fds, nfds, timeout); -#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int getsockopt(int sock, int level, int optname, void *optval, socklen_t *optlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_getsockopt(sock, level, optname, optval, optlen); -#else return zsock_getsockopt(sock, level, optname, optval, optlen); -#endif } static inline int setsockopt(int sock, int level, int optname, const void *optval, socklen_t optlen) { -#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) - return ztls_setsockopt(sock, level, optname, optval, optlen); -#else return zsock_setsockopt(sock, level, optname, optval, optlen); -#endif } static inline int getaddrinfo(const char *host, const char *service, diff --git a/lib/fdtable.c b/lib/fdtable.c index 9a3818c3d4c593..ce11f457584376 100644 --- a/lib/fdtable.c +++ b/lib/fdtable.c @@ -89,6 +89,20 @@ void *z_get_fd_obj(int fd, const struct fd_op_vtable *vtable, int err) return fd_entry->obj; } +void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable) +{ + struct fd_entry *fd_entry; + + if (_check_fd(fd) < 0) { + return NULL; + } + + fd_entry = &fdtable[fd]; + *vtable = fd_entry->vtable; + + return fd_entry->obj; +} + int z_reserve_fd(void) { int fd; diff --git a/subsys/net/lib/sockets/sockets.c b/subsys/net/lib/sockets/sockets.c index efdd5deaf48570..c49dc60acf2033 100644 --- a/subsys/net/lib/sockets/sockets.c +++ b/subsys/net/lib/sockets/sockets.c @@ -23,7 +23,24 @@ LOG_MODULE_REGISTER(net_sock, CONFIG_NET_SOCKETS_LOG_LEVEL); #define SET_ERRNO(x) \ { int _err = x; if (_err < 0) { errno = -_err; return -1; } } -static const struct fd_op_vtable sock_fd_op_vtable; +#define VTABLE_CALL(fn, sock, ...) \ + do { \ + const struct socket_op_vtable *vtable; \ + void *ctx = get_sock_vtable(sock, &vtable); \ + if (ctx == NULL) { \ + return -1; \ + } \ + return vtable->fn(ctx, __VA_ARGS__); \ + } while (0) + +const struct socket_op_vtable sock_fd_op_vtable; + +static inline void *get_sock_vtable( + int sock, const struct socket_op_vtable **vtable) +{ + return z_get_fd_obj_and_vtable(sock, + (const struct fd_op_vtable **)vtable); +} static void zsock_received_cb(struct net_context *ctx, struct net_pkt *pkt, int status, void *user_data); @@ -58,12 +75,7 @@ static void zsock_flush_queue(struct net_context *ctx) k_fifo_cancel_wait(&ctx->recv_q); } -static inline struct net_context *sock_to_net_ctx(int sock) -{ - return z_get_fd_obj(sock, &sock_fd_op_vtable, ENOTSOCK); -} - -int _impl_zsock_socket(int family, int type, int proto) +int zsock_socket_internal(int family, int type, int proto) { int fd = z_reserve_fd(); struct net_context *ctx; @@ -93,11 +105,23 @@ int _impl_zsock_socket(int family, int type, int proto) _k_object_recycle(ctx); #endif - z_finalize_fd(fd, ctx, &sock_fd_op_vtable); + z_finalize_fd(fd, ctx, (const struct fd_op_vtable *)&sock_fd_op_vtable); return fd; } +int _impl_zsock_socket(int family, int type, int proto) +{ +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + if (((proto >= IPPROTO_TLS_1_0) && (proto <= IPPROTO_TLS_1_2)) || + (proto >= IPPROTO_DTLS_1_0 && proto <= IPPROTO_DTLS_1_2)) { + return ztls_socket(family, type, proto); + } +#endif + + return zsock_socket_internal(family, type, proto); +} + #ifdef CONFIG_USERSPACE Z_SYSCALL_HANDLER(zsock_socket, family, type, proto) { @@ -133,7 +157,8 @@ int zsock_close_ctx(struct net_context *ctx) int _impl_zsock_close(int sock) { - struct net_context *ctx = sock_to_net_ctx(sock); + const struct socket_op_vtable *vtable; + void *ctx = get_sock_vtable(sock, &vtable); if (ctx == NULL) { return -1; @@ -141,7 +166,7 @@ int _impl_zsock_close(int sock) z_free_fd(sock); - return zsock_close_ctx(ctx); + return vtable->fd_vtable.ioctl(ctx, ZFD_IOCTL_CLOSE); } #ifdef CONFIG_USERSPACE @@ -209,14 +234,9 @@ static void zsock_received_cb(struct net_context *ctx, struct net_pkt *pkt, k_fifo_put(&ctx->recv_q, pkt); } -int _impl_zsock_bind(int sock, const struct sockaddr *addr, socklen_t addrlen) +int zsock_bind_ctx(struct net_context *ctx, const struct sockaddr *addr, + socklen_t addrlen) { - struct net_context *ctx = sock_to_net_ctx(sock); - - if (ctx == NULL) { - return -1; - } - SET_ERRNO(net_context_bind(ctx, addr, addrlen)); /* For DGRAM socket, we expect to receive packets after call to * bind(), but for STREAM socket, next expected operation is @@ -230,6 +250,11 @@ int _impl_zsock_bind(int sock, const struct sockaddr *addr, socklen_t addrlen) return 0; } +int _impl_zsock_bind(int sock, const struct sockaddr *addr, socklen_t addrlen) +{ + VTABLE_CALL(bind, sock, addr, addrlen); +} + #ifdef CONFIG_USERSPACE Z_SYSCALL_HANDLER(zsock_bind, sock, addr, addrlen) { @@ -243,22 +268,23 @@ Z_SYSCALL_HANDLER(zsock_bind, sock, addr, addrlen) } #endif /* CONFIG_USERSPACE */ -int _impl_zsock_connect(int sock, const struct sockaddr *addr, - socklen_t addrlen) +int zsock_connect_ctx(struct net_context *ctx, const struct sockaddr *addr, + socklen_t addrlen) { - struct net_context *ctx = sock_to_net_ctx(sock); - - if (ctx == NULL) { - return -1; - } - SET_ERRNO(net_context_connect(ctx, addr, addrlen, NULL, K_FOREVER, NULL)); - SET_ERRNO(net_context_recv(ctx, zsock_received_cb, K_NO_WAIT, ctx->user_data)); + SET_ERRNO(net_context_recv(ctx, zsock_received_cb, K_NO_WAIT, + ctx->user_data)); return 0; } +int _impl_zsock_connect(int sock, const struct sockaddr *addr, + socklen_t addrlen) +{ + VTABLE_CALL(connect, sock, addr, addrlen); +} + #ifdef CONFIG_USERSPACE Z_SYSCALL_HANDLER(zsock_connect, sock, addr, addrlen) { @@ -272,20 +298,19 @@ Z_SYSCALL_HANDLER(zsock_connect, sock, addr, addrlen) } #endif /* CONFIG_USERSPACE */ -int _impl_zsock_listen(int sock, int backlog) +int zsock_listen_ctx(struct net_context *ctx, int backlog) { - struct net_context *ctx = sock_to_net_ctx(sock); - - if (ctx == NULL) { - return -1; - } - SET_ERRNO(net_context_listen(ctx, backlog)); SET_ERRNO(net_context_accept(ctx, zsock_accepted_cb, K_NO_WAIT, ctx)); return 0; } +int _impl_zsock_listen(int sock, int backlog) +{ + VTABLE_CALL(listen, sock, backlog); +} + #ifdef CONFIG_USERSPACE Z_SYSCALL_HANDLER(zsock_listen, sock, backlog) { @@ -293,15 +318,11 @@ Z_SYSCALL_HANDLER(zsock_listen, sock, backlog) } #endif /* CONFIG_USERSPACE */ -int _impl_zsock_accept(int sock, struct sockaddr *addr, socklen_t *addrlen) +int zsock_accept_ctx(struct net_context *parent, struct sockaddr *addr, + socklen_t *addrlen) { - struct net_context *parent = sock_to_net_ctx(sock); int fd; - if (parent == NULL) { - return -1; - } - fd = z_reserve_fd(); if (fd < 0) { return -1; @@ -330,11 +351,16 @@ int _impl_zsock_accept(int sock, struct sockaddr *addr, socklen_t *addrlen) } } - z_finalize_fd(fd, ctx, &sock_fd_op_vtable); + z_finalize_fd(fd, ctx, (const struct fd_op_vtable *)&sock_fd_op_vtable); return fd; } +int _impl_zsock_accept(int sock, struct sockaddr *addr, socklen_t *addrlen) +{ + VTABLE_CALL(accept, sock, addr, addrlen); +} + #ifdef CONFIG_USERSPACE Z_SYSCALL_HANDLER(zsock_accept, sock, addr, addrlen) { @@ -416,13 +442,7 @@ ssize_t zsock_sendto_ctx(struct net_context *ctx, const void *buf, size_t len, ssize_t _impl_zsock_sendto(int sock, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { - struct net_context *ctx = sock_to_net_ctx(sock); - - if (ctx == NULL) { - return -1; - } - - return zsock_sendto_ctx(ctx, buf, len, flags, dest_addr, addrlen); + VTABLE_CALL(sendto, sock, buf, len, flags, dest_addr, addrlen); } #ifdef CONFIG_USERSPACE @@ -626,13 +646,7 @@ ssize_t zsock_recvfrom_ctx(struct net_context *ctx, void *buf, size_t max_len, ssize_t _impl_zsock_recvfrom(int sock, void *buf, size_t max_len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { - struct net_context *ctx = sock_to_net_ctx(sock); - - if (ctx == NULL) { - return -1; - } - - return zsock_recvfrom_ctx(ctx, buf, max_len, flags, src_addr, addrlen); + VTABLE_CALL(recvfrom, sock, buf, max_len, flags, src_addr, addrlen); } #ifdef CONFIG_USERSPACE @@ -668,17 +682,8 @@ Z_SYSCALL_HANDLER(zsock_recvfrom, sock, buf, max_len, flags, src_addr, } #endif /* CONFIG_USERSPACE */ -/* As this is limited function, we don't follow POSIX signature, with - * "..." instead of last arg. - */ -int _impl_zsock_fcntl(int sock, int cmd, int flags) +int zsock_fcntl_ctx(struct net_context *ctx, int cmd, int flags) { - struct net_context *ctx = sock_to_net_ctx(sock); - - if (ctx == NULL) { - return -1; - } - switch (cmd) { case F_GETFL: if (sock_is_nonblock(ctx)) { @@ -698,6 +703,14 @@ int _impl_zsock_fcntl(int sock, int cmd, int flags) } } +/* As this is limited function, we don't follow POSIX signature, with + * "..." instead of last arg. + */ +int _impl_zsock_fcntl(int sock, int cmd, int flags) +{ + VTABLE_CALL(fd_vtable.ioctl, sock, cmd, flags); +} + #ifdef CONFIG_USERSPACE Z_SYSCALL_HANDLER(zsock_fcntl, sock, cmd, flags) { @@ -705,14 +718,66 @@ Z_SYSCALL_HANDLER(zsock_fcntl, sock, cmd, flags) } #endif +static int zsock_poll_prepare_ctx(struct net_context *ctx, + struct zsock_pollfd *pfd, + struct k_poll_event **pev, + struct k_poll_event *pev_end) +{ + if (pfd->events & ZSOCK_POLLIN) { + if (*pev == pev_end) { + errno = ENOMEM; + return -1; + } + + (*pev)->obj = &ctx->recv_q; + (*pev)->type = K_POLL_TYPE_FIFO_DATA_AVAILABLE; + (*pev)->mode = K_POLL_MODE_NOTIFY_ONLY; + (*pev)->state = K_POLL_STATE_NOT_READY; + (*pev)++; + } + + return 0; +} + +static int zsock_poll_update_ctx(struct net_context *ctx, + struct zsock_pollfd *pfd, + struct k_poll_event **pev) +{ + ARG_UNUSED(ctx); + + /* For now, assume that socket is always writable */ + if (pfd->events & ZSOCK_POLLOUT) { + pfd->revents |= ZSOCK_POLLOUT; + } + + if (pfd->events & ZSOCK_POLLIN) { + if ((*pev)->state != K_POLL_STATE_NOT_READY) { + pfd->revents |= ZSOCK_POLLIN; + } + (*pev)++; + } + + return 0; +} + +static inline int time_left(u32_t start, u32_t timeout) +{ + u32_t elapsed = k_uptime_get_32() - start; + + return timeout - elapsed; +} + int _impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int timeout) { - int i; + bool retry; int ret = 0; + int i, remaining_time; struct zsock_pollfd *pfd; struct k_poll_event poll_events[CONFIG_NET_SOCKETS_POLL_MAX]; struct k_poll_event *pev; struct k_poll_event *pev_end = poll_events + ARRAY_SIZE(poll_events); + const struct socket_op_vtable *vtable; + u32_t entry_time = k_uptime_get_32(); if (timeout < 0) { timeout = K_FOREVER; @@ -727,69 +792,86 @@ int _impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int timeout) continue; } - ctx = sock_to_net_ctx(pfd->fd); - + ctx = get_sock_vtable(pfd->fd, &vtable); if (ctx == NULL) { /* Will set POLLNVAL in return loop */ continue; } - if (pfd->events & ZSOCK_POLLIN) { - if (pev == pev_end) { - errno = ENOMEM; - return -1; + if (vtable->fd_vtable.ioctl(ctx, ZFD_IOCTL_POLL_PREPARE, + pfd, &pev, pev_end) < 0) { + if (errno == EALREADY) { + timeout = K_NO_WAIT; + continue; } - pev->obj = &ctx->recv_q; - pev->type = K_POLL_TYPE_FIFO_DATA_AVAILABLE; - pev->mode = K_POLL_MODE_NOTIFY_ONLY; - pev->state = K_POLL_STATE_NOT_READY; - pev++; + return -1; } } - ret = k_poll(poll_events, pev - poll_events, timeout); - /* EAGAIN when timeout expired, EINTR when cancelled (i.e. EOF) */ - if (ret != 0 && ret != -EAGAIN && ret != -EINTR) { - errno = -ret; - return -1; - } + remaining_time = timeout; - ret = 0; + do { + ret = k_poll(poll_events, pev - poll_events, remaining_time); + /* EAGAIN when timeout expired, EINTR when cancelled (i.e. EOF) */ + if (ret != 0 && ret != -EAGAIN && ret != -EINTR) { + errno = -ret; + return -1; + } - pev = poll_events; - for (pfd = fds, i = nfds; i--; pfd++) { - struct net_context *ctx; + retry = false; + ret = 0; - pfd->revents = 0; + pev = poll_events; + for (pfd = fds, i = nfds; i--; pfd++) { + struct net_context *ctx; - if (pfd->fd < 0) { - continue; - } + pfd->revents = 0; - ctx = sock_to_net_ctx(pfd->fd); - if (ctx == NULL) { - pfd->revents = ZSOCK_POLLNVAL; - ret++; - continue; - } + if (pfd->fd < 0) { + continue; + } - /* For now, assume that socket is always writable */ - if (pfd->events & ZSOCK_POLLOUT) { - pfd->revents |= ZSOCK_POLLOUT; - } + ctx = get_sock_vtable(pfd->fd, &vtable); + if (ctx == NULL) { + pfd->revents = ZSOCK_POLLNVAL; + ret++; + continue; + } + + if (vtable->fd_vtable.ioctl(ctx, ZFD_IOCTL_POLL_UPDATE, + pfd, &pev) < 0) { + if (errno == EAGAIN) { + retry = true; + continue; + } + + return -1; + } - if (pfd->events & ZSOCK_POLLIN) { - if (pev->state != K_POLL_STATE_NOT_READY) { - pfd->revents |= ZSOCK_POLLIN; + if (pfd->revents != 0) { + ret++; } - pev++; } - if (pfd->revents != 0) { - ret++; + if (retry) { + if (ret > 0) { + break; + } + + if (timeout == K_NO_WAIT) { + break; + } + + if (timeout != K_FOREVER) { + /* Recalculate the timeout value. */ + remaining_time = time_left(entry_time, timeout); + if (remaining_time <= 0) { + break; + } + } } - } + } while (retry); return ret; } @@ -861,8 +943,21 @@ Z_SYSCALL_HANDLER(zsock_inet_pton, family, src, dst) } #endif +int zsock_getsockopt_ctx(struct net_context *ctx, int level, int optname, + void *optval, socklen_t *optlen) +{ + errno = ENOPROTOOPT; + return -1; +} + int zsock_getsockopt(int sock, int level, int optname, void *optval, socklen_t *optlen) +{ + VTABLE_CALL(getsockopt, sock, level, optname, optval, optlen); +} + +int zsock_setsockopt_ctx(struct net_context *ctx, int level, int optname, + const void *optval, socklen_t optlen) { errno = ENOPROTOOPT; return -1; @@ -871,8 +966,7 @@ int zsock_getsockopt(int sock, int level, int optname, int zsock_setsockopt(int sock, int level, int optname, const void *optval, socklen_t optlen) { - errno = ENOPROTOOPT; - return -1; + VTABLE_CALL(setsockopt, sock, level, optname, optval, optlen); } static ssize_t sock_read_vmeth(void *obj, void *buffer, size_t count) @@ -891,14 +985,115 @@ static int sock_ioctl_vmeth(void *obj, unsigned int request, ...) case ZFD_IOCTL_CLOSE: return zsock_close_ctx(obj); + case ZFD_IOCTL_FCNTL: { + va_list args; + int cmd, flags; + + va_start(args, request); + cmd = va_arg(args, int); + flags = va_arg(args, int); + va_end(args); + + return zsock_fcntl_ctx(obj, cmd, flags); + } + + case ZFD_IOCTL_POLL_PREPARE: { + va_list args; + struct zsock_pollfd *pfd; + struct k_poll_event **pev; + struct k_poll_event *pev_end; + + va_start(args, request); + pfd = va_arg(args, struct zsock_pollfd *); + pev = va_arg(args, struct k_poll_event **); + pev_end = va_arg(args, struct k_poll_event *); + va_end(args); + + return zsock_poll_prepare_ctx(obj, pfd, pev, pev_end); + } + + case ZFD_IOCTL_POLL_UPDATE: { + va_list args; + struct zsock_pollfd *pfd; + struct k_poll_event **pev; + + va_start(args, request); + pfd = va_arg(args, struct zsock_pollfd *); + pev = va_arg(args, struct k_poll_event **); + va_end(args); + + return zsock_poll_update_ctx(obj, pfd, pev); + } + default: errno = EOPNOTSUPP; return -1; } } -static const struct fd_op_vtable sock_fd_op_vtable = { - .read = sock_read_vmeth, - .write = sock_write_vmeth, - .ioctl = sock_ioctl_vmeth, +static int sock_bind_vmeth(void *obj, const struct sockaddr *addr, + socklen_t addrlen) +{ + return zsock_bind_ctx(obj, addr, addrlen); +} + +static int sock_connect_vmeth(void *obj, const struct sockaddr *addr, + socklen_t addrlen) +{ + return zsock_connect_ctx(obj, addr, addrlen); +} + +static int sock_listen_vmeth(void *obj, int backlog) +{ + return zsock_listen_ctx(obj, backlog); +} + +static int sock_accept_vmeth(void *obj, struct sockaddr *addr, + socklen_t *addrlen) +{ + return zsock_accept_ctx(obj, addr, addrlen); +} + +static ssize_t sock_sendto_vmeth(void *obj, const void *buf, size_t len, + int flags, const struct sockaddr *dest_addr, + socklen_t addrlen) +{ + return zsock_sendto_ctx(obj, buf, len, flags, dest_addr, addrlen); +} + +static ssize_t sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len, + int flags, struct sockaddr *src_addr, + socklen_t *addrlen) +{ + return zsock_recvfrom_ctx(obj, buf, max_len, flags, + src_addr, addrlen); +} + +static int sock_getsockopt_vmeth(void *obj, int level, int optname, + void *optval, socklen_t *optlen) +{ + return zsock_getsockopt_ctx(obj, level, optname, optval, optlen); +} + +static int sock_setsockopt_vmeth(void *obj, int level, int optname, + const void *optval, socklen_t optlen) +{ + return zsock_setsockopt_ctx(obj, level, optname, optval, optlen); +} + + +const struct socket_op_vtable sock_fd_op_vtable = { + .fd_vtable = { + .read = sock_read_vmeth, + .write = sock_write_vmeth, + .ioctl = sock_ioctl_vmeth, + }, + .bind = sock_bind_vmeth, + .connect = sock_connect_vmeth, + .listen = sock_listen_vmeth, + .accept = sock_accept_vmeth, + .sendto = sock_sendto_vmeth, + .recvfrom = sock_recvfrom_vmeth, + .getsockopt = sock_getsockopt_vmeth, + .setsockopt = sock_setsockopt_vmeth, }; diff --git a/subsys/net/lib/sockets/sockets_internal.h b/subsys/net/lib/sockets/sockets_internal.h index a885e034a0df7f..e096547404922a 100644 --- a/subsys/net/lib/sockets/sockets_internal.h +++ b/subsys/net/lib/sockets/sockets_internal.h @@ -7,6 +7,8 @@ #ifndef _SOCKETS_INTERNAL_H_ #define _SOCKETS_INTERNAL_H_ +#include + #define SOCK_EOF 1 #define SOCK_NONBLOCK 2 @@ -28,4 +30,23 @@ static inline u32_t sock_get_flag(struct net_context *ctx, u32_t mask) #define sock_set_eof(ctx) sock_set_flag(ctx, SOCK_EOF, SOCK_EOF) #define sock_is_nonblock(ctx) sock_get_flag(ctx, SOCK_NONBLOCK) +struct socket_op_vtable { + struct fd_op_vtable fd_vtable; + int (*bind)(void *obj, const struct sockaddr *addr, socklen_t addrlen); + int (*connect)(void *obj, const struct sockaddr *addr, + socklen_t addrlen); + int (*listen)(void *obj, int backlog); + int (*accept)(void *obj, struct sockaddr *addr, socklen_t *addrlen); + ssize_t (*sendto)(void *obj, const void *buf, size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen); + ssize_t (*recvfrom)(void *obj, void *buf, size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen); + int (*getsockopt)(void *obj, int level, int optname, + void *optval, socklen_t *optlen); + int (*setsockopt)(void *obj, int level, int optname, + const void *optval, socklen_t optlen); +}; + +int ztls_socket(int family, int type, int proto); + #endif /* _SOCKETS_INTERNAL_H_ */ diff --git a/subsys/net/lib/sockets/sockets_tls.c b/subsys/net/lib/sockets/sockets_tls.c index fc5aaed544006a..80e9fc83526f48 100644 --- a/subsys/net/lib/sockets/sockets_tls.c +++ b/subsys/net/lib/sockets/sockets_tls.c @@ -15,6 +15,7 @@ LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL); #include #include #include +#include #include #if defined(CONFIG_MBEDTLS) @@ -37,6 +38,10 @@ LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL); #include "sockets_internal.h" #include "tls_internal.h" +extern const struct socket_op_vtable sock_fd_op_vtable; + +static const struct socket_op_vtable tls_sock_fd_op_vtable; + /** A list of secure tags that TLS context should use. */ struct sec_tag_list { /** An array of secure tags referencing TLS credentials. */ @@ -144,11 +149,6 @@ static struct k_mutex context_lock; #define IS_LISTENING(context) (net_context_get_state(context) == \ NET_CONTEXT_LISTENING) -static inline struct net_context *sock_to_net_ctx(int sock) -{ - return z_get_fd_obj(sock, NULL, ENOTSOCK); -} - #if defined(MBEDTLS_DEBUG_C) && (CONFIG_NET_SOCKETS_LOG_LEVEL >= LOG_LEVEL_DBG) static void tls_debug(void *ctx, int level, const char *file, int line, const char *str) @@ -450,17 +450,12 @@ static void dtls_peer_address_get(struct net_context *context, static int dtls_tx(void *ctx, const unsigned char *buf, size_t len) { - int sock = POINTER_TO_INT(ctx); - struct net_context *context = sock_to_net_ctx(sock); + struct net_context *net_ctx = ctx; ssize_t sent; - if (context == NULL) { - return MBEDTLS_ERR_NET_INVALID_CONTEXT; - } - - sent = zsock_sendto(sock, buf, len, context->tls->flags, - &context->tls->dtls_peer_addr, - context->tls->dtls_peer_addrlen); + sent = sock_fd_op_vtable.sendto(net_ctx, buf, len, net_ctx->tls->flags, + &net_ctx->tls->dtls_peer_addr, + net_ctx->tls->dtls_peer_addrlen); if (sent < 0) { if (errno == EAGAIN) { return MBEDTLS_ERR_SSL_WANT_WRITE; @@ -474,39 +469,38 @@ static int dtls_tx(void *ctx, const unsigned char *buf, size_t len) static int dtls_rx(void *ctx, unsigned char *buf, size_t len, uint32_t timeout) { - int sock = POINTER_TO_INT(ctx); - struct net_context *context = sock_to_net_ctx(sock); - bool is_block = !((context->tls->flags & ZSOCK_MSG_DONTWAIT) || - sock_is_nonblock(context)); + struct net_context *net_ctx = ctx; + bool is_block = !((net_ctx->tls->flags & ZSOCK_MSG_DONTWAIT) || + sock_is_nonblock(net_ctx)); int remaining_time = (timeout == 0) ? K_FOREVER : timeout; u32_t entry_time = k_uptime_get_32(); socklen_t addrlen = sizeof(struct sockaddr); struct sockaddr addr; int err; ssize_t received; - struct pollfd fds; bool retry; - - if (context == NULL) { - return MBEDTLS_ERR_NET_INVALID_CONTEXT; - } + struct k_poll_event pev; do { retry = false; /* mbedtLS does not allow blocking rx for DTLS, therefore use - * poll for timeout functionality. + * k_poll for timeout functionality. */ if (is_block) { - fds.fd = sock; - fds.events = POLLIN; - if (zsock_poll(&fds, 1, remaining_time) == 0) { + pev.obj = &net_ctx->recv_q; + pev.type = K_POLL_TYPE_FIFO_DATA_AVAILABLE; + pev.mode = K_POLL_MODE_NOTIFY_ONLY; + pev.state = K_POLL_STATE_NOT_READY; + + if (k_poll(&pev, 1, remaining_time) == -EAGAIN) { return MBEDTLS_ERR_SSL_TIMEOUT; } } - received = zsock_recvfrom(sock, buf, len, context->tls->flags, - &addr, &addrlen); + received = sock_fd_op_vtable.recvfrom( + net_ctx, buf, len, net_ctx->tls->flags, + &addr, &addrlen); if (received < 0) { if (errno == EAGAIN) { return MBEDTLS_ERR_SSL_WANT_READ; @@ -515,14 +509,14 @@ static int dtls_rx(void *ctx, unsigned char *buf, size_t len, uint32_t timeout) return MBEDTLS_ERR_NET_RECV_FAILED; } - if (context->tls->dtls_peer_addrlen == 0) { + if (net_ctx->tls->dtls_peer_addrlen == 0) { /* Only allow to store peer address for DTLS servers. */ - if (context->tls->options.role + if (net_ctx->tls->options.role == MBEDTLS_SSL_IS_SERVER) { - dtls_peer_address_set(context, &addr, addrlen); + dtls_peer_address_set(net_ctx, &addr, addrlen); err = mbedtls_ssl_set_client_transport_id( - &context->tls->ssl, + &net_ctx->tls->ssl, (const unsigned char *)&addr, addrlen); if (err < 0) { return err; @@ -533,7 +527,7 @@ static int dtls_rx(void *ctx, unsigned char *buf, size_t len, uint32_t timeout) */ return MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED; } - } else if (!dtls_is_peer_addr_valid(context, &addr, addrlen)) { + } else if (!dtls_is_peer_addr_valid(net_ctx, &addr, addrlen)) { /* Received data from different peer, ignore it. */ retry = true; @@ -553,15 +547,11 @@ static int dtls_rx(void *ctx, unsigned char *buf, size_t len, uint32_t timeout) static int tls_tx(void *ctx, const unsigned char *buf, size_t len) { - int sock = POINTER_TO_INT(ctx); - struct net_context *context = sock_to_net_ctx(sock); + struct net_context *net_ctx = ctx; ssize_t sent; - if (context == NULL) { - return MBEDTLS_ERR_NET_INVALID_CONTEXT; - } - - sent = zsock_send(sock, buf, len, context->tls->flags); + sent = sock_fd_op_vtable.sendto(ctx, buf, len, + net_ctx->tls->flags, NULL, 0); if (sent < 0) { if (errno == EAGAIN) { return MBEDTLS_ERR_SSL_WANT_WRITE; @@ -575,15 +565,11 @@ static int tls_tx(void *ctx, const unsigned char *buf, size_t len) static int tls_rx(void *ctx, unsigned char *buf, size_t len) { - int sock = POINTER_TO_INT(ctx); - struct net_context *context = sock_to_net_ctx(sock); + struct net_context *net_ctx = ctx; ssize_t received; - if (context == NULL) { - return MBEDTLS_ERR_NET_INVALID_CONTEXT; - } - - received = zsock_recv(sock, buf, len, context->tls->flags); + received = sock_fd_op_vtable.recvfrom(ctx, buf, len, + net_ctx->tls->flags, NULL, 0); if (received < 0) { if (errno == EAGAIN) { return MBEDTLS_ERR_SSL_WANT_READ; @@ -815,8 +801,7 @@ static int tls_mbedtls_handshake(struct net_context *context, bool block) return ret; } -static int tls_mbedtls_init(int sock, struct net_context *context, - bool is_server) +static int tls_mbedtls_init(struct net_context *context, bool is_server) { int role, type, ret; @@ -827,11 +812,11 @@ static int tls_mbedtls_init(int sock, struct net_context *context, MBEDTLS_SSL_TRANSPORT_DATAGRAM; if (type == MBEDTLS_SSL_TRANSPORT_STREAM) { - mbedtls_ssl_set_bio(&context->tls->ssl, (void *)sock, + mbedtls_ssl_set_bio(&context->tls->ssl, context, tls_tx, tls_rx, NULL); } else { #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) - mbedtls_ssl_set_bio(&context->tls->ssl, (void *)sock, + mbedtls_ssl_set_bio(&context->tls->ssl, context, dtls_tx, NULL, dtls_rx); #else return -ENOTSUP; @@ -1109,7 +1094,13 @@ static int tls_opt_dtls_role_set(struct net_context *context, int ztls_socket(int family, int type, int proto) { enum net_ip_protocol_secure tls_proto = 0; - int sock, ret, err; + int fd = z_reserve_fd(); + int ret; + struct net_context *ctx; + + if (fd < 0) { + return -1; + } if (proto >= IPPROTO_TLS_1_0 && proto <= IPPROTO_TLS_1_2) { if (type != SOCK_STREAM) { @@ -1134,58 +1125,64 @@ int ztls_socket(int family, int type, int proto) #endif } - sock = zsock_socket(family, type, proto); - if (sock < 0) { - /* errno will be propagated */ + ret = net_context_get(family, type, proto, &ctx); + if (ret < 0) { + z_free_fd(fd); + errno = -ret; return -1; } - if (tls_proto != 0) { - /* If TLS protocol is used, allocate TLS context */ - struct net_context *context = sock_to_net_ctx(sock); + /* Initialize user_data, all other calls will preserve it */ + ctx->user_data = NULL; - context->tls = tls_alloc(); + /* recv_q and accept_q are in union */ + k_fifo_init(&ctx->recv_q); - if (!context->tls) { - ret = -ENOMEM; - goto error; +#ifdef CONFIG_USERSPACE + /* Set net context object as initialized and grant access to the + * calling thread (and only the calling thread) + */ + _k_object_recycle(ctx); +#endif + + if (tls_proto != 0) { + /* If TLS protocol is used, allocate TLS context */ + ctx->tls = tls_alloc(); + if (ctx->tls == NULL) { + z_free_fd(fd); + (void)net_context_put(ctx); + errno = ENOMEM; + return -1; } - context->tls->tls_version = tls_proto; + ctx->tls->tls_version = tls_proto; } - return sock; + z_finalize_fd( + fd, ctx, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable); -error: - err = zsock_close(sock); - __ASSERT(err == 0, "Socket close failed"); - - errno = -ret; - return -1; + return fd; } -int ztls_close(int sock) +int ztls_close_ctx(struct net_context *ctx) { - struct net_context *context = sock_to_net_ctx(sock); int ret, err = 0; - if (context == NULL) { - return -1; - } - if (context->tls) { + if (ctx->tls != NULL) { /* Try to send close notification. */ - context->tls->flags = 0; - (void)mbedtls_ssl_close_notify(&context->tls->ssl); + ctx->tls->flags = 0; + (void)mbedtls_ssl_close_notify(&ctx->tls->ssl); - err = tls_release(context->tls); + err = tls_release(ctx->tls); + } else { + err = -EBADF; } - ret = zsock_close(sock); + ret = sock_fd_op_vtable.fd_vtable.ioctl(ctx, ZFD_IOCTL_CLOSE); - /* In case zsock_close fails, we propagate errno value set by - * zsock_close. - * In case zsock_close succeeds, but tls_release fails, set errno + /* In case close fails, we propagate errno value set by close. + * In case close succeeds, but tls_release fails, set errno * according to tls_release return value. */ if (ret == 0 && err < 0) { @@ -1196,54 +1193,46 @@ int ztls_close(int sock) return ret; } -int ztls_bind(int sock, const struct sockaddr *addr, socklen_t addrlen) -{ - /* No extra action needed here. */ - return zsock_bind(sock, addr, addrlen); -} - -int ztls_connect(int sock, const struct sockaddr *addr, socklen_t addrlen) +int ztls_connect_ctx(struct net_context *ctx, const struct sockaddr *addr, + socklen_t addrlen) { int ret; - struct net_context *context = sock_to_net_ctx(sock); - if (context == NULL) { + if (ctx->tls == NULL) { + errno = EBADF; return -1; } - ret = zsock_connect(sock, addr, addrlen); + ret = sock_fd_op_vtable.connect(ctx, addr, addrlen); if (ret < 0) { - /* errno will be propagated */ - return -1; + return ret; } - if (context->tls) { - if (net_context_get_type(context) == SOCK_STREAM) { - /* Do the handshake for TLS, not DTLS. */ - ret = tls_mbedtls_init(sock, context, false); - if (ret < 0) { - goto error; - } + if (net_context_get_type(ctx) == SOCK_STREAM) { + /* Do the handshake for TLS, not DTLS. */ + ret = tls_mbedtls_init(ctx, false); + if (ret < 0) { + goto error; + } - /* Do not use any socket flags during the handshake. */ - context->tls->flags = 0; + /* Do not use any socket flags during the handshake. */ + ctx->tls->flags = 0; - /* TODO For simplicity, TLS handshake blocks the socket - * even for non-blocking socket. - */ - ret = tls_mbedtls_handshake(context, true); - if (ret < 0) { - goto error; - } - } else { + /* TODO For simplicity, TLS handshake blocks the socket + * even for non-blocking socket. + */ + ret = tls_mbedtls_handshake(ctx, true); + if (ret < 0) { + goto error; + } + } else { #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) - /* Just store the address. */ - dtls_peer_address_set(context, addr, addrlen); + /* Just store the address. */ + dtls_peer_address_set(ctx, addr, addrlen); #else - ret = -ENOTSUP; - goto error; + ret = -ENOTSUP; + goto error; #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */ - } } return 0; @@ -1253,86 +1242,93 @@ int ztls_connect(int sock, const struct sockaddr *addr, socklen_t addrlen) return -1; } -int ztls_listen(int sock, int backlog) -{ - /* No extra action needed here. */ - return zsock_listen(sock, backlog); -} - -int ztls_accept(int sock, struct sockaddr *addr, socklen_t *addrlen) +int ztls_accept_ctx(struct net_context *parent, struct sockaddr *addr, + socklen_t *addrlen) { - int child_sock, ret, err; - struct net_context *parent_context = sock_to_net_ctx(sock); - struct net_context *child_context = NULL; + int ret, err, fd; + struct net_context *child; - if (parent_context == NULL) { + if (parent->tls == NULL) { + errno = EBADF; return -1; } - child_sock = zsock_accept(sock, addr, addrlen); - if (child_sock < 0) { - /* errno will be propagated */ + fd = z_reserve_fd(); + if (fd < 0) { return -1; } - if (parent_context->tls) { - child_context = sock_to_net_ctx(child_sock); - - child_context->tls = tls_clone(parent_context->tls); - if (!child_context->tls) { - ret = -ENOMEM; - goto error; - } + child = k_fifo_get(&parent->accept_q, K_FOREVER); - ret = tls_mbedtls_init(child_sock, child_context, true); - if (ret < 0) { - goto error; - } + #ifdef CONFIG_USERSPACE + _k_object_recycle(child); + #endif - /* Do not use any socket flags during the handshake. */ - child_context->tls->flags = 0; + if (addr != NULL && addrlen != NULL) { + int len = min(*addrlen, sizeof(child->remote)); - /* TODO For simplicity, TLS handshake blocks the socket even for - * non-blocking socket. + memcpy(addr, &child->remote, len); + /* addrlen is a value-result argument, set to actual + * size of source address */ - ret = tls_mbedtls_handshake(child_context, true); - if (ret < 0) { + if (child->remote.sa_family == AF_INET) { + *addrlen = sizeof(struct sockaddr_in); + } else if (child->remote.sa_family == AF_INET6) { + *addrlen = sizeof(struct sockaddr_in6); + } else { + ret = -ENOTSUP; goto error; } } - return child_sock; + z_finalize_fd( + fd, child, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable); + + child->tls = tls_clone(parent->tls); + if (!child->tls) { + ret = -ENOMEM; + goto error; + } + + ret = tls_mbedtls_init(child, true); + if (ret < 0) { + goto error; + } + + /* Do not use any socket flags during the handshake. */ + child->tls->flags = 0; + + /* TODO For simplicity, TLS handshake blocks the socket even for + * non-blocking socket. + */ + ret = tls_mbedtls_handshake(child, true); + if (ret < 0) { + goto error; + } + + return fd; error: - if (child_context && child_context->tls) { - err = tls_release(child_context->tls); + if (child->tls != NULL) { + err = tls_release(child->tls); __ASSERT(err == 0, "TLS context release failed"); } - err = zsock_close(child_sock); + err = sock_fd_op_vtable.fd_vtable.ioctl(child, ZFD_IOCTL_CLOSE); __ASSERT(err == 0, "Child socket close failed"); + z_free_fd(fd); + errno = -ret; return -1; } -ssize_t ztls_send(int sock, const void *buf, size_t len, int flags) -{ - return ztls_sendto(sock, buf, len, flags, NULL, 0); -} - -ssize_t ztls_recv(int sock, void *buf, size_t max_len, int flags) -{ - return ztls_recvfrom(sock, buf, max_len, flags, NULL, 0); -} - -static ssize_t send_tls(int sock, const void *buf, +static ssize_t send_tls(struct net_context *ctx, const void *buf, size_t len, int flags) { - struct net_context *context = sock_to_net_ctx(sock); int ret; - ret = mbedtls_ssl_write(&context->tls->ssl, buf, len); + ret = mbedtls_ssl_write(&ctx->tls->ssl, buf, len); if (ret >= 0) { return ret; } @@ -1348,119 +1344,114 @@ static ssize_t send_tls(int sock, const void *buf, } #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) -static ssize_t sendto_dtls_client(int sock, const void *buf, size_t len, - int flags, const struct sockaddr *dest_addr, +static ssize_t sendto_dtls_client(struct net_context *ctx, const void *buf, + size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen) { - struct net_context *context = sock_to_net_ctx(sock); int ret; if (!dest_addr) { /* No address provided, check if we have stored one, * otherwise return error. */ - if (context->tls->dtls_peer_addrlen == 0) { + if (ctx->tls->dtls_peer_addrlen == 0) { ret = -EDESTADDRREQ; goto error; } - } else if (context->tls->dtls_peer_addrlen == 0) { + } else if (ctx->tls->dtls_peer_addrlen == 0) { /* Address provided and no peer address stored. */ - dtls_peer_address_set(context, dest_addr, addrlen); - } else if (!dtls_is_peer_addr_valid(context, dest_addr, addrlen) != 0) { + dtls_peer_address_set(ctx, dest_addr, addrlen); + } else if (!dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) { /* Address provided but it does not match stored one */ ret = -EISCONN; goto error; } - if (!context->tls->is_initialized) { - ret = tls_mbedtls_init(sock, context, false); + if (!ctx->tls->is_initialized) { + ret = tls_mbedtls_init(ctx, false); if (ret < 0) { goto error; } } - if (!context->tls->tls_established) { + if (!ctx->tls->tls_established) { /* TODO For simplicity, TLS handshake blocks the socket even for * non-blocking socket. */ - ret = tls_mbedtls_handshake(context, true); + ret = tls_mbedtls_handshake(ctx, true); if (ret < 0) { goto error; } } - return send_tls(sock, buf, len, flags); + return send_tls(ctx, buf, len, flags); error: errno = -ret; return -1; } -static ssize_t sendto_dtls_server(int sock, const void *buf, size_t len, - int flags, const struct sockaddr *dest_addr, +static ssize_t sendto_dtls_server(struct net_context *ctx, const void *buf, + size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen) { - struct net_context *context = sock_to_net_ctx(sock); - /* For DTLS server, require to have established DTLS connection * in order to send data. */ - if (!context->tls->tls_established) { + if (!ctx->tls->tls_established) { errno = ENOTCONN; return -1; } /* Verify we are sending to a peer that we have connection with. */ if (dest_addr && - !dtls_is_peer_addr_valid(context, dest_addr, addrlen) != 0) { + !dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) { errno = EISCONN; return -1; } - return send_tls(sock, buf, len, flags); + return send_tls(ctx, buf, len, flags); } #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */ -ssize_t ztls_sendto(int sock, const void *buf, size_t len, int flags, - const struct sockaddr *dest_addr, socklen_t addrlen) +ssize_t ztls_sendto_ctx(struct net_context *ctx, const void *buf, size_t len, + int flags, const struct sockaddr *dest_addr, + socklen_t addrlen) { - struct net_context *context = sock_to_net_ctx(sock); - - if (context == NULL) { + if (ctx->tls == NULL) { + errno = EBADF; return -1; } - if (!context->tls) { - return zsock_sendto(sock, buf, len, flags, dest_addr, addrlen); - } - - context->tls->flags = flags; + ctx->tls->flags = flags; /* TLS */ - if (net_context_get_type(context) == SOCK_STREAM) { - return send_tls(sock, buf, len, flags); + if (net_context_get_type(ctx) == SOCK_STREAM) { + return send_tls(ctx, buf, len, flags); } #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) /* DTLS */ - if (context->tls->options.role == MBEDTLS_SSL_IS_SERVER) { - return sendto_dtls_server(sock, buf, len, flags, + if (ctx->tls->options.role == MBEDTLS_SSL_IS_SERVER) { + return sendto_dtls_server(ctx, buf, len, flags, dest_addr, addrlen); } - return sendto_dtls_client(sock, buf, len, flags, dest_addr, addrlen); + return sendto_dtls_client(ctx, buf, len, flags, dest_addr, addrlen); #else errno = ENOTSUP; return -1; #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */ } -static ssize_t recv_tls(int sock, void *buf, size_t max_len, int flags) +static ssize_t recv_tls(struct net_context *ctx, void *buf, + size_t max_len, int flags) { - struct net_context *context = sock_to_net_ctx(sock); int ret; - ret = mbedtls_ssl_read(&context->tls->ssl, buf, max_len); + ret = mbedtls_ssl_read(&ctx->tls->ssl, buf, max_len); if (ret >= 0) { return ret; } @@ -1489,22 +1480,22 @@ static ssize_t recv_tls(int sock, void *buf, size_t max_len, int flags) } #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) -static ssize_t recvfrom_dtls_client(int sock, void *buf, size_t max_len, - int flags, struct sockaddr *src_addr, +static ssize_t recvfrom_dtls_client(struct net_context *ctx, void *buf, + size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) { - struct net_context *context = sock_to_net_ctx(sock); int ret; - if (!context->tls->tls_established) { + if (!ctx->tls->tls_established) { ret = -ENOTCONN; goto error; } - ret = mbedtls_ssl_read(&context->tls->ssl, buf, max_len); + ret = mbedtls_ssl_read(&ctx->tls->ssl, buf, max_len); if (ret >= 0) { if (src_addr && addrlen) { - dtls_peer_address_get(context, src_addr, addrlen); + dtls_peer_address_get(ctx, src_addr, addrlen); } return ret; } @@ -1515,7 +1506,7 @@ static ssize_t recvfrom_dtls_client(int sock, void *buf, size_t max_len, return 0; case MBEDTLS_ERR_SSL_TIMEOUT: - (void)mbedtls_ssl_close_notify(&context->tls->ssl); + (void)mbedtls_ssl_close_notify(&ctx->tls->ssl); ret = -ETIMEDOUT; break; @@ -1534,18 +1525,18 @@ static ssize_t recvfrom_dtls_client(int sock, void *buf, size_t max_len, return -1; } -static ssize_t recvfrom_dtls_server(int sock, void *buf, size_t max_len, - int flags, struct sockaddr *src_addr, +static ssize_t recvfrom_dtls_server(struct net_context *ctx, void *buf, + size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) { - struct net_context *context = sock_to_net_ctx(sock); int ret; bool repeat; bool is_block = !((flags & ZSOCK_MSG_DONTWAIT) || - sock_is_nonblock(context)); + sock_is_nonblock(ctx)); - if (!context->tls->is_initialized) { - ret = tls_mbedtls_init(sock, context, true); + if (!ctx->tls->is_initialized) { + ret = tls_mbedtls_init(ctx, true); if (ret < 0) { goto error; } @@ -1557,15 +1548,15 @@ static ssize_t recvfrom_dtls_server(int sock, void *buf, size_t max_len, do { repeat = false; - if (!context->tls->tls_established) { - ret = tls_mbedtls_handshake(context, is_block); + if (!ctx->tls->tls_established) { + ret = tls_mbedtls_handshake(ctx, is_block); if (ret < 0) { /* In case of EAGAIN, just exit. */ if (ret == -EAGAIN) { break; } - ret = tls_mbedtls_reset(context); + ret = tls_mbedtls_reset(ctx); if (ret == 0) { repeat = true; } else { @@ -1576,23 +1567,22 @@ static ssize_t recvfrom_dtls_server(int sock, void *buf, size_t max_len, } } - ret = mbedtls_ssl_read(&context->tls->ssl, buf, max_len); + ret = mbedtls_ssl_read(&ctx->tls->ssl, buf, max_len); if (ret >= 0) { if (src_addr && addrlen) { - dtls_peer_address_get(context, src_addr, - addrlen); + dtls_peer_address_get(ctx, src_addr, addrlen); } return ret; } switch (ret) { case MBEDTLS_ERR_SSL_TIMEOUT: - (void)mbedtls_ssl_close_notify(&context->tls->ssl); + (void)mbedtls_ssl_close_notify(&ctx->tls->ssl); /* fallthrough */ case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: case MBEDTLS_ERR_SSL_CLIENT_RECONNECT: - ret = tls_mbedtls_reset(context); + ret = tls_mbedtls_reset(ctx); if (ret == 0) { repeat = true; } else { @@ -1617,20 +1607,15 @@ static ssize_t recvfrom_dtls_server(int sock, void *buf, size_t max_len, } #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */ -ssize_t ztls_recvfrom(int sock, void *buf, size_t max_len, int flags, - struct sockaddr *src_addr, socklen_t *addrlen) +ssize_t ztls_recvfrom_ctx(struct net_context *ctx, void *buf, size_t max_len, + int flags, struct sockaddr *src_addr, + socklen_t *addrlen) { - struct net_context *context = sock_to_net_ctx(sock); - - if (context == NULL) { + if (ctx->tls == NULL) { + errno = EBADF; return -1; } - if (!context->tls) { - return zsock_recvfrom(sock, buf, max_len, flags, - src_addr, addrlen); - } - if (flags & ZSOCK_MSG_PEEK) { /* TODO mbedTLS does not support 'peeking' This could be * bypassed by having intermediate buffer for peeking @@ -1639,21 +1624,21 @@ ssize_t ztls_recvfrom(int sock, void *buf, size_t max_len, int flags, return -1; } - context->tls->flags = flags; + ctx->tls->flags = flags; /* TLS */ - if (net_context_get_type(context) == SOCK_STREAM) { - return recv_tls(sock, buf, max_len, flags); + if (net_context_get_type(ctx) == SOCK_STREAM) { + return recv_tls(ctx, buf, max_len, flags); } #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) /* DTLS */ - if (context->tls->options.role == MBEDTLS_SSL_IS_SERVER) { - return recvfrom_dtls_server(sock, buf, max_len, flags, + if (ctx->tls->options.role == MBEDTLS_SSL_IS_SERVER) { + return recvfrom_dtls_server(ctx, buf, max_len, flags, src_addr, addrlen); } - return recvfrom_dtls_client(sock, buf, max_len, flags, + return recvfrom_dtls_client(ctx, buf, max_len, flags, src_addr, addrlen); #else errno = ENOTSUP; @@ -1661,150 +1646,106 @@ ssize_t ztls_recvfrom(int sock, void *buf, size_t max_len, int flags, #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */ } -int ztls_fcntl(int sock, int cmd, int flags) +static int ztls_poll_prepare_ctx(struct net_context *ctx, + struct zsock_pollfd *pfd, + struct k_poll_event **pev, + struct k_poll_event *pev_end) { - /* No extra action needed here. */ - return zsock_fcntl(sock, cmd, flags); -} - -int ztls_poll(struct zsock_pollfd *fds, int nfds, int timeout) -{ - bool retry = true; - struct zsock_pollfd *pfd; - struct net_context *context; - int i, ret, remaining_time = timeout; - u32_t entry_time = k_uptime_get_32(); - - /* There might be some decrypted but unread data pending - * on mbedTLS, check for that. - */ - for (pfd = fds, i = nfds; i--; pfd++) { - /* Per POSIX, negative fd's are just ignored */ - if (pfd->fd < 0) { - continue; - } - - if (pfd->events & ZSOCK_POLLIN) { - context = sock_to_net_ctx(pfd->fd); - if (context == NULL) { - /* ZSOCK_POLLNVAL will be set by zsock_poll */ - continue; - } - - if (!context->tls || IS_LISTENING(context)) { - continue; - } + if (ctx->tls == NULL) { + /* POLLNVAL flag will be set in the update function. */ + return 0; + } - /* There already is mbedTLS data to read, so just poll - * with no timeout, to check if there has been any other - * activity on sockets. + if (pfd->events & ZSOCK_POLLIN) { + if (!IS_LISTENING(ctx)) { + /* If there already is mbedTLS data to read, there is no + * need to set the k_poll_event object. Return EALREADY + * so we won't block in the k_poll. */ - if (mbedtls_ssl_get_bytes_avail( - &context->tls->ssl) > 0) { - remaining_time = K_NO_WAIT; - break; + if (mbedtls_ssl_get_bytes_avail(&ctx->tls->ssl) > 0) { + errno = EALREADY; + return -1; } } - } - while (retry) { - ret = zsock_poll(fds, nfds, remaining_time); - if (ret < 0) { - /* errno will be propagated. */ - return ret; - } else if (ret == 0) { - /* Do not repeat on timeout. */ - retry = false; + if (*pev == pev_end) { + errno = ENOMEM; + return -1; } - /* Make mbedTLS recalculate the data on sockets that notified - * data availability, and update revents respectively. - */ - for (pfd = fds, i = nfds; i--; pfd++) { - /* Per POSIX, negative fd's are just ignored */ - if (pfd->fd < 0) { - continue; - } - - if (pfd->events & ZSOCK_POLLIN) { - context = sock_to_net_ctx(pfd->fd); - if (context == NULL) { - /* ZSOCK_POLLNVAL was set by - * zsock_poll - */ - continue; - } - - if (!context->tls || IS_LISTENING(context)) { - continue; - } + (*pev)->obj = &ctx->recv_q; + (*pev)->type = K_POLL_TYPE_FIFO_DATA_AVAILABLE; + (*pev)->mode = K_POLL_MODE_NOTIFY_ONLY; + (*pev)->state = K_POLL_STATE_NOT_READY; + (*pev)++; + } - if (pfd->revents & ZSOCK_POLLIN) { - /* EAGAIN might happen during or just - * after DLTS handshake. - */ - if (recv(pfd->fd, NULL, 0, - ZSOCK_MSG_DONTWAIT) < 0 && - errno != EAGAIN) { - /* No need to increment ret here - * as POLLIN was set. - */ - pfd->revents |= ZSOCK_POLLERR; - continue; - } - } + return 0; +} - if (mbedtls_ssl_get_bytes_avail( - &context->tls->ssl) > 0) { - if (pfd->revents == 0) { - ret++; - } +static int ztls_poll_update_ctx(struct net_context *ctx, + struct zsock_pollfd *pfd, + struct k_poll_event **pev) +{ + if (ctx->tls == NULL) { + pfd->revents = ZSOCK_POLLNVAL; + return 0; + } - pfd->revents |= ZSOCK_POLLIN; - } else if (!sock_is_eof(context)) { - if (pfd->revents == ZSOCK_POLLIN) { - ret--; - } + /* For now, assume that socket is always writable */ + if (pfd->events & ZSOCK_POLLOUT) { + pfd->revents |= ZSOCK_POLLOUT; + } - pfd->revents &= ~ZSOCK_POLLIN; - } + if (pfd->events & ZSOCK_POLLIN) { + if (!IS_LISTENING(ctx)) { + /* Already had TLS data to read on socket. */ + if (mbedtls_ssl_get_bytes_avail(&ctx->tls->ssl) > 0) { + pfd->revents |= ZSOCK_POLLIN; + return 0; } } - /* If there's something to report, exit. */ - if (ret > 0) { - retry = false; - } + /* Some encrypted data received on the socket. */ + if (((*pev)++)->state != K_POLL_STATE_NOT_READY) { + if (IS_LISTENING(ctx)) { + pfd->revents |= ZSOCK_POLLIN; + return 0; + } - if (retry && remaining_time != K_FOREVER && - remaining_time != K_NO_WAIT) { - /* Recalculate the timeout value. */ - remaining_time = time_left(entry_time, timeout); - if (remaining_time <= 0) { - retry = false; + /* EAGAIN might happen during or just after + * DLTS handshake. + */ + if (recv(pfd->fd, NULL, 0, ZSOCK_MSG_DONTWAIT) < 0 && + errno != EAGAIN) { + pfd->revents |= ZSOCK_POLLERR; + return 0; + } + + if (mbedtls_ssl_get_bytes_avail(&ctx->tls->ssl) > 0 || + sock_is_eof(ctx)) { + pfd->revents |= ZSOCK_POLLIN; + return 0; } + + /* Received encrypted data, but still not enough + * to decrypt it and return data through socket, + * ask for retry. + */ + errno = EAGAIN; + return -1; } } - return ret; + return 0; } -int ztls_getsockopt(int sock, int level, int optname, - void *optval, socklen_t *optlen) +int ztls_getsockopt_ctx(struct net_context *ctx, int level, int optname, + void *optval, socklen_t *optlen) { int err; - struct net_context *context; - - if (level != SOL_TLS) { - return zsock_getsockopt(sock, level, optname, optval, optlen); - } - - context = sock_to_net_ctx(sock); - if (context == NULL) { - return -1; - } - if (!context->tls) { + if (!ctx->tls) { errno = EBADF; return -1; } @@ -1814,17 +1755,22 @@ int ztls_getsockopt(int sock, int level, int optname, return -1; } + if (level != SOL_TLS) { + return sock_fd_op_vtable.getsockopt(ctx, level, optname, + optval, optlen); + } + switch (optname) { case TLS_SEC_TAG_LIST: - err = tls_opt_sec_tag_list_get(context, optval, optlen); + err = tls_opt_sec_tag_list_get(ctx, optval, optlen); break; case TLS_CIPHERSUITE_LIST: - err = tls_opt_ciphersuite_list_get(context, optval, optlen); + err = tls_opt_ciphersuite_list_get(ctx, optval, optlen); break; case TLS_CIPHERSUITE_USED: - err = tls_opt_ciphersuite_used_get(context, optval, optlen); + err = tls_opt_ciphersuite_used_get(ctx, optval, optlen); break; default: @@ -1841,45 +1787,40 @@ int ztls_getsockopt(int sock, int level, int optname, return 0; } -int ztls_setsockopt(int sock, int level, int optname, - const void *optval, socklen_t optlen) +int ztls_setsockopt_ctx(struct net_context *ctx, int level, int optname, + const void *optval, socklen_t optlen) { int err; - struct net_context *context; - if (level != SOL_TLS) { - return zsock_setsockopt(sock, level, optname, optval, optlen); - } - - context = sock_to_net_ctx(sock); - if (context == NULL) { + if (!ctx->tls) { + errno = EBADF; return -1; } - if (!context->tls) { - errno = EBADF; - return -1; + if (level != SOL_TLS) { + return sock_fd_op_vtable.setsockopt(ctx, level, optname, + optval, optlen); } switch (optname) { case TLS_SEC_TAG_LIST: - err = tls_opt_sec_tag_list_set(context, optval, optlen); + err = tls_opt_sec_tag_list_set(ctx, optval, optlen); break; case TLS_HOSTNAME: - err = tls_opt_hostname_set(context, optval, optlen); + err = tls_opt_hostname_set(ctx, optval, optlen); break; case TLS_CIPHERSUITE_LIST: - err = tls_opt_ciphersuite_list_set(context, optval, optlen); + err = tls_opt_ciphersuite_list_set(ctx, optval, optlen); break; case TLS_PEER_VERIFY: - err = tls_opt_peer_verify_set(context, optval, optlen); + err = tls_opt_peer_verify_set(ctx, optval, optlen); break; case TLS_DTLS_ROLE: - err = tls_opt_dtls_role_set(context, optval, optlen); + err = tls_opt_dtls_role_set(ctx, optval, optlen); break; default: @@ -1895,3 +1836,134 @@ int ztls_setsockopt(int sock, int level, int optname, return 0; } + +static ssize_t tls_sock_read_vmeth(void *obj, void *buffer, size_t count) +{ + return ztls_recvfrom_ctx(obj, buffer, count, 0, NULL, 0); +} + +static ssize_t tls_sock_write_vmeth(void *obj, const void *buffer, + size_t count) +{ + return ztls_sendto_ctx(obj, buffer, count, 0, NULL, 0); +} + +static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, ...) +{ + switch (request) { + case ZFD_IOCTL_CLOSE: + return ztls_close_ctx(obj); + + case ZFD_IOCTL_FCNTL: { + va_list args; + int err; + + /* Pass the call to the core socket implementation. */ + va_start(args, request); + err = sock_fd_op_vtable.fd_vtable.ioctl(obj, request, args); + va_end(args); + + return err; + } + + case ZFD_IOCTL_POLL_PREPARE: { + va_list args; + struct zsock_pollfd *pfd; + struct k_poll_event **pev; + struct k_poll_event *pev_end; + + va_start(args, request); + pfd = va_arg(args, struct zsock_pollfd *); + pev = va_arg(args, struct k_poll_event **); + pev_end = va_arg(args, struct k_poll_event *); + va_end(args); + + return ztls_poll_prepare_ctx(obj, pfd, pev, pev_end); + } + + case ZFD_IOCTL_POLL_UPDATE: { + va_list args; + struct zsock_pollfd *pfd; + struct k_poll_event **pev; + + va_start(args, request); + pfd = va_arg(args, struct zsock_pollfd *); + pev = va_arg(args, struct k_poll_event **); + va_end(args); + + return ztls_poll_update_ctx(obj, pfd, pev); + } + + default: + errno = EOPNOTSUPP; + return -1; + } +} + +static int tls_sock_bind_vmeth(void *obj, const struct sockaddr *addr, + socklen_t addrlen) +{ + return sock_fd_op_vtable.bind(obj, addr, addrlen); +} + +static int tls_sock_connect_vmeth(void *obj, const struct sockaddr *addr, + socklen_t addrlen) +{ + return ztls_connect_ctx(obj, addr, addrlen); +} + +static int tls_sock_listen_vmeth(void *obj, int backlog) +{ + return sock_fd_op_vtable.listen(obj, backlog); +} + +static int tls_sock_accept_vmeth(void *obj, struct sockaddr *addr, + socklen_t *addrlen) +{ + return ztls_accept_ctx(obj, addr, addrlen); +} + +static ssize_t tls_sock_sendto_vmeth(void *obj, const void *buf, size_t len, + int flags, + const struct sockaddr *dest_addr, + socklen_t addrlen) +{ + return ztls_sendto_ctx(obj, buf, len, flags, dest_addr, addrlen); +} + +static ssize_t tls_sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len, + int flags, struct sockaddr *src_addr, + socklen_t *addrlen) +{ + return ztls_recvfrom_ctx(obj, buf, max_len, flags, + src_addr, addrlen); +} + +static int tls_sock_getsockopt_vmeth(void *obj, int level, int optname, + void *optval, socklen_t *optlen) +{ + return ztls_getsockopt_ctx(obj, level, optname, optval, optlen); +} + +static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname, + const void *optval, socklen_t optlen) +{ + return ztls_setsockopt_ctx(obj, level, optname, optval, optlen); +} + + +static const struct socket_op_vtable tls_sock_fd_op_vtable = { + .fd_vtable = { + .read = tls_sock_read_vmeth, + .write = tls_sock_write_vmeth, + .ioctl = tls_sock_ioctl_vmeth, + }, + .bind = tls_sock_bind_vmeth, + .connect = tls_sock_connect_vmeth, + .listen = tls_sock_listen_vmeth, + .accept = tls_sock_accept_vmeth, + .sendto = tls_sock_sendto_vmeth, + .recvfrom = tls_sock_recvfrom_vmeth, + .getsockopt = tls_sock_getsockopt_vmeth, + .setsockopt = tls_sock_setsockopt_vmeth, +};