diff --git a/CHANGES.md b/CHANGES.md
index 45312d09..992ab023 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,6 +9,13 @@
 - `read_arrow` and `open_arrow` now provide
   [GeoArrow-compliant extension metadata](https://geoarrow.org/extension-types.html),
   including the CRS, when using GDAL 3.8 or higher (#366).
+- The `open_arrow` function can now be used without a `pyarrow` dependency. By
+  default, it will now return a stream object implementing the
+  [Arrow PyCapsule Protocol](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html)
+  (i.e. having an `__arrow_c_stream__` method). This object can then be
+  consumed by any Arrow implementation that supports this protocol. To keep
+  the previous behaviour of returning a `pyarrow.RecordBatchReader`, specify
+  `use_pyarrow=True` (#349).
 - Warn when reading from a multilayer file without specifying a layer (#362).
 
 ### Bug fixes

diff --git a/pyogrio/_compat.py b/pyogrio/_compat.py
index 2e070397..d3fbf214 100644
--- a/pyogrio/_compat.py
+++ b/pyogrio/_compat.py
@@ -24,7 +24,8 @@
     pandas = None
 
 
-HAS_ARROW_API = __gdal_version__ >= (3, 6, 0) and pyarrow is not None
+HAS_ARROW_API = __gdal_version__ >= (3, 6, 0)
+HAS_PYARROW = pyarrow is not None
 
 HAS_GEOPANDAS = geopandas is not None

diff --git a/pyogrio/_io.pyx b/pyogrio/_io.pyx
index c2be6d00..004abc84 100644
--- a/pyogrio/_io.pyx
+++ b/pyogrio/_io.pyx
@@ -18,6 +18,8 @@
 from libc.string cimport strlen
 from libc.math cimport isnan
 
 cimport cython
 
+from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer
+
 import numpy as np
 
 from pyogrio._ogr cimport *
@@ -1256,6 +1258,35 @@
             field_data
         )
 
+
+cdef void pycapsule_array_stream_deleter(object stream_capsule) noexcept:
+    cdef ArrowArrayStream* stream = <ArrowArrayStream*>PyCapsule_GetPointer(
+        stream_capsule, 'arrow_array_stream'
+    )
+    # Do not invoke the deleter on a used/moved capsule
+    if stream.release != NULL:
+        stream.release(stream)
+
+    free(stream)
+
+
+cdef object alloc_c_stream(ArrowArrayStream** c_stream):
+    c_stream[0] = <ArrowArrayStream*>malloc(sizeof(ArrowArrayStream))
+    # Ensure the capsule destructor doesn't call a random release pointer
+    c_stream[0].release = NULL
+    return PyCapsule_New(c_stream[0], 'arrow_array_stream', &pycapsule_array_stream_deleter)
+
+
+class _ArrowStream:
+    def __init__(self, capsule):
+        self._capsule = capsule
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        if requested_schema is not None:
+            raise NotImplementedError("requested_schema is not supported")
+        return self._capsule
+
+
 @contextlib.contextmanager
 def ogr_open_arrow(
     str path,
@@ -1274,7 +1305,9 @@
     str sql=None,
     str sql_dialect=None,
     int return_fids=False,
-    int batch_size=0):
+    int batch_size=0,
+    use_pyarrow=False,
+):
 
     cdef int err = 0
     cdef const char *path_c = NULL
@@ -1286,7 +1319,7 @@
     cdef char **fields_c = NULL
     cdef const char *field_c = NULL
     cdef char **options = NULL
-    cdef ArrowArrayStream stream
+    cdef ArrowArrayStream* stream
     cdef ArrowSchema schema
 
     IF CTE_GDAL_VERSION < (3, 6, 0):
@@ -1419,19 +1452,23 @@
         # make sure layer is read from beginning
         OGR_L_ResetReading(ogr_layer)
 
-        if not OGR_L_GetArrowStream(ogr_layer, &stream, options):
-            raise RuntimeError("Failed to open ArrowArrayStream from Layer")
+        # allocate the stream struct and wrap in capsule to ensure clean-up on error
+        capsule = alloc_c_stream(&stream)
 
-        stream_ptr = &stream
+        if not OGR_L_GetArrowStream(ogr_layer, stream, options):
+            raise RuntimeError("Failed to open ArrowArrayStream from Layer")
         if skip_features:
             # only supported for GDAL >= 3.8.0; have to do this after getting
             # the Arrow stream
             OGR_L_SetNextByIndex(ogr_layer, skip_features)
 
-        # stream has to be consumed before the Dataset is closed
-        import pyarrow as pa
-        reader = pa.RecordBatchStreamReader._import_from_c(stream_ptr)
+        if use_pyarrow:
+            import pyarrow as pa
+
+            reader = pa.RecordBatchStreamReader._import_from_c(<uintptr_t> stream)
+        else:
+            reader = _ArrowStream(capsule)
 
         meta = {
             'crs': crs,
@@ -1442,13 +1479,16 @@
             'fid_column': fid_column,
         }
 
+        # stream has to be consumed before the Dataset is closed
         yield meta, reader
 
     finally:
-        if reader is not None:
+        if use_pyarrow and reader is not None:
             # Mark reader as closed to prevent reading batches
             reader.close()
 
+        # `stream` will be freed through `capsule` destructor
+
         CSLDestroy(options)
         if fields_c != NULL:
             CSLDestroy(fields_c)
@@ -1465,6 +1505,7 @@
         GDALClose(ogr_dataset)
         ogr_dataset = NULL
 
+
 def ogr_read_bounds(
     str path,
     object layer=None,

diff --git a/pyogrio/_ogr.pxd b/pyogrio/_ogr.pxd
index 35fbd29a..fa75dd89 100644
--- a/pyogrio/_ogr.pxd
+++ b/pyogrio/_ogr.pxd
@@ -190,12 +190,13 @@
     void OSRRelease(OGRSpatialReferenceH srs)
 
 
-cdef extern from "arrow_bridge.h":
+cdef extern from "arrow_bridge.h" nogil:
     struct ArrowSchema:
         int64_t n_children
 
     struct ArrowArrayStream:
-        int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out)
+        int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out) noexcept
+        void (*release)(ArrowArrayStream*) noexcept
 
 
 cdef extern from "ogr_api.h":

diff --git a/pyogrio/raw.py b/pyogrio/raw.py
index a32f84d8..6466f4d7 100644
--- a/pyogrio/raw.py
+++ b/pyogrio/raw.py
@@ -1,7 +1,7 @@
 import warnings
 
 from pyogrio._env import GDALEnv
-from pyogrio._compat import HAS_ARROW_API
+from pyogrio._compat import HAS_ARROW_API, HAS_PYARROW
 from pyogrio.core import detect_write_driver
 from pyogrio.errors import DataSourceError
 from pyogrio.util import (
@@ -256,6 +256,12 @@
         "geometry_name": "",
     }
     """
+    if not HAS_PYARROW:
+        raise RuntimeError(
+            "pyarrow required to read using 'read_arrow'. You can use 'open_arrow' "
+            "to read data with an alternative Arrow implementation"
+        )
+
     from pyarrow import Table
 
     gdal_version = get_gdal_version()
@@ -297,6 +303,7 @@
             return_fids=return_fids,
             skip_features=gdal_skip_features,
             batch_size=batch_size,
+            use_pyarrow=True,
             **kwargs,
         ) as source:
             meta, reader = source
@@ -351,17 +358,37 @@
     sql_dialect=None,
     return_fids=False,
     batch_size=65_536,
+    use_pyarrow=False,
     **kwargs,
 ):
     """
-    Open OGR data source as a stream of pyarrow record batches.
+    Open OGR data source as a stream of Arrow record batches.
 
     See docstring of `read` for parameters.
 
-    The RecordBatchStreamReader is reading from a stream provided by OGR and must not be
+    The returned object reads from a stream provided by OGR and must not be
     accessed after the OGR dataset has been closed, i.e. after the context
     manager has been closed.
 
+    By default, this function returns a generic stream object implementing
+    the `Arrow PyCapsule Protocol`_ (i.e. having an ``__arrow_c_stream__``
+    method). This object can then be consumed by any Arrow implementation
+    that supports this protocol.
+    Optionally, you can specify ``use_pyarrow=True`` to directly get the
+    stream as a `pyarrow.RecordBatchReader`.
+
+    .. _Arrow PyCapsule Protocol: https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+
+    Other Parameters
+    ----------------
+    batch_size : int (default: 65_536)
+        Maximum number of features to retrieve in a batch.
+    use_pyarrow : bool (default: False)
+        If True, return a pyarrow RecordBatchReader instead of a generic
+        ArrowStream object. In the default case, this stream object needs
+        to be passed to another library supporting the Arrow PyCapsule
+        Protocol to consume the stream of data.
+
     Examples
     --------
 
     >>> import pyarrow as pa
     >>> import shapely
     >>>
     >>> with open_arrow(path) as source:
+    >>>     meta, stream = source
+    >>>     # wrap the arrow stream object in a pyarrow RecordBatchReader
+    >>>     reader = pa.RecordBatchReader.from_stream(stream)
+    >>>     for batch in reader:
+    >>>         geometries = shapely.from_wkb(batch[meta["geometry_name"] or "wkb_geometry"])
+
+    The returned `stream` object needs to be consumed by a library implementing
+    the Arrow PyCapsule Protocol. In the above example, pyarrow is used through
+    its RecordBatchReader. For this case, you can also specify ``use_pyarrow=True``
+    to directly get this result as a shortcut:
+
+    >>> with open_arrow(path, use_pyarrow=True) as source:
     >>>     meta, reader = source
-    >>>     for table in reader:
-    >>>         geometries = shapely.from_wkb(table[meta["geometry_name"]])
+    >>>     for batch in reader:
+    >>>         geometries = shapely.from_wkb(batch[meta["geometry_name"] or "wkb_geometry"])
 
     Returns
     -------
-    (dict, pyarrow.RecordBatchStreamReader)
+    (dict, pyarrow.RecordBatchReader or ArrowStream)
         Returns a tuple of meta information about the data source in a dict,
-        and a pyarrow RecordBatchStreamReader with data.
+        and a data stream object (a generic ArrowStream object, or a pyarrow
+        RecordBatchReader if `use_pyarrow` is set to True).
         Meta is: {
             "crs": "",
 
@@ -390,7 +430,7 @@
     }
     """
     if not HAS_ARROW_API:
-        raise RuntimeError("pyarrow and GDAL>= 3.6 required to read using arrow")
+        raise RuntimeError("GDAL >= 3.6 required to read using arrow")
 
     path, buffer = get_vsi_path(path_or_buffer)
 
@@ -415,6 +455,7 @@
             return_fids=return_fids,
             dataset_kwargs=dataset_kwargs,
             batch_size=batch_size,
+            use_pyarrow=use_pyarrow,
         )
     finally:
         if buffer is not None:

diff --git a/pyogrio/tests/conftest.py b/pyogrio/tests/conftest.py
index 76327b4f..20b84415 100644
--- a/pyogrio/tests/conftest.py
+++ b/pyogrio/tests/conftest.py
@@ -8,7 +8,7 @@
     __version__,
     list_drivers,
 )
-from pyogrio._compat import HAS_ARROW_API, HAS_GDAL_GEOS, HAS_SHAPELY
+from pyogrio._compat import HAS_ARROW_API, HAS_GDAL_GEOS, HAS_PYARROW, HAS_SHAPELY
 from pyogrio.raw import read, write
 
 
@@ -43,8 +43,9 @@ def pytest_report_header(config):
 
 
 # marks to skip tests if optional dependecies are not present
-requires_arrow_api = pytest.mark.skipif(
-    not HAS_ARROW_API, reason="GDAL>=3.6 and pyarrow required"
+requires_arrow_api = pytest.mark.skipif(not HAS_ARROW_API, reason="GDAL>=3.6 required")
+requires_pyarrow_api = pytest.mark.skipif(
+    not HAS_ARROW_API or not HAS_PYARROW, reason="GDAL>=3.6 and pyarrow required"
 )
 
 requires_gdal_geos = pytest.mark.skipif(

diff --git a/pyogrio/tests/test_arrow.py b/pyogrio/tests/test_arrow.py
index 5300ea83..481e6b8c 100644
--- a/pyogrio/tests/test_arrow.py
+++ b/pyogrio/tests/test_arrow.py
@@ -2,13 +2,15 @@
 import json
 import math
 import os
+import sys
 
 import pytest
-
 import numpy as np
+
+import pyogrio
 from pyogrio import __gdal_version__, read_dataframe
 from pyogrio.raw import open_arrow, read_arrow, write
-from pyogrio.tests.conftest import ALL_EXTS, requires_arrow_api
+from pyogrio.tests.conftest import ALL_EXTS, requires_pyarrow_api
 
 try:
     import pandas as pd
@@ -20,7 +22,7 @@
     pass
 
 
 # skip all tests in this file if Arrow API or GeoPandas are unavailable
-pytestmark = requires_arrow_api
+pytestmark = requires_pyarrow_api
 
 pytest.importorskip("geopandas")
@@ -137,8 +139,8 @@
     assert isinstance(table, pyarrow.Table)
 
 
-def test_open_arrow(naturalearth_lowres):
-    with open_arrow(naturalearth_lowres) as (meta, reader):
+def test_open_arrow_pyarrow(naturalearth_lowres):
+    with open_arrow(naturalearth_lowres, use_pyarrow=True) as (meta, reader):
         assert isinstance(meta, dict)
         assert isinstance(reader, pyarrow.RecordBatchReader)
         assert isinstance(reader.read_all(), pyarrow.Table)
@@ -148,7 +150,10 @@
     meta, table = read_arrow(naturalearth_lowres)
     batch_size = math.ceil(len(table) / 2)
 
-    with open_arrow(naturalearth_lowres, batch_size=batch_size) as (meta, reader):
+    with open_arrow(naturalearth_lowres, batch_size=batch_size, use_pyarrow=True) as (
+        meta,
+        reader,
+    ):
         assert isinstance(meta, dict)
         assert isinstance(reader, pyarrow.RecordBatchReader)
         count = 0
@@ -207,6 +212,36 @@
     assert parsed_meta["crs"]["id"]["code"] == 4326
 
 
+def test_open_arrow_capsule_protocol(naturalearth_lowres):
+    pytest.importorskip("pyarrow", minversion="14")
+
+    with open_arrow(naturalearth_lowres) as (meta, reader):
+        assert isinstance(meta, dict)
+        assert isinstance(reader, pyogrio._io._ArrowStream)
+
+        result = pyarrow.table(reader)
+
+    _, expected = read_arrow(naturalearth_lowres)
+    assert result.equals(expected)
+
+
+def test_open_arrow_capsule_protocol_without_pyarrow(naturalearth_lowres):
+    pyarrow = pytest.importorskip("pyarrow", minversion="14")
+
+    # Make PyArrow temporarily unavailable (importing will fail)
+    sys.modules["pyarrow"] = None
+    try:
+        with open_arrow(naturalearth_lowres) as (meta, reader):
+            assert isinstance(meta, dict)
+            assert isinstance(reader, pyogrio._io._ArrowStream)
+            result = pyarrow.table(reader)
+    finally:
+        sys.modules["pyarrow"] = pyarrow
+
+    _, expected = read_arrow(naturalearth_lowres)
+    assert result.equals(expected)
+
+
 @contextlib.contextmanager
 def use_arrow_context():
     original = os.environ.get("PYOGRIO_USE_ARROW", None)

diff --git a/pyogrio/tests/test_geopandas_io.py b/pyogrio/tests/test_geopandas_io.py
index 3d254ee5..cec77d9b 100644
--- a/pyogrio/tests/test_geopandas_io.py
+++ b/pyogrio/tests/test_geopandas_io.py
@@ -14,7 +14,7 @@
 from pyogrio.tests.conftest import (
     ALL_EXTS,
     DRIVERS,
-    requires_arrow_api,
+    requires_pyarrow_api,
     requires_gdal_geos,
 )
 from pyogrio._compat import PANDAS_GE_15
@@ -45,7 +45,7 @@
     scope="session",
     params=[
         False,
-        pytest.param(True, marks=requires_arrow_api),
+        pytest.param(True, marks=requires_pyarrow_api),
     ],
 )
 def use_arrow(request):
@@ -1582,7 +1582,7 @@
     assert_geodataframe_equal(result, df)
 
 
-@requires_arrow_api
+@requires_pyarrow_api
 @pytest.mark.skipif(
     __gdal_version__ < (3, 8, 3), reason="Arrow bool value bug fixed in GDAL >= 3.8.3"
 )
@@ -1607,7 +1607,7 @@
     assert_geodataframe_equal(result, df, check_dtype=ext != ".shp")
 
 
-@requires_arrow_api
+@requires_pyarrow_api
 @pytest.mark.skipif(
     __gdal_version__ >= (3, 8, 3), reason="Arrow bool value bug fixed in GDAL >= 3.8.3"
 )

diff --git a/pyogrio/tests/test_raw_io.py b/pyogrio/tests/test_raw_io.py
index 362c5773..8330bbfe 100644
--- a/pyogrio/tests/test_raw_io.py
+++ b/pyogrio/tests/test_raw_io.py
@@ -1,4 +1,5 @@
 import contextlib
+import ctypes
 import json
 import os
 import sys
@@ -7,6 +8,7 @@
 from numpy import array_equal
 import pytest
 
+import pyogrio
 from pyogrio import (
     list_layers,
     list_drivers,
@@ -14,13 +16,14 @@
     set_gdal_config_options,
     __gdal_version__,
 )
-from pyogrio._compat import HAS_SHAPELY
-from pyogrio.raw import read, write
+from pyogrio._compat import HAS_SHAPELY, HAS_PYARROW
+from pyogrio.raw import read, write, open_arrow
 from pyogrio.errors import DataSourceError, DataLayerError, FeatureError
 from pyogrio.tests.conftest import (
     DRIVERS,
     DRIVER_EXT,
     prepare_testfile,
+    requires_pyarrow_api,
     requires_arrow_api,
 )
 
@@ -1025,7 +1028,7 @@
     assert '{ "col": NaN }' in content
 
 
-@requires_arrow_api
+@requires_pyarrow_api
 @pytest.mark.skipif(
     "Arrow" not in list_drivers(), reason="Arrow driver is not available"
 )
@@ -1194,3 +1197,31 @@
     field_mask = [np.array([False, True, False])] * 2
     with pytest.raises(ValueError):
         write(filename, geometry, field_data, fields, field_mask, **meta)
+
+
+@requires_arrow_api
+def test_open_arrow_capsule_protocol_without_pyarrow(naturalearth_lowres):
+    # this test is included here instead of test_arrow.py to ensure we also run
+    # it when pyarrow is not installed
+
+    with open_arrow(naturalearth_lowres) as (meta, reader):
+        assert isinstance(meta, dict)
+        assert isinstance(reader, pyogrio._io._ArrowStream)
+        capsule = reader.__arrow_c_stream__()
+        assert (
+            ctypes.pythonapi.PyCapsule_IsValid(
+                ctypes.py_object(capsule), b"arrow_array_stream"
+            )
+            == 1
+        )
+
+
+@pytest.mark.skipif(HAS_PYARROW, reason="pyarrow is installed")
+@requires_arrow_api
+def test_open_arrow_error_no_pyarrow(naturalearth_lowres):
+    # this test is included here instead of test_arrow.py to ensure we run
+    # it when pyarrow is not installed
+
+    with pytest.raises(ImportError):
+        with open_arrow(naturalearth_lowres, use_pyarrow=True) as _:
+            pass
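Beyond the diff itself, a brief usage sketch of the new default code path may help reviewers. It mirrors the docstring example added in `raw.py` above; it assumes pyarrow >= 14 (which added `RecordBatchReader.from_stream`) and uses a placeholder dataset path.

```python
import pyarrow as pa
import shapely

from pyogrio.raw import open_arrow

# Placeholder path: substitute any OGR-readable dataset.
path = "naturalearth_lowres.shp"

with open_arrow(path) as (meta, stream):
    # `stream` only exposes __arrow_c_stream__; pyogrio itself never
    # imports pyarrow on this code path.
    reader = pa.RecordBatchReader.from_stream(stream)
    # The stream must be fully consumed before the context manager exits,
    # because OGR closes the dataset at that point.
    for batch in reader:
        # Fall back to GDAL's default geometry column name when the driver
        # does not report one.
        geometries = shapely.from_wkb(batch[meta["geometry_name"] or "wkb_geometry"])
```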
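When no Arrow implementation is installed at all, the stream object can still be inspected at the PyCapsule level, which is exactly what the new test in `test_raw_io.py` does. A minimal standalone version (again with a placeholder path):

```python
import ctypes

from pyogrio.raw import open_arrow

with open_arrow("naturalearth_lowres.shp") as (meta, stream):
    capsule = stream.__arrow_c_stream__()
    # The PyCapsule protocol mandates the name "arrow_array_stream";
    # consumers use it to retrieve the underlying ArrowArrayStream pointer.
    assert (
        ctypes.pythonapi.PyCapsule_IsValid(
            ctypes.py_object(capsule), b"arrow_array_stream"
        )
        == 1
    )
```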
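Finally, a sketch of the error behaviour exercised by the new tests when pyarrow is missing. Note the asymmetry: `read_arrow` fails fast on the new `HAS_PYARROW` guard before touching the file, whereas `use_pyarrow=True` only fails once `ogr_open_arrow` executes `import pyarrow`, i.e. after the dataset has been opened, so that path must point to a real dataset. Paths are placeholders.

```python
from pyogrio.raw import open_arrow, read_arrow

path = "naturalearth_lowres.shp"  # must exist for the open_arrow case

# In an environment without pyarrow:
try:
    read_arrow(path)
except RuntimeError:
    # raised up front by the HAS_PYARROW guard in raw.py
    pass

try:
    with open_arrow(path, use_pyarrow=True):
        pass
except ImportError:
    # raised by `import pyarrow` inside ogr_open_arrow
    pass
```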