Skip to content

Commit

Permalink
Merge pull request #971 from jphickey/fix-969-socket-accept
Browse files Browse the repository at this point in the history
Fix #969, socket accept using incorrect record
  • Loading branch information
astrogeco authored Apr 28, 2021
2 parents 8fad92d + 05a20f1 commit bcb8050
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 130 deletions.
2 changes: 1 addition & 1 deletion src/os/shared/src/osapi-sockets.c
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ int32 OS_SocketAccept(osal_id_t sock_id, osal_id_t *connsock_id, OS_SockAddr_t *
if (return_code == OS_SUCCESS)
{
conn_record = OS_OBJECT_TABLE_GET(OS_global_stream_table, conn_token);
conn = OS_OBJECT_TABLE_GET(OS_stream_table, sock_token);
conn = OS_OBJECT_TABLE_GET(OS_stream_table, conn_token);

/* Incr the refcount to record the fact that an operation is pending on this */
memset(conn, 0, sizeof(OS_stream_internal_record_t));
Expand Down
322 changes: 193 additions & 129 deletions src/tests/network-api-test/network-api-test.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@

#define UT_EXIT_LOOP_MAX 100

/*
* Number of client->server connections to create.
* This tests that the server socket can accept multiple connections.
*/
#define UT_STREAM_CONNECTION_COUNT 4

osal_id_t s_task_id;
osal_id_t p1_socket_id;
osal_id_t p2_socket_id;
Expand All @@ -45,6 +51,8 @@ OS_SockAddr_t s_addr;
OS_SockAddr_t c_addr;
bool networkImplemented = true;

char ServerFn_ErrorString[128];

/*****************************************************************************
*
* Datagram Network Functional Test Setup
Expand Down Expand Up @@ -406,31 +414,68 @@ void Server_Fn(void)
osal_id_t connsock_id = OS_OBJECT_ID_UNDEFINED;
uint32 iter;
OS_SockAddr_t addr;
char Buf_rcv_s[4] = {0};
char Buf_trans[8] = {0};
uint8 Buf_each_char_s[256] = {0};
int32 Status;

/* Accept incoming connections */
OS_SocketAccept(s_socket_id, &connsock_id, &addr, OS_PEND);
/* Fill the memory with a count pattern */
UtMemFill(Buf_each_char_s, sizeof(Buf_each_char_s));

/* Recieve incoming data from client*/
OS_TimedRead(connsock_id, Buf_rcv_s, sizeof(Buf_rcv_s), 10);
iter = 0;
while (iter < UT_STREAM_CONNECTION_COUNT)
{
++iter;

/* Transform the incoming data and send it back to client */
strcpy(Buf_trans, "uvw");
strcat(Buf_trans, Buf_rcv_s);
OS_TimedWrite(connsock_id, Buf_trans, sizeof(Buf_trans), 10);
/* Accept incoming connections */
Status = OS_SocketAccept(s_socket_id, &connsock_id, &addr, OS_PEND);
if (Status != OS_SUCCESS)
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_SocketAccept() return code=%d",
(int)Status);
break;
}

/* Send all 256 chars to client */
for (iter = 0; iter < 256; iter++)
{
Buf_each_char_s[iter] = iter;
}
/* Recieve incoming data from client (should be exactly 4 bytes) */
Status = OS_TimedRead(connsock_id, Buf_trans, sizeof(Buf_trans), 10);
if (Status != 4)
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedRead() return code=%d", (int)Status);
break;
}

/* Send back to client:
* 1. uint32 value indicating number of connections so far (4 bytes)
* 2. Original value recieved above (4 bytes)
* 3. String of all possible 8-bit chars [0-255] (256 bytes)
*/
Status = OS_TimedWrite(connsock_id, &iter, sizeof(iter), 10);
if (Status != sizeof(iter))
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedWrite(uint32) return code=%d",
(int)Status);
break;
}

OS_TimedWrite(connsock_id, Buf_each_char_s, sizeof(Buf_each_char_s), 10);
Status = OS_TimedWrite(connsock_id, Buf_trans, 4, 10);
if (Status != 4)
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedWrite(Buf_trans) return code=%d",
(int)Status);
break;
}

Status = OS_TimedWrite(connsock_id, Buf_each_char_s, sizeof(Buf_each_char_s), 10);
if (Status != sizeof(Buf_each_char_s))
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString),
"OS_TimedWrite(Buf_each_char_s) return code=%d", (int)Status);
break;
}

