Skip to content

Commit

Permalink
Add additional steps to download merlin artifacts from s3 storage locally
Browse files Browse the repository at this point in the history
  • Loading branch information
deadlycoconuts committed Nov 7, 2024
1 parent 9e78611 commit c15b98d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
4 changes: 2 additions & 2 deletions python/sdk/merlin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from merlin.transformer import Transformer
from merlin.util import (
autostr,
download_files_from_gcs,
download_files_from_blob_storage,
extract_optional_value_with_default,
guess_mlp_ui_url,
valid_name_check,
Expand Down Expand Up @@ -956,7 +956,7 @@ def download_artifact(self, destination_path):
if artifact_uri is None or artifact_uri == "":
raise Exception("There is no artifact uri for this model version")

download_files_from_gcs(artifact_uri, destination_path)
download_files_from_blob_storage(artifact_uri, destination_path)

def log_artifacts(self, local_dir, artifact_path=None):
"""
Expand Down
56 changes: 40 additions & 16 deletions python/sdk/merlin/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import re
import os
import boto3
from urllib.parse import urlparse
from google.cloud import storage
from os.path import dirname
Expand Down Expand Up @@ -66,6 +67,11 @@ def valid_name_check(input_name: str) -> bool:
return matching_group == input_name


def get_blob_storage_schema(artifact_uri: str) -> str:
    """Return the URI scheme of a blob storage URI (e.g. ``gs`` or ``s3``)."""
    return urlparse(artifact_uri).scheme


def get_bucket_name(gcs_uri: str) -> str:
    """Return the bucket name (the network-location part) of a storage URI."""
    parsed = urlparse(gcs_uri)
    return parsed.netloc
Expand All @@ -76,24 +82,42 @@ def get_gcs_path(gcs_uri: str) -> str:
return parsed_result.path.strip("/")


def download_files_from_gcs(gcs_uri: str, destination_path: str):
def download_files_from_blob_storage(artifact_uri: str, destination_path: str):
    """Download every artifact object under ``artifact_uri`` into ``destination_path``.

    Supports Google Cloud Storage (``gs://``) and Amazon S3 (``s3://``) URIs.
    Object keys are assumed to look like
    ``mlflow/<exp>/<run>/artifacts/model/<...>``; only the portion after the
    first five path segments (e.g. ``1/saved_model.pb``) is recreated under
    ``destination_path``.

    :param artifact_uri: blob storage URI of the artifact root
    :param destination_path: local directory to download into (created if absent)
    :raises ValueError: if the URI scheme is neither ``gs`` nor ``s3``
    """
    makedirs(destination_path, exist_ok=True)

    storage_schema = get_blob_storage_schema(artifact_uri)
    bucket_name = get_bucket_name(artifact_uri)
    path = get_gcs_path(artifact_uri)

    if storage_schema == "gs":
        client = storage.Client()
        bucket = client.get_bucket(bucket_name)
        for blob in bucket.list_blobs(prefix=path):
            # Keep only the path after .../artifacts/model, e.g. from
            # mlflow/3/ad8f15a4023f461796955f71e1152bac/artifacts/model/1/saved_model.pb
            # we only want to extract 1/saved_model.pb
            object_parts = blob.name.split("/")[5:]
            # Guard against keys with no trailing path (e.g. the prefix itself);
            # os.path.join(*[]) would raise TypeError. Mirrors the S3 branch.
            if not object_parts:
                continue
            artifact_path = os.path.join(*object_parts)
            makedirs(os.path.join(destination_path, dirname(artifact_path)), exist_ok=True)
            blob.download_to_filename(os.path.join(destination_path, artifact_path))
    elif storage_schema == "s3":
        client = boto3.client("s3")
        # Use the paginator: a plain list_objects_v2 call caps at 1000 keys and
        # raises KeyError on ["Contents"] when the prefix matches nothing.
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket_name, Prefix=path):
            for s3_object in page.get("Contents", []):
                # list_objects_v2 also lists directory markers (keys ending in "/")
                if s3_object["Key"].endswith("/"):
                    continue
                # Keep only the path after .../artifacts/model (see GCS branch).
                object_parts = s3_object["Key"].split("/")[5:]
                if not object_parts:
                    continue
                artifact_path = os.path.join(*object_parts)
                makedirs(os.path.join(destination_path, dirname(artifact_path)), exist_ok=True)
                client.download_file(
                    bucket_name, s3_object["Key"], os.path.join(destination_path, artifact_path)
                )
    else:
        # Previously an unrecognised scheme returned silently with nothing
        # downloaded; fail loudly so callers don't proceed with an empty dir.
        raise ValueError(f"Unsupported artifact storage scheme: {storage_schema!r}")


def extract_optional_value_with_default(opt: Optional[Any], default: Any) -> Any:
if opt is not None:
Expand Down

0 comments on commit c15b98d

Please sign in to comment.