Skip to content

Commit

Permalink
Merge pull request #1129 from gilles-peskine-arm/psa-buffers-test-poi…
Browse files Browse the repository at this point in the history
…son-2.28

Backport 2.28: Memory poisoning function for Asan
  • Loading branch information
davidhorstmann-arm authored Dec 11, 2023
2 parents 806c27c + 7d68a19 commit 95b54f3
Show file tree
Hide file tree
Showing 16 changed files with 979 additions and 166 deletions.
109 changes: 103 additions & 6 deletions programs/test/metatest.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@

#define MBEDTLS_ALLOW_PRIVATE_ACCESS

#include <mbedtls/debug.h>
#include <mbedtls/platform.h>
#include <mbedtls/platform_util.h>
#include "test/helpers.h"
#include "test/macros.h"
#include "test/memory.h"

#include <stdio.h>
#include <string.h>
Expand All @@ -40,19 +42,39 @@
#include <mbedtls/threading.h>
#endif

/* C99 feature missing from older versions of MSVC */
#if (defined(_MSC_VER) && (_MSC_VER <= 1900))
#define /*no-check-names*/ __func__ __FUNCTION__
#endif


/* This is an external variable, so the compiler doesn't know that we're never
 * changing its value. It is initialized to 0 and declared volatile, so every
 * use is a real read of the variable rather than a folded constant.
 */
volatile int false_but_the_compiler_does_not_know = 0;

/* Hide calls to calloc/free from static checkers such as
 * `gcc-12 -Wuse-after-free`, to avoid compile-time complaints about
 * code where we do mean to cause a runtime error.
 * Both are volatile function pointers, initialized to mbedtls_calloc and
 * mbedtls_free respectively, and are meant to be called in place of those
 * functions. */
void * (* volatile calloc_but_the_compiler_does_not_know)(size_t, size_t) = mbedtls_calloc;
void(*volatile free_but_the_compiler_does_not_know)(void *) = mbedtls_free;

/* Fill the n bytes at address p with all-bits-zero.
 * The fill value is read from false_but_the_compiler_does_not_know instead of
 * being written as a literal 0, so that the compiler should not know that
 * the memory at p ends up all-bits-zero. */
static void set_to_zero_but_the_compiler_does_not_know(volatile void *p, size_t n)
{
    const int fill_value = false_but_the_compiler_does_not_know;
    void *destination = (void *) p;

    memset(destination, fill_value, n);
}

/* Simulate an access to the given object, to avoid compiler optimizations
 * in code that prepares or consumes the object.
 * The function body deliberately ignores p: it only exists so that there is
 * a call taking the object's address. */
static void do_nothing_with_object(void *p)
{
(void) p;
}
/* Volatile function pointer initialized to do_nothing_with_object. Calling
 * through this pointer rather than calling the static function directly is
 * presumably what keeps the compiler from seeing that the call has no
 * effect. */
void(*volatile do_nothing_with_object_but_the_compiler_does_not_know)(void *) =
do_nothing_with_object;


/****************************************************************/
/* Test framework features */
Expand Down Expand Up @@ -98,21 +120,21 @@ void null_pointer_call(const char *name)
/* Metatest: allocate one byte, write to it, free it, then read it back.
 * The final read is deliberately undefined behavior (read after free).
 * The name argument is unused.
 * NOTE(review): this view interleaves removed and added diff lines (each
 * mbedtls_calloc/mbedtls_free line is followed by its
 * *_but_the_compiler_does_not_know replacement); only one line of each pair
 * exists in the real file. */
void read_after_free(const char *name)
{
(void) name;
volatile char *p = mbedtls_calloc(1, 1);
volatile char *p = calloc_but_the_compiler_does_not_know(1, 1);
*p = 'a';
mbedtls_free((void *) p);
free_but_the_compiler_does_not_know((void *) p);
/* Undefined behavior (read after free) */
mbedtls_printf("%u\n", (unsigned) *p);
}

/* Metatest: allocate one byte, write to it, then free it twice.
 * The second free is deliberately undefined behavior (double free).
 * The name argument is unused.
 * NOTE(review): this view interleaves removed and added diff lines (each
 * mbedtls_calloc/mbedtls_free line is followed by its
 * *_but_the_compiler_does_not_know replacement); only one line of each pair
 * exists in the real file. */
