Skip to content

Commit

Permalink
Add support for fids filter with use_arrow=True (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
theroggy authored Apr 17, 2024
1 parent 87c0e99 commit d2dfbd1
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Improvements

- Add support for `fids` filter to `read_arrow` and `open_arrow`, and to
`read_dataframe` with `use_arrow=True` (#304).
- Add some missing properties to `read_info`, including layer name, geometry name
and FID column name (#365).
- `read_arrow` and `open_arrow` now provide
Expand Down
41 changes: 38 additions & 3 deletions pyogrio/_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ def ogr_open_arrow(
raise ValueError("forcing 2D is not supported for Arrow")

if fids is not None:
raise ValueError("reading by FID is not supported for Arrow")
if where is not None or bbox is not None or mask is not None or sql is not None or skip_features or max_features:
raise ValueError(
"cannot set both 'fids' and any of 'where', 'bbox', 'mask', "
"'sql', 'skip_features', or 'max_features'"
)

IF CTE_GDAL_VERSION < (3, 8, 0):
if skip_features:
Expand Down Expand Up @@ -1407,14 +1411,45 @@ def ogr_open_arrow(
geometry_name = get_string(OGR_L_GetGeometryColumn(ogr_layer))

fid_column = get_string(OGR_L_GetFIDColumn(ogr_layer))
fid_column_where = fid_column
# OGR_L_GetFIDColumn returns the column name if it is a custom column,
# or "" if not. For arrow, the default column name is "OGC_FID".
# or "" if not. For arrow, the default column name used to return the FID data
# read is "OGC_FID". When accessing the underlying datasource like when using a
# where clause, the default column name is "FID".
if fid_column == "":
fid_column = "OGC_FID"
fid_column_where = "FID"

# Use fids list to create a where clause, as arrow doesn't support direct fid
# filtering.
if fids is not None:
IF CTE_GDAL_VERSION < (3, 8, 0):
driver = get_driver(ogr_dataset)
if driver not in {"GPKG", "GeoJSON"}:
warnings.warn(
"Using 'fids' and 'use_arrow=True' with GDAL < 3.8 can be slow "
"for some drivers. Upgrading GDAL or using 'use_arrow=False' "
"can avoid this.",
stacklevel=2,
)

fids_str = ",".join([str(fid) for fid in fids])
where = f"{fid_column_where} IN ({fids_str})"

# Apply the attribute filter
if where is not None and where != "":
apply_where_filter(ogr_layer, where)
try:
apply_where_filter(ogr_layer, where)
except ValueError as ex:
if fids is not None and str(ex).startswith("Invalid SQL query"):
# If fids is not None, the where being applied is the one formatted
# above.
raise ValueError(
f"error applying filter for {len(fids)} fids; max. number for "
f"drivers with default SQL dialect 'OGRSQL' is 4997"
) from ex

raise

# Apply the spatial filter
if bbox is not None:
Expand Down
11 changes: 8 additions & 3 deletions pyogrio/geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,12 @@ def read_dataframe(
the starting index is driver and file specific (e.g. typically 0 for
Shapefile and 1 for GeoPackage, but can still depend on the specific
file). The performance of reading a large number of features usings FIDs
is also driver specific.
is also driver specific and depends on the value of ``use_arrow``. The order
of the rows returned is undefined. If you would like to sort based on FID, use
``fid_as_index=True`` to have the index of the GeoDataFrame returned set to the
FIDs of the features read. If ``use_arrow=True``, the number of FIDs is limited
to 4997 for drivers with 'OGRSQL' as default SQL dialect. To read a larger
number of FIDs, set ``user_arrow=False``.
sql : str, optional (default: None)
The SQL statement to execute. Look at the sql_dialect parameter for more
information on the syntax to use for the query. When combined with other
Expand Down Expand Up @@ -345,7 +350,7 @@ def write_dataframe(
in the output file.
path : str
path to file
layer :str, optional (default: None)
layer : str, optional (default: None)
layer name
driver : string, optional (default: None)
The OGR format driver used to write the vector file. By default write_dataframe
Expand Down Expand Up @@ -545,7 +550,7 @@ def write_dataframe(
# if possible use EPSG codes instead
epsg = geometry.crs.to_epsg()
if epsg:
crs = f"EPSG:{epsg}"
crs = f"EPSG:{epsg}" # noqa: E231
else:
crs = geometry.crs.to_wkt(WktVersion.WKT1_GDAL)

Expand Down
37 changes: 34 additions & 3 deletions pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,14 +518,45 @@ def test_read_mask_where(naturalearth_lowres_all_ext, use_arrow):
assert np.array_equal(df.iso_a3, ["CAN"])


def test_read_fids(naturalearth_lowres_all_ext):
@pytest.mark.parametrize("fids", [[1, 5, 10], np.array([1, 5, 10], dtype=np.int64)])
def test_read_fids(naturalearth_lowres_all_ext, fids, use_arrow):
# ensure keyword is properly passed through
fids = np.array([1, 10, 5], dtype=np.int64)
df = read_dataframe(naturalearth_lowres_all_ext, fids=fids, fid_as_index=True)
df = read_dataframe(
naturalearth_lowres_all_ext, fids=fids, fid_as_index=True, use_arrow=use_arrow
)
assert len(df) == 3
assert np.array_equal(fids, df.index.values)


@requires_pyarrow_api
def test_read_fids_arrow_max_exception(naturalearth_lowres):
# Maximum number at time of writing is 4997 for "OGRSQL". For e.g. for SQLite based
# formats like Geopackage, there is no limit.
nb_fids = 4998
fids = range(nb_fids)
with pytest.raises(ValueError, match=f"error applying filter for {nb_fids} fids"):
_ = read_dataframe(naturalearth_lowres, fids=fids, use_arrow=True)


@requires_pyarrow_api
@pytest.mark.skipif(
__gdal_version__ >= (3, 8, 0), reason="GDAL >= 3.8.0 does not need to warn"
)
def test_read_fids_arrow_warning_old_gdal(naturalearth_lowres_all_ext):
# A warning should be given for old GDAL versions, except for some file formats.
if naturalearth_lowres_all_ext.suffix not in [".gpkg", ".geojson"]:
handler = pytest.warns(
UserWarning,
match="Using 'fids' and 'use_arrow=True' with GDAL < 3.8 can be slow",
)
else:
handler = contextlib.nullcontext()

with handler:
df = read_dataframe(naturalearth_lowres_all_ext, fids=[22], use_arrow=True)
assert len(df) == 1


def test_read_fids_force_2d(test_fgdb_vsi):
with pytest.warns(
UserWarning, match=r"Measured \(M\) geometry types are not supported"
Expand Down

0 comments on commit d2dfbd1

Please sign in to comment.