diff --git a/CHANGELOG.md b/CHANGELOG.md index 68b533a79..58c0b0c8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - Self links no longer included in Items for "relative published" catalogs ([#725](https://github.com/stac-utils/pystac/pull/725)) - Adding New and Custom Extensions tutorial now up-to-date with new extensions API ([#724](https://github.com/stac-utils/pystac/pull/724)) +- Type errors when initializing `TemporalExtent` using a list of `datetime` objects ([#744](https://github.com/stac-utils/pystac/pull/744)) ### Deprecated diff --git a/pystac/collection.py b/pystac/collection.py index b40a8f3ae..86d4ccd54 100644 --- a/pystac/collection.py +++ b/pystac/collection.py @@ -37,6 +37,10 @@ from pystac.provider import Provider as Provider_Type T = TypeVar("T") +TemporalIntervals = Union[List[List[datetime]], List[List[Optional[datetime]]]] +TemporalIntervalsLike = Union[ + TemporalIntervals, List[datetime], List[Optional[datetime]] +] class SpatialExtent: @@ -176,7 +180,7 @@ class TemporalExtent: Datetimes are required to be in UTC. """ - intervals: List[List[Optional[datetime]]] + intervals: TemporalIntervals """A list of two datetimes wrapped in a list, representing the temporal extent of a Collection. Open date ranges are represented by either the start (the first element of the interval) or the @@ -188,16 +192,16 @@ class TemporalExtent: def __init__( self, - intervals: Union[List[List[Optional[datetime]]], List[Optional[datetime]]], + intervals: TemporalIntervals, extra_fields: Optional[Dict[str, Any]] = None, ): # A common mistake is to pass in a single interval instead of a # list of intervals. Account for this by transforming the input # in that case. if isinstance(intervals, list) and isinstance(intervals[0], datetime): - self.intervals = [cast(List[Optional[datetime]], intervals)] + self.intervals = intervals else: - self.intervals = cast(List[List[Optional[datetime]]], intervals) + self.intervals = intervals self.extra_fields = extra_fields or {} diff --git a/tests/test_collection.py b/tests/test_collection.py index 868550294..2a307c38f 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -18,7 +18,7 @@ CatalogType, Provider, ) -from pystac.utils import datetime_to_str, get_required +from pystac.utils import datetime_to_str, get_required, str_to_datetime from tests.utils import TestCases, ARBITRARY_GEOM, ARBITRARY_BBOX TEST_DATETIME = datetime(2020, 3, 14, 16, 32) @@ -263,6 +263,14 @@ class ExtentTest(unittest.TestCase): def setUp(self) -> None: self.maxDiff = None + def test_temporal_extent_init_typing(self) -> None: + # This test exists purely to test the typing of the intervals argument to + # TemporalExtent + start_datetime = str_to_datetime("2022-01-01T00:00:00Z") + end_datetime = str_to_datetime("2022-01-31T23:59:59Z") + + _ = TemporalExtent([[start_datetime, end_datetime]]) + def test_spatial_allows_single_bbox(self) -> None: temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]])