Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #969, socket accept using incorrect record #971

Merged
merged 2 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/os/shared/src/osapi-sockets.c
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ int32 OS_SocketAccept(osal_id_t sock_id, osal_id_t *connsock_id, OS_SockAddr_t *
if (return_code == OS_SUCCESS)
{
conn_record = OS_OBJECT_TABLE_GET(OS_global_stream_table, conn_token);
conn = OS_OBJECT_TABLE_GET(OS_stream_table, sock_token);
conn = OS_OBJECT_TABLE_GET(OS_stream_table, conn_token);

/* Incr the refcount to record the fact that an operation is pending on this */
memset(conn, 0, sizeof(OS_stream_internal_record_t));
Expand Down
322 changes: 193 additions & 129 deletions src/tests/network-api-test/network-api-test.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@

#define UT_EXIT_LOOP_MAX 100

/*
* Number of client->server connections to create.
* This tests that the server socket can accept multiple connections.
*/
#define UT_STREAM_CONNECTION_COUNT 4

osal_id_t s_task_id;
osal_id_t p1_socket_id;
osal_id_t p2_socket_id;
Expand All @@ -45,6 +51,8 @@ OS_SockAddr_t s_addr;
OS_SockAddr_t c_addr;
bool networkImplemented = true;

char ServerFn_ErrorString[128];

/*****************************************************************************
*
* Datagram Network Functional Test Setup
Expand Down Expand Up @@ -406,31 +414,68 @@ void Server_Fn(void)
osal_id_t connsock_id = OS_OBJECT_ID_UNDEFINED;
uint32 iter;
OS_SockAddr_t addr;
char Buf_rcv_s[4] = {0};
char Buf_trans[8] = {0};
uint8 Buf_each_char_s[256] = {0};
int32 Status;

/* Accept incoming connections */
OS_SocketAccept(s_socket_id, &connsock_id, &addr, OS_PEND);
/* Fill the memory with a count pattern */
UtMemFill(Buf_each_char_s, sizeof(Buf_each_char_s));

/* Recieve incoming data from client*/
OS_TimedRead(connsock_id, Buf_rcv_s, sizeof(Buf_rcv_s), 10);
iter = 0;
while (iter < UT_STREAM_CONNECTION_COUNT)
{
++iter;

/* Transform the incoming data and send it back to client */
strcpy(Buf_trans, "uvw");
strcat(Buf_trans, Buf_rcv_s);
OS_TimedWrite(connsock_id, Buf_trans, sizeof(Buf_trans), 10);
/* Accept incoming connections */
Status = OS_SocketAccept(s_socket_id, &connsock_id, &addr, OS_PEND);
if (Status != OS_SUCCESS)
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_SocketAccept() return code=%d",
(int)Status);
break;
}

/* Send all 256 chars to client */
for (iter = 0; iter < 256; iter++)
{
Buf_each_char_s[iter] = iter;
}
/* Recieve incoming data from client (should be exactly 4 bytes) */
Status = OS_TimedRead(connsock_id, Buf_trans, sizeof(Buf_trans), 10);
if (Status != 4)
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedRead() return code=%d", (int)Status);
break;
}

/* Send back to client:
* 1. uint32 value indicating number of connections so far (4 bytes)
* 2. Original value recieved above (4 bytes)
* 3. String of all possible 8-bit chars [0-255] (256 bytes)
*/
Status = OS_TimedWrite(connsock_id, &iter, sizeof(iter), 10);
if (Status != sizeof(iter))
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedWrite(uint32) return code=%d",
(int)Status);
break;
}

OS_TimedWrite(connsock_id, Buf_each_char_s, sizeof(Buf_each_char_s), 10);
Status = OS_TimedWrite(connsock_id, Buf_trans, 4, 10);
if (Status != 4)
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedWrite(Buf_trans) return code=%d",
(int)Status);
break;
}

Status = OS_TimedWrite(connsock_id, Buf_each_char_s, sizeof(Buf_each_char_s), 10);
if (Status != sizeof(Buf_each_char_s))
{
snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString),
"OS_TimedWrite(Buf_each_char_s) return code=%d", (int)Status);
break;
}

OS_close(connsock_id);
}

OS_close(s_socket_id);
OS_close(connsock_id);

} /* end Server_Fn */

