Skip to content

Commit

Permalink
gh-104223: Fix issues with inheriting from buffer classes (#104227)
Browse files Browse the repository at this point in the history
Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
  • Loading branch information
JelleZijlstra and kumaraditya303 authored May 8, 2023
1 parent 874010c commit 405eacc
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 13 deletions.
1 change: 1 addition & 0 deletions Include/cpython/memoryobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ typedef struct {
#define _Py_MEMORYVIEW_FORTRAN 0x004 /* Fortran contiguous layout */
#define _Py_MEMORYVIEW_SCALAR 0x008 /* scalar: ndim = 0 */
#define _Py_MEMORYVIEW_PIL 0x010 /* PIL-style layout */
#define _Py_MEMORYVIEW_RESTRICTED 0x020 /* Disallow new references to the memoryview's buffer */

typedef struct {
PyObject_VAR_HEAD
Expand Down
3 changes: 2 additions & 1 deletion Include/internal/pycore_memoryobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ extern "C" {
#endif

PyObject *
PyMemoryView_FromObjectAndFlags(PyObject *v, int flags);
_PyMemoryView_FromBufferProc(PyObject *v, int flags,
getbufferproc bufferproc);

#ifdef __cplusplus
}
Expand Down
170 changes: 170 additions & 0 deletions Lib/test/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4579,6 +4579,176 @@ def test_c_buffer(self):
buf.__release_buffer__(mv)
self.assertEqual(buf.references, 0)

def test_inheritance(self):
class A(bytearray):
def __buffer__(self, flags):
return super().__buffer__(flags)

a = A(b"hello")
mv = memoryview(a)
self.assertEqual(mv.tobytes(), b"hello")

def test_inheritance_releasebuffer(self):
rb_call_count = 0
class B(bytearray):
def __buffer__(self, flags):
return super().__buffer__(flags)
def __release_buffer__(self, view):
nonlocal rb_call_count
rb_call_count += 1
super().__release_buffer__(view)

b = B(b"hello")
with memoryview(b) as mv:
self.assertEqual(mv.tobytes(), b"hello")
self.assertEqual(rb_call_count, 0)
self.assertEqual(rb_call_count, 1)

def test_inherit_but_return_something_else(self):
class A(bytearray):
def __buffer__(self, flags):
return memoryview(b"hello")

a = A(b"hello")
with memoryview(a) as mv:
self.assertEqual(mv.tobytes(), b"hello")

rb_call_count = 0
rb_raised = False
class B(bytearray):
def __buffer__(self, flags):
return memoryview(b"hello")
def __release_buffer__(self, view):
nonlocal rb_call_count
rb_call_count += 1
try:
super().__release_buffer__(view)
except ValueError:
nonlocal rb_raised
rb_raised = True

b = B(b"hello")
with memoryview(b) as mv:
self.assertEqual(mv.tobytes(), b"hello")
self.assertEqual(rb_call_count, 0)
self.assertEqual(rb_call_count, 1)
self.assertIs(rb_raised, True)

def test_override_only_release(self):
class C(bytearray):
def __release_buffer__(self, buffer):
super().__release_buffer__(buffer)

c = C(b"hello")
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")

def test_release_saves_reference(self):
smuggled_buffer = None

class C(bytearray):
def __release_buffer__(s, buffer: memoryview):
with self.assertRaises(ValueError):
memoryview(buffer)
with self.assertRaises(ValueError):
buffer.cast("b")
with self.assertRaises(ValueError):
buffer.toreadonly()
with self.assertRaises(ValueError):
buffer[:1]
with self.assertRaises(ValueError):
buffer.__buffer__(0)
nonlocal smuggled_buffer
smuggled_buffer = buffer
self.assertEqual(buffer.tobytes(), b"hello")
super().__release_buffer__(buffer)

c = C(b"hello")
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")
c.clear()
with self.assertRaises(ValueError):
smuggled_buffer.tobytes()

def test_release_saves_reference_no_subclassing(self):
ba = bytearray(b"hello")

class C:
def __buffer__(self, flags):
return memoryview(ba)

def __release_buffer__(self, buffer):
self.buffer = buffer

c = C()
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")
self.assertEqual(c.buffer.tobytes(), b"hello")

with self.assertRaises(BufferError):
ba.clear()
c.buffer.release()
ba.clear()

def test_multiple_inheritance_buffer_last(self):
class A:
def __buffer__(self, flags):
return memoryview(b"hello A")

class B(A, bytearray):
def __buffer__(self, flags):
return super().__buffer__(flags)

b = B(b"hello")
with memoryview(b) as mv:
self.assertEqual(mv.tobytes(), b"hello A")

class Releaser:
def __release_buffer__(self, buffer):
self.buffer = buffer

class C(Releaser, bytearray):
def __buffer__(self, flags):
return super().__buffer__(flags)

c = C(b"hello C")
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello C")
c.clear()
with self.assertRaises(ValueError):
c.buffer.tobytes()

def test_multiple_inheritance_buffer_last(self):
class A:
def __buffer__(self, flags):
raise RuntimeError("should not be called")

def __release_buffer__(self, buffer):
raise RuntimeError("should not be called")

class B(bytearray, A):
def __buffer__(self, flags):
return super().__buffer__(flags)

b = B(b"hello")
with memoryview(b) as mv:
self.assertEqual(mv.tobytes(), b"hello")

class Releaser:
buffer = None
def __release_buffer__(self, buffer):
self.buffer = buffer

class C(bytearray, Releaser):
def __buffer__(self, flags):
return super().__buffer__(flags)

c = C(b"hello")
with memoryview(c) as mv:
self.assertEqual(mv.tobytes(), b"hello")
c.clear()
self.assertIs(c.buffer, None)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions Objects/bytearrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ static void
bytearray_releasebuffer(PyByteArrayObject *obj, Py_buffer *view)
{
obj->ob_exports--;
assert(obj->ob_exports >= 0);
}

static int
Expand Down
45 changes: 44 additions & 1 deletion Objects/memoryobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ PyTypeObject _PyManagedBuffer_Type = {
return -1; \
}

#define CHECK_RESTRICTED(mv) \
if (((PyMemoryViewObject *)(mv))->flags & _Py_MEMORYVIEW_RESTRICTED) { \
PyErr_SetString(PyExc_ValueError, \
"cannot create new view on restricted memoryview"); \
return NULL; \
}

#define CHECK_RESTRICTED_INT(mv) \
if (((PyMemoryViewObject *)(mv))->flags & _Py_MEMORYVIEW_RESTRICTED) { \
PyErr_SetString(PyExc_ValueError, \
"cannot create new view on restricted memoryview"); \
return -1; \
}

/* See gh-92888. These macros signal that we need to check the memoryview
again due to possible read after frees. */
#define CHECK_RELEASED_AGAIN(mv) CHECK_RELEASED(mv)
Expand Down Expand Up @@ -781,14 +795,15 @@ PyMemoryView_FromBuffer(const Py_buffer *info)
using the given flags.
If the object is a memoryview, the new memoryview must be registered
with the same managed buffer. Otherwise, a new managed buffer is created. */
PyObject *
static PyObject *
PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
{
_PyManagedBufferObject *mbuf;

if (PyMemoryView_Check(v)) {
PyMemoryViewObject *mv = (PyMemoryViewObject *)v;
CHECK_RELEASED(mv);
CHECK_RESTRICTED(mv);
return mbuf_add_view(mv->mbuf, &mv->view);
}
else if (PyObject_CheckBuffer(v)) {
Expand All @@ -806,6 +821,30 @@ PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
Py_TYPE(v)->tp_name);
return NULL;
}

/* Create a memoryview from an object that implements the buffer protocol,
using the given flags.
If the object is a memoryview, the new memoryview must be registered
with the same managed buffer. Otherwise, a new managed buffer is created. */
PyObject *
_PyMemoryView_FromBufferProc(PyObject *v, int flags, getbufferproc bufferproc)
{
_PyManagedBufferObject *mbuf = mbuf_alloc();
if (mbuf == NULL)
return NULL;

int res = bufferproc(v, &mbuf->master, flags);
if (res < 0) {
mbuf->master.obj = NULL;
Py_DECREF(mbuf);
return NULL;
}

PyObject *ret = mbuf_add_view(mbuf, NULL);
Py_DECREF(mbuf);
return ret;
}

/* Create a memoryview from an object that implements the buffer protocol.
If the object is a memoryview, the new memoryview must be registered
with the same managed buffer. Otherwise, a new managed buffer is created. */
Expand Down Expand Up @@ -1397,6 +1436,7 @@ memoryview_cast_impl(PyMemoryViewObject *self, PyObject *format,
Py_ssize_t ndim = 1;

CHECK_RELEASED(self);
CHECK_RESTRICTED(self);

if (!MV_C_CONTIGUOUS(self->flags)) {
PyErr_SetString(PyExc_TypeError,
Expand Down Expand Up @@ -1452,6 +1492,7 @@ memoryview_toreadonly_impl(PyMemoryViewObject *self)
/*[clinic end generated code: output=2c7e056f04c99e62 input=dc06d20f19ba236f]*/
{
CHECK_RELEASED(self);
CHECK_RESTRICTED(self);
/* Even if self is already readonly, we still need to create a new
* object for .release() to work correctly.
*/
Expand All @@ -1474,6 +1515,7 @@ memory_getbuf(PyMemoryViewObject *self, Py_buffer *view, int flags)
int baseflags = self->flags;

CHECK_RELEASED_INT(self);
CHECK_RESTRICTED_INT(self);

/* start with complete information */
*view = *base;
Expand Down Expand Up @@ -2535,6 +2577,7 @@ memory_subscript(PyMemoryViewObject *self, PyObject *key)
return memory_item(self, index);
}
else if (PySlice_Check(key)) {
CHECK_RESTRICTED(self);
PyMemoryViewObject *sliced;

sliced = (PyMemoryViewObject *)mbuf_add_view(self->mbuf, view);
Expand Down
Loading

0 comments on commit 405eacc

Please sign in to comment.