Skip to content

Commit

Permalink
refactor: Use tox memory allocator for temporary buffers in crypto.
Browse files Browse the repository at this point in the history
  • Loading branch information
iphydf committed Nov 27, 2024
1 parent 819aa2b commit 9e8c8b7
Show file tree
Hide file tree
Showing 28 changed files with 232 additions and 187 deletions.
16 changes: 8 additions & 8 deletions auto_tests/TCP_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ static void test_basic(void)
random_nonce(rng, handshake + CRYPTO_PUBLIC_KEY_SIZE);

// Encrypting handshake
int ret = encrypt_data(self_public_key, f_secret_key, handshake + CRYPTO_PUBLIC_KEY_SIZE, handshake_plain,
int ret = encrypt_data(mem, self_public_key, f_secret_key, handshake + CRYPTO_PUBLIC_KEY_SIZE, handshake_plain,
TCP_HANDSHAKE_PLAIN_SIZE, handshake + CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_NONCE_SIZE);
ck_assert_msg(ret == TCP_CLIENT_HANDSHAKE_SIZE - (CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_NONCE_SIZE),
"encrypt_data() call failed.");
Expand All @@ -128,7 +128,7 @@ static void test_basic(void)
uint8_t response_plain[TCP_HANDSHAKE_PLAIN_SIZE];
ck_assert_msg(net_recv(ns, logger, sock, response, TCP_SERVER_HANDSHAKE_SIZE, &localhost) == TCP_SERVER_HANDSHAKE_SIZE,
"Could/did not receive a server response to the initial handshake.");
ret = decrypt_data(self_public_key, f_secret_key, response, response + CRYPTO_NONCE_SIZE,
ret = decrypt_data(mem, self_public_key, f_secret_key, response, response + CRYPTO_NONCE_SIZE,
TCP_SERVER_HANDSHAKE_SIZE - CRYPTO_NONCE_SIZE, response_plain);
ck_assert_msg(ret == TCP_HANDSHAKE_PLAIN_SIZE, "Failed to decrypt handshake response.");
uint8_t f_nonce_r[CRYPTO_NONCE_SIZE];
Expand All @@ -143,7 +143,7 @@ static void test_basic(void)
uint8_t r_req[2 + 1 + CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_MAC_SIZE];
uint16_t size = 1 + CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_MAC_SIZE;
size = net_htons(size);
encrypt_data_symmetric(f_shared_key, f_nonce, r_req_p, 1 + CRYPTO_PUBLIC_KEY_SIZE, r_req + 2);
encrypt_data_symmetric(mem, f_shared_key, f_nonce, r_req_p, 1 + CRYPTO_PUBLIC_KEY_SIZE, r_req + 2);
increment_nonce(f_nonce);
memcpy(r_req, &size, 2);

Expand Down Expand Up @@ -174,7 +174,7 @@ static void test_basic(void)
"Wrong packet size for request response.");

uint8_t packet_resp_plain[4096];
ret = decrypt_data_symmetric(f_shared_key, f_nonce_r, packet_resp + 2, recv_data_len - 2, packet_resp_plain);
ret = decrypt_data_symmetric(mem, f_shared_key, f_nonce_r, packet_resp + 2, recv_data_len - 2, packet_resp_plain);
ck_assert_msg(ret != -1, "Failed to decrypt the TCP server's response.");
increment_nonce(f_nonce_r);

Expand Down Expand Up @@ -228,7 +228,7 @@ static struct sec_TCP_con *new_tcp_con(const Logger *logger, const Memory *mem,
memcpy(handshake, sec_c->public_key, CRYPTO_PUBLIC_KEY_SIZE);
random_nonce(rng, handshake + CRYPTO_PUBLIC_KEY_SIZE);

int ret = encrypt_data(tcp_server_public_key(tcp_s), f_secret_key, handshake + CRYPTO_PUBLIC_KEY_SIZE, handshake_plain,
int ret = encrypt_data(mem, tcp_server_public_key(tcp_s), f_secret_key, handshake + CRYPTO_PUBLIC_KEY_SIZE, handshake_plain,
TCP_HANDSHAKE_PLAIN_SIZE, handshake + CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_NONCE_SIZE);
ck_assert_msg(ret == TCP_CLIENT_HANDSHAKE_SIZE - (CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_NONCE_SIZE),
"Failed to encrypt the outgoing handshake.");
Expand All @@ -248,7 +248,7 @@ static struct sec_TCP_con *new_tcp_con(const Logger *logger, const Memory *mem,
uint8_t response_plain[TCP_HANDSHAKE_PLAIN_SIZE];
ck_assert_msg(net_recv(sec_c->ns, logger, sock, response, TCP_SERVER_HANDSHAKE_SIZE, &localhost) == TCP_SERVER_HANDSHAKE_SIZE,
"Failed to receive server handshake response.");
ret = decrypt_data(tcp_server_public_key(tcp_s), f_secret_key, response, response + CRYPTO_NONCE_SIZE,
ret = decrypt_data(mem, tcp_server_public_key(tcp_s), f_secret_key, response, response + CRYPTO_NONCE_SIZE,
TCP_SERVER_HANDSHAKE_SIZE - CRYPTO_NONCE_SIZE, response_plain);
ck_assert_msg(ret == TCP_HANDSHAKE_PLAIN_SIZE, "Failed to decrypt server handshake response.");
encrypt_precompute(response_plain, t_secret_key, sec_c->shared_key);
Expand All @@ -271,7 +271,7 @@ static int write_packet_tcp_test_connection(const Logger *logger, struct sec_TCP

uint16_t c_length = net_htons(length + CRYPTO_MAC_SIZE);
memcpy(packet, &c_length, sizeof(uint16_t));
int len = encrypt_data_symmetric(con->shared_key, con->sent_nonce, data, length, packet + sizeof(uint16_t));
int len = encrypt_data_symmetric(con->mem, con->shared_key, con->sent_nonce, data, length, packet + sizeof(uint16_t));

if ((unsigned int)len != (packet_size - sizeof(uint16_t))) {
return -1;
Expand All @@ -296,7 +296,7 @@ static int read_packet_sec_tcp(const Logger *logger, struct sec_TCP_con *con, ui

int rlen = net_recv(con->ns, logger, con->sock, data, length, &localhost);
ck_assert_msg(rlen == length, "Did not receive packet of correct length. Wanted %i, instead got %i", length, rlen);
rlen = decrypt_data_symmetric(con->shared_key, con->recv_nonce, data + 2, length - 2, data);
rlen = decrypt_data_symmetric(con->mem, con->shared_key, con->recv_nonce, data + 2, length - 2, data);
ck_assert_msg(rlen != -1, "Failed to decrypt a received packet from the Relay server.");
increment_nonce(con->recv_nonce);
return rlen;
Expand Down
50 changes: 32 additions & 18 deletions auto_tests/crypto_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ static const uint8_t test_c[147] = {

static void test_known(void)
{
const Memory *mem = os_memory();
ck_assert(mem != nullptr);

uint8_t c[147];
uint8_t m[131];

Expand All @@ -88,19 +91,22 @@ static void test_known(void)
ck_assert_msg(sizeof(test_c) == sizeof(c), "sanity check failed");
ck_assert_msg(sizeof(test_m) == sizeof(m), "sanity check failed");

const uint16_t clen = encrypt_data(bobpk, alicesk, test_nonce, test_m, sizeof(test_m) / sizeof(uint8_t), c);
const uint16_t clen = encrypt_data(mem, bobpk, alicesk, test_nonce, test_m, sizeof(test_m) / sizeof(uint8_t), c);

ck_assert_msg(memcmp(test_c, c, sizeof(c)) == 0, "cyphertext doesn't match test vector");
ck_assert_msg(clen == sizeof(c) / sizeof(uint8_t), "wrong ciphertext length");

const uint16_t mlen = decrypt_data(bobpk, alicesk, test_nonce, test_c, sizeof(test_c) / sizeof(uint8_t), m);
const uint16_t mlen = decrypt_data(mem, bobpk, alicesk, test_nonce, test_c, sizeof(test_c) / sizeof(uint8_t), m);

ck_assert_msg(memcmp(test_m, m, sizeof(m)) == 0, "decrypted text doesn't match test vector");
ck_assert_msg(mlen == sizeof(m) / sizeof(uint8_t), "wrong plaintext length");
}

static void test_fast_known(void)
{
const Memory *mem = os_memory();
ck_assert(mem != nullptr);

uint8_t k[CRYPTO_SHARED_KEY_SIZE];
uint8_t c[147];
uint8_t m[131];
Expand All @@ -112,19 +118,21 @@ static void test_fast_known(void)
ck_assert_msg(sizeof(test_c) == sizeof(c), "sanity check failed");
ck_assert_msg(sizeof(test_m) == sizeof(m), "sanity check failed");

const uint16_t clen = encrypt_data_symmetric(k, test_nonce, test_m, sizeof(test_m) / sizeof(uint8_t), c);
const uint16_t clen = encrypt_data_symmetric(mem, k, test_nonce, test_m, sizeof(test_m) / sizeof(uint8_t), c);

ck_assert_msg(memcmp(test_c, c, sizeof(c)) == 0, "cyphertext doesn't match test vector");
ck_assert_msg(clen == sizeof(c) / sizeof(uint8_t), "wrong ciphertext length");

const uint16_t mlen = decrypt_data_symmetric(k, test_nonce, test_c, sizeof(test_c) / sizeof(uint8_t), m);
const uint16_t mlen = decrypt_data_symmetric(mem, k, test_nonce, test_c, sizeof(test_c) / sizeof(uint8_t), m);

ck_assert_msg(memcmp(test_m, m, sizeof(m)) == 0, "decrypted text doesn't match test vector");
ck_assert_msg(mlen == sizeof(m) / sizeof(uint8_t), "wrong plaintext length");
}

static void test_endtoend(void)
{
const Memory *mem = os_memory();
ck_assert(mem != nullptr);
const Random *rng = os_random();
ck_assert(rng != nullptr);

Expand Down Expand Up @@ -166,21 +174,21 @@ static void test_endtoend(void)
ck_assert_msg(memcmp(k1, k2, CRYPTO_SHARED_KEY_SIZE) == 0, "encrypt_precompute: bad");

//Encrypt all four ways
const uint16_t c1len = encrypt_data(pk2, sk1, n, m, mlen, c1);
const uint16_t c2len = encrypt_data(pk1, sk2, n, m, mlen, c2);
const uint16_t c3len = encrypt_data_symmetric(k1, n, m, mlen, c3);
const uint16_t c4len = encrypt_data_symmetric(k2, n, m, mlen, c4);
const uint16_t c1len = encrypt_data(mem, pk2, sk1, n, m, mlen, c1);
const uint16_t c2len = encrypt_data(mem, pk1, sk2, n, m, mlen, c2);
const uint16_t c3len = encrypt_data_symmetric(mem, k1, n, m, mlen, c3);
const uint16_t c4len = encrypt_data_symmetric(mem, k2, n, m, mlen, c4);

ck_assert_msg(c1len == c2len && c1len == c3len && c1len == c4len, "cyphertext lengths differ");
ck_assert_msg(c1len == mlen + (uint16_t)CRYPTO_MAC_SIZE, "wrong cyphertext length");
ck_assert_msg(memcmp(c1, c2, c1len) == 0 && memcmp(c1, c3, c1len) == 0
&& memcmp(c1, c4, c1len) == 0, "crypertexts differ");

//Decrypt all four ways
const uint16_t m1len = decrypt_data(pk2, sk1, n, c1, c1len, m1);
const uint16_t m2len = decrypt_data(pk1, sk2, n, c1, c1len, m2);
const uint16_t m3len = decrypt_data_symmetric(k1, n, c1, c1len, m3);
const uint16_t m4len = decrypt_data_symmetric(k2, n, c1, c1len, m4);
const uint16_t m1len = decrypt_data(mem, pk2, sk1, n, c1, c1len, m1);
const uint16_t m2len = decrypt_data(mem, pk1, sk2, n, c1, c1len, m2);
const uint16_t m3len = decrypt_data_symmetric(mem, k1, n, c1, c1len, m3);
const uint16_t m4len = decrypt_data_symmetric(mem, k2, n, c1, c1len, m4);

ck_assert_msg(m1len == m2len && m1len == m3len && m1len == m4len, "decrypted text lengths differ");
ck_assert_msg(m1len == mlen, "wrong decrypted text length");
Expand All @@ -192,6 +200,8 @@ static void test_endtoend(void)

static void test_large_data(void)
{
const Memory *mem = os_memory();
ck_assert(mem != nullptr);
const Random *rng = os_random();
ck_assert(rng != nullptr);
uint8_t k[CRYPTO_SHARED_KEY_SIZE];
Expand All @@ -216,13 +226,13 @@ static void test_large_data(void)
//Generate key
rand_bytes(rng, k, CRYPTO_SHARED_KEY_SIZE);

const uint16_t c1len = encrypt_data_symmetric(k, n, m1, m1_size, c1);
const uint16_t c2len = encrypt_data_symmetric(k, n, m2, m2_size, c2);
const uint16_t c1len = encrypt_data_symmetric(mem, k, n, m1, m1_size, c1);
const uint16_t c2len = encrypt_data_symmetric(mem, k, n, m2, m2_size, c2);

ck_assert_msg(c1len == m1_size + CRYPTO_MAC_SIZE, "could not encrypt");
ck_assert_msg(c2len == m2_size + CRYPTO_MAC_SIZE, "could not encrypt");

const uint16_t m1plen = decrypt_data_symmetric(k, n, c1, c1len, m1prime);
const uint16_t m1plen = decrypt_data_symmetric(mem, k, n, c1, c1len, m1prime);

ck_assert_msg(m1plen == m1_size, "decrypted text lengths differ");
ck_assert_msg(memcmp(m1prime, m1, m1_size) == 0, "decrypted texts differ");
Expand All @@ -236,6 +246,8 @@ static void test_large_data(void)

static void test_large_data_symmetric(void)
{
const Memory *mem = os_memory();
ck_assert(mem != nullptr);
const Random *rng = os_random();
ck_assert(rng != nullptr);
uint8_t k[CRYPTO_SYMMETRIC_KEY_SIZE];
Expand All @@ -256,10 +268,10 @@ static void test_large_data_symmetric(void)
//Generate key
new_symmetric_key(rng, k);

const uint16_t c1len = encrypt_data_symmetric(k, n, m1, m1_size, c1);
const uint16_t c1len = encrypt_data_symmetric(mem, k, n, m1, m1_size, c1);
ck_assert_msg(c1len == m1_size + CRYPTO_MAC_SIZE, "could not encrypt data");

const uint16_t m1plen = decrypt_data_symmetric(k, n, c1, c1len, m1prime);
const uint16_t m1plen = decrypt_data_symmetric(mem, k, n, c1, c1len, m1prime);

ck_assert_msg(m1plen == m1_size, "decrypted text lengths differ");
ck_assert_msg(memcmp(m1prime, m1, m1_size) == 0, "decrypted texts differ");
Expand All @@ -271,6 +283,8 @@ static void test_large_data_symmetric(void)

static void test_very_large_data(void)
{
const Memory *mem = os_memory();
ck_assert(mem != nullptr);
const Random *rng = os_random();
ck_assert(rng != nullptr);

Expand All @@ -287,7 +301,7 @@ static void test_very_large_data(void)
ck_assert(plain != nullptr);
ck_assert(encrypted != nullptr);

encrypt_data(pk, sk, nonce, plain, plain_size, encrypted);
encrypt_data(mem, pk, sk, nonce, plain, plain_size, encrypted);

free(encrypted);
free(plain);
Expand Down
18 changes: 9 additions & 9 deletions auto_tests/onion_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ static int handle_test_3(void *object, const IP_Port *source, const uint8_t *pac
#if 0
print_client_id(packet, length);
#endif
int len = decrypt_data(test_3_pub_key, dht_get_self_secret_key(onion->dht),
int len = decrypt_data(onion->mem, test_3_pub_key, dht_get_self_secret_key(onion->dht),
packet + 1 + ONION_ANNOUNCE_SENDBACK_DATA_LENGTH,
packet + 1 + ONION_ANNOUNCE_SENDBACK_DATA_LENGTH + CRYPTO_NONCE_SIZE,
2 + CRYPTO_SHA256_SIZE + CRYPTO_MAC_SIZE, plain);
Expand Down Expand Up @@ -144,7 +144,7 @@ static int handle_test_3_old(void *object, const IP_Port *source, const uint8_t
#if 0
print_client_id(packet, length);
#endif
int len = decrypt_data(test_3_pub_key, dht_get_self_secret_key(onion->dht),
int len = decrypt_data(onion->mem, test_3_pub_key, dht_get_self_secret_key(onion->dht),
packet + 1 + ONION_ANNOUNCE_SENDBACK_DATA_LENGTH,
packet + 1 + ONION_ANNOUNCE_SENDBACK_DATA_LENGTH + CRYPTO_NONCE_SIZE,
1 + CRYPTO_SHA256_SIZE + CRYPTO_MAC_SIZE, plain);
Expand Down Expand Up @@ -182,7 +182,7 @@ static int handle_test_4(void *object, const IP_Port *source, const uint8_t *pac
return 1;
}

int len = decrypt_data(packet + 1 + CRYPTO_NONCE_SIZE, dht_get_self_secret_key(onion->dht), packet + 1,
int len = decrypt_data(onion->mem, packet + 1 + CRYPTO_NONCE_SIZE, dht_get_self_secret_key(onion->dht), packet + 1,
packet + 1 + CRYPTO_NONCE_SIZE + CRYPTO_PUBLIC_KEY_SIZE, sizeof("Install gentoo") + CRYPTO_MAC_SIZE, plain);

if (len == -1) {
Expand All @@ -202,10 +202,10 @@ static int handle_test_4(void *object, const IP_Port *source, const uint8_t *pac
* Use Onion_Path path to send data of length to dest.
* Maximum length of data is ONION_MAX_DATA_SIZE.
*/
static void send_onion_packet(const Networking_Core *net, const Random *rng, const Onion_Path *path, const IP_Port *dest, const uint8_t *data, uint16_t length)
static void send_onion_packet(const Networking_Core *net, const Memory *mem, const Random *rng, const Onion_Path *path, const IP_Port *dest, const uint8_t *data, uint16_t length)
{
uint8_t packet[ONION_MAX_PACKET_SIZE];
const int len = create_onion_packet(rng, packet, sizeof(packet), path, dest, data, length);
const int len = create_onion_packet(mem, rng, packet, sizeof(packet), path, dest, data, length);
ck_assert_msg(len != -1, "failed to create onion packet");
ck_assert_msg(sendpacket(net, &path->ip_port1, packet, len) == len, "failed to send onion packet");
}
Expand Down Expand Up @@ -264,7 +264,7 @@ static void test_basic(void)
nodes[3] = n2;
Onion_Path path;
create_onion_path(rng, onion1->dht, &path, nodes);
send_onion_packet(onion1->net, rng, &path, &nodes[3].ip_port, req_packet, sizeof(req_packet));
send_onion_packet(onion1->net, onion1->mem, rng, &path, &nodes[3].ip_port, req_packet, sizeof(req_packet));

handled_test_1 = 0;

Expand All @@ -291,7 +291,7 @@ static void test_basic(void)
uint64_t s;
memcpy(&s, sb_data, sizeof(uint64_t));
memcpy(test_3_pub_key, nodes[3].public_key, CRYPTO_PUBLIC_KEY_SIZE);
int ret = send_announce_request(log1, onion1->net, rng, &path, &nodes[3],
int ret = send_announce_request(log1, onion1->mem, onion1->net, rng, &path, &nodes[3],
dht_get_self_public_key(onion1->dht),
dht_get_self_secret_key(onion1->dht),
zeroes,
Expand All @@ -313,7 +313,7 @@ static void test_basic(void)
memcpy(onion_announce_entry_public_key(onion2_a, 1), dht_get_self_public_key(onion2->dht), CRYPTO_PUBLIC_KEY_SIZE);
onion_announce_entry_set_time(onion2_a, 1, mono_time_get(mono_time2));
networking_registerhandler(onion1->net, NET_PACKET_ONION_DATA_RESPONSE, &handle_test_4, onion1);
send_announce_request(log1, onion1->net, rng, &path, &nodes[3],
send_announce_request(log1, onion1->mem, onion1->net, rng, &path, &nodes[3],
dht_get_self_public_key(onion1->dht),
dht_get_self_secret_key(onion1->dht),
test_3_ping_id,
Expand All @@ -338,7 +338,7 @@ static void test_basic(void)
ck_assert_msg((onion3 != nullptr), "Onion failed initializing.");

random_nonce(rng, nonce);
ret = send_data_request(log3, onion3->net, rng, &path, &nodes[3].ip_port,
ret = send_data_request(log3, onion3->mem, onion3->net, rng, &path, &nodes[3].ip_port,
dht_get_self_public_key(onion1->dht),
dht_get_self_public_key(onion1->dht),
nonce, (const uint8_t *)"Install gentoo", sizeof("Install gentoo"));
Expand Down
5 changes: 5 additions & 0 deletions toxcore/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ cc_library(
deps = [
":attributes",
":ccompat",
":mem",
":util",
"@libsodium",
],
Expand Down Expand Up @@ -209,6 +210,7 @@ cc_test(
deps = [
":crypto_core",
":crypto_core_test_util",
":mem_test_util",
":util",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
Expand Down Expand Up @@ -479,10 +481,12 @@ cc_test(
cc_fuzz_test(
name = "DHT_fuzz_test",
size = "small",
testonly = True,
srcs = ["DHT_fuzz_test.cc"],
corpus = ["//tools/toktok-fuzzer/corpus:DHT_fuzz_test"],
deps = [
":DHT",
":mem_test_util",
"//c-toxcore/testing/fuzzing:fuzz_support",
],
)
Expand Down Expand Up @@ -778,6 +782,7 @@ cc_library(
":crypto_core",
":group_announce",
":logger",
":mem",
":mono_time",
":network",
":onion_announce",
Expand Down
Loading

0 comments on commit 9e8c8b7

Please sign in to comment.