Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Bound constructors for PyByteArray and PyMemoryView #3786

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 100 additions & 31 deletions src/types/bytearray.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::err::{PyErr, PyResult};
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::instance::{Borrowed, Bound};
use crate::{ffi, AsPyPointer, Py, PyAny, PyNativeType, Python};
use crate::py_result_ext::PyResultExt;
use crate::types::any::PyAnyMethods;
use crate::{ffi, AsPyPointer, PyAny, PyNativeType, Python};
use std::os::raw::c_char;
use std::slice;

Expand All @@ -11,13 +14,44 @@ pub struct PyByteArray(PyAny);
pyobject_native_type_core!(PyByteArray, pyobject_native_static_type_object!(ffi::PyByteArray_Type), #checkfunction=ffi::PyByteArray_Check);

impl PyByteArray {
/// Deprecated form of [`PyByteArray::new_bound`]
#[cfg_attr(
not(feature = "gil-refs"),
deprecated(
since = "0.21.0",
note = "`PyByteArray::new` will be replaced by `PyByteArray::new_bound` in a future PyO3 version"
)
)]
pub fn new<'py>(py: Python<'py>, src: &[u8]) -> &'py PyByteArray {
Self::new_bound(py, src).into_gil_ref()
}

/// Creates a new Python bytearray object.
///
/// The byte string is initialized by copying the data from the `&[u8]`.
pub fn new<'p>(py: Python<'p>, src: &[u8]) -> &'p PyByteArray {
pub fn new_bound<'py>(py: Python<'py>, src: &[u8]) -> Bound<'py, PyByteArray> {
let ptr = src.as_ptr() as *const c_char;
let len = src.len() as ffi::Py_ssize_t;
unsafe { py.from_owned_ptr::<PyByteArray>(ffi::PyByteArray_FromStringAndSize(ptr, len)) }
unsafe {
ffi::PyByteArray_FromStringAndSize(ptr, len)
.assume_owned(py)
.downcast_into_unchecked()
}
}

