Commit
fixup! Issue #683/#681 align additional/job_options argument in create_job, download, ...
soxofaan committed Dec 10, 2024
1 parent 0076152 commit 5a83763
Showing 3 changed files with 140 additions and 1 deletion.
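For context, this change aligns how `additional` (extra top-level request properties) and `job_options` are accepted across `Connection.create_job`, `Connection.download`, and related methods. A minimal usage sketch (placeholder URL; argument values mirror the tests added below):

import openeo

connection = openeo.connect("https://openeo.example")
pg = {"foo1": {"process_id": "foo"}}

# Batch job: `additional` entries are merged into the top level of the POST body,
# while `job_options` is sent as its own top-level property.
job = connection.create_job(pg, additional={"color": "blue"}, job_options={"color": "green"})

# Synchronous processing: `download` accepts the same arguments.
connection.download(pg, "result.data", additional={"color": "blue"}, job_options={"color": "green"})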
27 changes: 26 additions & 1 deletion openeo/rest/_testing.py
@@ -32,13 +32,17 @@ class DummyBackend:
"""

# TODO: move to openeo.testing
# TODO: unify "batch_jobs", "batch_jobs_full" and "extra_job_metadata_fields"?
# TODO: unify "sync_requests" and "sync_requests_full"?

__slots__ = (
"_requests_mock",
"connection",
"file_formats",
"sync_requests",
"sync_requests_full",
"batch_jobs",
"batch_jobs_full",
"validation_requests",
"next_result",
"next_validation_errors",
@@ -60,7 +64,9 @@ def __init__(
self.connection = connection
self.file_formats = {"input": {}, "output": {}}
self.sync_requests = []
self.sync_requests_full = []
self.batch_jobs = {}
self.batch_jobs_full = {}
self.validation_requests = []
self.next_result = self.DEFAULT_RESULT
self.next_validation_errors = []
@@ -163,7 +169,9 @@ def setup_file_format(self, name: str, type: str = "output", gis_data_types: Ite

def _handle_post_result(self, request, context):
"""handler of `POST /result` (synchronous execute)"""
pg = request.json()["process"]["process_graph"]
post_data = request.json()
pg = post_data["process"]["process_graph"]
self.sync_requests_full.append(post_data)
self.sync_requests.append(pg)
result = self.next_result
if isinstance(result, (dict, list)):
@@ -185,6 +193,10 @@ def _handle_post_jobs(self, request, context):
job_id = f"job-{len(self.batch_jobs):03d}"
assert job_id not in self.batch_jobs

# Full post data dump
self.batch_jobs_full[job_id] = post_data

# Batch job essentials
job_data = {"job_id": job_id, "pg": pg, "status": "created"}
for field in ["title", "description"]:
if field in post_data:
Expand Down Expand Up @@ -272,6 +284,11 @@ def get_sync_pg(self) -> dict:
assert len(self.sync_requests) == 1
return self.sync_requests[0]

def get_sync_post_data(self) -> dict:
"""Get post data of the one and only synchronous job."""
assert len(self.sync_requests_full) == 1
return self.sync_requests_full[0]

def get_batch_pg(self) -> dict:
"""
Get process graph of the one and only batch job.
@@ -280,6 +297,14 @@ def get_batch_pg(self) -> dict:
assert len(self.batch_jobs) == 1
return self.batch_jobs[max(self.batch_jobs.keys())]["pg"]

def get_batch_post_data(self) -> dict:
"""
Get post data of the one and only batch job.
Fails when there is none or more than one.
"""
assert len(self.batch_jobs_full) == 1
return self.batch_jobs_full[max(self.batch_jobs_full.keys())]

def get_validation_pg(self) -> dict:
"""
Get process graph of the one and only validation request.
2 changes: 2 additions & 0 deletions openeo/rest/connection.py
@@ -1661,6 +1661,8 @@ def _build_request_with_process_graph(
if additional:
result.update(additional)
if job_options is not None:
# Note: this "job_options" top-level property is not in official openEO API spec,
# but a commonly used convention, e.g. in openeo-python-driver based deployments.
assert "job_options" not in result
result["job_options"] = job_options

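As the new tests assert, `additional` entries are merged into the top level of the request body, while `job_options` stays a separate top-level property. Roughly, for a call with both arguments (values taken from the tests below), the POST body looks like:

expected_post_body = {
    "process": {"process_graph": {"foo1": {"process_id": "foo"}}},
    "color": "blue",  # merged in from `additional`
    "job_options": {"color": "green"},  # top-level convention, not part of the official openEO API spec
}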
112 changes: 112 additions & 0 deletions tests/rest/test_connection.py
@@ -16,6 +16,7 @@
import shapely.geometry

import openeo
from openeo import BatchJob
from openeo.capabilities import ApiVersionException
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension
@@ -2902,6 +2903,58 @@ def flat_graph(self) -> typing.Dict[str, dict]:
return {"foo1": {"process_id": "foo"}}


@pytest.mark.parametrize(
"pg",
[
{"foo1": {"process_id": "foo"}},
{"process_graph": {"foo1": {"process_id": "foo"}}},
DummyFlatGraphable(),
],
)
def test_create_job(dummy_backend, pg):
job = dummy_backend.connection.create_job(pg)
assert isinstance(job, BatchJob)
assert dummy_backend.get_batch_pg() == {"foo1": {"process_id": "foo"}}


def test_create_job_with_additional(dummy_backend):
job = dummy_backend.connection.create_job(
{"foo1": {"process_id": "foo"}},
additional={"color": "green"},
)
assert isinstance(job, BatchJob)
assert dummy_backend.get_batch_post_data() == {
"process": {"process_graph": {"foo1": {"process_id": "foo"}}},
"color": "green",
}


def test_create_job_with_job_options(dummy_backend):
job = dummy_backend.connection.create_job(
{"foo1": {"process_id": "foo"}},
job_options={"color": "green"},
)
assert isinstance(job, BatchJob)
assert dummy_backend.get_batch_post_data() == {
"process": {"process_graph": {"foo1": {"process_id": "foo"}}},
"job_options": {"color": "green"},
}


def test_create_job_with_additional_and_job_options(dummy_backend):
job = dummy_backend.connection.create_job(
{"foo1": {"process_id": "foo"}},
additional={"color": "blue"},
job_options={"color": "green"},
)
assert isinstance(job, BatchJob)
assert dummy_backend.get_batch_post_data() == {
"process": {"process_graph": {"foo1": {"process_id": "foo"}}},
"color": "blue",
"job_options": {"color": "green"},
}


@pytest.mark.parametrize(
"pg",
[
@@ -2928,6 +2981,65 @@ def test_download_100(requests_mock, pg):
]


@pytest.mark.parametrize(
"pg",
[
{"foo1": {"process_id": "foo"}},
{"process_graph": {"foo1": {"process_id": "foo"}}},
DummyFlatGraphable(),
],
)
def test_download(dummy_backend, pg, tmp_path):
output_path = tmp_path / "result.data"
dummy_backend.connection.download(pg, output_path)
assert output_path.read_bytes() == b'{"what?": "Result data"}'
assert dummy_backend.get_sync_pg() == {"foo1": {"process_id": "foo"}}


def test_download_with_additional(dummy_backend, tmp_path):
output_path = tmp_path / "result.data"
dummy_backend.connection.download(
{"foo1": {"process_id": "foo"}},
output_path,
additional={"color": "green"},
)
assert output_path.read_bytes() == b'{"what?": "Result data"}'
assert dummy_backend.get_sync_post_data() == {
"process": {"process_graph": {"foo1": {"process_id": "foo"}}},
"color": "green",
}


def test_download_with_job_options(dummy_backend, tmp_path):
output_path = tmp_path / "result.data"
dummy_backend.connection.download(
{"foo1": {"process_id": "foo"}},
output_path,
job_options={"color": "green"},
)
assert output_path.read_bytes() == b'{"what?": "Result data"}'
assert dummy_backend.get_sync_post_data() == {
"process": {"process_graph": {"foo1": {"process_id": "foo"}}},
"job_options": {"color": "green"},
}


def test_download_with_additional_and_job_options(dummy_backend, tmp_path):
output_path = tmp_path / "result.data"
dummy_backend.connection.download(
{"foo1": {"process_id": "foo"}},
output_path,
additional={"color": "blue"},
job_options={"color": "green"},
)
assert output_path.read_bytes() == b'{"what?": "Result data"}'
assert dummy_backend.get_sync_post_data() == {
"process": {"process_graph": {"foo1": {"process_id": "foo"}}},
"color": "blue",
"job_options": {"color": "green"},
}


@pytest.mark.parametrize(
"pg",
[
