Skip to content

Commit

Permalink
Support sequence of Items (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
gjoseph92 authored Jul 28, 2022
1 parent 2027a88 commit c67944b
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 20 deletions.
29 changes: 23 additions & 6 deletions stackstac/stac_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Tuple,
TypedDict,
Union,
cast,
)

possible_problems: list[str] = []
Expand Down Expand Up @@ -134,7 +133,11 @@ class ItemDict(TypedDict):
]


def items_to_plain(items: Union[ItemCollectionIsh, ItemIsh]) -> ItemSequence:
def items_to_plain(
items: Union[
ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem]
]
) -> ItemSequence:
"""
Convert something like a collection/Catalog of STAC items into a list of plain dicts
Expand All @@ -149,10 +152,24 @@ def items_to_plain(items: Union[ItemCollectionIsh, ItemIsh]) -> ItemSequence:
if isinstance(items, Sequence):
# slicing a satstac `ItemCollection` produces a list, not another `ItemCollection`,
# so having a `List[SatstacItem]` is quite possible
try:
return [item._data for item in cast(SatstacItemCollection, items)]
except AttributeError:
return items
results = []
for item in items:
if isinstance(item, PystacItem):
results.append(item.to_dict())
elif isinstance(item, SatstacItem):
results.append(item._data)
elif isinstance(item, dict):
results.append(item)
else:
raise TypeError(
f"Unrecognized STAC item type {type(item)}: {item!r}"
+ (
"\n".join(["\nPossible problems:"] + possible_problems)
if possible_problems
else ""
)
)
return results

if isinstance(items, SatstacItem):
return [items._data]
Expand Down
14 changes: 11 additions & 3 deletions stackstac/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
from .reader_protocol import Reader
from .rio_env import LayeredEnv
from .rio_reader import AutoParallelRioReader
from .stac_types import ItemCollectionIsh, ItemIsh, items_to_plain
from .stac_types import (
ItemCollectionIsh,
ItemIsh,
PystacItem,
SatstacItem,
items_to_plain,
)
from .to_dask import items_to_dask, ChunksParam


def stack(
items: Union[ItemCollectionIsh, ItemIsh],
items: Union[
ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem]
],
assets: Optional[Union[List[str], AbstractSet[str]]] = frozenset(
["image/tiff", "image/x.geotiff", "image/vnd.stac.geotiff", "image/jp2"]
),
Expand Down Expand Up @@ -309,5 +317,5 @@ def stack(
band_coords=band_coords,
),
attrs=to_attrs(spec),
name="stackstac-" + dask.base.tokenize(arr)
name="stackstac-" + dask.base.tokenize(arr),
)
113 changes: 102 additions & 11 deletions stackstac/tests/test_stac_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,87 @@

from stackstac import stac_types

satstac_itemcollection = satstac.itemcollection.ItemCollection(
[
satstac.item.Item(
{"id": "foo"},
),
satstac.item.Item(
{"id": "bar"},
),
]
)


pystac_catalog = pystac.Catalog("foo", "bar")
pystac_catalog.add_items(
[
pystac.Item("foo", None, None, datetime(2000, 1, 1), {}),
pystac.Item("bar", None, None, datetime(2001, 1, 1), {}),
]
)
pystac_foo_dict = {
k: v
for k, v in pystac_catalog.get_item("foo").to_dict().items()
if k
in (
"type",
"stac_version",
"id",
"properties",
"geometry",
"href",
"assets",
"stac_extensions",
)
}
pystac_bar_dict = {
k: v
for k, v in pystac_catalog.get_item("bar").to_dict().items()
if k
in (
"type",
"stac_version",
"id",
"properties",
"geometry",
"assets",
"stac_extensions",
)
}


@pytest.mark.parametrize(
"input, expected",
[
({"id": "foo"}, [{"id": "foo"}]),
([{"id": "foo"}, {"id": "bar"}], [{"id": "foo"}, {"id": "bar"}]),
# satstac,
(satstac_itemcollection[0], [{"id": "foo"}]),
(satstac_itemcollection, [{"id": "foo"}, {"id": "bar"}]),
(satstac_itemcollection[:], [{"id": "foo"}, {"id": "bar"}]),
# pystac,
(pystac_catalog.get_item("foo"), [pystac_foo_dict]),
(
pystac.ItemCollection(pystac_catalog.get_all_items()),
[pystac_foo_dict, pystac_bar_dict],
),
(
pystac_catalog,
[pystac_foo_dict, pystac_bar_dict],
),
(list(pystac_catalog.get_all_items()), [pystac_foo_dict, pystac_bar_dict]),
],
)
def test_basic(input, expected):
results = stac_types.items_to_plain(input)
assert isinstance(results, list)
assert len(results) == len(expected)
for result, exp in zip(results, expected):
# Only check fields stackstac actually cares about (we don't use the `link` field, for example)
subset = {k: v for k, v in result.items() if k in exp}
assert subset == exp


def test_normal_case():
assert stac_types.SatstacItem is satstac.item.Item
Expand All @@ -31,6 +112,9 @@ def test_missing_satstac(monkeypatch: pytest.MonkeyPatch):
assert "stackstac" in reloaded_stac_types.SatstacItemCollection.__module__

assert not reloaded_stac_types.possible_problems
# clean things up for other tests
monkeypatch.undo()
importlib.reload(stac_types)


def test_missing_pystac(monkeypatch: pytest.MonkeyPatch):
Expand All @@ -44,6 +128,9 @@ def test_missing_pystac(monkeypatch: pytest.MonkeyPatch):
assert "stackstac" in reloaded_stac_types.PystacItemCollection.__module__

assert not reloaded_stac_types.possible_problems
# clean things up for other tests
monkeypatch.undo()
importlib.reload(stac_types)


@pytest.mark.parametrize(
Expand All @@ -52,26 +139,26 @@ def test_missing_pystac(monkeypatch: pytest.MonkeyPatch):
(
satstac,
"item.Item",
satstac.item.Item(
{"id": "foo"},
),
satstac_itemcollection[0],
),
(
satstac,
"itemcollection.ItemCollection",
satstac.itemcollection.ItemCollection(
[],
),
satstac_itemcollection,
),
(
satstac,
"item.Item",
list(satstac_itemcollection),
),
(pystac, "Item", pystac.Item("foo", None, None, datetime(2000, 1, 1), {})),
(pystac, "Catalog", pystac.Catalog("foo", "bar")),
(pystac, "Item", pystac_catalog.get_item("foo")),
(pystac, "Catalog", pystac_catalog),
(
pystac,
"ItemCollection",
pystac.ItemCollection(
[],
),
pystac.ItemCollection([]),
),
(pystac, "Item", list(pystac_catalog.get_all_items())),
],
)
def test_unimportable_path(
Expand All @@ -97,3 +184,7 @@ def test_unimportable_path(

with pytest.raises(TypeError, match=f"Your version of `{modname}` is too old"):
reloaded_stac_types.items_to_plain(inst)

# clean things up for other tests
monkeypatch.undo()
importlib.reload(stac_types)

0 comments on commit c67944b

Please sign in to comment.