void double_free(const char *name)
{
(void) name;
volatile char *p = mbedtls_calloc(1, 1);
volatile char *p = calloc_but_the_compiler_does_not_know(1, 1);
*p = 'a';
mbedtls_free((void *) p);
free_but_the_compiler_does_not_know((void *) p);
/* Undefined behavior (double free) */
mbedtls_free((void *) p);
free_but_the_compiler_does_not_know((void *) p);
}

void read_uninitialized_stack(const char *name)
Expand All @@ -132,11 +154,70 @@ void read_uninitialized_stack(const char *name)
/* Metatest: allocate one byte, print its value, and return without freeing
 * it, so that the heap object is leaked on purpose.
 * The name argument is unused.
 * NOTE(review): this view interleaves removed and added diff lines (the
 * mbedtls_calloc line is followed by its
 * calloc_but_the_compiler_does_not_know replacement); only one of the two
 * declarations of p exists in the real file. */
void memory_leak(const char *name)
{
(void) name;
volatile char *p = mbedtls_calloc(1, 1);
volatile char *p = calloc_but_the_compiler_does_not_know(1, 1);
mbedtls_printf("%u\n", (unsigned) *p);
/* Leak of a heap object */
}

/* name = "test_memory_poison_%(start)_%(offset)_%(count)_%(direction)"
 * Poison a region starting at start from an 8-byte aligned origin,
 * encompassing count bytes. Access the region at offset from the start.
 * %(start), %(offset) and %(count) are decimal integers.
 * %(direction) is either the character 'r' for read or 'w' for write
 * (any character other than 'w' is handled as a read).
 *
 * The parameters are validated before anything is poisoned: the region
 * [start, start + count) must fit in the 32-byte local buffer, and offset
 * must be strictly less than count. On a malformed name or an out-of-range
 * parameter, a message is printed to stderr and the function returns without
 * touching the buffer.
 */
void test_memory_poison(const char *name)
{
    size_t start = 0, offset = 0, count = 0;
    char direction = 'r';

    /* Skip the non-digit prefix and separators, picking up the three
     * numbers and then the direction character after the last '_'. */
    if (sscanf(name,
               "%*[^0-9]%" MBEDTLS_PRINTF_SIZET
               "%*[^0-9]%" MBEDTLS_PRINTF_SIZET
               "%*[^0-9]%" MBEDTLS_PRINTF_SIZET
               "_%c",
               &start, &offset, &count, &direction) != 4) {
        mbedtls_fprintf(stderr, "%s: Bad name format: %s\n", __func__, name);
        return;
    }

    /* The long long member is only there to force the buffer's alignment. */
    union {
        long long ll;
        unsigned char buf[32];
    } aligned;
    memset(aligned.buf, 'a', sizeof(aligned.buf));

    if (start > sizeof(aligned.buf)) {
        mbedtls_fprintf(stderr,
                        "%s: start=%" MBEDTLS_PRINTF_SIZET
                        " > size=%" MBEDTLS_PRINTF_SIZET "\n",
                        __func__, start, sizeof(aligned.buf));
        return;
    }
    /* Compare count against the room left after start instead of computing
     * start + count, which could wrap around for a huge count and let an
     * out-of-range region through. start <= size is guaranteed above, so
     * the subtraction cannot underflow. */
    if (count > sizeof(aligned.buf) - start) {
        mbedtls_fprintf(stderr,
                        "%s: start=%" MBEDTLS_PRINTF_SIZET
                        " + count=%" MBEDTLS_PRINTF_SIZET
                        " > size=%" MBEDTLS_PRINTF_SIZET "\n",
                        __func__, start, count, sizeof(aligned.buf));
        return;
    }
    if (offset >= count) {
        mbedtls_fprintf(stderr,
                        "%s: offset=%" MBEDTLS_PRINTF_SIZET
                        " >= count=%" MBEDTLS_PRINTF_SIZET "\n",
                        __func__, offset, count);
        return;
    }

    /* NOTE(review): MBEDTLS_TEST_MEMORY_POISON is defined outside this view
     * (presumably in test/memory.h); it is expected to make the region
     * inaccessible so that the access below is reported at runtime. */
    MBEDTLS_TEST_MEMORY_POISON(aligned.buf + start, count);

    /* The do-nothing call simulates a use of the buffer right next to the
     * poisoned access, to avoid the access being optimized away. */
    if (direction == 'w') {
        aligned.buf[start + offset] = 'b';
        do_nothing_with_object_but_the_compiler_does_not_know(aligned.buf);
    } else {
        do_nothing_with_object_but_the_compiler_does_not_know(aligned.buf);
        mbedtls_printf("%u\n", (unsigned) aligned.buf[start + offset]);
    }
}


