diff --git a/docs/installation.rst b/docs/installation.rst index 68df3f4af..6991c7f87 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -92,6 +92,7 @@ For example: - ``rioxarray`` for GeoTIFF support in the assert helpers from ``openeo.testing.results`` - ``geopandas`` for working with dataframes with geospatial support, (e.g. with :py:class:`~openeo.extra.job_management.MultiBackendJobManager`) +- ``pystac_client`` for creating a STAC API Job Database (e.g. with :py:class:`~openeo.extra.job_management.stac_job_db.STACAPIJobDatabase`) Enabling additional features diff --git a/openeo/extra/job_management/__init__.py b/openeo/extra/job_management/__init__.py index 2a37e2b73..8f7f4b7df 100644 --- a/openeo/extra/job_management/__init__.py +++ b/openeo/extra/job_management/__init__.py @@ -104,21 +104,6 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame: """ ... - @abc.abstractmethod - def initialize_from_df(self, df: pd.DataFrame, on_exists: str = "error") -> "JobDatabaseInterface": - """ - Initialize the job database from a given dataframe, - - :param df: dataframe with some columns your ``start_job`` callable expects - :param on_exists: what to do when the job database already exists: - - "error": (default) raise an exception - - "skip": work with existing database, ignore given dataframe and skip any initialization - - :return: initialized job database. - """ - ... - - def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs): raise NotImplementedError("No 'start_job' callable provided") diff --git a/openeo/extra/job_management/stac_job_db.py b/openeo/extra/job_management/stac_job_db.py index 16755b7a2..e794757f3 100644 --- a/openeo/extra/job_management/stac_job_db.py +++ b/openeo/extra/job_management/stac_job_db.py @@ -1,16 +1,14 @@ -import concurrent +import concurrent.futures import datetime import logging -from typing import Iterable, List, Union +from typing import Iterable, List import geopandas as gpd import numpy as np import pandas as pd import pystac +import pystac_client import requests -from pystac import Collection, Item -from pystac_client import Client -from requests.auth import HTTPBasicAuth from shapely.geometry import mapping, shape from openeo.extra.job_management import JobDatabaseInterface, MultiBackendJobManager @@ -45,7 +43,7 @@ def __init__( :param geometry_column: The name of the geometry column in the job metadata that implements __geo_interface__. """ self.collection_id = collection_id - self.client = Client.open(stac_root_url) + self.client = pystac_client.Client.open(stac_root_url) self._auth = auth self.has_geometry = has_geometry @@ -54,7 +52,7 @@ def __init__( self.bulk_size = 500 def exists(self) -> bool: - return len([c.id for c in self.client.get_collections() if c.id == self.collection_id]) > 0 + return any(c.id == self.collection_id for c in self.client.get_collections()) def initialize_from_df(self, df: pd.DataFrame, *, on_exists: str = "error"): """ @@ -108,7 +106,6 @@ def series_from(self, item: pystac.Item) -> pd.Series: item_dict = item.to_dict() item_id = item_dict["id"] dt = item_dict["properties"]["datetime"] - item_dict["datetime"] = pystac.utils.str_to_datetime(dt) return pd.Series(item_dict["properties"], name=item_id) @@ -151,6 +148,9 @@ def item_from(self, series: pd.Series) -> pystac.Item: return item def count_by_status(self, statuses: Iterable[str] = ()) -> dict: + if isinstance(statuses, str): + statuses = {statuses} + statuses = set(statuses) items = self.get_by_status(statuses, max=200) if items is None: return {k: 0 for k in statuses} @@ -199,13 +199,13 @@ def handle_row(series): self._upload_items_bulk(self.collection_id, all_items) - def _prepare_item(self, item: Item, collection_id: str): + def _prepare_item(self, item: pystac.Item, collection_id: str): item.collection_id = collection_id if not item.get_links(pystac.RelType.COLLECTION): item.add_link(pystac.Link(rel=pystac.RelType.COLLECTION, target=item.collection_id)) - def _ingest_bulk(self, items: Iterable[Item]) -> dict: + def _ingest_bulk(self, items: List[pystac.Item]) -> dict: collection_id = items[0].collection_id if not all(i.collection_id == collection_id for i in items): raise Exception("All collection IDs should be identical for bulk ingests") @@ -219,7 +219,7 @@ def _ingest_bulk(self, items: Iterable[Item]) -> dict: _check_response_status(response, _EXPECTED_STATUS_POST) return response.json() - def _upload_items_bulk(self, collection_id: str, items: Iterable[Item]) -> None: + def _upload_items_bulk(self, collection_id: str, items: List[pystac.Item]) -> None: chunk = [] futures = [] @@ -246,7 +246,7 @@ def join_url(self, url_path: str) -> str: """ return str(self.base_url + "/" + url_path) - def _create_collection(self, collection: Collection) -> dict: + def _create_collection(self, collection: pystac.Collection) -> dict: """Create a new collection. :param collection: pystac.Collection object to create in the STAC API backend (or upload if you will) @@ -254,7 +254,7 @@ def _create_collection(self, collection: Collection) -> dict: :return: dict that contains the JSON body of the HTTP response. """ - if not isinstance(collection, Collection): + if not isinstance(collection, pystac.Collection): raise TypeError( f'Argument "collection" must be of type pystac.Collection, but its type is {type(collection)=}' ) diff --git a/tests/extra/job_management/test_stac_job_db.py b/tests/extra/job_management/test_stac_job_db.py index 7eeb16c68..69c9d866c 100644 --- a/tests/extra/job_management/test_stac_job_db.py +++ b/tests/extra/job_management/test_stac_job_db.py @@ -21,7 +21,7 @@ def mock_auth(): @pytest.fixture def mock_stac_api_job_database(mock_auth) -> STACAPIJobDatabase: - return STACAPIJobDatabase(collection_id="test_id", stac_root_url="http://fake-stac-api", auth=mock_auth) + return STACAPIJobDatabase(collection_id="test_id", stac_root_url="http://fake-stac-api.test", auth=mock_auth) @pytest.fixture