Expand All @@ -451,14 +496,16 @@ void TestStreamNetworkApi(void)
OS_task_prop_t taskprop;
char Buf_rcv_c[4] = {0};
char Buf_send_c[4] = {0};
char Buf_rcv_trans[8] = {0};
char Buf_expec_trans[8] = {0};
uint8 Buf_each_expected[256] = {0};
uint8 Buf_each_char_rcv[256] = {0};

/*
* Set up a server
* NOTE: The server cannot directly use UtAssert because the library is not thread-safe
* If the server task encounters an error, it will write the string to this buffer, and it
* will be reported at the end of this routine.
*
* Be sure it is empty to start with.
*/
memset(ServerFn_ErrorString, 0, sizeof(ServerFn_ErrorString));

/* Open a server socket */
s_socket_id = OS_OBJECT_ID_UNDEFINED;
Expand Down Expand Up @@ -493,13 +540,15 @@ void TestStreamNetworkApi(void)
* Set up a client
*/

/* Open a client socket */
expected = OS_SUCCESS;
c_socket_id = OS_OBJECT_ID_UNDEFINED;
/*
* Create a server thread, and connect client from
* this thread to server thread and verify connection
*/

actual = OS_SocketOpen(&c_socket_id, OS_SocketDomain_INET, OS_SocketType_STREAM);
UtAssert_True(actual == expected, "OS_SocketOpen() (%ld) == OS_SUCCESS", (long)actual);
UtAssert_True(OS_ObjectIdDefined(c_socket_id), "c_socket_id (%lu) != 0", OS_ObjectIdToInteger(c_socket_id));
/* Create a server task/thread */
status = OS_TaskCreate(&s_task_id, "Server", Server_Fn, OSAL_TASK_STACK_ALLOCATE, OSAL_SIZE_C(16384),
OSAL_PRIORITY_C(50), 0);
UtAssert_True(status == OS_SUCCESS, "OS_TaskCreate() (%ld) == OS_SUCCESS", (long)status);

/* Initialize client address */
actual = OS_SocketAddrInit(&c_addr, OS_SocketDomain_INET);
Expand All @@ -514,113 +563,122 @@ void TestStreamNetworkApi(void)
UtAssert_True(actual == expected, "OS_SocketAddrFromString() (%ld) == OS_SUCCESS", (long)actual);

/*
* Create a server thread, and connect client from
* this thread to server thread and verify connection
* Connect to a server - this is done in a loop
* to confirm a server socket can be re-used for multiple clients
*/

/* Create a server task/thread */
status = OS_TaskCreate(&s_task_id, "Server", Server_Fn, OSAL_TASK_STACK_ALLOCATE, OSAL_SIZE_C(16384),
OSAL_PRIORITY_C(50), 0);
UtAssert_True(status == OS_SUCCESS, "OS_TaskCreate() (%ld) == OS_SUCCESS", (long)status);

