diff --git a/stackstac/stac_types.py b/stackstac/stac_types.py index 2fc3086..738c4c5 100644 --- a/stackstac/stac_types.py +++ b/stackstac/stac_types.py @@ -16,7 +16,6 @@ Tuple, TypedDict, Union, - cast, ) possible_problems: list[str] = [] @@ -134,7 +133,11 @@ class ItemDict(TypedDict): ] -def items_to_plain(items: Union[ItemCollectionIsh, ItemIsh]) -> ItemSequence: +def items_to_plain( + items: Union[ + ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem] + ] +) -> ItemSequence: """ Convert something like a collection/Catalog of STAC items into a list of plain dicts @@ -149,10 +152,24 @@ def items_to_plain(items: Union[ItemCollectionIsh, ItemIsh]) -> ItemSequence: if isinstance(items, Sequence): # slicing a satstac `ItemCollection` produces a list, not another `ItemCollection`, # so having a `List[SatstacItem]` is quite possible - try: - return [item._data for item in cast(SatstacItemCollection, items)] - except AttributeError: - return items + results = [] + for item in items: + if isinstance(item, PystacItem): + results.append(item.to_dict()) + elif isinstance(item, SatstacItem): + results.append(item._data) + elif isinstance(item, dict): + results.append(item) + else: + raise TypeError( + f"Unrecognized STAC item type {type(item)}: {item!r}" + + ( + "\n".join(["\nPossible problems:"] + possible_problems) + if possible_problems + else "" + ) + ) + return results if isinstance(items, SatstacItem): return [items._data] diff --git a/stackstac/stack.py b/stackstac/stack.py index 6ab51e0..13f3368 100644 --- a/stackstac/stack.py +++ b/stackstac/stack.py @@ -13,12 +13,20 @@ from .reader_protocol import Reader from .rio_env import LayeredEnv from .rio_reader import AutoParallelRioReader -from .stac_types import ItemCollectionIsh, ItemIsh, items_to_plain +from .stac_types import ( + ItemCollectionIsh, + ItemIsh, + PystacItem, + SatstacItem, + items_to_plain, +) from .to_dask import items_to_dask, ChunksParam def stack( - items: Union[ItemCollectionIsh, ItemIsh], + items: Union[ + ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem] + ], assets: Optional[Union[List[str], AbstractSet[str]]] = frozenset( ["image/tiff", "image/x.geotiff", "image/vnd.stac.geotiff", "image/jp2"] ), @@ -309,5 +317,5 @@ def stack( band_coords=band_coords, ), attrs=to_attrs(spec), - name="stackstac-" + dask.base.tokenize(arr) + name="stackstac-" + dask.base.tokenize(arr), ) diff --git a/stackstac/tests/test_stac_types.py b/stackstac/tests/test_stac_types.py index cc635f8..f774b63 100644 --- a/stackstac/tests/test_stac_types.py +++ b/stackstac/tests/test_stac_types.py @@ -9,6 +9,87 @@ from stackstac import stac_types +satstac_itemcollection = satstac.itemcollection.ItemCollection( + [ + satstac.item.Item( + {"id": "foo"}, + ), + satstac.item.Item( + {"id": "bar"}, + ), + ] +) + + +pystac_catalog = pystac.Catalog("foo", "bar") +pystac_catalog.add_items( + [ + pystac.Item("foo", None, None, datetime(2000, 1, 1), {}), + pystac.Item("bar", None, None, datetime(2001, 1, 1), {}), + ] +) +pystac_foo_dict = { + k: v + for k, v in pystac_catalog.get_item("foo").to_dict().items() + if k + in ( + "type", + "stac_version", + "id", + "properties", + "geometry", + "href", + "assets", + "stac_extensions", + ) +} +pystac_bar_dict = { + k: v + for k, v in pystac_catalog.get_item("bar").to_dict().items() + if k + in ( + "type", + "stac_version", + "id", + "properties", + "geometry", + "assets", + "stac_extensions", + ) +} + + +@pytest.mark.parametrize( + "input, expected", + [ + ({"id": "foo"}, [{"id": "foo"}]), + ([{"id": "foo"}, {"id": "bar"}], [{"id": "foo"}, {"id": "bar"}]), + # satstac, + (satstac_itemcollection[0], [{"id": "foo"}]), + (satstac_itemcollection, [{"id": "foo"}, {"id": "bar"}]), + (satstac_itemcollection[:], [{"id": "foo"}, {"id": "bar"}]), + # pystac, + (pystac_catalog.get_item("foo"), [pystac_foo_dict]), + ( + pystac.ItemCollection(pystac_catalog.get_all_items()), + [pystac_foo_dict, pystac_bar_dict], + ), + ( + pystac_catalog, + [pystac_foo_dict, pystac_bar_dict], + ), + (list(pystac_catalog.get_all_items()), [pystac_foo_dict, pystac_bar_dict]), + ], +) +def test_basic(input, expected): + results = stac_types.items_to_plain(input) + assert isinstance(results, list) + assert len(results) == len(expected) + for result, exp in zip(results, expected): + # Only check fields stackstac actually cares about (we don't use the `link` field, for example) + subset = {k: v for k, v in result.items() if k in exp} + assert subset == exp + def test_normal_case(): assert stac_types.SatstacItem is satstac.item.Item @@ -31,6 +112,9 @@ def test_missing_satstac(monkeypatch: pytest.MonkeyPatch): assert "stackstac" in reloaded_stac_types.SatstacItemCollection.__module__ assert not reloaded_stac_types.possible_problems + # clean things up for other tests + monkeypatch.undo() + importlib.reload(stac_types) def test_missing_pystac(monkeypatch: pytest.MonkeyPatch): @@ -44,6 +128,9 @@ def test_missing_pystac(monkeypatch: pytest.MonkeyPatch): assert "stackstac" in reloaded_stac_types.PystacItemCollection.__module__ assert not reloaded_stac_types.possible_problems + # clean things up for other tests + monkeypatch.undo() + importlib.reload(stac_types) @pytest.mark.parametrize( @@ -52,26 +139,26 @@ def test_missing_pystac(monkeypatch: pytest.MonkeyPatch): ( satstac, "item.Item", - satstac.item.Item( - {"id": "foo"}, - ), + satstac_itemcollection[0], ), ( satstac, "itemcollection.ItemCollection", - satstac.itemcollection.ItemCollection( - [], - ), + satstac_itemcollection, + ), + ( + satstac, + "item.Item", + list(satstac_itemcollection), ), - (pystac, "Item", pystac.Item("foo", None, None, datetime(2000, 1, 1), {})), - (pystac, "Catalog", pystac.Catalog("foo", "bar")), + (pystac, "Item", pystac_catalog.get_item("foo")), + (pystac, "Catalog", pystac_catalog), ( pystac, "ItemCollection", - pystac.ItemCollection( - [], - ), + pystac.ItemCollection([]), ), + (pystac, "Item", list(pystac_catalog.get_all_items())), ], ) def test_unimportable_path( @@ -97,3 +184,7 @@ def test_unimportable_path( with pytest.raises(TypeError, match=f"Your version of `{modname}` is too old"): reloaded_stac_types.items_to_plain(inst) + + # clean things up for other tests + monkeypatch.undo() + importlib.reload(stac_types)