/// Deprecated form of [`PyByteArray::new_bound_with`]
#[cfg_attr(
not(feature = "gil-refs"),
deprecated(
since = "0.21.0",
note = "`PyByteArray::new_with` will be replaced by `PyByteArray::new_bound_with` in a future PyO3 version"
)
)]
pub fn new_with<F>(py: Python<'_>, len: usize, init: F) -> PyResult<&PyByteArray>
where
F: FnOnce(&mut [u8]) -> PyResult<()>,
{
Self::new_bound_with(py, len, init).map(Bound::into_gil_ref)
}

/// Creates a new Python `bytearray` object with an `init` closure to write its contents.
Expand All @@ -34,7 +68,7 @@ impl PyByteArray {
///
/// # fn main() -> PyResult<()> {
/// Python::with_gil(|py| -> PyResult<()> {
/// let py_bytearray = PyByteArray::new_with(py, 10, |bytes: &mut [u8]| {
/// let py_bytearray = PyByteArray::new_bound_with(py, 10, |bytes: &mut [u8]| {
/// bytes.copy_from_slice(b"Hello Rust");
/// Ok(())
/// })?;
Expand All @@ -44,34 +78,56 @@ impl PyByteArray {
/// })
/// # }
/// ```
pub fn new_with<F>(py: Python<'_>, len: usize, init: F) -> PyResult<&PyByteArray>
pub fn new_bound_with<F>(
py: Python<'_>,
len: usize,
init: F,
) -> PyResult<Bound<'_, PyByteArray>>
where
F: FnOnce(&mut [u8]) -> PyResult<()>,
{
unsafe {
let pyptr =
ffi::PyByteArray_FromStringAndSize(std::ptr::null(), len as ffi::Py_ssize_t);
// Check for an allocation error and return it
let pypybytearray: Py<PyByteArray> = Py::from_owned_ptr_or_err(py, pyptr)?;
let buffer: *mut u8 = ffi::PyByteArray_AsString(pyptr).cast();
// Allocate buffer and check for an error
let pybytearray: Bound<'_, Self> =
ffi::PyByteArray_FromStringAndSize(std::ptr::null(), len as ffi::Py_ssize_t)
.assume_owned_or_err(py)?
.downcast_into_unchecked();

let buffer: *mut u8 = ffi::PyByteArray_AsString(pybytearray.as_ptr()).cast();
debug_assert!(!buffer.is_null());
// Zero-initialise the uninitialised bytearray
std::ptr::write_bytes(buffer, 0u8, len);
// (Further) Initialise the bytearray in init
// If init returns an Err, pypybytearray will automatically deallocate the buffer
init(std::slice::from_raw_parts_mut(buffer, len)).map(|_| pypybytearray.into_ref(py))
init(std::slice::from_raw_parts_mut(buffer, len)).map(|_| pybytearray)
}
}

/// Creates a new Python `bytearray` object from another Python object that
/// implements the buffer protocol.
/// Deprecated form of [`PyByteArray::from_bound`]
#[cfg_attr(
not(feature = "gil-refs"),
deprecated(
since = "0.21.0",
note = "`PyByteArray::from` will be replaced by `PyByteArray::from_bound` in a future PyO3 version"
)
)]
pub fn from(src: &PyAny) -> PyResult<&PyByteArray> {
unsafe {
src.py()
.from_owned_ptr_or_err(ffi::PyByteArray_FromObject(src.as_ptr()))
}
}

/// Creates a new Python `bytearray` object from another Python object that
/// implements the buffer protocol.
pub fn from_bound<'py>(src: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyByteArray>> {
unsafe {
ffi::PyByteArray_FromObject(src.as_ptr())
.assume_owned_or_err(src.py())
.downcast_into_unchecked()
}
}

/// Gets the length of the bytearray.
#[inline]
pub fn len(&self) -> usize {
Expand Down Expand Up @@ -211,7 +267,7 @@ impl PyByteArray {
/// # use pyo3::prelude::*;
/// # use pyo3::types::PyByteArray;
/// # Python::with_gil(|py| {
/// let bytearray = PyByteArray::new(py, b"Hello World.");
/// let bytearray = PyByteArray::new_bound(py, b"Hello World.");
/// let mut copied_message = bytearray.to_vec();
/// assert_eq!(b"Hello World.", copied_message.as_slice());
///
Expand Down Expand Up @@ -369,7 +425,7 @@ pub trait PyByteArrayMethods<'py> {
/// # use pyo3::prelude::*;
/// # use pyo3::types::PyByteArray;
/// # Python::with_gil(|py| {
/// let bytearray = PyByteArray::new(py, b"Hello World.");
/// let bytearray = PyByteArray::new_bound(py, b"Hello World.");
/// let mut copied_message = bytearray.to_vec();
/// assert_eq!(b"Hello World.", copied_message.as_slice());
///
Expand Down Expand Up @@ -450,21 +506,34 @@ impl<'py> TryFrom<&'py PyAny> for &'py PyByteArray {
/// Creates a new Python `bytearray` object from another Python object that
/// implements the buffer protocol.
fn try_from(value: &'py PyAny) -> Result<Self, Self::Error> {
PyByteArray::from(value)
PyByteArray::from_bound(&value.as_borrowed()).map(Bound::into_gil_ref)
}
}

impl<'py> TryFrom<&Bound<'py, PyAny>> for Bound<'py, PyByteArray> {
type Error = crate::PyErr;

/// Creates a new Python `bytearray` object from another Python object that
/// implements the buffer protocol.
fn try_from(value: &Bound<'py, PyAny>) -> Result<Self, Self::Error> {
PyByteArray::from_bound(value)
}
}

#[cfg(test)]
mod tests {
use crate::types::any::PyAnyMethods;
use crate::types::bytearray::PyByteArrayMethods;
use crate::types::string::PyStringMethods;
use crate::types::PyByteArray;
use crate::{exceptions, PyAny};
use crate::{exceptions, Bound, PyAny, PyNativeType};
use crate::{PyObject, Python};

#[test]
fn test_len() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray = PyByteArray::new(py, src);
let bytearray = PyByteArray::new_bound(py, src);
assert_eq!(src.len(), bytearray.len());
});
}
Expand All @@ -473,7 +542,7 @@ mod tests {
fn test_as_bytes() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray = PyByteArray::new(py, src);
let bytearray = PyByteArray::new_bound(py, src);

let slice = unsafe { bytearray.as_bytes() };
assert_eq!(src, slice);
Expand All @@ -485,7 +554,7 @@ mod tests {
fn test_as_bytes_mut() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray = PyByteArray::new(py, src);
let bytearray = PyByteArray::new_bound(py, src);

let slice = unsafe { bytearray.as_bytes_mut() };
assert_eq!(src, slice);
Expand All @@ -494,7 +563,7 @@ mod tests {
slice[0..5].copy_from_slice(b"Hi...");

assert_eq!(
bytearray.str().unwrap().to_str().unwrap(),
bytearray.str().unwrap().to_cow().unwrap(),
"bytearray(b'Hi... Python')"
);
});
Expand All @@ -504,7 +573,7 @@ mod tests {
fn test_to_vec() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray = PyByteArray::new(py, src);
let bytearray = PyByteArray::new_bound(py, src);

let vec = bytearray.to_vec();
assert_eq!(src, vec.as_slice());
Expand All @@ -515,10 +584,10 @@ mod tests {
fn test_from() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray = PyByteArray::new(py, src);
let bytearray = PyByteArray::new_bound(py, src);

let ba: PyObject = bytearray.into();
let bytearray = PyByteArray::from(ba.as_ref(py)).unwrap();
let bytearray = PyByteArray::from_bound(ba.bind(py)).unwrap();

assert_eq!(src, unsafe { bytearray.as_bytes() });
});
Expand All @@ -527,7 +596,7 @@ mod tests {
#[test]
fn test_from_err() {
Python::with_gil(|py| {
if let Err(err) = PyByteArray::from(py.None()) {
if let Err(err) = PyByteArray::from_bound(&py.None().as_borrowed()) {
assert!(err.is_instance_of::<exceptions::PyTypeError>(py));
} else {
panic!("error");
Expand All @@ -539,8 +608,8 @@ mod tests {
fn test_try_from() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray: &PyAny = PyByteArray::new(py, src).into();
let bytearray: &PyByteArray = TryInto::try_into(bytearray).unwrap();
let bytearray: &Bound<'_, PyAny> = &PyByteArray::new_bound(py, src);
let bytearray: Bound<'_, PyByteArray> = TryInto::try_into(bytearray).unwrap();

assert_eq!(src, unsafe { bytearray.as_bytes() });
});
Expand All @@ -550,7 +619,7 @@ mod tests {
fn test_resize() {
Python::with_gil(|py| {
let src = b"Hello Python";
let bytearray = PyByteArray::new(py, src);
let bytearray = PyByteArray::new_bound(py, src);

bytearray.resize(20).unwrap();
assert_eq!(20, bytearray.len());
Expand All @@ -560,7 +629,7 @@ mod tests {
#[test]
fn test_byte_array_new_with() -> super::PyResult<()> {
Python::with_gil(|py| -> super::PyResult<()> {
let py_bytearray = PyByteArray::new_with(py, 10, |b: &mut [u8]| {
let py_bytearray = PyByteArray::new_bound_with(py, 10, |b: &mut [u8]| {
b.copy_from_slice(b"Hello Rust");
Ok(())
})?;
Expand All @@ -573,7 +642,7 @@ mod tests {
#[test]
fn test_byte_array_new_with_zero_initialised() -> super::PyResult<()> {
Python::with_gil(|py| -> super::PyResult<()> {
let py_bytearray = PyByteArray::new_with(py, 10, |_b: &mut [u8]| Ok(()))?;
let py_bytearray = PyByteArray::new_bound_with(py, 10, |_b: &mut [u8]| Ok(()))?;
let bytearray: &[u8] = unsafe { py_bytearray.as_bytes() };
assert_eq!(bytearray, &[0; 10]);
Ok(())
Expand All @@ -584,7 +653,7 @@ mod tests {
fn test_byte_array_new_with_error() {
use crate::exceptions::PyValueError;
Python::with_gil(|py| {
let py_bytearray_result = PyByteArray::new_with(py, 10, |_b: &mut [u8]| {
let py_bytearray_result = PyByteArray::new_bound_with(py, 10, |_b: &mut [u8]| {
Err(PyValueError::new_err("Hello Crustaceans!"))
});
assert!(py_bytearray_result.is_err());
Expand Down
36 changes: 32 additions & 4 deletions src/types/memoryview.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::err::PyResult;
use crate::{ffi, AsPyPointer, PyAny};
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::py_result_ext::PyResultExt;
use crate::{ffi, AsPyPointer, Bound, PyAny, PyNativeType};

/// Represents a Python `memoryview`.
#[repr(transparent)]
Expand All @@ -8,14 +10,30 @@ pub struct PyMemoryView(PyAny);
pyobject_native_type_core!(PyMemoryView, pyobject_native_static_type_object!(ffi::PyMemoryView_Type), #checkfunction=ffi::PyMemoryView_Check);

impl PyMemoryView {
/// Creates a new Python `memoryview` object from another Python object that
/// implements the buffer protocol.
/// Deprecated form of [`PyMemoryView::from_bound`]
#[cfg_attr(
not(feature = "gil-refs"),
deprecated(
since = "0.21.0",
note = "`PyMemoryView::from` will be replaced by `PyMemoryView::from_bound` in a future PyO3 version"
)
)]
pub fn from(src: &PyAny) -> PyResult<&PyMemoryView> {
unsafe {
src.py()
.from_owned_ptr_or_err(ffi::PyMemoryView_FromObject(src.as_ptr()))
}
}

/// Creates a new Python `memoryview` object from another Python object that
/// implements the buffer protocol.
pub fn from_bound<'py>(src: &Bound<'py, PyAny>) -> PyResult<Bound<'py, Self>> {
unsafe {
ffi::PyMemoryView_FromObject(src.as_ptr())
.assume_owned_or_err(src.py())
.downcast_into_unchecked()
}
}
}

impl<'py> TryFrom<&'py PyAny> for &'py PyMemoryView {
Expand All @@ -24,6 +42,16 @@ impl<'py> TryFrom<&'py PyAny> for &'py PyMemoryView {
/// Creates a new Python `memoryview` object from another Python object that
/// implements the buffer protocol.
fn try_from(value: &'py PyAny) -> Result<Self, Self::Error> {
PyMemoryView::from(value)
PyMemoryView::from_bound(&value.as_borrowed()).map(Bound::into_gil_ref)
}
}

impl<'py> TryFrom<&Bound<'py, PyAny>> for Bound<'py, PyMemoryView> {
type Error = crate::PyErr;

/// Creates a new Python `memoryview` object from another Python object that
/// implements the buffer protocol.
fn try_from(value: &Bound<'py, PyAny>) -> Result<Self, Self::Error> {
PyMemoryView::from_bound(value)
}
}
Loading