Skip to content

Commit

Permalink
Add a method to reject a TCP connection
Browse files Browse the repository at this point in the history
  • Loading branch information
ThadHouse committed Sep 8, 2024
1 parent 9f2f0d4 commit 9aaec8c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 9 deletions.
5 changes: 4 additions & 1 deletion src/inc/quic_datapath.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,14 @@ typedef struct CXPLAT_QEO_CONNECTION {

//
// Function pointer type for datapath TCP accept callbacks.
// Any QUIC_FAILED status will reject the connection.
// Do not call CxPlatSocketDelete from this callback, it will
// crash.
//
typedef
_IRQL_requires_max_(DISPATCH_LEVEL)
_Function_class_(CXPLAT_DATAPATH_ACCEPT_CALLBACK)
void
QUIC_STATUS
(CXPLAT_DATAPATH_ACCEPT_CALLBACK)(
_In_ CXPLAT_SOCKET* ListenerSocket,
_In_ void* ListenerContext,
Expand Down
3 changes: 2 additions & 1 deletion src/perf/lib/Tcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ bool TcpServer::Start(const QUIC_ADDR* LocalAddress)

_IRQL_requires_max_(DISPATCH_LEVEL)
_Function_class_(CXPLAT_DATAPATH_ACCEPT_CALLBACK)
void
QUIC_STATUS
TcpServer::AcceptCallback(
_In_ CXPLAT_SOCKET* /* ListenerSocket */,
_In_ void* ListenerContext,
Expand All @@ -373,6 +373,7 @@ TcpServer::AcceptCallback(
auto This = (TcpServer*)ListenerContext;
auto Connection = new(std::nothrow) TcpConnection(This->Engine, This->SecConfig, AcceptSocket, This);
*AcceptClientContext = Connection;
return QUIC_STATUS_SUCCESS;
}

// ############################ CONNECTION ############################
Expand Down
2 changes: 1 addition & 1 deletion src/perf/lib/Tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class TcpServer {
static
_IRQL_requires_max_(DISPATCH_LEVEL)
_Function_class_(CXPLAT_DATAPATH_ACCEPT_CALLBACK)
void
QUIC_STATUS
AcceptCallback(
_In_ CXPLAT_SOCKET* ListenerSocket,
_In_ void* ListenerContext,
Expand Down
7 changes: 5 additions & 2 deletions src/platform/datapath_epoll.c
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ CxPlatSocketContextInitialize(
// Only set SO_REUSEPORT on a server socket, otherwise the client could be
// assigned a server port (unless it's forcing sharing).
//
if ((Config->Flags & CXPLAT_SOCKET_FLAG_SHARE || Config->RemoteAddress == NULL) &&
if ((Config->Flags & CXPLAT_SOCKET_FLAG_SHARE || Config->RemoteAddress == NULL) &&
SocketContext->Binding->Datapath->PartitionCount > 1) {
//
// The port is shared across processors.
Expand Down Expand Up @@ -1552,11 +1552,14 @@ CxPlatSocketContextAcceptCompletion(

CxPlatSocketContextSetEvents(&SocketContext->AcceptSocket->SocketContexts[0], EPOLL_CTL_ADD, EPOLLIN);
SocketContext->AcceptSocket->SocketContexts[0].IoStarted = TRUE;
Datapath->TcpHandlers.Accept(
Status = Datapath->TcpHandlers.Accept(
SocketContext->Binding,
SocketContext->Binding->ClientContext,
SocketContext->AcceptSocket,
&SocketContext->AcceptSocket->ClientContext);
if (QUIC_FAILED(Status)) {
goto Error;
}

SocketContext->AcceptSocket = NULL;

Expand Down
6 changes: 5 additions & 1 deletion src/platform/datapath_winuser.c
Original file line number Diff line number Diff line change
Expand Up @@ -2667,11 +2667,15 @@ CxPlatDataPathSocketProcessAcceptCompletion(
goto Error;
}

Datapath->TcpHandlers.Accept(
QUIC_STATUS Status = Datapath->TcpHandlers.Accept(
ListenerSocketProc->Parent,
ListenerSocketProc->Parent->ClientContext,
ListenerSocketProc->AcceptSocket,
&ListenerSocketProc->AcceptSocket->ClientContext);
if (QUIC_FAILED(Status)) {
goto Error;
}

ListenerSocketProc->AcceptSocket = NULL;

AcceptSocketProc->IoStarted = TRUE;
Expand Down
53 changes: 50 additions & 3 deletions src/platform/unittest/DataPathTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,10 @@ struct TcpListenerContext {
CXPLAT_SOCKET* Server;
TcpClientContext ServerContext;
bool Accepted : 1;
bool Reject : 1;
bool Rejected : 1;
CXPLAT_EVENT AcceptEvent;
TcpListenerContext() : Server(nullptr), Accepted(false) {
TcpListenerContext() : Server(nullptr), Accepted(false), Reject{false}, Rejected{false} {
CxPlatEventInitialize(&AcceptEvent, FALSE, FALSE);
}
~TcpListenerContext() {
Expand Down Expand Up @@ -317,14 +319,16 @@ struct DataPathTest : public ::testing::TestWithParam<int32_t>
CxPlatRecvDataReturn(RecvDataChain);
}

static void
static QUIC_STATUS
EmptyAcceptCallback(
_In_ CXPLAT_SOCKET* /* ListenerSocket */,
_In_ void* /* ListenerContext */,
_In_ CXPLAT_SOCKET* /* ClientSocket */,
_Out_ void** /* ClientContext */
)
{
// If we somehow get a connection here, reject it
return QUIC_STATUS_CONNECTION_REFUSED;
}

static void
Expand All @@ -336,7 +340,7 @@ struct DataPathTest : public ::testing::TestWithParam<int32_t>
{
}

static void
static QUIC_STATUS
TcpAcceptCallback(
_In_ CXPLAT_SOCKET* /* ListenerSocket */,
_In_ void* Context,
Expand All @@ -345,10 +349,16 @@ struct DataPathTest : public ::testing::TestWithParam<int32_t>
)
{
TcpListenerContext* ListenerContext = (TcpListenerContext*)Context;
if (ListenerContext->Reject) {
ListenerContext->Rejected = true;
CxPlatEventSet(ListenerContext->AcceptEvent);
return QUIC_STATUS_CONNECTION_REFUSED;
}
ListenerContext->Server = ClientSocket;
*ClientContext = &ListenerContext->ServerContext;
ListenerContext->Accepted = true;
CxPlatEventSet(ListenerContext->AcceptEvent);
return QUIC_STATUS_SUCCESS;
}

static void
Expand Down Expand Up @@ -1077,6 +1087,43 @@ TEST_P(DataPathTest, TcpConnect)
ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.DisconnectEvent, 500));
}

TEST_P(DataPathTest, TcpRejectConnect)
{
CxPlatDataPath Datapath(nullptr, &TcpRecvCallbacks);
if (!Datapath.IsSupported(CXPLAT_DATAPATH_FEATURE_TCP)) {
GTEST_SKIP_("TCP is not supported");
}
VERIFY_QUIC_SUCCESS(Datapath.GetInitStatus());
ASSERT_NE(nullptr, Datapath.Datapath);

TcpListenerContext ListenerContext;
auto serverAddress = GetNewLocalAddr();
CxPlatSocket Listener; Listener.CreateTcpListener(Datapath, &serverAddress.SockAddr, &ListenerContext);
while (Listener.GetInitStatus() == QUIC_STATUS_ADDRESS_IN_USE) {
serverAddress.SockAddr.Ipv4.sin_port = GetNextPort();
Listener.CreateTcpListener(Datapath, &serverAddress.SockAddr, &ListenerContext);
}
VERIFY_QUIC_SUCCESS(Listener.GetInitStatus());
ASSERT_NE(nullptr, Listener.Socket);
serverAddress.SockAddr = Listener.GetLocalAddress();
ASSERT_NE(serverAddress.SockAddr.Ipv4.sin_port, (uint16_t)0);

ListenerContext.Reject = true;

TcpClientContext ClientContext;
CxPlatSocket Client; Client.CreateTcp(Datapath, nullptr, &serverAddress.SockAddr, &ClientContext);
VERIFY_QUIC_SUCCESS(Client.GetInitStatus());
ASSERT_NE(nullptr, Client.Socket);
ASSERT_NE(Client.GetLocalAddress().Ipv4.sin_port, (uint16_t)0);

ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.ConnectEvent, 500));
ASSERT_TRUE(CxPlatEventWaitWithTimeout(ListenerContext.AcceptEvent, 500));
ASSERT_EQ(true, ListenerContext.Rejected);
ASSERT_EQ(nullptr, ListenerContext.Server);

ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.DisconnectEvent, 500));
}

TEST_P(DataPathTest, TcpDisconnect)
{
CxPlatDataPath Datapath(nullptr, &TcpRecvCallbacks);
Expand Down

0 comments on commit 9aaec8c

Please sign in to comment.