/****************************************************************/
/* Threading */
Expand Down Expand Up @@ -285,6 +366,22 @@ metatest_t metatests[] = {
{ "double_free", "asan", double_free },
{ "read_uninitialized_stack", "msan", read_uninitialized_stack },
{ "memory_leak", "asan", memory_leak },
{ "test_memory_poison_0_0_8_r", "asan", test_memory_poison },
{ "test_memory_poison_0_0_8_w", "asan", test_memory_poison },
{ "test_memory_poison_0_7_8_r", "asan", test_memory_poison },
{ "test_memory_poison_0_7_8_w", "asan", test_memory_poison },
{ "test_memory_poison_0_0_1_r", "asan", test_memory_poison },
{ "test_memory_poison_0_0_1_w", "asan", test_memory_poison },
{ "test_memory_poison_0_1_2_r", "asan", test_memory_poison },
{ "test_memory_poison_0_1_2_w", "asan", test_memory_poison },
{ "test_memory_poison_7_0_8_r", "asan", test_memory_poison },
{ "test_memory_poison_7_0_8_w", "asan", test_memory_poison },
{ "test_memory_poison_7_7_8_r", "asan", test_memory_poison },
{ "test_memory_poison_7_7_8_w", "asan", test_memory_poison },
{ "test_memory_poison_7_0_1_r", "asan", test_memory_poison },
{ "test_memory_poison_7_0_1_w", "asan", test_memory_poison },
{ "test_memory_poison_7_1_2_r", "asan", test_memory_poison },
{ "test_memory_poison_7_1_2_w", "asan", test_memory_poison },
{ "mutex_lock_not_initialized", "pthread", mutex_lock_not_initialized },
{ "mutex_unlock_not_initialized", "pthread", mutex_unlock_not_initialized },
{ "mutex_free_not_initialized", "pthread", mutex_free_not_initialized },
Expand Down
20 changes: 0 additions & 20 deletions scripts/find-mem-leak.cocci

This file was deleted.

112 changes: 112 additions & 0 deletions scripts/mbedtls_dev/crypto_data_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Generate test data for cryptographic mechanisms.
This module is a work in progress, only implementing a few cases for now.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later


import hashlib
from typing import Callable, Dict, Iterator, List, Optional #pylint: disable=unused-import

from . import crypto_knowledge
from . import psa_information
from . import test_case


def psa_low_level_dependencies(*expressions: str) -> List[str]:
    """Return the MBEDTLS_PSA_BUILTIN_xxx dependencies of a low-level test case.

    The PSA_WANT_xxx symbols are obtained by passing ``expressions`` to
    ``psa_information.automatic_dependencies``; each one is then translated
    by replacing its ``PSA_WANT_`` prefix with ``MBEDTLS_PSA_BUILTIN_``.
    Every returned symbol is asserted to carry the ``PSA_WANT_`` prefix
    before any translation takes place.
    """
    want_prefix = 'PSA_WANT_'
    builtin_prefix = 'MBEDTLS_PSA_BUILTIN_'
    wanted = psa_information.automatic_dependencies(*expressions)
    for symbol in wanted:
        assert symbol.startswith(want_prefix)
    builtin = []
    for symbol in wanted:
        builtin.append(builtin_prefix + symbol[len(want_prefix):])
    return builtin


class HashPSALowLevel:
    """Test case generator for the PSA low-level hash interface."""

    def __init__(self, info: psa_information.Information) -> None:
        """Keep ``info`` and collect the non-wildcard hash algorithms it yields."""
        self.info = info
        seeds = sorted(info.constructors.algorithms)
        candidates = [
            crypto_knowledge.Algorithm(expression)
            for expression in info.constructors.generate_expressions(seeds)
        ]
        hashes = []
        for candidate in candidates:
            if candidate.is_wildcard:
                continue
            if candidate.can_do(crypto_knowledge.AlgorithmCategory.HASH):
                hashes.append(candidate)
        self.algorithms = hashes

    # CALCULATE[alg] = function returning the hex digest of its argument,
    # or None when no reference implementation is wired up yet.
    # TO-DO: implement the None entries with a third-party library, because
    # hashlib might not have everything, depending on the Python version and
    # the underlying OpenSSL. On Ubuntu 16.04, truncated sha512 and sha3/shake
    # are not available. On Ubuntu 22.04, md2, md4 and ripemd160 are not
    # available.
    CALCULATE = {
        'PSA_ALG_MD2': None,
        'PSA_ALG_MD4': None,
        'PSA_ALG_MD5': lambda data: hashlib.md5(data).hexdigest(),
        'PSA_ALG_RIPEMD160': None, #lambda data: hashlib.new('ripemd160', data).hexdigest()
        'PSA_ALG_SHA_1': lambda data: hashlib.sha1(data).hexdigest(),
        'PSA_ALG_SHA_224': lambda data: hashlib.sha224(data).hexdigest(),
        'PSA_ALG_SHA_256': lambda data: hashlib.sha256(data).hexdigest(),
        'PSA_ALG_SHA_384': lambda data: hashlib.sha384(data).hexdigest(),
        'PSA_ALG_SHA_512': lambda data: hashlib.sha512(data).hexdigest(),
        'PSA_ALG_SHA_512_224': None, #lambda data: hashlib.new('sha512_224', data).hexdigest()
        'PSA_ALG_SHA_512_256': None, #lambda data: hashlib.new('sha512_256', data).hexdigest()
        'PSA_ALG_SHA3_224': None, #lambda data: hashlib.sha3_224(data).hexdigest(),
        'PSA_ALG_SHA3_256': None, #lambda data: hashlib.sha3_256(data).hexdigest(),
        'PSA_ALG_SHA3_384': None, #lambda data: hashlib.sha3_384(data).hexdigest(),
        'PSA_ALG_SHA3_512': None, #lambda data: hashlib.sha3_512(data).hexdigest(),
        'PSA_ALG_SHAKE256_512': None, #lambda data: hashlib.shake_256(data).hexdigest(64),
    } #type: Dict[str, Optional[Callable[[bytes], str]]]

    @staticmethod
    def one_test_case(alg: crypto_knowledge.Algorithm,
                      function: str, note: str,
                      arguments: List[str]) -> test_case.TestCase:
        """Build a single hash test case.

        The description is ``function``, then ``note`` if it is non-empty,
        then the algorithm's short expression. The test arguments are the
        algorithm expression followed by each of ``arguments`` wrapped in
        double quotes.
        """
        note_part = ' ' + note if note else ''
        description = '{}{} {}'.format(function, note_part,
                                       alg.short_expression())
        quoted_arguments = [alg.expression]
        for argument in arguments:
            quoted_arguments.append('"{}"'.format(argument))

        tc = test_case.TestCase()
        tc.set_description(description)
        tc.set_dependencies(psa_low_level_dependencies(alg.expression))
        tc.set_function(function)
        tc.set_arguments(quoted_arguments)
        return tc

    def test_cases_for_hash(self,
                            alg: crypto_knowledge.Algorithm
                            ) -> Iterator[test_case.TestCase]:
        """Yield every test case for one hash algorithm.

        Nothing is yielded when the CALCULATE entry for the algorithm is
        None. Otherwise this yields one empty-input case, one one-shot case
        and one multipart case per split point of the long message.
        """
        digest_of = self.CALCULATE[alg.expression]
        if digest_of is None:
            # No reference implementation for this algorithm yet.
            return

        short_message = b'abc'
        short_digest = digest_of(short_message)
        long_message = (b'Hello, world. Here are 16 unprintable bytes: ['
                        b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a'
                        b'\x80\x81\x82\x83\xfe\xff]. '
                        b' This message was brought to you by a natural intelligence. '
                        b' If you can read this, good luck with your debugging!')
        long_digest = digest_of(long_message)
        total = len(long_message)

        yield self.one_test_case(alg, 'hash_empty', '', [digest_of(b'')])
        yield self.one_test_case(alg, 'hash_valid_one_shot', '',
                                 [short_message.hex(), short_digest])
        # Split the long message at both ends, next to both ends, and at 64.
        for split in (0, 1, 64, total - 1, total):
            head = long_message[:split]
            tail = long_message[split:]
            note = '{} + {}'.format(split, total - split)
            yield self.one_test_case(alg, 'hash_valid_multipart', note,
                                     [head.hex(), digest_of(head),
                                      tail.hex(), long_digest])

    def all_test_cases(self) -> Iterator[test_case.TestCase]:
        """Yield the test cases of every algorithm in ``self.algorithms``."""
        for alg in self.algorithms:
            for tc in self.test_cases_for_hash(alg):
                yield tc
Loading

0 comments on commit 95b54f3

Please sign in to comment.