Skip to content

Commit

Permalink
use s3a links for RandomForestModel._load_java() #562
Browse files Browse the repository at this point in the history
  • Loading branch information
JeroenVerstraelen committed Apr 29, 2024
1 parent e47cd18 commit 8b38fe1
Showing 1 changed file with 45 additions and 19 deletions.
64 changes: 45 additions & 19 deletions openeogeotrellis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,26 +764,27 @@ def load_ml_model(self, model_id: str) -> 'JavaObject':
# Trick to make sure IDE infers right type of `self.batch_jobs` and can resolve `get_job_output_dir`
gps_batch_jobs: GpsBatchJobs = self.batch_jobs

def _create_model_dir():
def _create_model_dir(use_s3=False):
if use_s3:
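# Return "<bucket>/<key prefix>" without the "s3a://" scheme; the caller splits it into bucket and key for the upload and prepends "s3a://" when loading the model.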
return f"openeo-ml-models-dev/{generate_unique_id(prefix='model')}"

def _set_permissions(job_dir: Path):
if not ConfigParams().is_kube_deploy:
try:
shutil.chown(job_dir, user = None, group = 'eodata')
except LookupError as e:
logger.warning(f"Could not change group of {job_dir} to eodata.")
try:
shutil.chown(job_dir, user = None, group = 'eodata')
except LookupError as e:
logger.warning(f"Could not change group of {job_dir} to eodata.")
add_permissions(job_dir, stat.S_ISGID | stat.S_IWGRP) # make children inherit this group
ml_models_path = gps_batch_jobs.get_job_output_dir("ml_models")
if not os.path.exists(ml_models_path):
logger.info("Creating directory: {}".format(ml_models_path))
os.makedirs(ml_models_path)
_set_permissions(ml_models_path)
# Use a random id to avoid collisions.
model_dir_path = ml_models_path / generate_unique_id(prefix="model")
if not os.path.exists(model_dir_path):
logger.info("Creating directory: {}".format(model_dir_path))
os.makedirs(model_dir_path)
_set_permissions(model_dir_path)
return str(model_dir_path)

result_dir = gps_batch_jobs.get_job_output_dir("ml_models")
result_path = result_dir / generate_unique_id(prefix="model")
result_dir_exists = os.path.exists(result_dir)
logger.info("Creating directory: {}".format(result_path))
os.makedirs(result_path)
if not result_dir_exists:
_set_permissions(result_dir)
_set_permissions(result_path)
return str(result_path)

if model_id.startswith('http'):
# Load the model using its STAC metadata file.
Expand Down Expand Up @@ -813,15 +814,40 @@ def _set_permissions(job_dir: Path):
model_url = checkpoints[0]["href"]
architecture = metadata["properties"]["ml-model:architecture"]
# Download the model (to the ml_models folder, or to a temporary directory that is uploaded to S3 on Kubernetes deployments) and load it as a Java object.
model_dir_path = _create_model_dir()
use_s3 = ConfigParams().is_kube_deploy
model_dir_path = _create_model_dir(use_s3)
if architecture == "random-forest":
if use_s3:
with tempfile.TemporaryDirectory() as tmp_dir:
# Download to tmp_dir and unpack it there.
tmp_path = Path(tmp_dir + "/randomforest.model.tar.gz")
logger.info(f"Downloading ml_model from {model_url} to {tmp_path}")
with open(tmp_path, 'wb') as f:
f.write(requests.get(model_url).content)
shutil.unpack_archive(tmp_path, extract_dir = tmp_dir, format = 'gztar')
# Upload the unpacked model to s3.
unpacked_model_path = str(tmp_path).replace(".tar.gz", "")
logger.info(f"Uploading ml_model to {model_dir_path}")
path_split = model_dir_path.split("/")
bucket, key = path_split[0], path_split[1]
s3 = s3_client()
for root, dirs, files in os.walk(unpacked_model_path):
for file in files:
relative_filepath = os.path.relpath(os.path.join(root, file), tmp_dir)
s3.upload_file(os.path.join(root, file), bucket, key + "/" + relative_filepath)
# Load the spark model using the new s3 path.
s3_path = f"s3a://{model_dir_path}/randomforest.model/"
logger.info("Loading ml_model using filename: {}".format(s3_path))
model: JavaObject = RandomForestModel._load_java(sc = gps.get_spark_context(), path = s3_path)
return model
dest_path = Path(model_dir_path + "/randomforest.model.tar.gz")
with open(dest_path, 'wb') as f:
f.write(requests.get(model_url).content)
shutil.unpack_archive(dest_path, extract_dir=model_dir_path, format='gztar')
unpacked_model_path = str(dest_path).replace(".tar.gz", "")
logger.info("Loading ml_model using filename: {}".format(unpacked_model_path))
model: JavaObject = RandomForestModel._load_java(sc=gps.get_spark_context(), path="file:" + unpacked_model_path)
return model
elif architecture == "catboost":
filename = Path(model_dir_path + "/catboost_model.cbm")
with open(filename, 'wb') as f:
Expand Down

0 comments on commit 8b38fe1

Please sign in to comment.