Skip to content

Commit

Permalink
Allow to load signed rsult URL in load_ml_model. #562
Browse files Browse the repository at this point in the history
  • Loading branch information
EmileSonneveld committed Mar 19, 2024
1 parent c95ba2a commit 6ddf018
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions openeogeotrellis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,15 +1190,27 @@ def _set_permissions(job_dir: Path):
with requests.get(model_id) as resp:
resp.raise_for_status()
metadata = resp.json()
if deep_get(metadata, "properties", "ml-model:architecture", default=None) is None:
raise OpenEOApiException(
message=f"{model_id} does not specify a model architecture under properties.ml-model:architecture.",
status_code=400)
checkpoints = []
assets = metadata.get('assets', {})
for asset in assets:
if "ml-model:checkpoint" in assets[asset].get('roles', []):
checkpoints.append(assets[asset])
if "ml_model_metadata.json" in model_id:
architecture = deep_get(metadata, "properties", "ml-model:architecture", default=None)
if architecture is None:
raise OpenEOApiException(
message=f"{model_id} does not specify a model architecture under properties.ml-model:architecture.",
status_code=400)
checkpoints = []
assets = metadata.get('assets', {})
for asset in assets:
if "ml-model:checkpoint" in assets[asset].get('roles', []):
checkpoints.append(assets[asset])
else:
architecture = deep_get(metadata, "summaries", "ml-model:architecture", 0, None)
if architecture is None:
raise OpenEOApiException(
message=f"{model_id} does not specify a model architecture under summaries.ml-model:architecture[0].",
status_code=400)
checkpoints = [
metadata.get("assets").get("randomforest.model.tar.gz")
]

if len(checkpoints) == 0 or checkpoints[0].get("href", None) is None:
raise OpenEOApiException(
message=f"{model_id} does not contain a link to the ml model in its assets section.",
Expand All @@ -1208,10 +1220,8 @@ def _set_permissions(job_dir: Path):
raise OpenEOApiException(
message=f"{model_id} contains multiple checkpoints.",
status_code=400)

# Get the url for the actual model from the STAC metadata.
model_url = checkpoints[0]["href"]
architecture = metadata["properties"]["ml-model:architecture"]
# Download the model to the ml_models folder and load it as a java object.
model_dir_path = _create_model_dir()
if architecture == "random-forest":
Expand Down

0 comments on commit 6ddf018

Please sign in to comment.