OS_close(connsock_id);
}

OS_close(s_socket_id);
OS_close(connsock_id);

} /* end Server_Fn */

Expand All @@ -451,14 +496,16 @@ void TestStreamNetworkApi(void)
OS_task_prop_t taskprop;
char Buf_rcv_c[4] = {0};
char Buf_send_c[4] = {0};
char Buf_rcv_trans[8] = {0};
char Buf_expec_trans[8] = {0};
uint8 Buf_each_expected[256] = {0};
uint8 Buf_each_char_rcv[256] = {0};

/*
* Set up a server
* NOTE: The server cannot directly use UtAssert because the library is not thread-safe
* If the server task encounters an error, it will write the string to this buffer, and it
* will be reported at the end of this routine.
*
* Be sure it is empty to start with.
*/
memset(ServerFn_ErrorString, 0, sizeof(ServerFn_ErrorString));

/* Open a server socket */
s_socket_id = OS_OBJECT_ID_UNDEFINED;
Expand Down Expand Up @@ -493,13 +540,15 @@ void TestStreamNetworkApi(void)
* Set up a client
*/

/* Open a client socket */
expected = OS_SUCCESS;
c_socket_id = OS_OBJECT_ID_UNDEFINED;
/*
* Create a server thread, and connect client from
* this thread to server thread and verify connection
*/

actual = OS_SocketOpen(&c_socket_id, OS_SocketDomain_INET, OS_SocketType_STREAM);
UtAssert_True(actual == expected, "OS_SocketOpen() (%ld) == OS_SUCCESS", (long)actual);
UtAssert_True(OS_ObjectIdDefined(c_socket_id), "c_socket_id (%lu) != 0", OS_ObjectIdToInteger(c_socket_id));
/* Create a server task/thread */
status = OS_TaskCreate(&s_task_id, "Server", Server_Fn, OSAL_TASK_STACK_ALLOCATE, OSAL_SIZE_C(16384),
OSAL_PRIORITY_C(50), 0);
UtAssert_True(status == OS_SUCCESS, "OS_TaskCreate() (%ld) == OS_SUCCESS", (long)status);

/* Initialize client address */
actual = OS_SocketAddrInit(&c_addr, OS_SocketDomain_INET);
Expand All @@ -514,113 +563,122 @@ void TestStreamNetworkApi(void)
UtAssert_True(actual == expected, "OS_SocketAddrFromString() (%ld) == OS_SUCCESS", (long)actual);

/*
* Create a server thread, and connect client from
* this thread to server thread and verify connection
* Connect to a server - this is done in a loop
* to confirm a server socket can be re-used for multiple clients
*/

/* Create a server task/thread */
status = OS_TaskCreate(&s_task_id, "Server", Server_Fn, OSAL_TASK_STACK_ALLOCATE, OSAL_SIZE_C(16384),
OSAL_PRIORITY_C(50), 0);
UtAssert_True(status == OS_SUCCESS, "OS_TaskCreate() (%ld) == OS_SUCCESS", (long)status);

