Skip to content

Commit

Permalink
Merge pull request #772 from andrewwhitehead/buffer-refs
Browse files Browse the repository at this point in the history
Buffer protocol updates to support object references, custom release method
  • Loading branch information
kngwyu authored Feb 21, 2020
2 parents 45d892a + aae57e7 commit 90b14fb
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 28 deletions.
52 changes: 45 additions & 7 deletions src/class/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,27 @@
use crate::callback::UnitCallbackConverter;
use crate::err::PyResult;
use crate::ffi;
use crate::type_object::PyTypeInfo;
use crate::{PyClass, PyClassShell};
use std::os::raw::c_int;

/// Buffer protocol interface
///
/// For more information check [buffer protocol](https://docs.python.org/3/c-api/buffer.html)
/// c-api
#[allow(unused_variables)]
pub trait PyBufferProtocol<'p>: PyTypeInfo {
fn bf_getbuffer(&'p self, view: *mut ffi::Py_buffer, flags: c_int) -> Self::Result
pub trait PyBufferProtocol<'p>: PyClass {
fn bf_getbuffer(
slf: &mut PyClassShell<Self>,
view: *mut ffi::Py_buffer,
flags: c_int,
) -> Self::Result
where
Self: PyBufferGetBufferProtocol<'p>,
{
unimplemented!()
}

fn bf_releasebuffer(&'p self, view: *mut ffi::Py_buffer) -> Self::Result
fn bf_releasebuffer(slf: &mut PyClassShell<Self>, view: *mut ffi::Py_buffer) -> Self::Result
where
Self: PyBufferReleaseBufferProtocol<'p>,
{
Expand Down Expand Up @@ -59,7 +63,7 @@ where
fn tp_as_buffer() -> Option<ffi::PyBufferProcs> {
Some(ffi::PyBufferProcs {
bf_getbuffer: Self::cb_bf_getbuffer(),
bf_releasebuffer: None,
bf_releasebuffer: Self::cb_bf_releasebuffer(),
..ffi::PyBufferProcs_INIT
})
}
Expand Down Expand Up @@ -94,11 +98,45 @@ where
{
let py = crate::Python::assume_gil_acquired();
let _pool = crate::GILPool::new(py);
let slf = py.mut_from_borrowed_ptr::<T>(slf);
let slf = &mut *(slf as *mut PyClassShell<T>);

let result = slf.bf_getbuffer(arg1, arg2).into();
let result = T::bf_getbuffer(slf, arg1, arg2).into();
crate::callback::cb_convert(UnitCallbackConverter, py, result)
}
Some(wrap::<T>)
}
}

trait PyBufferReleaseBufferProtocolImpl {
fn cb_bf_releasebuffer() -> Option<ffi::releasebufferproc>;
}

impl<'p, T> PyBufferReleaseBufferProtocolImpl for T
where
T: PyBufferProtocol<'p>,
{
default fn cb_bf_releasebuffer() -> Option<ffi::releasebufferproc> {
None
}
}

impl<T> PyBufferReleaseBufferProtocolImpl for T
where
T: for<'p> PyBufferReleaseBufferProtocol<'p>,
{
#[inline]
fn cb_bf_releasebuffer() -> Option<ffi::releasebufferproc> {
unsafe extern "C" fn wrap<T>(slf: *mut ffi::PyObject, arg1: *mut ffi::Py_buffer)
where
T: for<'p> PyBufferReleaseBufferProtocol<'p>,
{
let py = crate::Python::assume_gil_acquired();
let _pool = crate::GILPool::new(py);
let slf = &mut *(slf as *mut PyClassShell<T>);

let result = T::bf_releasebuffer(slf, arg1).into();
crate::callback::cb_convert(UnitCallbackConverter, py, result);
}
Some(wrap::<T>)
}
}
99 changes: 78 additions & 21 deletions tests/test_buffer_protocol.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
use pyo3::buffer::PyBuffer;
use pyo3::class::PyBufferProtocol;
use pyo3::exceptions::BufferError;
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use pyo3::{AsPyPointer, PyClassShell};
use std::ffi::CStr;
use std::os::raw::{c_int, c_void};
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

#[pyclass]
struct TestClass {
struct TestBufferClass {
vec: Vec<u8>,
drop_called: Arc<AtomicBool>,
}

#[pyproto]
impl PyBufferProtocol for TestClass {
fn bf_getbuffer(&self, view: *mut ffi::Py_buffer, flags: c_int) -> PyResult<()> {
impl PyBufferProtocol for TestBufferClass {
fn bf_getbuffer(
slf: &mut PyClassShell<Self>,
view: *mut ffi::Py_buffer,
flags: c_int,
) -> PyResult<()> {
if view.is_null() {
return Err(BufferError::py_err("View is null"));
}

unsafe {
(*view).obj = ptr::null_mut();
}

if (flags & ffi::PyBUF_WRITABLE) == ffi::PyBUF_WRITABLE {
return Err(BufferError::py_err("Object is not writable"));
}

let bytes = &self.vec;
unsafe {
(*view).obj = slf.as_ptr();
ffi::Py_INCREF((*view).obj);
}

let bytes = &slf.vec;

unsafe {
(*view).buf = bytes.as_ptr() as *mut c_void;
Expand Down Expand Up @@ -58,21 +68,68 @@ impl PyBufferProtocol for TestClass {

Ok(())
}

fn bf_releasebuffer(_slf: &mut PyClassShell<Self>, _view: *mut ffi::Py_buffer) -> PyResult<()> {
Ok(())
}
}

impl Drop for TestBufferClass {
fn drop(&mut self) {
print!("dropped");
self.drop_called.store(true, Ordering::Relaxed);
}
}

#[test]
fn test_buffer() {
let gil = Python::acquire_gil();
let py = gil.python();

let t = Py::new(
py,
TestClass {
vec: vec![b' ', b'2', b'3'],
},
)
.unwrap();

let d = [("ob", t)].into_py_dict(py);
py.run("assert bytes(ob) == b' 23'", None, Some(d)).unwrap();
let drop_called = Arc::new(AtomicBool::new(false));

{
let gil = Python::acquire_gil();
let py = gil.python();
let instance = Py::new(
py,
TestBufferClass {
vec: vec![b' ', b'2', b'3'],
drop_called: drop_called.clone(),
},
)
.unwrap();
let env = [("ob", instance)].into_py_dict(py);
py.run("assert bytes(ob) == b' 23'", None, Some(env))
.unwrap();
}

assert!(drop_called.load(Ordering::Relaxed));
}

#[test]
fn test_buffer_referenced() {
let drop_called = Arc::new(AtomicBool::new(false));

let buf = {
let input = vec![b' ', b'2', b'3'];
let gil = Python::acquire_gil();
let py = gil.python();
let instance: PyObject = TestBufferClass {
vec: input.clone(),
drop_called: drop_called.clone(),
}
.into_py(py);

let buf = PyBuffer::get(py, instance.as_ref(py)).unwrap();
assert_eq!(buf.to_vec::<u8>(py).unwrap(), input);
drop(instance);
buf
};

assert!(!drop_called.load(Ordering::Relaxed));

{
let _py = Python::acquire_gil().python();
drop(buf);
}

assert!(drop_called.load(Ordering::Relaxed));
}

0 comments on commit 90b14fb

Please sign in to comment.