Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sequence of Items #164

Merged
merged 1 commit into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions stackstac/stac_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Tuple,
TypedDict,
Union,
cast,
)

possible_problems: list[str] = []
Expand Down Expand Up @@ -134,7 +133,11 @@ class ItemDict(TypedDict):
]


def items_to_plain(items: Union[ItemCollectionIsh, ItemIsh]) -> ItemSequence:
def items_to_plain(
items: Union[
ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem]
]
) -> ItemSequence:
"""
Convert something like a collection/Catalog of STAC items into a list of plain dicts

Expand All @@ -149,10 +152,24 @@ def items_to_plain(items: Union[ItemCollectionIsh, ItemIsh]) -> ItemSequence:
if isinstance(items, Sequence):
# slicing a satstac `ItemCollection` produces a list, not another `ItemCollection`,
# so having a `List[SatstacItem]` is quite possible
try:
return [item._data for item in cast(SatstacItemCollection, items)]
except AttributeError:
return items
results = []
for item in items:
if isinstance(item, PystacItem):
results.append(item.to_dict())
elif isinstance(item, SatstacItem):
results.append(item._data)
elif isinstance(item, dict):
results.append(item)
else:
raise TypeError(
f"Unrecognized STAC item type {type(item)}: {item!r}"
+ (
"\n".join(["\nPossible problems:"] + possible_problems)
if possible_problems
else ""
)
)
return results

if isinstance(items, SatstacItem):
return [items._data]
Expand Down
14 changes: 11 additions & 3 deletions stackstac/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
from .reader_protocol import Reader
from .rio_env import LayeredEnv
from .rio_reader import AutoParallelRioReader
from .stac_types import ItemCollectionIsh, ItemIsh, items_to_plain
from .stac_types import (
ItemCollectionIsh,
ItemIsh,
PystacItem,
SatstacItem,
items_to_plain,
)
from .to_dask import items_to_dask, ChunksParam


def stack(
items: Union[ItemCollectionIsh, ItemIsh],
items: Union[
ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem]
],
assets: Optional[Union[List[str], AbstractSet[str]]] = frozenset(
["image/tiff", "image/x.geotiff", "image/vnd.stac.geotiff", "image/jp2"]
),
Expand Down Expand Up @@ -309,5 +317,5 @@ def stack(
band_coords=band_coords,
),
attrs=to_attrs(spec),
name="stackstac-" + dask.base.tokenize(arr)
name="stackstac-" + dask.base.tokenize(arr),
)
113 changes: 102 additions & 11 deletions stackstac/tests/test_stac_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,87 @@

from stackstac import stac_types

satstac_itemcollection = satstac.itemcollection.ItemCollection(
[
satstac.item.Item(
{"id": "foo"},
),
satstac.item.Item(
{"id": "bar"},
),
]
)


pystac_catalog = pystac.Catalog("foo", "bar")
pystac_catalog.add_items(
[
pystac.Item("foo", None, None, datetime(2000, 1, 1), {}),
pystac.Item("bar", None, None, datetime(2001, 1, 1), {}),
]
)
pystac_foo_dict = {
k: v
for k, v in pystac_catalog.get_item("foo").to_dict().items()
if k
in (
"type",
"stac_version",
"id",
"properties",
"geometry",
"href",
"assets",
"stac_extensions",
)
}
pystac_bar_dict = {
k: v
for k, v in pystac_catalog.get_item("bar").to_dict().items()
if k
in (
"type",
"stac_version",
"id",
"properties",
"geometry",
"assets",
"stac_extensions",
)
}


@pytest.mark.parametrize(
"input, expected",
[
({"id": "foo"}, [{"id": "foo"}]),
([{"id": "foo"}, {"id": "bar"}], [{"id": "foo"}, {"id": "bar"}]),
# satstac,
(satstac_itemcollection[0], [{"id": "foo"}]),
(satstac_itemcollection, [{"id": "foo"}, {"id": "bar"}]),
(satstac_itemcollection[:], [{"id": "foo"}, {"id": "bar"}]),
# pystac,
(pystac_catalog.get_item("foo"), [pystac_foo_dict]),
(
pystac.ItemCollection(pystac_catalog.get_all_items()),
[pystac_foo_dict, pystac_bar_dict],
),
(
pystac_catalog,
[pystac_foo_dict, pystac_bar_dict],
),
(list(pystac_catalog.get_all_items()), [pystac_foo_dict, pystac_bar_dict]),
],
)
def test_basic(input, expected):
results = stac_types.items_to_plain(input)
assert isinstance(results, list)
assert len(results) == len(expected)
for result, exp in zip(results, expected):
# Only check fields stackstac actually cares about (we don't use the `link` field, for example)
subset = {k: v for k, v in result.items() if k in exp}
assert subset == exp


def test_normal_case():
assert stac_types.SatstacItem is satstac.item.Item
Expand All @@ -31,6 +112,9 @@ def test_missing_satstac(monkeypatch: pytest.MonkeyPatch):
assert "stackstac" in reloaded_stac_types.SatstacItemCollection.__module__

assert not reloaded_stac_types.possible_problems
# clean things up for other tests
monkeypatch.undo()
importlib.reload(stac_types)


def test_missing_pystac(monkeypatch: pytest.MonkeyPatch):
Expand All @@ -44,6 +128,9 @@ def test_missing_pystac(monkeypatch: pytest.MonkeyPatch):
assert "stackstac" in reloaded_stac_types.PystacItemCollection.__module__

assert not reloaded_stac_types.possible_problems
# clean things up for other tests
monkeypatch.undo()
importlib.reload(stac_types)


@pytest.mark.parametrize(
Expand All @@ -52,26 +139,26 @@ def test_missing_pystac(monkeypatch: pytest.MonkeyPatch):
(
satstac,
"item.Item",
satstac.item.Item(
{"id": "foo"},
),
satstac_itemcollection[0],
),
(
satstac,
"itemcollection.ItemCollection",
satstac.itemcollection.ItemCollection(
[],
),
satstac_itemcollection,
),
(
satstac,
"item.Item",
list(satstac_itemcollection),
),
(pystac, "Item", pystac.Item("foo", None, None, datetime(2000, 1, 1), {})),
(pystac, "Catalog", pystac.Catalog("foo", "bar")),
(pystac, "Item", pystac_catalog.get_item("foo")),
(pystac, "Catalog", pystac_catalog),
(
pystac,
"ItemCollection",
pystac.ItemCollection(
[],
),
pystac.ItemCollection([]),
),
(pystac, "Item", list(pystac_catalog.get_all_items())),
],
)
def test_unimportable_path(
Expand All @@ -97,3 +184,7 @@ def test_unimportable_path(

with pytest.raises(TypeError, match=f"Your version of `{modname}` is too old"):
reloaded_stac_types.items_to_plain(inst)

# clean things up for other tests
monkeypatch.undo()
importlib.reload(stac_types)