Skip to content

Commit

Permalink
Feature/sagemaker compiler (#662)
Browse files Browse the repository at this point in the history
resolves: #661
  • Loading branch information
GeorgesLorre authored Nov 23, 2023
1 parent eb27d99 commit c01dd02
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ adlfs = { version = ">= 2023.4.0", optional = true }
docker = {version = ">= 6.1.3", optional = true }
kfp = { version = "2.3.0", optional = true, extras =["kubernetes"] }
google-cloud-aiplatform = { version = "1.34.0", optional = true}
sagemaker = {version = ">= 2.197.0", optional = true}

[tool.poetry.extras]
component = ["dask"]
Expand All @@ -68,6 +69,7 @@ gcp = ["gcsfs"]

kfp = ["docker", "kfp"]
vertex = ["docker", "kfp", "google-cloud-aiplatform"]
sagemaker = ["sagemaker"]
docker = ["docker"]

[tool.poetry.group.test.dependencies]
Expand Down
169 changes: 169 additions & 0 deletions src/fondant/pipeline/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import os
import tempfile
import typing as t
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -515,3 +517,170 @@ def _set_configuration(self, task, fondant_component_operation):
task.set_accelerator_type(accelerator_name)

return task


class SagemakerCompiler(Compiler):
    """Compile a fondant pipeline to an AWS SageMaker pipeline definition (JSON).

    The sagemaker SDK is imported lazily so fondant can be installed and used
    without the `sagemaker` extra.
    """

    def __init__(self):
        self._resolve_imports()

    def _resolve_imports(self):
        """Import the sagemaker SDK and keep a reference on the instance.

        Raises:
            ImportError: if the `sagemaker` extra is not installed.
        """
        try:
            import sagemaker
            import sagemaker.processing
            import sagemaker.workflow.pipeline
            import sagemaker.workflow.steps

            self.sagemaker = sagemaker

        except ImportError:
            msg = """You need to install the sagemaker extras to use the sagemaker compiler,\n
            you can install it with `pip install fondant[sagemaker]`"""
            raise ImportError(
                msg,
            )

    def _get_build_command(
        self,
        metadata: Metadata,
        arguments: t.Dict[str, t.Any],
        dependencies: t.Optional[t.List[str]] = None,
    ) -> t.List[str]:
        """Build the `fondant execute main` argument list for one component.

        Args:
            metadata: metadata of the component run (pipeline name, run id, ...).
            arguments: the component's user arguments; dict and list values are
                JSON-encoded.
            dependencies: ids of the components this one depends on; each adds
                an `--input_manifest_path` pointing at that component's output
                manifest. Defaults to no dependencies. (Was a mutable default
                `[]`, a Python anti-pattern.)

        Returns:
            The command-line arguments as a flat list of strings.
        """
        # add metadata argument to command
        command = ["--metadata", f"'{metadata.to_json()}'"]

        # add in and out manifest paths to command
        command.extend(
            [
                "--output_manifest_path",
                f"{metadata.base_path}/{metadata.pipeline_name}/{metadata.run_id}/"
                f"{metadata.component_id}/manifest.json",
            ],
        )

        # add arguments if any to command
        for key, value in arguments.items():
            if isinstance(value, (dict, list)):
                command.extend([f"--{key}", f"'{json.dumps(value)}'"])
            else:
                command.extend([f"--{key}", f"'{value}'"])

        # resolve dependencies
        if dependencies:
            for dependency in dependencies:
                # there is only an input manifest if the component has dependencies
                command.extend(
                    [
                        "--input_manifest_path",
                        f"{metadata.base_path}/{metadata.pipeline_name}/{metadata.run_id}/"
                        f"{dependency}/manifest.json",
                    ],
                )

        return command

    def compile(
        self,
        pipeline: Pipeline,
        output_path: str,
        *,
        instance_type: str = "ml.t3.medium",
        role_arn: t.Optional[str] = None,
    ) -> None:
        """Compile a fondant pipeline to sagemaker pipeline spec and save it
        to a specified output path.

        Args:
            pipeline: the pipeline to compile
            output_path: the path where to save the sagemaker pipeline spec.
            instance_type: the instance type to use for the processing steps
                (see: https://aws.amazon.com/ec2/instance-types/ for options).
            role_arn: the Amazon Resource Name role to use for the processing steps,
                if none provided the `sagemaker.get_execution_role()` role will be used.
        """
        run_id = pipeline.get_run_id()
        path = pipeline.base_path
        pipeline.validate(run_id=run_id)

        component_cache_key = None

        steps: t.List[t.Any] = []

        # Scripts must still exist on disk when the pipeline definition is
        # serialized, so everything happens inside the temp-dir context.
        with tempfile.TemporaryDirectory(dir=os.getcwd()) as tmpdirname:
            for component_name, component in pipeline._graph.items():
                component_op = component["fondant_component_op"]
                # cache keys chain: each component's key folds in the previous one
                component_cache_key = component_op.get_component_cache_key(
                    previous_component_cache=component_cache_key,
                )

                metadata = Metadata(
                    pipeline_name=pipeline.name,
                    run_id=run_id,
                    base_path=path,
                    component_id=component_name,
                    cache_key=component_cache_key,
                )

                logger.info(f"Compiling service for {component_name}")

                command = self._get_build_command(
                    metadata,
                    component_op.arguments,
                    component["dependencies"],
                )
                # NOTE(review): only the previous step is declared as a
                # dependency — this assumes a linear pipeline graph; confirm
                # before supporting fan-in/fan-out topologies.
                depends_on = [steps[-1]] if component["dependencies"] else []

                script_path = self.generate_component_script(
                    component_name,
                    command,
                    tmpdirname,
                )

                if not role_arn:
                    # if no role is provided use the default sagemaker execution role
                    # https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-ex-role.html
                    role_arn = self.sagemaker.get_execution_role()

                processor = self.sagemaker.processing.ScriptProcessor(
                    image_uri=component_op.component_spec.image,
                    command=["bash"],
                    instance_type=instance_type,
                    instance_count=1,
                    base_job_name=component_name,
                    role=role_arn,
                )
                step = self.sagemaker.workflow.steps.ProcessingStep(
                    name=component_name,
                    processor=processor,
                    depends_on=depends_on,
                    code=script_path,
                )
                steps.append(step)

            sagemaker_pipeline = self.sagemaker.workflow.pipeline.Pipeline(
                name=pipeline.name,
                steps=steps,
            )
            # pretty-print the generated definition as JSON
            with open(output_path, "w") as outfile:
                json.dump(
                    json.loads(sagemaker_pipeline.definition()),
                    outfile,
                    indent=4,
                )

    def _set_configuration(self, *args, **kwargs) -> None:
        # Resource configuration (accelerators etc.) is not supported on the
        # sagemaker backend yet.
        raise NotImplementedError

    def generate_component_script(
        self,
        component_name: str,
        command: t.List[str],
        directory: str,
    ) -> str:
        """Generate a bash script for a component to be used as input in a
        sagemaker pipeline step. Returns the path to the script.
        """
        content = " ".join(["fondant", "execute", "main", *command])

        # compute the path once instead of duplicating the f-string
        script_path = f"{directory}/{component_name}.sh"
        with open(script_path, "w") as f:
            f.write(content)
        return script_path
51 changes: 51 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import pytest
from fondant.core.exceptions import InvalidPipelineDefinition
from fondant.core.manifest import Metadata
from fondant.pipeline import ComponentOp, Pipeline, Resources
from fondant.pipeline.compiler import (
DockerCompiler,
KubeFlowCompiler,
SagemakerCompiler,
VertexCompiler,
)
from fondant.testing import (
Expand Down Expand Up @@ -605,3 +607,52 @@ def test_caching_dependency_kfp(tmp_path_factory):
second_component_cache_key_dict[arg_list[0]]
!= second_component_cache_key_dict[arg_list[1]]
)


def test_sagemaker_build_command():
    """The build command should carry metadata, manifest paths and all arguments."""
    compiler = SagemakerCompiler()
    metadata = Metadata(
        pipeline_name="example_pipeline",
        base_path="/foo/bar",
        component_id="component_2",
        run_id="example_pipeline_2024",
        cache_key="42",
    )
    arguments = {"foo": "bar", "baz": "qux"}

    expected = [
        "--metadata",
        '\'{"base_path": "/foo/bar", "pipeline_name": "example_pipeline", '
        '"run_id": "example_pipeline_2024", "component_id": "component_2", '
        '"cache_key": "42"}\'',
        "--output_manifest_path",
        "/foo/bar/example_pipeline/example_pipeline_2024/component_2/manifest.json",
        "--foo",
        "'bar'",
        "--baz",
        "'qux'",
    ]
    assert compiler._get_build_command(metadata, arguments) == expected

    # a dependency appends an input manifest path pointing at its output manifest
    with_deps = compiler._get_build_command(
        metadata,
        arguments,
        dependencies=["component_1"],
    )
    assert with_deps == [
        *expected,
        "--input_manifest_path",
        "/foo/bar/example_pipeline/example_pipeline_2024/component_1/manifest.json",
    ]


def test_sagemaker_generate_script(tmp_path_factory):
    """The generated script should live at <dir>/<component>.sh and wrap the command."""
    compiler = SagemakerCompiler()
    command = ["echo", "hello world"]
    # `tmp_path_factory.mktemp` returns a pathlib.Path; using it as a context
    # manager was deprecated in Python 3.11 and removed in 3.13, so assign it
    # directly instead of `with ... as fn:`.
    fn = tmp_path_factory.mktemp("temp")
    script_path = compiler.generate_component_script("component_1", command, fn)

    assert script_path == f"{fn}/component_1.sh"

    with open(script_path) as f:
        assert f.read() == "fondant execute main echo hello world"

0 comments on commit c01dd02

Please sign in to comment.