Skip to content

Commit

Permalink
add a kubernetes path for loading cataboost models #562
Browse files Browse the repository at this point in the history
  • Loading branch information
JeroenVerstraelen committed Apr 29, 2024
1 parent 8b38fe1 commit 8115c38
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions openeogeotrellis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,16 @@ def _set_permissions(job_dir: Path):
model: JavaObject = RandomForestModel._load_java(sc=gps.get_spark_context(), path="file:" + unpacked_model_path)
return model
elif architecture == "catboost":
if use_s3:
# TODO: Verify that local files work. If it does, we can remove the model_dir_path implementation.
# Download the model to the tmp directory and load it as a java object.
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir + "/catboost_model.cbm")
logger.info(f"Downloading ml_model from {model_url} to {tmp_path}")
with open(tmp_path, 'wb') as f:
f.write(requests.get(model_url).content)
model: JavaObject = CatBoostClassificationModel.load_native_model(tmp_path)
return model
filename = Path(model_dir_path + "/catboost_model.cbm")
with open(filename, 'wb') as f:
f.write(requests.get(model_url).content)
Expand Down

0 comments on commit 8115c38

Please sign in to comment.