/* Connect to a server */
actual = OS_SocketConnect(c_socket_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_SUCCESS", (long)actual);

/*
* Test for invalid input parameters
*/

/* OS_TimedRead */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedRead(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedRead(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_ERROR_TIMEOUT;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 0);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/* OS_TimedWrite */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedWrite(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedWrite(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* OS_SocketAccept */
expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, NULL, 0);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, &temp_addr, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, &temp_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

/* OS_SocketConnect */
expected = OS_INVALID_POINTER;
actual = OS_SocketConnect(c_socket_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_ERR_INCORRECT_OBJ_STATE;
actual = OS_SocketConnect(c_socket_id, &s_addr, 0);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INCORRECT_OBJ_STATE", (long)actual);

expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_SocketConnect(temp_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INVALID_ID", (long)actual);

/*
* Once connection is made between
* server and client, transfer data
*/

/* Send data to server to be transformed and sent back */
strcpy(Buf_send_c, "xyz");
expected = sizeof(Buf_send_c);
actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* Recieve back transformed data from server*/
expected = sizeof(Buf_expec_trans);
strcpy(Buf_expec_trans, "uvwxyz");

actual = OS_TimedRead(c_socket_id, Buf_rcv_trans, sizeof(Buf_rcv_trans), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_True(strcmp(Buf_rcv_trans, Buf_expec_trans) == 0, "Buf_rcv_trans (%s) == Buf_expected (%s)",
Buf_rcv_trans, Buf_expec_trans);

/* Recieve all 256 chars from server one at a time */
expected = sizeof(Buf_each_char_rcv);
actual = OS_TimedRead(c_socket_id, Buf_each_char_rcv, sizeof(Buf_each_char_rcv), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/* Verify all 256 chars received */
for (iter = 0; iter < 256; iter++)
iter = 0;
while (iter < UT_STREAM_CONNECTION_COUNT)
{
Buf_each_expected[iter] = iter;
/* Open a client socket */
expected = OS_SUCCESS;
c_socket_id = OS_OBJECT_ID_UNDEFINED;

actual = OS_SocketOpen(&c_socket_id, OS_SocketDomain_INET, OS_SocketType_STREAM);
UtAssert_True(actual == expected, "OS_SocketOpen() (%ld) == OS_SUCCESS", (long)actual);
UtAssert_True(OS_ObjectIdDefined(c_socket_id), "c_socket_id (%lu) != 0", OS_ObjectIdToInteger(c_socket_id));

actual = OS_SocketConnect(c_socket_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_SUCCESS", (long)actual);

/*
* Test for invalid input parameters -
* This is done after valid connection when the c_socket_id is valid,
* but it only needs to be done once, so only do this on the first pass.
*/
if (iter == 0)
{
/* OS_TimedRead */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedRead(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedRead(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_ERROR_TIMEOUT;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 0);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/* OS_TimedWrite */
expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_TimedWrite(temp_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

expected = OS_INVALID_POINTER;
actual = OS_TimedWrite(c_socket_id, NULL, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* OS_SocketAccept */
expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, NULL, 0);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, NULL, &temp_addr, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_INVALID_POINTER;
actual = OS_SocketAccept(s_socket_id, &temp_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketAccept() (%ld) == OS_INVALID_POINTER", (long)actual);

/* OS_SocketConnect */
expected = OS_INVALID_POINTER;
actual = OS_SocketConnect(c_socket_id, NULL, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_INVALID_POINTER", (long)actual);

expected = OS_ERR_INCORRECT_OBJ_STATE;
actual = OS_SocketConnect(c_socket_id, &s_addr, 0);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INCORRECT_OBJ_STATE",
(long)actual);

expected = OS_ERR_INVALID_ID;
temp_id = OS_ObjectIdFromInteger(0xFFFFFFFF);
actual = OS_SocketConnect(temp_id, &s_addr, 10);
UtAssert_True(actual == expected, "OS_SocketConnect() (%ld) == OS_ERR_INVALID_ID", (long)actual);
}

/*
* Once connection is made between
* server and client, transfer data
*/
++iter;
snprintf(Buf_send_c, sizeof(Buf_send_c), "%03x", (iter + 0xabc) & 0xfff);

/* Send data to server */
expected = sizeof(Buf_send_c);
actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10);
UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected);

/* Recieve back data from server, first is loop count */
expected = sizeof(loopcnt);
actual = OS_TimedRead(c_socket_id, &loopcnt, sizeof(loopcnt), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_UINT32_EQ(iter, loopcnt);

/* Recieve back data from server, next is original string */
expected = sizeof(Buf_rcv_c);
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_True(strcmp(Buf_send_c, Buf_rcv_c) == 0, "Buf_rcv_c (%s) == Buf_send_c (%s)", Buf_rcv_c,
Buf_send_c);

/* Recieve back data from server, next is 8-bit charset */
expected = sizeof(Buf_each_char_rcv);
actual = OS_TimedRead(c_socket_id, Buf_each_char_rcv, sizeof(Buf_each_char_rcv), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);
UtAssert_MemCmpCount(Buf_each_char_rcv, sizeof(Buf_each_char_rcv), "Verify byte count pattern");

/* Server should close the socket, reads will return 0 indicating EOF */
expected = 0;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

OS_close(c_socket_id);
}

UtAssert_True(memcmp(Buf_each_expected, Buf_each_char_rcv, sizeof(Buf_each_expected)) == 0,
"buffer content match");

/* Once connection socket is closed, verify that no data is recieved */
expected = 0;
actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10);
UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected);

/*
* NOTE: Tests for invalid and other nominal input parameters
* to some of the network functions being called here are already
Expand All @@ -635,6 +693,12 @@ void TestStreamNetworkApi(void)
loopcnt++;
}
UtAssert_True(loopcnt < UT_EXIT_LOOP_MAX, "Task exited after %ld iterations", (long)loopcnt);

/* Check that the server function did NOT Report any errors */
if (ServerFn_ErrorString[0] != 0)
{
UtAssert_Failed("Server_Fn(): %s", ServerFn_ErrorString);
}
}

} /* end TestStreamNetworkApi */
Expand Down