diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py
index 7a3d07c..08840b3 100644
--- a/stac_geoparquet/arrow/_api.py
+++ b/stac_geoparquet/arrow/_api.py
@@ -146,6 +146,10 @@ def stac_table_to_ndjson(
 ) -> None:
     """Write STAC Arrow to a newline-delimited JSON file.
 
+    !!! note
+        This function _appends_ to the JSON file at `dest`; it does not overwrite any
+        existing data.
+
     Args:
         table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
             RecordBatchReader, or any other Arrow stream object exposed through the
diff --git a/tests/test_arrow.py b/tests/test_arrow.py
index 78584b0..e9f4151 100644
--- a/tests/test_arrow.py
+++ b/tests/test_arrow.py
@@ -1,3 +1,4 @@
+import itertools
 import json
 from io import BytesIO
 from pathlib import Path
@@ -7,6 +8,7 @@
 import pytest
 
 from stac_geoparquet.arrow import (
+    DEFAULT_JSON_CHUNK_SIZE,
     parse_stac_items_to_arrow,
     parse_stac_ndjson_to_arrow,
     stac_table_to_items,
@@ -33,30 +35,46 @@
     "us-census",
 ]
 
+CHUNK_SIZES = [2, DEFAULT_JSON_CHUNK_SIZE]
 
-@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
-def test_round_trip_read_write(collection_id: str):
+
+@pytest.mark.parametrize(
+    "collection_id,chunk_size", itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)
+)
+def test_round_trip_read_write(collection_id: str, chunk_size: int):
     with open(HERE / "data" / f"{collection_id}-pc.json") as f:
         items = json.load(f)
 
-    table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
+    table = parse_stac_items_to_arrow(items, chunk_size=chunk_size).read_all()
     items_result = list(stac_table_to_items(table))
 
     for result, expected in zip(items_result, items):
         assert_json_value_equal(result, expected, precision=0)
 
 
-@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
-def test_round_trip_write_read_ndjson(collection_id: str, tmp_path: Path):
+@pytest.mark.parametrize(
+    "collection_id,chunk_size", itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)
+)
+def test_round_trip_write_read_ndjson(
+    collection_id: str, chunk_size: int, tmp_path: Path
+):
     # First load into a STAC-GeoParquet table
     path = HERE / "data" / f"{collection_id}-pc.json"
-    table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path))
+    table = parse_stac_ndjson_to_arrow(path, chunk_size=chunk_size).read_all()
 
     # Then write to disk
     stac_table_to_ndjson(table, tmp_path / "tmp.ndjson")
 
-    # Then read back and assert tables match
-    table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(tmp_path / "tmp.ndjson"))
+    with open(path) as f:
+        orig_json = json.load(f)
+
+    rt_json = []
+    with open(tmp_path / "tmp.ndjson") as f:
+        for line in f:
+            rt_json.append(json.loads(line))
+
+    # Then read back and assert JSON data matches
+    assert_json_value_equal(orig_json, rt_json, precision=0)
 
 
 def test_table_contains_geoarrow_metadata():
@@ -64,7 +82,7 @@
     with open(HERE / "data" / f"{collection_id}-pc.json") as f:
         items = json.load(f)
 
-    table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
+    table = parse_stac_items_to_arrow(items).read_all()
     field_meta = table.schema.field("geometry").metadata
     assert field_meta[b"ARROW:extension:name"] == b"geoarrow.wkb"
     assert json.loads(field_meta[b"ARROW:extension:metadata"])["crs"]["id"] == {
@@ -107,7 +125,7 @@ def test_to_parquet_two_geometry_columns():
     with open(HERE / "data" / "3dep-lidar-copc-pc.json") as f:
         items = json.load(f)
 
-    table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
+    table = parse_stac_items_to_arrow(items).read_all()
     with BytesIO() as bio:
         to_parquet(table, bio)
         bio.seek(0)
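
A minimal usage sketch of the reworked API this diff exercises, under the assumptions visible in the changed call sites (the parse functions return an Arrow stream with a `read_all()` method and accept a `chunk_size` keyword; `items.json` and `out.ndjson` are placeholder paths):

```python
import json

from stac_geoparquet.arrow import (
    DEFAULT_JSON_CHUNK_SIZE,
    parse_stac_items_to_arrow,
    stac_table_to_ndjson,
)

# Hypothetical input: a list of STAC item dicts loaded from a placeholder path.
with open("items.json") as f:
    items = json.load(f)

# The parse functions yield an Arrow stream; read_all() materializes it as a
# single pyarrow Table, replacing the old pa.Table.from_batches(...) pattern.
table = parse_stac_items_to_arrow(items, chunk_size=DEFAULT_JSON_CHUNK_SIZE).read_all()

# Per the docstring note added above, this appends to out.ndjson rather than
# overwriting it, so delete any stale file first when a clean write is needed.
stac_table_to_ndjson(table, "out.ndjson")
```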