diff --git a/pw_rpc/client_integration_test.cc b/pw_rpc/client_integration_test.cc index 8f10211764..38dcbbd6d5 100644 --- a/pw_rpc/client_integration_test.cc +++ b/pw_rpc/client_integration_test.cc @@ -12,8 +12,6 @@ // License for the specific language governing permissions and limitations under // the License. -#include - #include #include #include @@ -30,10 +28,6 @@ namespace { constexpr int kIterations = 3; -// This client configures a socket read timeout to allow the RPC dispatch thread -// to exit gracefully. -constexpr timeval kSocketReadTimeout = {.tv_sec = 1, .tv_usec = 0}; - using namespace std::chrono_literals; using pw::ByteSpan; using pw::ConstByteSpan; @@ -145,18 +139,6 @@ int main(int argc, char* argv[]) { return 1; } - // Set read timout on socket to allow - // pw::rpc::integration_test::TerminateClient() to complete. - int retval = setsockopt(pw::rpc::integration_test::GetClientSocketFd(), - SOL_SOCKET, - SO_RCVTIMEO, - &rpc_test::kSocketReadTimeout, - sizeof(rpc_test::kSocketReadTimeout)); - PW_CHECK_INT_EQ(retval, - 0, - "Failed to configure socket receive timeout with errno=%d", - errno); - int test_retval = RUN_ALL_TESTS(); pw::rpc::integration_test::TerminateClient(); diff --git a/pw_rpc/fuzz/client_fuzzer.cc b/pw_rpc/fuzz/client_fuzzer.cc index c5086becf4..d34b94af62 100644 --- a/pw_rpc/fuzz/client_fuzzer.cc +++ b/pw_rpc/fuzz/client_fuzzer.cc @@ -16,8 +16,6 @@ #include "pw_rpc/internal/log_config.h" // PW_LOG_* macros must be first. // clang-format on -#include - #include #include "pw_log/log.h" @@ -28,10 +26,6 @@ namespace pw::rpc::fuzz { namespace { -// This client configures a socket read timeout to allow the RPC dispatch thread -// to exit gracefully. -constexpr timeval kSocketReadTimeout = {.tv_sec = 1, .tv_usec = 0}; - int FuzzClient(int argc, char** argv) { // TODO(aarongreen): Incorporate descriptions into usage message. Vector parsers{ @@ -78,19 +72,6 @@ int FuzzClient(int argc, char** argv) { return 1; } - // Set read timout on socket to allow - // pw::rpc::integration_test::TerminateClient() to complete. - int fd = integration_test::GetClientSocketFd(); - if (setsockopt(fd, - SOL_SOCKET, - SO_RCVTIMEO, - &kSocketReadTimeout, - sizeof(kSocketReadTimeout)) != 0) { - PW_LOG_ERROR("Failed to configure socket receive timeout with errno=%d", - errno); - return 1; - } - if (num_actions == 0) { num_actions = std::numeric_limits::max(); } diff --git a/pw_rpc/integration_testing.cc b/pw_rpc/integration_testing.cc index d0b18fadaa..2697b1c08c 100644 --- a/pw_rpc/integration_testing.cc +++ b/pw_rpc/integration_testing.cc @@ -33,7 +33,12 @@ unit_test::LoggingEventHandler log_test_events; Client& client() { return context.client(); } -int GetClientSocketFd() { return context.GetSocketFd(); } +int SetClientSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen) { + return context.SetSockOpt(level, optname, optval, optlen); +} void SetEgressChannelManipulator(ChannelManipulator* new_channel_manipulator) { context.SetEgressChannelManipulator(new_channel_manipulator); diff --git a/pw_rpc/public/pw_rpc/integration_test_socket_client.h b/pw_rpc/public/pw_rpc/integration_test_socket_client.h index f8847874b8..eba8404fa3 100644 --- a/pw_rpc/public/pw_rpc/integration_test_socket_client.h +++ b/pw_rpc/public/pw_rpc/integration_test_socket_client.h @@ -53,17 +53,21 @@ class SocketClientContext { } // Terminates the client, joining the RPC dispatch thread. - // - // WARNING: This may block forever if the socket is configured to block - // indefinitely on reads. Configuring the client socket's `SO_RCVTIMEO` to a - // nonzero timeout will allow the dispatch thread to always return. void Terminate() { PW_ASSERT(rpc_dispatch_thread_handle_.has_value()); should_terminate_.test_and_set(); + // Close the stream to avoid blocking forever on a socket read. + stream_.Close(); rpc_dispatch_thread_handle_->join(); } - int GetSocketFd() { return stream_.connection_fd(); } + // Configure options for the socket associated with the client. + int SetSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen) { + return stream_.SetSockOpt(level, optname, optval, optlen); + } void SetEgressChannelManipulator( ChannelManipulator* new_channel_manipulator) { diff --git a/pw_rpc/public/pw_rpc/integration_testing.h b/pw_rpc/public/pw_rpc/integration_testing.h index 8d6dd17c8b..2071c53ccb 100644 --- a/pw_rpc/public/pw_rpc/integration_testing.h +++ b/pw_rpc/public/pw_rpc/integration_testing.h @@ -85,9 +85,11 @@ void SetIngressChannelManipulator(ChannelManipulator* new_channel_manipulator); // Returns the global RPC client for integration test use. Client& client(); -// The file descriptor for the socket associated with the client. This may be -// used to configure socket options. -int GetClientSocketFd(); +// Configure options for the socket associated with the client. +int SetClientSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen); // Initializes logging and the global RPC client for integration testing. Starts // a background thread that processes incoming. @@ -98,10 +100,6 @@ Status InitializeClient(int argc, Status InitializeClient(int port); // Terminates the client, joining the RPC dispatch thread. -// -// WARNING: This may block forever if the socket is configured to block -// indefinitely on reads. Configuring the client socket's `SO_RCVTIMEO` to a -// nonzero timeout will allow the dispatch thread to always return. void TerminateClient(); } // namespace pw::rpc::integration_test diff --git a/pw_rpc/system_server/public/pw_rpc_system_server/socket.h b/pw_rpc/system_server/public/pw_rpc_system_server/socket.h index dc34e93354..96b79ada52 100644 --- a/pw_rpc/system_server/public/pw_rpc_system_server/socket.h +++ b/pw_rpc/system_server/public/pw_rpc_system_server/socket.h @@ -20,8 +20,10 @@ namespace pw::rpc::system_server { // Sets the port to use for pw::rpc::system_server backends that use sockets. void set_socket_port(uint16_t port); -// The file descriptor for the socket associated with the server. This may be -// used to configure socket options. -int GetServerSocketFd(); +// Configure options for the socket associated with the server. +int SetServerSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen); } // namespace pw::rpc::system_server diff --git a/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h b/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h index e3a84c5eaf..3d41214716 100644 --- a/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h +++ b/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h @@ -106,6 +106,10 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore { while (!stopped_) { const auto read_status = ReadData(); + // Break if ReadData was cancelled after the transport was stopped. + if (stopped_) { + break; + } if (!read_status.ok()) { internal::LogSocketReadError(read_status); } @@ -122,7 +126,11 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore { } } - void Stop() { stopped_ = true; } + void Stop() { + stopped_ = true; + socket_stream_.Close(); + server_socket_.Close(); + } private: enum class ClientServerRole { kClient, kServer }; @@ -156,6 +164,11 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore { NotifyReady(); Result stream = server_socket_.Accept(); + // If Accept was cancelled due to stopping the transport, return without + // error. + if (stopped_) { + return OkStatus(); + } if (!stream.ok()) { internal::LogSocketAcceptError(stream.status()); return stream.status(); diff --git a/pw_rpc_transport/rpc_integration_test.cc b/pw_rpc_transport/rpc_integration_test.cc index c48672d3e3..633ff848f6 100644 --- a/pw_rpc_transport/rpc_integration_test.cc +++ b/pw_rpc_transport/rpc_integration_test.cc @@ -120,12 +120,6 @@ TEST(RpcIntegrationTest, SocketTransport) { a.transport.Stop(); b.transport.Stop(); - // Unblock socket transports by sending terminator packets. - const std::array terminator_bytes{std::byte{0x42}}; - RpcFrame terminator{.header = {}, .payload = terminator_bytes}; - EXPECT_EQ(a.transport.Send(terminator), OkStatus()); - EXPECT_EQ(b.transport.Send(terminator), OkStatus()); - a_local_egress_thread.join(); b_local_egress_thread.join(); a_transport_thread.join(); diff --git a/pw_rpc_transport/socket_rpc_transport_test.cc b/pw_rpc_transport/socket_rpc_transport_test.cc index 12c0d845c2..8f128726ea 100644 --- a/pw_rpc_transport/socket_rpc_transport_test.cc +++ b/pw_rpc_transport/socket_rpc_transport_test.cc @@ -115,18 +115,12 @@ class SocketSender { } } - // stream::SocketStream doesn't support read timeouts so we have to - // unblock socket reads by sending more data after the transport is stopped. - pw::Status Terminate() { return transport_.Send(terminator_); } - private: SocketRpcTransport& transport_; std::vector sent_; std::array data_{}; std::uniform_int_distribution offset_dist_{0, 255}; std::uniform_int_distribution size_dist_{1, kMaxWriteSize}; - std::array terminator_bytes_{std::byte{0x42}}; - RpcFrame terminator_{.header = {}, .payload = terminator_bytes_}; }; class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore { @@ -182,10 +176,6 @@ TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) { server.Stop(); client.Stop(); - // Unblock socket reads to propagate the stop signal. - EXPECT_EQ(server_sender.Terminate(), OkStatus()); - EXPECT_EQ(client_sender.Terminate(), OkStatus()); - server_thread.join(); client_thread.join(); @@ -242,7 +232,6 @@ TEST(SocketRpcTransportTest, ServerReconnects) { // Stop the client but not the server: we're re-using the same server // with a new client below. client.Stop(); - EXPECT_EQ(server_sender.Terminate(), OkStatus()); client_thread.join(); } @@ -267,13 +256,11 @@ TEST(SocketRpcTransportTest, ServerReconnects) { std::back_inserter(received)); client.Stop(); - EXPECT_EQ(server_sender.Terminate(), OkStatus()); client_thread.join(); // This time stop the server as well. SocketSender client_sender(client); server.Stop(); - EXPECT_EQ(client_sender.Terminate(), OkStatus()); server_thread.join(); } @@ -322,7 +309,6 @@ TEST(SocketRpcTransportTest, ClientReconnects) { server1_sent.end(), std::back_inserter(sent_by_server)); - EXPECT_EQ(client_sender.Terminate(), OkStatus()); server_thread.join(); server = nullptr; @@ -345,11 +331,9 @@ TEST(SocketRpcTransportTest, ClientReconnects) { server2_sent.end(), std::back_inserter(sent_by_server)); - EXPECT_EQ(client_sender.Terminate(), OkStatus()); server_thread.join(); client.Stop(); - EXPECT_EQ(server2_sender.Terminate(), OkStatus()); client_thread.join(); server = nullptr; diff --git a/pw_stream/BUILD.bazel b/pw_stream/BUILD.bazel index 1d40bfd84f..733878ad0a 100644 --- a/pw_stream/BUILD.bazel +++ b/pw_stream/BUILD.bazel @@ -52,6 +52,7 @@ pw_cc_library( ":pw_stream", "//pw_log", "//pw_string", + "//pw_sync:mutex", "//pw_sys_io", ], ) diff --git a/pw_stream/BUILD.gn b/pw_stream/BUILD.gn index a55fb91a0d..45357a0ec0 100644 --- a/pw_stream/BUILD.gn +++ b/pw_stream/BUILD.gn @@ -48,7 +48,10 @@ pw_source_set("pw_stream") { pw_source_set("socket_stream") { public_configs = [ ":public_include_path" ] - public_deps = [ ":pw_stream" ] + public_deps = [ + ":pw_stream", + "$dir_pw_sync:mutex", + ] deps = [ dir_pw_assert, dir_pw_log, diff --git a/pw_stream/CMakeLists.txt b/pw_stream/CMakeLists.txt index b04b440af3..623ac47203 100644 --- a/pw_stream/CMakeLists.txt +++ b/pw_stream/CMakeLists.txt @@ -40,6 +40,7 @@ pw_add_library(pw_stream.socket_stream STATIC public PUBLIC_DEPS pw_stream + pw_sync.mutex SOURCES socket_stream.cc PRIVATE_DEPS diff --git a/pw_stream/public/pw_stream/socket_stream.h b/pw_stream/public/pw_stream/socket_stream.h index a9b7b16a65..6b6bd6a774 100644 --- a/pw_stream/public/pw_stream/socket_stream.h +++ b/pw_stream/public/pw_stream/socket_stream.h @@ -18,25 +18,27 @@ #include "pw_result/result.h" #include "pw_span/span.h" #include "pw_stream/stream.h" +#include "pw_sync/lock_annotations.h" +#include "pw_sync/mutex.h" namespace pw::stream { class SocketStream : public NonSeekableReaderWriter { public: - constexpr SocketStream() = default; + SocketStream() = default; // Construct a SocketStream directly from a file descriptor. - explicit SocketStream(int connection_fd) : connection_fd_(connection_fd) {} + explicit SocketStream(int connection_fd) : connection_fd_(connection_fd) { + // Mark as ready and take ownership of the connection by this object. + ready_ = true; + TakeConnection(); + } // SocketStream objects are moveable but not copyable. SocketStream& operator=(SocketStream&& other) { - connection_fd_ = other.connection_fd_; - other.connection_fd_ = kInvalidFd; + MoveFrom(std::move(other)); return *this; } - SocketStream(SocketStream&& other) noexcept - : connection_fd_(other.connection_fd_) { - other.connection_fd_ = kInvalidFd; - } + SocketStream(SocketStream&& other) noexcept { MoveFrom(std::move(other)); } SocketStream(const SocketStream&) = delete; SocketStream& operator=(const SocketStream&) = delete; @@ -47,26 +49,81 @@ class SocketStream : public NonSeekableReaderWriter { // instead. Status Connect(const char* host, uint16_t port); + // Configures socket options. + int SetSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen); + // Close the socket stream and release all resources void Close(); - // Exposes the file descriptor for the active connection. This is exposed to - // allow configuration and introspection of this socket's current - // configuration using setsockopt() and getsockopt(). - // - // Returns -1 if there is no active connection. - int connection_fd() { return connection_fd_; } - private: - friend class ServerSocket; - static constexpr int kInvalidFd = -1; + class ConnectionOwnership { + public: + explicit ConnectionOwnership(SocketStream* socket_stream) + : socket_stream_(socket_stream) { + fd_ = socket_stream_->TakeConnection(); + std::lock_guard lock(socket_stream_->connection_mutex_); + pipe_r_fd_ = socket_stream->connection_pipe_r_fd_; + } + + ~ConnectionOwnership() { socket_stream_->ReleaseConnection(); } + + int fd() { return fd_; } + + int pipe_r_fd() { return pipe_r_fd_; } + + private: + SocketStream* socket_stream_; + int fd_; + int pipe_r_fd_; + }; + Status DoWrite(span data) override; StatusWithSize DoRead(ByteSpan dest) override; - int connection_fd_ = kInvalidFd; + // Take ownership of the connection. There may be multiple owners. Each time + // TakeConnection is called, ReleaseConnection must be called to release + // ownership, even if the connection is not valid. + // + // Returns the connection fd or kInvalidFd if the connection is not valid. + int TakeConnection(); + int TakeConnectionWithLockHeld() + PW_EXCLUSIVE_LOCKS_REQUIRED(connection_mutex_); + + // Release ownership of the connection. If no owners remain, close and clear + // the connection fds. + void ReleaseConnection(); + void ReleaseConnectionWithLockHeld() + PW_EXCLUSIVE_LOCKS_REQUIRED(connection_mutex_); + + // Moves other to this. + void MoveFrom(SocketStream&& other) { + std::lock_guard lock(connection_mutex_); + std::lock_guard other_lock(other.connection_mutex_); + + connection_own_count_ = other.connection_own_count_; + other.connection_own_count_ = 0; + ready_ = other.ready_; + other.ready_ = false; + connection_fd_ = other.connection_fd_; + other.connection_fd_ = kInvalidFd; + connection_pipe_r_fd_ = other.connection_pipe_r_fd_; + other.connection_pipe_r_fd_ = kInvalidFd; + connection_pipe_w_fd_ = other.connection_pipe_w_fd_; + other.connection_pipe_w_fd_ = kInvalidFd; + } + + sync::Mutex connection_mutex_; + int connection_own_count_ PW_GUARDED_BY(connection_mutex_) = 0; + bool ready_ PW_GUARDED_BY(connection_mutex_) = false; + int connection_fd_ PW_GUARDED_BY(connection_mutex_) = kInvalidFd; + int connection_pipe_r_fd_ PW_GUARDED_BY(connection_mutex_) = kInvalidFd; + int connection_pipe_w_fd_ PW_GUARDED_BY(connection_mutex_) = kInvalidFd; }; /// `ServerSocket` wraps a POSIX-style server socket, producing a `SocketStream` @@ -100,8 +157,47 @@ class ServerSocket { private: static constexpr int kInvalidFd = -1; + class SocketOwnership { + public: + explicit SocketOwnership(ServerSocket* server_socket) + : server_socket_(server_socket) { + fd_ = server_socket_->TakeSocket(); + std::lock_guard lock(server_socket->socket_mutex_); + pipe_r_fd_ = server_socket->socket_pipe_r_fd_; + } + + ~SocketOwnership() { server_socket_->ReleaseSocket(); } + + int fd() { return fd_; } + + int pipe_r_fd() { return pipe_r_fd_; } + + private: + ServerSocket* server_socket_; + int fd_; + int pipe_r_fd_; + }; + + // Take ownership of the socket. There may be multiple owners. Each time + // TakeSocket is called, ReleaseSocket must be called to release ownership, + // even if the socket is not invalid. + // + // Returns the socket fd or kInvalidFd if the socket is not valid. + int TakeSocket(); + int TakeSocketWithLockHeld() PW_EXCLUSIVE_LOCKS_REQUIRED(socket_mutex_); + + // Release ownership of the socket. If no owners remain, close and clear the + // socket fds. + void ReleaseSocket(); + void ReleaseSocketWithLockHeld() PW_EXCLUSIVE_LOCKS_REQUIRED(socket_mutex_); + uint16_t port_ = -1; - int socket_fd_ = kInvalidFd; + sync::Mutex socket_mutex_; + int socket_own_count_ PW_GUARDED_BY(socket_mutex_) = 0; + bool ready_ PW_GUARDED_BY(socket_mutex_) = false; + int socket_fd_ PW_GUARDED_BY(socket_mutex_) = kInvalidFd; + int socket_pipe_r_fd_ PW_GUARDED_BY(socket_mutex_) = kInvalidFd; + int socket_pipe_w_fd_ PW_GUARDED_BY(socket_mutex_) = kInvalidFd; }; } // namespace pw::stream diff --git a/pw_stream/socket_stream.cc b/pw_stream/socket_stream.cc index b3125439ce..d7cec21f08 100644 --- a/pw_stream/socket_stream.cc +++ b/pw_stream/socket_stream.cc @@ -15,6 +15,8 @@ #include "pw_stream/socket_stream.h" #if defined(_WIN32) && _WIN32 +#include +#include #include #include #define SHUT_RDWR SD_BOTH @@ -22,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +59,25 @@ void ConfigureSocket([[maybe_unused]] int socket) { #if defined(_WIN32) && _WIN32 int close(SOCKET s) { return closesocket(s); } +ssize_t write(int fd, const void* buf, size_t count) { + return _write(fd, buf, count); +} + +int poll(struct pollfd* fds, unsigned int nfds, int timeout) { + return WSAPoll(fds, nfds, timeout); +} + +int pipe(int pipefd[2]) { return _pipe(pipefd, 256, O_BINARY); } + +int setsockopt( + int fd, int level, int optname, const void* optval, unsigned int optlen) { + return setsockopt(static_cast(fd), + level, + optname, + static_cast(optval), + static_cast(optlen)); +} + class WinsockInitializer { public: WinsockInitializer() { @@ -93,37 +115,72 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) { } struct addrinfo* rp; + int connection_fd; for (rp = res; rp != nullptr; rp = rp->ai_next) { - connection_fd_ = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (connection_fd_ != kInvalidFd) { + connection_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (connection_fd != kInvalidFd) { break; } } - if (connection_fd_ == kInvalidFd) { + if (connection_fd == kInvalidFd) { PW_LOG_ERROR("Failed to create a socket: %s", std::strerror(errno)); freeaddrinfo(res); return Status::Unknown(); } - ConfigureSocket(connection_fd_); - if (connect(connection_fd_, rp->ai_addr, rp->ai_addrlen) == -1) { - close(connection_fd_); - connection_fd_ = kInvalidFd; + ConfigureSocket(connection_fd); + if (connect(connection_fd, rp->ai_addr, rp->ai_addrlen) == -1) { + close(connection_fd); PW_LOG_ERROR( "Failed to connect to %s:%d: %s", host, port, std::strerror(errno)); freeaddrinfo(res); return Status::Unknown(); } + // Mark as ready and take ownership of the connection by this object. + { + std::lock_guard lock(connection_mutex_); + connection_fd_ = connection_fd; + TakeConnectionWithLockHeld(); + ready_ = true; + } + freeaddrinfo(res); return OkStatus(); } +// Configures socket options. +int SocketStream::SetSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen) { + ConnectionOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return EBADF; + } + return setsockopt(ownership.fd(), level, optname, optval, optlen); +} + void SocketStream::Close() { - if (connection_fd_ != kInvalidFd) { - close(connection_fd_); - connection_fd_ = kInvalidFd; + ConnectionOwnership ownership(this); + { + std::lock_guard lock(connection_mutex_); + if (ready_) { + // Shutdown the connection and send tear down notification to unblock any + // waiters. + if (connection_fd_ != kInvalidFd) { + shutdown(connection_fd_, SHUT_RDWR); + } + if (connection_pipe_w_fd_ != kInvalidFd) { + write(connection_pipe_w_fd_, "T", 1); + } + + // Release ownership of the connection by this object and mark as no + // longer ready. + ReleaseConnectionWithLockHeld(); + ready_ = false; + } } } @@ -135,10 +192,17 @@ Status SocketStream::DoWrite(span data) { send_flags |= MSG_NOSIGNAL; #endif // defined(__linux__) - ssize_t bytes_sent = send(connection_fd_, - reinterpret_cast(data.data()), - data.size_bytes(), - send_flags); + ssize_t bytes_sent; + { + ConnectionOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return Status::Unknown(); + } + bytes_sent = send(ownership.fd(), + reinterpret_cast(data.data()), + data.size_bytes(), + send_flags); + } if (bytes_sent < 0 || static_cast(bytes_sent) != data.size()) { if (errno == EPIPE) { @@ -153,7 +217,23 @@ Status SocketStream::DoWrite(span data) { } StatusWithSize SocketStream::DoRead(ByteSpan dest) { - ssize_t bytes_rcvd = recv(connection_fd_, + ConnectionOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return StatusWithSize::Unknown(); + } + + // Wait for data to read or a tear down notification. + pollfd fds_to_poll[2]; + fds_to_poll[0].fd = ownership.fd(); + fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP; + fds_to_poll[1].fd = ownership.pipe_r_fd(); + fds_to_poll[1].events = POLLIN; + poll(fds_to_poll, 2, -1); + if (!(fds_to_poll[0].revents & POLLIN)) { + return StatusWithSize::Unknown(); + } + + ssize_t bytes_rcvd = recv(ownership.fd(), reinterpret_cast(dest.data()), dest.size_bytes(), 0); @@ -174,18 +254,67 @@ StatusWithSize SocketStream::DoRead(ByteSpan dest) { return StatusWithSize(bytes_rcvd); } +int SocketStream::TakeConnection() { + std::lock_guard lock(connection_mutex_); + return TakeConnectionWithLockHeld(); +} + +int SocketStream::TakeConnectionWithLockHeld() { + ++connection_own_count_; + + if (ready_ && (connection_fd_ != kInvalidFd) && + (connection_pipe_r_fd_ == kInvalidFd)) { + int fd_list[2]; + if (pipe(fd_list) >= 0) { + connection_pipe_r_fd_ = fd_list[0]; + connection_pipe_w_fd_ = fd_list[1]; + } + } + + if (!ready_ || (connection_pipe_r_fd_ == kInvalidFd) || + (connection_pipe_w_fd_ == kInvalidFd)) { + return kInvalidFd; + } + return connection_fd_; +} + +void SocketStream::ReleaseConnection() { + std::lock_guard lock(connection_mutex_); + ReleaseConnectionWithLockHeld(); +} + +void SocketStream::ReleaseConnectionWithLockHeld() { + --connection_own_count_; + + if (connection_own_count_ <= 0) { + ready_ = false; + if (connection_fd_ != kInvalidFd) { + close(connection_fd_); + connection_fd_ = kInvalidFd; + } + if (connection_pipe_r_fd_ != kInvalidFd) { + close(connection_pipe_r_fd_); + connection_pipe_r_fd_ = kInvalidFd; + } + if (connection_pipe_w_fd_ != kInvalidFd) { + close(connection_pipe_w_fd_); + connection_pipe_w_fd_ = kInvalidFd; + } + } +} + // Listen for connections on the given port. // If port is 0, a random unused port is chosen and can be retrieved with // port(). Status ServerSocket::Listen(uint16_t port) { - socket_fd_ = socket(AF_INET6, SOCK_STREAM, 0); - if (socket_fd_ == kInvalidFd) { + int socket_fd = socket(AF_INET6, SOCK_STREAM, 0); + if (socket_fd == kInvalidFd) { return Status::Unknown(); } // Allow binding to an address that may still be in use by a closed socket. constexpr int value = 1; - setsockopt(socket_fd_, + setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&value), @@ -197,27 +326,37 @@ Status ServerSocket::Listen(uint16_t port) { addr.sin6_family = AF_INET6; addr.sin6_port = htons(port); addr.sin6_addr = in6addr_any; - if (bind(socket_fd_, reinterpret_cast(&addr), addr_len) < 0) { + if (bind(socket_fd, reinterpret_cast(&addr), addr_len) < 0) { + close(socket_fd); return Status::Unknown(); } } - if (listen(socket_fd_, kServerBacklogLength) < 0) { + if (listen(socket_fd, kServerBacklogLength) < 0) { + close(socket_fd); return Status::Unknown(); } // Find out which port the socket is listening on, and fill in port_. struct sockaddr_in6 addr = {}; socklen_t addr_len = sizeof(addr); - if (getsockname(socket_fd_, reinterpret_cast(&addr), &addr_len) < + if (getsockname(socket_fd, reinterpret_cast(&addr), &addr_len) < 0 || static_cast(addr_len) > sizeof(addr)) { - close(socket_fd_); + close(socket_fd); return Status::Unknown(); } port_ = ntohs(addr.sin6_port); + // Mark as ready and take ownership of the socket by this object. + { + std::lock_guard lock(socket_mutex_); + socket_fd_ = socket_fd; + TakeSocketWithLockHeld(); + ready_ = true; + } + return OkStatus(); } @@ -227,23 +366,101 @@ Result ServerSocket::Accept() { struct sockaddr_in6 sockaddr_client_ = {}; socklen_t len = sizeof(sockaddr_client_); - int connection_fd = - accept(socket_fd_, reinterpret_cast(&sockaddr_client_), &len); + SocketOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return Status::Unknown(); + } + + // Wait for a connection or a tear down notification. + pollfd fds_to_poll[2]; + fds_to_poll[0].fd = ownership.fd(); + fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP; + fds_to_poll[1].fd = ownership.pipe_r_fd(); + fds_to_poll[1].events = POLLIN; + int rv = poll(fds_to_poll, 2, -1); + if ((rv <= 0) || !(fds_to_poll[0].revents & POLLIN)) { + return Status::Unknown(); + } + + int connection_fd = accept( + ownership.fd(), reinterpret_cast(&sockaddr_client_), &len); if (connection_fd == kInvalidFd) { return Status::Unknown(); } ConfigureSocket(connection_fd); - SocketStream client_stream; - client_stream.connection_fd_ = connection_fd; - return client_stream; + return SocketStream(connection_fd); } // Close the server socket, preventing further connections. void ServerSocket::Close() { - if (socket_fd_ != kInvalidFd) { - close(socket_fd_); - socket_fd_ = kInvalidFd; + SocketOwnership ownership(this); + { + std::lock_guard lock(socket_mutex_); + if (ready_) { + // Shutdown the socket and send tear down notification to unblock any + // waiters. + if (socket_fd_ != kInvalidFd) { + shutdown(socket_fd_, SHUT_RDWR); + } + if (socket_pipe_w_fd_ != kInvalidFd) { + write(socket_pipe_w_fd_, "T", 1); + } + + // Release ownership of the socket by this object and mark as no longer + // ready. + ReleaseSocketWithLockHeld(); + ready_ = false; + } + } +} + +int ServerSocket::TakeSocket() { + std::lock_guard lock(socket_mutex_); + return TakeSocketWithLockHeld(); +} + +int ServerSocket::TakeSocketWithLockHeld() { + ++socket_own_count_; + + if (ready_ && (socket_fd_ != kInvalidFd) && + (socket_pipe_r_fd_ == kInvalidFd)) { + int fd_list[2]; + if (pipe(fd_list) >= 0) { + socket_pipe_r_fd_ = fd_list[0]; + socket_pipe_w_fd_ = fd_list[1]; + } + } + + if (!ready_ || (socket_pipe_r_fd_ == kInvalidFd) || + (socket_pipe_w_fd_ == kInvalidFd)) { + return kInvalidFd; + } + return socket_fd_; +} + +void ServerSocket::ReleaseSocket() { + std::lock_guard lock(socket_mutex_); + ReleaseSocketWithLockHeld(); +} + +void ServerSocket::ReleaseSocketWithLockHeld() { + --socket_own_count_; + + if (socket_own_count_ <= 0) { + ready_ = false; + if (socket_fd_ != kInvalidFd) { + close(socket_fd_); + socket_fd_ = kInvalidFd; + } + if (socket_pipe_r_fd_ != kInvalidFd) { + close(socket_pipe_r_fd_); + socket_pipe_r_fd_ = kInvalidFd; + } + if (socket_pipe_w_fd_ != kInvalidFd) { + close(socket_pipe_w_fd_); + socket_pipe_w_fd_ = kInvalidFd; + } } } diff --git a/pw_transfer/integration_test/client.cc b/pw_transfer/integration_test/client.cc index 488dbbe40c..b3702e54e7 100644 --- a/pw_transfer/integration_test/client.cc +++ b/pw_transfer/integration_test/client.cc @@ -57,10 +57,6 @@ namespace { // smaller receive buffer size. constexpr int kMaxSocketSendBufferSize = 1; -// This client configures a socket read timeout to allow the RPC dispatch thread -// to exit gracefully. -constexpr timeval kSocketReadTimeout = {.tv_sec = 1, .tv_usec = 0}; - thread::Options& TransferThreadOptions() { static thread::stl::Options options; return options; @@ -203,8 +199,7 @@ int main(int argc, char* argv[]) { return 1; } - int retval = setsockopt( - pw::rpc::integration_test::GetClientSocketFd(), + int retval = pw::rpc::integration_test::SetClientSockOpt( SOL_SOCKET, SO_SNDBUF, &pw::transfer::integration_test::kMaxSocketSendBufferSize, @@ -214,17 +209,6 @@ int main(int argc, char* argv[]) { "Failed to configure socket send buffer size with errno=%d", errno); - retval = - setsockopt(pw::rpc::integration_test::GetClientSocketFd(), - SOL_SOCKET, - SO_RCVTIMEO, - &pw::transfer::integration_test::kSocketReadTimeout, - sizeof(pw::transfer::integration_test::kSocketReadTimeout)); - PW_CHECK_INT_EQ(retval, - 0, - "Failed to configure socket receive timeout with errno=%d", - errno); - if (!pw::transfer::integration_test::PerformTransferActions(config).ok()) { PW_LOG_INFO("Failed to transfer!"); return 1; diff --git a/pw_transfer/integration_test/server.cc b/pw_transfer/integration_test/server.cc index 98a29f0450..7aacd4a180 100644 --- a/pw_transfer/integration_test/server.cc +++ b/pw_transfer/integration_test/server.cc @@ -158,11 +158,11 @@ void RunServer(int socket_port, ServerConfig config) { thread::Thread transfer_thread_handle = thread::Thread(thread::stl::Options(), transfer_thread); - int retval = setsockopt(rpc::system_server::GetServerSocketFd(), - SOL_SOCKET, - SO_SNDBUF, - &kMaxSocketSendBufferSize, - sizeof(kMaxSocketSendBufferSize)); + int retval = + rpc::system_server::SetServerSockOpt(SOL_SOCKET, + SO_SNDBUF, + &kMaxSocketSendBufferSize, + sizeof(kMaxSocketSendBufferSize)); PW_CHECK_INT_EQ(retval, 0, "Failed to configure socket send buffer size with errno=%d", diff --git a/targets/host/system_rpc_server.cc b/targets/host/system_rpc_server.cc index 8704be48ce..f4053cd0d9 100644 --- a/targets/host/system_rpc_server.cc +++ b/targets/host/system_rpc_server.cc @@ -49,7 +49,12 @@ void set_socket_port(uint16_t new_socket_port) { socket_port = new_socket_port; } -int GetServerSocketFd() { return socket_stream.connection_fd(); } +int SetServerSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen) { + return socket_stream.SetSockOpt(level, optname, optval, optlen); +} void Init() { log_basic::SetOutput([](std::string_view log) {