From 5317aebb0a01614b584e8baa29dede9483936bcf Mon Sep 17 00:00:00 2001 From: Abhi Agarwal Date: Fri, 14 Jun 2024 09:28:11 -0400 Subject: [PATCH] chore: migrate to pyo3 Bounds API (#2596) # Description This migrates the Python package to use the new pyo3 bounds-based API, which allows more control over memory management on the library side and theoretical performance improvements (I benchmarked, and didn't notice anything substantial). The old API will be removed in 0.22. # Related Issue(s) # Documentation --- python/.gitignore | 1 + python/Cargo.toml | 2 +- python/src/filesystem.rs | 48 +++++++----- python/src/lib.rs | 158 +++++++++++++++++++++------------------ python/src/schema.rs | 152 +++++++++++++++++++------------------ 5 files changed, 194 insertions(+), 167 deletions(-) diff --git a/python/.gitignore b/python/.gitignore index e1e978f0a6..56df04b804 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -7,6 +7,7 @@ __pycache__/ # Unit test / coverage reports .coverage .pytest_cache/ +.benchmarks/ # mypy .mypy_cache/ diff --git a/python/Cargo.toml b/python/Cargo.toml index 672ba4ee50..f6c3b62a55 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -44,7 +44,7 @@ deltalake-mount = { path = "../crates/mount" } [dependencies.pyo3] version = "0.21.1" -features = ["extension-module", "abi3", "abi3-py38", "gil-refs"] +features = ["extension-module", "abi3", "abi3-py38"] [dependencies.deltalake] path = "../crates/deltalake" diff --git a/python/src/filesystem.rs b/python/src/filesystem.rs index af8410af72..6a26b3e6f4 100644 --- a/python/src/filesystem.rs +++ b/python/src/filesystem.rs @@ -66,7 +66,7 @@ impl DeltaFileSystemHandler { #[classmethod] #[pyo3(signature = (table, options = None, known_sizes = None))] fn from_table( - _cls: &PyType, + _cls: &Bound<'_, PyType>, table: &RawDeltaTable, options: Option>, known_sizes: Option>, @@ -123,12 +123,20 @@ impl DeltaFileSystemHandler { Ok(format!("{self:?}") == format!("{other:?}")) } - fn get_file_info<'py>(&self, paths: Vec, py: Python<'py>) -> PyResult> { - let fs = PyModule::import(py, "pyarrow.fs")?; + fn get_file_info<'py>( + &self, + paths: Vec, + py: Python<'py>, + ) -> PyResult>> { + let fs = PyModule::import_bound(py, "pyarrow.fs")?; let file_types = fs.getattr("FileType")?; - let to_file_info = |loc: &str, type_: &PyAny, kwargs: &HashMap<&str, i64>| { - fs.call_method("FileInfo", (loc, type_), Some(kwargs.into_py_dict(py))) + let to_file_info = |loc: &str, type_: &Bound<'py, PyAny>, kwargs: &HashMap<&str, i64>| { + fs.call_method( + "FileInfo", + (loc, type_), + Some(&kwargs.into_py_dict_bound(py)), + ) }; let mut infos = Vec::new(); @@ -155,14 +163,14 @@ impl DeltaFileSystemHandler { ]); infos.push(to_file_info( meta.location.as_ref(), - file_types.getattr("File")?, + &file_types.getattr("File")?, &kwargs, )?); } Err(ObjectStoreError::NotFound { .. }) => { infos.push(to_file_info( path.as_ref(), - file_types.getattr("NotFound")?, + &file_types.getattr("NotFound")?, &HashMap::new(), )?); } @@ -173,7 +181,7 @@ impl DeltaFileSystemHandler { } else { infos.push(to_file_info( path.as_ref(), - file_types.getattr("Directory")?, + &file_types.getattr("Directory")?, &HashMap::new(), )?); } @@ -189,12 +197,16 @@ impl DeltaFileSystemHandler { allow_not_found: bool, recursive: bool, py: Python<'py>, - ) -> PyResult> { - let fs = PyModule::import(py, "pyarrow.fs")?; + ) -> PyResult>> { + let fs = PyModule::import_bound(py, "pyarrow.fs")?; let file_types = fs.getattr("FileType")?; - let to_file_info = |loc: String, type_: &PyAny, kwargs: HashMap<&str, i64>| { - fs.call_method("FileInfo", (loc, type_), Some(kwargs.into_py_dict(py))) + let to_file_info = |loc: String, type_: &Bound<'py, PyAny>, kwargs: HashMap<&str, i64>| { + fs.call_method( + "FileInfo", + (loc, type_), + Some(&kwargs.into_py_dict_bound(py)), + ) }; let path = Self::parse_path(&base_dir); @@ -222,7 +234,7 @@ impl DeltaFileSystemHandler { .map(|p| { to_file_info( p.to_string(), - file_types.getattr("Directory")?, + &file_types.getattr("Directory")?, HashMap::new(), ) }) @@ -244,7 +256,7 @@ impl DeltaFileSystemHandler { ]); to_file_info( meta.location.to_string(), - file_types.getattr("File")?, + &file_types.getattr("File")?, kwargs, ) }) @@ -438,7 +450,7 @@ impl ObjectInputFile { } #[pyo3(signature = (nbytes = None))] - fn read(&mut self, nbytes: Option, py: Python<'_>) -> PyResult> { + fn read<'py>(&mut self, nbytes: Option, py: Python<'py>) -> PyResult> { self.check_closed()?; let range = match nbytes { Some(len) => { @@ -466,7 +478,7 @@ impl ObjectInputFile { // TODO: PyBytes copies the buffer. If we move away from the limited CPython // API (the stable C API), we could implement the buffer protocol for // bytes::Bytes and return this zero-copy. - Ok(PyBytes::new(py, data.as_ref()).into_py(py)) + Ok(PyBytes::new_bound(py, data.as_ref())) } fn fileno(&self) -> PyResult<()> { @@ -580,7 +592,7 @@ impl ObjectOutputStream { Err(PyNotImplementedError::new_err("'read' not implemented")) } - fn write(&mut self, data: &PyBytes) -> PyResult { + fn write(&mut self, data: &Bound<'_, PyBytes>) -> PyResult { self.check_closed()?; let py = data.py(); let bytes = data.as_bytes(); @@ -598,7 +610,7 @@ impl ObjectOutputStream { Ok(_) => Ok(()), Err(err) => { rt().block_on(self.upload.abort()) - .map_err(|err| PythonError::from(err))?; + .map_err(PythonError::from)?; Err(PyIOError::new_err(err.to_string())) } }) diff --git a/python/src/lib.rs b/python/src/lib.rs index 9d766a8dfb..3ce837f31b 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -54,6 +54,7 @@ use deltalake::{DeltaOps, DeltaResult}; use futures::future::join_all; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyDict, PyFrozenSet}; use serde_json::{Map, Value}; @@ -64,9 +65,9 @@ use crate::schema::schema_to_pyobject; use crate::utils::rt; #[derive(FromPyObject)] -enum PartitionFilterValue<'a> { - Single(&'a str), - Multiple(Vec<&'a str>), +enum PartitionFilterValue { + Single(PyBackedStr), + Multiple(Vec), } #[pyclass(module = "deltalake._internal")] @@ -248,21 +249,10 @@ impl RawDeltaTable { pub fn files_by_partitions( &self, py: Python, - partitions_filters: Vec<(&str, &str, PartitionFilterValue)>, + partitions_filters: Vec<(PyBackedStr, PyBackedStr, PartitionFilterValue)>, ) -> PyResult> { py.allow_threads(|| { - let partition_filters: Result, DeltaTableError> = - partitions_filters - .into_iter() - .map(|filter| match filter { - (key, op, PartitionFilterValue::Single(v)) => { - PartitionFilter::try_from((key, op, v)) - } - (key, op, PartitionFilterValue::Multiple(v)) => { - PartitionFilter::try_from((key, op, v.as_slice())) - } - }) - .collect(); + let partition_filters = convert_partition_filters(partitions_filters); match partition_filters { Ok(filters) => Ok(self ._table @@ -279,7 +269,7 @@ impl RawDeltaTable { pub fn files( &self, py: Python, - partition_filters: Option>, + partition_filters: Option>, ) -> PyResult> { py.allow_threads(|| { if let Some(filters) = partition_filters { @@ -304,7 +294,7 @@ impl RawDeltaTable { pub fn file_uris( &self, - partition_filters: Option>, + partition_filters: Option>, ) -> PyResult> { if let Some(filters) = partition_filters { let filters = convert_partition_filters(filters).map_err(PythonError::from)?; @@ -322,9 +312,9 @@ impl RawDeltaTable { } #[getter] - pub fn schema(&self, py: Python) -> PyResult { + pub fn schema<'py>(&self, py: Python<'py>) -> PyResult> { let schema: &StructType = self._table.get_schema().map_err(PythonError::from)?; - schema_to_pyobject(schema, py) + schema_to_pyobject(schema.to_owned(), py) } /// Run the Vacuum command on the Delta Table: list and delete files no longer referenced @@ -422,7 +412,7 @@ impl RawDeltaTable { pub fn compact_optimize( &mut self, py: Python, - partition_filters: Option>, + partition_filters: Option>, target_size: Option, max_concurrent_tasks: Option, min_commit_interval: Option, @@ -481,7 +471,7 @@ impl RawDeltaTable { &mut self, py: Python, z_order_columns: Vec, - partition_filters: Option>, + partition_filters: Option>, target_size: Option, max_concurrent_tasks: Option, max_spill_size: usize, @@ -852,7 +842,7 @@ impl RawDeltaTable { #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false, custom_metadata=None))] pub fn restore( &mut self, - target: Option<&PyAny>, + target: Option<&Bound<'_, PyAny>>, ignore_missing_files: bool, protocol_downgrade_allowed: bool, custom_metadata: Option>, @@ -865,9 +855,9 @@ impl RawDeltaTable { if let Ok(version) = val.extract::() { cmd = cmd.with_version_to_restore(version) } - if let Ok(ds) = val.extract::<&str>() { + if let Ok(ds) = val.extract::() { let datetime = DateTime::::from( - DateTime::::parse_from_rfc3339(ds).map_err(|err| { + DateTime::::parse_from_rfc3339(ds.as_ref()).map_err(|err| { PyValueError::new_err(format!("Failed to parse datetime string: {err}")) })?, ); @@ -913,8 +903,8 @@ impl RawDeltaTable { &mut self, py: Python<'py>, schema: PyArrowType, - partition_filters: Option>, - ) -> PyResult)>> { + partition_filters: Option>, + ) -> PyResult>)>> { let path_set = match partition_filters { Some(filters) => Some(HashSet::<_>::from_iter( self.files_by_partitions(py, filters)?.iter().cloned(), @@ -942,9 +932,9 @@ impl RawDeltaTable { fn get_active_partitions<'py>( &self, - partitions_filters: Option>, + partitions_filters: Option>, py: Python<'py>, - ) -> PyResult<&'py PyFrozenSet> { + ) -> PyResult> { let column_names: HashSet<&str> = self ._table .get_schema() @@ -962,10 +952,13 @@ impl RawDeltaTable { .collect(); if let Some(filters) = &partitions_filters { - let unknown_columns: Vec<&str> = filters + let unknown_columns: Vec<&PyBackedStr> = filters .iter() - .map(|(column_name, _, _)| *column_name) - .filter(|column_name| !column_names.contains(column_name)) + .map(|(column_name, _, _)| column_name) + .filter(|column_name| { + let column_name: &'_ str = column_name.as_ref(); + !column_names.contains(column_name) + }) .collect(); if !unknown_columns.is_empty() { return Err(PyValueError::new_err(format!( @@ -973,10 +966,13 @@ impl RawDeltaTable { ))); } - let non_partition_columns: Vec<&str> = filters + let non_partition_columns: Vec<&PyBackedStr> = filters .iter() - .map(|(column_name, _, _)| *column_name) - .filter(|column_name| !partition_columns.contains(column_name)) + .map(|(column_name, _, _)| column_name) + .filter(|column_name| { + let column_name: &'_ str = column_name.as_ref(); + !partition_columns.contains(column_name) + }) .collect(); if !non_partition_columns.is_empty() { @@ -1019,11 +1015,11 @@ impl RawDeltaTable { }) .collect(); - let active_partitions: Vec<&'py PyFrozenSet> = active_partitions + let active_partitions = active_partitions .into_iter() - .map(|part| PyFrozenSet::new(py, part.iter())) - .collect::>()?; - PyFrozenSet::new(py, active_partitions) + .map(|part| PyFrozenSet::new_bound(py, part.iter())) + .collect::>, PyErr>>()?; + PyFrozenSet::new_bound(py, &active_partitions) } #[allow(clippy::too_many_arguments)] @@ -1034,7 +1030,7 @@ impl RawDeltaTable { mode: &str, partition_by: Vec, schema: PyArrowType, - partitions_filters: Option>, + partitions_filters: Option>, custom_metadata: Option>, ) -> PyResult<()> { py.allow_threads(|| { @@ -1253,7 +1249,6 @@ impl RawDeltaTable { #[pyo3(signature = (dry_run = true, custom_metadata = None))] pub fn repair( &mut self, - _py: Python, dry_run: bool, custom_metadata: Option>, ) -> PyResult { @@ -1317,23 +1312,32 @@ fn set_writer_properties( Ok(properties.build()) } -fn convert_partition_filters<'a>( - partitions_filters: Vec<(&'a str, &'a str, PartitionFilterValue)>, +fn convert_partition_filters( + partitions_filters: Vec<(PyBackedStr, PyBackedStr, PartitionFilterValue)>, ) -> Result, DeltaTableError> { partitions_filters .into_iter() .map(|filter| match filter { - (key, op, PartitionFilterValue::Single(v)) => PartitionFilter::try_from((key, op, v)), + (key, op, PartitionFilterValue::Single(v)) => { + let key: &'_ str = key.as_ref(); + let op: &'_ str = op.as_ref(); + let v: &'_ str = v.as_ref(); + PartitionFilter::try_from((key, op, v)) + } (key, op, PartitionFilterValue::Multiple(v)) => { + let key: &'_ str = key.as_ref(); + let op: &'_ str = op.as_ref(); + let v: Vec<&'_ str> = v.iter().map(|v| v.as_ref()).collect(); PartitionFilter::try_from((key, op, v.as_slice())) } }) .collect() } -fn scalar_to_py(value: &Scalar, py_date: &PyAny, py: Python) -> PyResult { +fn scalar_to_py<'py>(value: &Scalar, py_date: &Bound<'py, PyAny>) -> PyResult> { use Scalar::*; + let py = py_date.py(); let val = match value { Null(_) => py.None(), Boolean(val) => val.to_object(py), @@ -1363,15 +1367,15 @@ fn scalar_to_py(value: &Scalar, py_date: &PyAny, py: Python) -> PyResult value.serialize().to_object(py), Struct(data) => { - let py_struct = PyDict::new(py); + let py_struct = PyDict::new_bound(py); for (field, value) in data.fields().iter().zip(data.values().iter()) { - py_struct.set_item(field.name(), scalar_to_py(value, py_date, py)?)?; + py_struct.set_item(field.name(), scalar_to_py(value, py_date)?)?; } py_struct.to_object(py) } }; - Ok(val) + Ok(val.into_bound(py)) } /// Create expression that file statistics guarantee to be true. @@ -1390,14 +1394,14 @@ fn filestats_to_expression_next<'py>( py: Python<'py>, schema: &PyArrowType, file_info: LogicalFile<'_>, -) -> PyResult> { - let ds = PyModule::import(py, "pyarrow.dataset")?; +) -> PyResult>> { + let ds = PyModule::import_bound(py, "pyarrow.dataset")?; let py_field = ds.getattr("field")?; - let pa = PyModule::import(py, "pyarrow")?; - let py_date = Python::import(py, "datetime")?.getattr("date")?; - let mut expressions: Vec> = Vec::new(); + let pa = PyModule::import_bound(py, "pyarrow")?; + let py_date = Python::import_bound(py, "datetime")?.getattr("date")?; + let mut expressions = Vec::new(); - let cast_to_type = |column_name: &String, value: PyObject, schema: &ArrowSchema| { + let cast_to_type = |column_name: &String, value: &Bound<'py, PyAny>, schema: &ArrowSchema| { let column_type = schema .field_with_name(column_name) .map_err(|_| { @@ -1416,7 +1420,7 @@ fn filestats_to_expression_next<'py>( if !value.is_null() { // value is a string, but needs to be parsed into appropriate type let converted_value = - cast_to_type(&column, scalar_to_py(value, py_date, py)?, &schema.0)?; + cast_to_type(&column, &scalar_to_py(value, &py_date)?, &schema.0)?; expressions.push( py_field .call1((&column,))? @@ -1453,7 +1457,7 @@ fn filestats_to_expression_next<'py>( Scalar::Struct(_) => {} _ => { let maybe_minimum = - cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0); + cast_to_type(field.name(), &scalar_to_py(value, &py_date)?, &schema.0); if let Ok(minimum) = maybe_minimum { let field_expr = py_field.call1((field.name(),))?; let expr = field_expr.call_method1("__ge__", (minimum,)); @@ -1480,7 +1484,7 @@ fn filestats_to_expression_next<'py>( Scalar::Struct(_) => {} _ => { let maybe_maximum = - cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0); + cast_to_type(field.name(), &scalar_to_py(value, &py_date)?, &schema.0); if let Ok(maximum) = maybe_maximum { let field_expr = py_field.call1((field.name(),))?; let expr = field_expr.call_method1("__le__", (maximum,)); @@ -1859,29 +1863,41 @@ impl PyDeltaDataChecker { #[pymodule] // module name need to match project name -fn _internal(py: Python, m: &PyModule) -> PyResult<()> { +fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { use crate::error::{CommitFailedError, DeltaError, SchemaMismatchError, TableNotFoundError}; - deltalake::aws::register_handlers(None); deltalake::azure::register_handlers(None); deltalake::gcp::register_handlers(None); deltalake_mount::register_handlers(None); - m.add("DeltaError", py.get_type::())?; - m.add("CommitFailedError", py.get_type::())?; - m.add("DeltaProtocolError", py.get_type::())?; - m.add("TableNotFoundError", py.get_type::())?; - m.add("SchemaMismatchError", py.get_type::())?; + let py = m.py(); + m.add("DeltaError", py.get_type_bound::())?; + m.add( + "CommitFailedError", + py.get_type_bound::(), + )?; + m.add( + "DeltaProtocolError", + py.get_type_bound::(), + )?; + m.add( + "TableNotFoundError", + py.get_type_bound::(), + )?; + m.add( + "SchemaMismatchError", + py.get_type_bound::(), + )?; env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init(); m.add("__version__", env!("CARGO_PKG_VERSION"))?; - m.add_function(pyo3::wrap_pyfunction!(rust_core_version, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(create_deltalake, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(write_to_deltalake, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(convert_to_deltalake, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(batch_distinct, m)?)?; - m.add_function(pyo3::wrap_pyfunction!( + m.add_function(pyo3::wrap_pyfunction_bound!(rust_core_version, m)?)?; + m.add_function(pyo3::wrap_pyfunction_bound!(create_deltalake, m)?)?; + m.add_function(pyo3::wrap_pyfunction_bound!(write_new_deltalake, m)?)?; + m.add_function(pyo3::wrap_pyfunction_bound!(write_to_deltalake, m)?)?; + m.add_function(pyo3::wrap_pyfunction_bound!(convert_to_deltalake, m)?)?; + m.add_function(pyo3::wrap_pyfunction_bound!(batch_distinct, m)?)?; + m.add_function(pyo3::wrap_pyfunction_bound!( get_num_idx_cols_and_stats_columns, m )?)?; diff --git a/python/src/schema.rs b/python/src/schema.rs index 36f301ab98..6f1380709a 100644 --- a/python/src/schema.rs +++ b/python/src/schema.rs @@ -22,40 +22,40 @@ use std::collections::HashMap; // Decimal is separate special case, since it has parameters -fn schema_type_to_python(schema_type: DataType, py: Python) -> PyResult { +fn schema_type_to_python(schema_type: DataType, py: Python<'_>) -> PyResult> { match schema_type { - DataType::Primitive(data_type) => { - Ok((PrimitiveType::new(data_type.to_string())?).into_py(py)) - } + DataType::Primitive(data_type) => Ok((PrimitiveType::new(data_type.to_string())?) + .into_py(py) + .into_bound(py)), DataType::Array(array_type) => { let array_type: ArrayType = (*array_type).into(); - Ok(array_type.into_py(py)) + Ok(array_type.into_py(py).into_bound(py)) } DataType::Map(map_type) => { let map_type: MapType = (*map_type).into(); - Ok(map_type.into_py(py)) + Ok(map_type.into_py(py).into_bound(py)) } DataType::Struct(struct_type) => { let struct_type: StructType = (*struct_type).into(); - Ok(struct_type.into_py(py)) + Ok(struct_type.into_py(py).into_bound(py)) } } } -fn python_type_to_schema(ob: PyObject, py: Python) -> PyResult { - if let Ok(data_type) = ob.extract::(py) { +fn python_type_to_schema(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(data_type) = ob.extract::() { return Ok(DataType::Primitive(data_type.inner_type)); } - if let Ok(array_type) = ob.extract::(py) { + if let Ok(array_type) = ob.extract::() { return Ok(array_type.into()); } - if let Ok(map_type) = ob.extract::(py) { + if let Ok(map_type) = ob.extract::() { return Ok(map_type.into()); } - if let Ok(struct_type) = ob.extract::(py) { + if let Ok(struct_type) = ob.extract::() { return Ok(struct_type.into()); } - if let Ok(raw_primitive) = ob.extract::(py) { + if let Ok(raw_primitive) = ob.extract::() { // Pass through PrimitiveType::new() to do validation return PrimitiveType::new(raw_primitive) .map(|data_type| DataType::Primitive(data_type.inner_type)); @@ -187,16 +187,15 @@ impl TryFrom for ArrayType { impl ArrayType { #[new] #[pyo3(signature = (element_type, contains_null = true))] - fn new(element_type: PyObject, contains_null: bool, py: Python) -> PyResult { - let inner_type = - DeltaArrayType::new(python_type_to_schema(element_type, py)?, contains_null); + fn new(element_type: &Bound<'_, PyAny>, contains_null: bool) -> PyResult { + let inner_type = DeltaArrayType::new(python_type_to_schema(element_type)?, contains_null); Ok(Self { inner_type }) } fn __repr__(&self, py: Python) -> PyResult { let type_repr: String = schema_type_to_python(self.inner_type.element_type().clone(), py)? - .call_method0(py, "__repr__")? - .extract(py)?; + .call_method0("__repr__")? + .extract()?; Ok(format!( "ArrayType({}, contains_null={})", type_repr, @@ -224,13 +223,13 @@ impl ArrayType { } #[getter] - fn element_type(&self, py: Python) -> PyResult { + fn element_type<'py>(&self, py: Python<'py>) -> PyResult> { schema_type_to_python(self.inner_type.element_type().to_owned(), py) } #[getter] - fn contains_null(&self, py: Python) -> PyResult { - Ok(self.inner_type.contains_null().into_py(py)) + fn contains_null<'py>(&self, py: Python<'py>) -> PyResult> { + Ok(self.inner_type.contains_null().into_py(py).into_bound(py)) } #[pyo3(text_signature = "($self)")] @@ -301,15 +300,14 @@ impl TryFrom for MapType { impl MapType { #[new] #[pyo3(signature = (key_type, value_type, value_contains_null = true))] - fn new( - key_type: PyObject, - value_type: PyObject, + fn new<'py>( + key_type: &Bound<'py, PyAny>, + value_type: &Bound<'py, PyAny>, value_contains_null: bool, - py: Python, ) -> PyResult { let inner_type = DeltaMapType::new( - python_type_to_schema(key_type, py)?, - python_type_to_schema(value_type, py)?, + python_type_to_schema(key_type)?, + python_type_to_schema(value_type)?, value_contains_null, ); Ok(Self { inner_type }) @@ -317,11 +315,11 @@ impl MapType { fn __repr__(&self, py: Python) -> PyResult { let key_repr: String = schema_type_to_python(self.inner_type.key_type().clone(), py)? - .call_method0(py, "__repr__")? - .extract(py)?; + .call_method0("__repr__")? + .extract()?; let value_repr: String = schema_type_to_python(self.inner_type.value_type().clone(), py)? - .call_method0(py, "__repr__")? - .extract(py)?; + .call_method0("__repr__")? + .extract()?; Ok(format!( "MapType({}, {}, value_contains_null={})", key_repr, @@ -350,18 +348,22 @@ impl MapType { } #[getter] - fn key_type(&self, py: Python) -> PyResult { + fn key_type<'py>(&self, py: Python<'py>) -> PyResult> { schema_type_to_python(self.inner_type.key_type().to_owned(), py) } #[getter] - fn value_type(&self, py: Python) -> PyResult { + fn value_type<'py>(&self, py: Python<'py>) -> PyResult> { schema_type_to_python(self.inner_type.value_type().to_owned(), py) } #[getter] - fn value_contains_null(&self, py: Python) -> PyResult { - Ok(self.inner_type.value_contains_null().into_py(py)) + fn value_contains_null<'py>(&self, py: Python<'py>) -> PyResult> { + Ok(self + .inner_type + .value_contains_null() + .into_py(py) + .into_bound(py)) } #[pyo3(text_signature = "($self)")] @@ -408,18 +410,18 @@ pub struct Field { impl Field { #[new] #[pyo3(signature = (name, r#type, nullable = true, metadata = None))] - fn new( + fn new<'py>( name: String, - r#type: PyObject, + r#type: &Bound<'py, PyAny>, nullable: bool, - metadata: Option, - py: Python, + metadata: Option<&Bound<'py, PyAny>>, ) -> PyResult { - let ty = python_type_to_schema(r#type, py)?; + let py = r#type.py(); + let ty = python_type_to_schema(r#type)?; // Serialize and de-serialize JSON (it needs to be valid JSON anyways) - let metadata: HashMap = if let Some(ref json) = metadata { - let json_dumps = PyModule::import(py, "json")?.getattr("dumps")?; + let metadata: HashMap = if let Some(json) = metadata { + let json_dumps = PyModule::import_bound(py, "json")?.getattr("dumps")?; let metadata_json: String = json_dumps.call1((json,))?.extract()?; let metadata_json = Some(metadata_json) .filter(|x| x != "null") @@ -459,7 +461,7 @@ impl Field { } #[getter] - fn get_type(&self, py: Python) -> PyResult { + fn get_type<'py>(&self, py: Python<'py>) -> PyResult> { schema_type_to_python(self.inner.data_type().clone(), py) } @@ -469,26 +471,27 @@ impl Field { } #[getter] - fn metadata(&self, py: Python) -> PyResult { - let json_loads = PyModule::import(py, "json")?.getattr("loads")?; + fn metadata<'py>(&self, py: Python<'py>) -> PyResult> { + let json_loads = PyModule::import_bound(py, "json")?.getattr("loads")?; let metadata_json: String = serde_json::to_string(self.inner.metadata()) .map_err(|err| PyValueError::new_err(err.to_string()))?; - Ok(json_loads.call1((metadata_json,))?.to_object(py)) + Ok(json_loads + .call1((metadata_json,))? + .to_object(py) + .bind(py) + .to_owned()) } fn __repr__(&self, py: Python) -> PyResult { let type_repr: String = schema_type_to_python(self.inner.data_type().clone(), py)? - .call_method0(py, "__repr__")? - .extract(py)?; + .call_method0("__repr__")? + .extract()?; let metadata = self.inner.metadata(); let maybe_metadata = if metadata.is_empty() { "".to_string() } else { - let metadata_repr: String = self - .metadata(py)? - .call_method0(py, "__repr__")? - .extract(py)?; + let metadata_repr: String = self.metadata(py)?.call_method0("__repr__")?.extract()?; format!(", metadata={metadata_repr}") }; Ok(format!( @@ -660,19 +663,14 @@ impl StructType { } } -pub fn schema_to_pyobject(schema: &DeltaStructType, py: Python) -> PyResult { - let fields: Vec = schema - .fields() - .map(|field| Field { - inner: field.clone(), - }) - .collect(); +pub fn schema_to_pyobject(schema: DeltaStructType, py: Python<'_>) -> PyResult> { + let fields = schema.fields().map(|field| Field { + inner: field.clone(), + }); - let py_schema = PyModule::import(py, "deltalake.schema")?.getattr("Schema")?; + let py_schema = PyModule::import_bound(py, "deltalake.schema")?.getattr("Schema")?; - py_schema - .call1((fields,)) - .map(|schema| schema.to_object(py)) + py_schema.call1((fields.collect::>(),)) } /// A Delta Lake schema @@ -718,26 +716,23 @@ impl PySchema { Ok(format!("Schema([{}])", inner_data.join(", "))) } - fn json(self_: PyRef<'_, Self>, py: Python) -> PyResult { - let warnings_warn = PyModule::import(py, "warnings")?.getattr("warn")?; - let deprecation_warning = PyModule::import(py, "builtins")? - .getattr("DeprecationWarning")? - .to_object(py); - let kwargs: [(&str, PyObject); 2] = [ + fn json<'py>(self_: PyRef<'_, Self>, py: Python<'py>) -> PyResult> { + let warnings_warn = PyModule::import_bound(py, "warnings")?.getattr("warn")?; + let deprecation_warning = + PyModule::import_bound(py, "builtins")?.getattr("DeprecationWarning")?; + let kwargs: [(&str, Bound<'py, PyAny>); 2] = [ ("category", deprecation_warning), - ("stacklevel", 2.to_object(py)), + ("stacklevel", 2.to_object(py).into_bound(py)), ]; warnings_warn.call( ("Schema.json() is deprecated. Use json.loads(Schema.to_json()) instead.",), - Some(kwargs.into_py_dict(py)), + Some(&kwargs.into_py_dict_bound(py)), )?; let super_ = self_.as_ref(); let json = super_.to_json()?; - let json_loads = PyModule::import(py, "json")?.getattr("loads")?; - json_loads - .call1((json.into_py(py),)) - .map(|obj| obj.to_object(py)) + let json_loads = PyModule::import_bound(py, "json")?.getattr("loads")?; + json_loads.call1((json.into_py(py),)) } #[pyo3(signature = (as_large_types = false))] @@ -815,12 +810,15 @@ impl PySchema { #[staticmethod] #[pyo3(text_signature = "(data_type)")] - fn from_pyarrow(data_type: PyArrowType, py: Python) -> PyResult { + fn from_pyarrow( + data_type: PyArrowType, + py: Python<'_>, + ) -> PyResult> { let inner_type: DeltaStructType = (&data_type.0) .try_into() .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?; - schema_to_pyobject(&inner_type, py) + schema_to_pyobject(inner_type, py) } #[pyo3(text_signature = "($self)")]