Skip to content

Commit

Permalink
fix: Ensure type object of inputs for cached any-value conversion fun…
Browse files Browse the repository at this point in the history
…ctions are kept alive (pola-rs#19866)
  • Loading branch information
nameexhaustion authored Nov 20, 2024
1 parent cf6f375 commit 5f61d70
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 13 deletions.
74 changes: 62 additions & 12 deletions crates/polars-python/src/conversion/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use std::borrow::{Borrow, Cow};

#[cfg(feature = "object")]
use polars::chunked_array::object::PolarsObjectSafe;
Expand All @@ -12,7 +12,9 @@ use polars_core::utils::arrow::temporal_conversions::date32_to_date;
use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PySequence, PyString, PyTuple};
use pyo3::types::{
PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PySequence, PyString, PyTuple, PyType,
};

use super::datetime::{
elapsed_offset_to_timedelta, nanos_since_midnight_to_naivetime, timestamp_to_naive_datetime,
Expand Down Expand Up @@ -152,9 +154,50 @@ fn datetime_to_py_object(
}
}

type TypeObjectPtr = usize;
/// Holds a Python type object and implements hashing / equality based on the pointer address of the
/// type object. This is used as a hashtable key instead of only the `usize` pointer value, as we
/// need to hold a ref to the Python type object to keep it alive.
#[derive(Debug)]
pub struct TypeObjectKey {
#[allow(unused)]
type_object: Py<PyType>,
/// We need to store this in a field for `Borrow<usize>`
address: usize,
}

impl TypeObjectKey {
fn new(type_object: Py<PyType>) -> Self {
let address = type_object.as_ptr() as usize;
Self {
type_object,
address,
}
}
}

impl PartialEq for TypeObjectKey {
fn eq(&self, other: &Self) -> bool {
self.address == other.address
}
}

impl Eq for TypeObjectKey {}

impl std::borrow::Borrow<usize> for TypeObjectKey {
fn borrow(&self) -> &usize {
&self.address
}
}

impl std::hash::Hash for TypeObjectKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
let v: &usize = self.borrow();
v.hash(state)
}
}

type InitFn = for<'py> fn(&Bound<'py, PyAny>, bool) -> PyResult<AnyValue<'py>>;
pub(crate) static LUT: crate::gil_once_cell::GILOnceCell<PlHashMap<TypeObjectPtr, InitFn>> =
pub(crate) static LUT: crate::gil_once_cell::GILOnceCell<PlHashMap<TypeObjectKey, InitFn>> =
crate::gil_once_cell::GILOnceCell::new();

/// Convert a Python object to an [`AnyValue`].
Expand Down Expand Up @@ -473,16 +516,23 @@ pub(crate) fn py_object_to_any_value<'py>(
}
}

let type_object_ptr = ob.get_type().as_type_ptr() as usize;
let py_type = ob.get_type();
let py_type_address = py_type.as_ptr() as usize;

Python::with_gil(move |py| {
LUT.with_gil(py, move |lut| {
if !lut.contains_key(&py_type_address) {
let k = TypeObjectKey::new(py_type.clone().unbind());

assert_eq!(k.address, py_type_address);

Python::with_gil(|py| {
let conversion_function = get_conversion_function(ob, py, allow_object)?;
unsafe {
lut.insert_unique_unchecked(k, get_conversion_function(ob, py, allow_object)?);
}
}

LUT.with_gil(py, |lut| {
let convert_fn = lut
.entry(type_object_ptr)
.or_insert_with(|| conversion_function);
convert_fn(ob, strict)
let conversion_func = lut.get(&py_type_address).unwrap();
conversion_func(ob, strict)
})
})
}
1 change: 0 additions & 1 deletion py-polars/tests/unit/functions/range/test_date_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ def test_date_ranges_datetime_input() -> None:
assert_series_equal(result, expected)


@pytest.mark.may_fail_auto_streaming
def test_date_range_with_subclass_18470_18447() -> None:
class MyAmazingDate(date):
pass
Expand Down

0 comments on commit 5f61d70

Please sign in to comment.