/* Connect to a server */
actual = OS_SocketConnect(c_socket_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_SUCCESS", (long)actual);

/*
* Test for invalid input parameters
*/

/* OS_TimedRead */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedRead(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedRead(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_ERROR_TIMEOUT;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 0);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/* OS_TimedWrite */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedWrite(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedWrite(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* OS_SocketAccept */
expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, NULL, 0);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, &temp_addr, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, &temp_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

/* OS_SocketConnect */
expected = OS_INVALID_POINTER;
actual = OS_SocketConnect(c_socket_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_ERR_INCORRECT_OBJ_STATE;
actual = OS_SocketConnect(c_socket_id, &s_addr, 0);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INCORRECT_OBJ_STATE", (long)actual);

expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_SocketConnect(temp_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INVALID_ID", (long)actual);

/*
* Once connection is made between
* server and client, transfer data
*/

/* Send data to server to be transformed and sent back */
strcpy(Buf_send_c, "xyz");
expected = sizeof(Buf_send_c);
actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* Recieve back transformed data from server*/
expected = sizeof(Buf_expec_trans);
strcpy(Buf_expec_trans, "uvwxyz");

actual = OS_TimedRead(c_socket_id, Buf_rcv_trans, sizeof(Buf_rcv_trans), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_True(strcmp(Buf_rcv_trans, Buf_expec_trans) == 0, "Buf_rcv_trans (%s) == Buf_expected (%s)",
Buf_rcv_trans, Buf_expec_trans);

/* Recieve all 256 chars from server one at a time */
expected = sizeof(Buf_each_char_rcv);
actual = OS_TimedRead(c_socket_id, Buf_each_char_rcv, sizeof(Buf_each_char_rcv), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/* Verify all 256 chars received */
for (iter = 0; iter < 256; iter++)
iter = 0;
while (iter < UT_STREAM_CONNECTION_COUNT)
{
Buf_each_expected[iter] = iter;
/* Open a client socket */
expected = OS_SUCCESS;
c_socket_id = OS_OBJECT_ID_UNDEFINED;

actual = OS_SocketOpen(&c_socket_id, OS_SocketDomain_INET, OS_SocketType_STREAM);
UtAssert_True(actual == expected, "OS_SocketOpen() (%ld) == OS_SUCCESS", (long)actual);
UtAssert_True(OS_ObjectIdDefined(c_socket_id), "c_socket_id (%lu) != 0", OS_ObjectIdToInteger(c_socket_id));

actual = OS_SocketConnect(c_socket_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_SUCCESS", (long)actual);

/*
* Test for invalid input parameters -
* This is done after valid connection when the c_socket_id is valid,
* but it only needs to be done once, so only do this on the first pass.
*/
if (iter == 0)
{
/* OS_TimedRead */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedRead(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedRead(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_ERROR_TIMEOUT;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 0);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/* OS_TimedWrite */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedWrite(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedWrite(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* OS_SocketAccept */
expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, NULL, 0);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, &temp_addr, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, &temp_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

/* OS_SocketConnect */
expected = OS_INVALID_POINTER;
actual = OS_SocketConnect(c_socket_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_ERR_INCORRECT_OBJ_STATE;
actual = OS_SocketConnect(c_socket_id, &s_addr, 0);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INCORRECT_OBJ_STATE",
(long)actual);

expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_SocketConnect(temp_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INVALID_ID", (long)actual);
}

/*
* Once connection is made between
* server and client, transfer data
*/
++iter;
snprintf(Buf_send_c, sizeof(Buf_send_c), "%03x", (iter + 0xabc) & 0xfff);

/* Send data to server */
expected = sizeof(Buf_send_c);
actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* Recieve back data from server, first is loop count */
expected = sizeof(loopcnt);
actual = OS_TimedRead(c_socket_id, &loopcnt, sizeof(loopcnt), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_UINT32_EQ(iter, loopcnt);

/* Recieve back data from server, next is original string */
expected = sizeof(Buf_rcv_c);
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_True(strcmp(Buf_send_c, Buf_rcv_c) == 0, "Buf_rcv_c (%s) == Buf_send_c (%s)", Buf_rcv_c,
Buf_send_c);

/* Recieve back data from server, next is 8-bit charset */
expected = sizeof(Buf_each_char_rcv);
actual = OS_TimedRead(c_socket_id, Buf_each_char_rcv, sizeof(Buf_each_char_rcv), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_MemCmpCount(Buf_each_char_rcv, sizeof(Buf_each_char_rcv), "Verify byte count pattern");

/* Server should close the socket, reads will return 0 indicating EOF */
expected = 0;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

OS_close(c_socket_id);
}

UtAssert_True(memcmp(Buf_each_expected, Buf_each_char_rcv, sizeof(Buf_each_expected)) == 0,
"buffer content match");

/* Once connection socket is closed, verify that no data is recieved */
expected = 0;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/*
* NOTE: Tests for invalid and other nominal input parameters
* to some of the network functions being called here are already
Expand All @@ -635,6 +693,12 @@ void TestStreamNetworkApi(void)
loopcnt++;
}
UtAssert_True(loopcnt < UT_EXIT_LOOP_MAX, "Task exited after %ld iterations", (long)loopcnt);

/* Check that the server function did NOT Report any errors */
if (ServerFn_ErrorString[0] != 0)
{
UtAssert_Failed("Server_Fn(): %s", ServerFn_ErrorString);
}
}

} /* end TestStreamNetworkApi */
Expand Down

0 comments on commit bcb8050

Please sign in to comment.