From fda182a61c49051d6f0fd0fccda8798cdadc8140 Mon Sep 17 00:00:00 2001
From: Axel Huebl
Date: Mon, 17 Jan 2022 20:48:30 -0800
Subject: [PATCH] Python: Fix UB in Inputs Passing

Trying to fix the macOS PyPy 3.7 error seen in
https://github.com/conda-forge/warpx-feedstock/issues/37

Testing in https://github.com/conda-forge/warpx-feedstock/pull/38

After some searching, it appears the original implementation was likely
based on https://code.activestate.com/lists/python-list/704158, which
contains bugs.

1) Bug: `create_string_buffer`

Allocating new, null-terminated char arrays with
`ctypes.create_string_buffer` leads to scrambled arrays in PyPy 3.7.
As far as I can see, this
[should have also worked](https://docs.python.org/3/library/ctypes.html),
but maybe there is a bug in the upstream implementation, or the original
code created some kind of use-after-free on a temporary, while the new
implementation just shares the existing byte address.

This leads to errors such as the ones here:
https://github.com/conda-forge/warpx-feedstock/pull/38#issuecomment-1010160519

The call `self.argvC[i] = ctypes.c_char_p(enc_arg)` likewise yields a
`NULL`-terminated char array.

2) Bug: Last Argv Argument

In ANSI C, the last element of the `argv` array of char arrays needs to
be a plain `NULL` pointer. Before this PR, this element was allocated but
never initialized, leading to undefined behavior (read as: crashes).

Reference: https://stackoverflow.com/a/39096006/2719194
---
 Python/pywarpx/_libwarpx.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py
index 9de8227397d..78f731c9a1b 100755
--- a/Python/pywarpx/_libwarpx.py
+++ b/Python/pywarpx/_libwarpx.py
@@ -397,10 +397,15 @@ def get_nattr_species(self, species_name):
     def amrex_init(self, argv, mpi_comm=None):
         # --- Construct the ctype list of strings to pass in
         argc = len(argv)
+
+        # note: +1 since there is an extra char-string array element,
+        #       that ANSII C requires to be a simple NULL entry
+        #       https://stackoverflow.com/a/39096006/2719194
         argvC = (_LP_c_char * (argc+1))()
         for i, arg in enumerate(argv):
             enc_arg = arg.encode('utf-8')
-            argvC[i] = ctypes.create_string_buffer(enc_arg)
+            argvC[i] = ctypes.c_char_p(enc_arg)
+        argvC[argc] = ctypes.c_char_p(b"\0")  # +1 element must be NULL
 
         if mpi_comm is None or MPI is None:
             self.libwarpx_so.amrex_init(argc, argvC)
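
For illustration, here is a minimal, standalone sketch of the argv layout described in the commit message: `argc` string pointers followed by one `NULL` entry. It is not the pywarpx code. The helper names `build_argv` and `is_null` are hypothetical, and it uses an array of `ctypes.c_char_p` rather than the `_LP_c_char` type from the patch. In ctypes, assigning `None` to a `c_char_p` slot stores a true `NULL` pointer, whereas `c_char_p(b"\0")` holds a non-`NULL` pointer to an empty string, so the sketch uses `None` for the terminator.

```python
import ctypes


def is_null(ptr):
    """Return True if a ctypes pointer-like object holds a NULL address."""
    return ctypes.cast(ptr, ctypes.c_void_p).value is None


def build_argv(args):
    """Build (argc, argv_c, keepalive) for a C function taking `char **argv`.

    `keepalive` holds the encoded bytes objects; the caller should keep it
    referenced for as long as the C side may read the strings.
    """
    argc = len(args)
    keepalive = [a.encode("utf-8") for a in args]

    # argc entries for the strings, plus one trailing terminator entry.
    argv_c = (ctypes.c_char_p * (argc + 1))()
    for i, enc in enumerate(keepalive):
        argv_c[i] = enc      # pointer to a NUL-terminated copy-free bytes buffer
    argv_c[argc] = None      # explicit NULL pointer as the last element

    return argc, argv_c, keepalive


argc, argv_c, keepalive = build_argv(["warpx", "inputs"])

# A c_char_p built from bytes is never NULL, even for an empty string.
empty_string_ptr = ctypes.c_char_p(b"\0")
null_ptr = ctypes.c_char_p(None)
```

Reading `argv_c[argc]` back in Python gives `None`, while `is_null(empty_string_ptr)` is `False` and `is_null(null_ptr)` is `True`.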