Skip to content

Commit

Permalink
allow multi step configuration for skypilot
Browse files Browse the repository at this point in the history
  • Loading branch information
safoinme committed Dec 19, 2023
1 parent 1097444 commit 27fc323
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ To use the SkyPilot VM Orchestrator, you need:
* One of the SkyPilot integrations installed. You can install the SkyPilot integration for your cloud provider of choice using the following command:
```shell
# For AWS
pip install "zenml[connectors-gcp]"
zenml integration install aws vm_aws
pip install "zenml[connectors-aws]"
zenml integration install aws skypilot_aws

# for GCP
pip install "zenml[connectors-gcp]"
zenml integration install gcp vm_gcp # for GCP
zenml integration install gcp skypilot_gcp # for GCP

# for Azure
pip install "zenml[connectors-azure]"
zenml integration install azure vm_azure # for Azure
zenml integration install azure skypilot_azure # for Azure
```
* [Docker](https://www.docker.com) installed and running.
* A [remote artifact store](../artifact-stores/artifact-stores.md) as part of your stack.
Expand Down Expand Up @@ -371,6 +371,38 @@ skypilot_settings = SkypilotAzureOrchestratorSettings(
{% endtab %}
{% endtabs %}


One of the key features of the SkyPilot VM Orchestrator is the ability to run each step of a pipeline on a separate VM with its own specific settings. This allows for fine-grained control over the resources allocated to each step, ensuring that each part of your pipeline has the necessary compute power while optimizing for cost and efficiency.

## Configuring Step-Specific Resources

The SkyPilot VM Orchestrator allows you to configure resources for each step individually. This means you can specify different VM types, CPU and memory requirements, and even use spot instances for certain steps while using on-demand instances for others.

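To configure step-specific resources, pass a `SkypilotBaseOrchestratorSettings` object (in practice, the settings class for your cloud, such as `SkypilotAWSOrchestratorSettings`) in the `settings` dictionary of the `@step` decorator, keyed by the orchestrator flavor (for example, `"orchestrator.vm_aws"`). This object allows you to define various attributes such as `instance_type`, `cpus`, `memory`, `use_spot`, `region`, and more.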

Here's an example of how to configure specific resources for a step for the AWS cloud:

```python
from zenml.integrations.skypilot.flavors.skypilot_orchestrator_aws_vm_flavor import SkypilotAWSOrchestratorSettings

# Settings for a specific step that requires more resources
high_resource_settings = SkypilotAWSOrchestratorSettings(
instance_type='t2.2xlarge',
cpus=8,
memory=32,
use_spot=False,
region='us-east-1',
# ... other settings
)

@step(settings={"orchestrator.vm_aws": high_resource_settings})
def my_resource_intensive_step():
# Step implementation
pass
```

By using the `settings` parameter, you can tailor the resources for each step according to its specific needs. This flexibility allows you to optimize your pipeline execution for both performance and cost.

Check out
the [SDK docs](https://sdkdocs.zenml.io/latest/integration\_code\_docs/integrations-skypilot/#zenml.integrations.skypilot.flavors.skypilot\_orchestrator\_base\_vm\_flavor.SkypilotBaseOrchestratorSettings)
for a full list of available attributes and [this docs page](/docs/book/user-guide/advanced-guide/pipelining-features/configure-steps-pipelines.md) for more
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
"""Implementation of the Skypilot base VM orchestrator."""

import os
import time
import re
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from uuid import uuid4

import sky

from zenml.client import Client
from zenml.entrypoints import PipelineEntrypointConfiguration
from zenml.entrypoints import StepEntrypointConfiguration
from zenml.enums import StackComponentType
from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import (
SkypilotBaseOrchestratorSettings,
Expand All @@ -33,7 +33,6 @@
)
from zenml.orchestrators import utils as orchestrator_utils
from zenml.stack import StackValidator
from zenml.utils import string_utils

if TYPE_CHECKING:
from zenml.models import PipelineDeploymentResponse
Expand Down Expand Up @@ -138,7 +137,7 @@ def prepare_or_run_pipeline(
stack: "Stack",
environment: Dict[str, str],
) -> Any:
"""Runs all pipeline steps in Skypilot containers.
"""Runs each pipeline step in a separate Skypilot container.
Args:
deployment: The pipeline deployment to prepare or run.
Expand All @@ -161,32 +160,18 @@ def prepare_or_run_pipeline(
environment[
ENV_ZENML_SKYPILOT_ORCHESTRATOR_RUN_ID
] = orchestrator_run_id

settings = cast(
SkypilotBaseOrchestratorSettings,
self.get_settings(deployment),
)

entrypoint = PipelineEntrypointConfiguration.get_entrypoint_command()
entrypoint_str = " ".join(entrypoint)
arguments = PipelineEntrypointConfiguration.get_entrypoint_arguments(
deployment_id=deployment.id
)
arguments_str = " ".join(arguments)

# Set up docker run command
image = self.get_image(deployment=deployment)
docker_environment_str = " ".join(
f"-e {k}={v}" for k, v in environment.items()
)

start_time = time.time()

instance_type = settings.instance_type or self.DEFAULT_INSTANCE_TYPE
# Set up some variables for configuration
orchestrator_run_id = str(uuid4())
environment[
ENV_ZENML_SKYPILOT_ORCHESTRATOR_RUN_ID
] = orchestrator_run_id

# Set up credentials
self.setup_credentials()

# Set the service connector AWS profile ENV variable
self.prepare_environment_variable(set=True)

# Guaranteed by stack validation
assert stack is not None and stack.container_registry is not None

Expand All @@ -205,71 +190,145 @@ def prepare_or_run_pipeline(
setup = None
task_envs = None

# Run the entire pipeline

# Set the service connector AWS profile ENV variable
self.prepare_environment_variable(set=True)

try:
task = sky.Task(
run=f"docker run --rm {docker_environment_str} {image} {entrypoint_str} {arguments_str}",
setup=setup,
envs=task_envs,
unique_resource_configs: Dict[str, List[str]] = {}
for step_name, step in deployment.step_configurations.items():
settings = cast(
SkypilotBaseOrchestratorSettings,
self.get_settings(step),
)
task = task.set_resources(
sky.Resources(
cloud=self.cloud,
instance_type=instance_type,
cpus=settings.cpus,
memory=settings.memory,
accelerators=settings.accelerators,
accelerator_args=settings.accelerator_args,
use_spot=settings.use_spot,
spot_recovery=settings.spot_recovery,
region=settings.region,
zone=settings.zone,
image_id=settings.image_id,
disk_size=settings.disk_size,
disk_tier=settings.disk_tier,
# Handle both str and Dict[str, int] types for accelerators
if isinstance(settings.accelerators, dict):
accelerators_hashable = frozenset(
settings.accelerators.items()
)
elif isinstance(settings.accelerators, str):
accelerators_hashable = frozenset({(settings.accelerators, 1)})
else:
accelerators_hashable = None
resource_config = (
settings.instance_type,
settings.cpus,
settings.memory,
settings.disk_size, # Assuming disk_size is part of the settings
settings.disk_tier, # Assuming disk_tier is part of the settings
settings.use_spot,
settings.spot_recovery,
settings.region,
settings.zone,
accelerators_hashable,
)
cluster_name_parts = [
self.sanitize_for_cluster_name(part)
for part in resource_config
if part is not None
]
cluster_name = "cluster-" + "-".join(cluster_name_parts)
unique_resource_configs.setdefault(cluster_name, []).append(
step_name
)

cluster_name = settings.cluster_name
if cluster_name is None:
# Find existing cluster
for i in sky.status(refresh=True):
if isinstance(
i["handle"].launched_resources.cloud, type(self.cloud)
):
cluster_name = i["handle"].cluster_name
logger.info(
f"Found existing cluster {cluster_name}. Reusing..."
)
# Process each unique resource configuration
for cluster_name, steps in unique_resource_configs.items():
# Process each step in the resource configuration
for step_name in steps:
step = deployment.step_configurations[step_name]
settings = cast(
SkypilotBaseOrchestratorSettings,
self.get_settings(step),
)
# Use the step-specific entrypoint configuration
entrypoint = (
StepEntrypointConfiguration.get_entrypoint_command()
)
entrypoint_str = " ".join(entrypoint)
arguments = (
StepEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name, deployment_id=deployment.id
)
)
arguments_str = " ".join(arguments)

# Launch the cluster
sky.launch(
task,
cluster_name,
retry_until_up=settings.retry_until_up,
idle_minutes_to_autostop=settings.idle_minutes_to_autostop,
down=settings.down,
stream_logs=settings.stream_logs,
)
# Set up docker run command
image = self.get_image(
deployment=deployment, step_name=step_name
)
docker_environment_str = " ".join(
f"-e {k}={v}" for k, v in environment.items()
)

except Exception as e:
raise e
# Set up the task
task = sky.Task(
run=f"docker run --rm {docker_environment_str} {image} {entrypoint_str} {arguments_str}",
setup=setup,
envs=task_envs,
)
task = task.set_resources(
sky.Resources(
cloud=self.cloud,
instance_type=settings.instance_type
or self.DEFAULT_INSTANCE_TYPE,
cpus=settings.cpus,
memory=settings.memory,
disk_size=settings.disk_size,
disk_tier=settings.disk_tier,
accelerators=settings.accelerators,
accelerator_args=settings.accelerator_args,
use_spot=settings.use_spot,
spot_recovery=settings.spot_recovery,
region=settings.region,
zone=settings.zone,
image_id=settings.image_id,
)
)

finally:
# Unset the service connector AWS profile ENV variable
self.prepare_environment_variable(set=False)
# Launch the cluster
try:
sky.launch(
task,
cluster_name,
retry_until_up=settings.retry_until_up,
idle_minutes_to_autostop=settings.idle_minutes_to_autostop,
down=settings.down,
stream_logs=settings.stream_logs,
)
except Exception as e:
# If there's a resource mismatch, check if this is the last step using this configuration
if steps[-1] == step_name:
logger.warning(
f"Resource mismatch for cluster '{cluster_name}': {e}. "
"Attempting to down the existing cluster."
)
sky.down(cluster_name)
# Retry launching the task after bringing down the cluster
sky.launch(
task,
cluster_name,
retry_until_up=settings.retry_until_up,
idle_minutes_to_autostop=settings.idle_minutes_to_autostop,
down=settings.down,
stream_logs=settings.stream_logs,
)
else:
# Not the last step mapped to this cluster: the cluster is left
# running for the remaining steps and the launch error is only
# reported. The final fragment must be an f-string, otherwise the
# literal text "{e}" is logged instead of the exception.
logger.warning(
    f"Resource mismatch for cluster '{cluster_name}', but "
    "the cluster will be used for subsequent steps. "
    f"Skipping the downing of the cluster: {e}."
)

run_duration = time.time() - start_time
# Unset the service connector AWS profile ENV variable
self.prepare_environment_variable(set=False)

# Log the completion of the pipeline run
run_id = orchestrator_utils.get_run_id_for_orchestrator_run_id(
orchestrator=self, orchestrator_run_id=orchestrator_run_id
)
run_model = Client().zen_store.get_run(run_id)
logger.info(
"Pipeline run `%s` has finished in `%s`.\n",
"Pipeline run `%s` has finished.\n",
run_model.name,
string_utils.get_human_readable_time(run_duration),
)

def sanitize_for_cluster_name(self, value: Any) -> str:
    """Sanitize a value so it can be used as part of a cluster name.

    The value is converted to a string and lowercased. Every run of
    characters outside `a-z` / `0-9` is collapsed into a single dash, and
    leading/trailing dashes are removed so that joining several sanitized
    parts never yields a name that starts or ends with a dash.

    Args:
        value: The value to sanitize. Any object is accepted; its `str()`
            representation is used.

    Returns:
        The sanitized, lowercase string. May be empty if the value
        contains no alphanumeric characters.
    """
    # Collapse whole runs of invalid characters (e.g. the "})" at the end
    # of a frozenset repr, or the "+" in "8+") instead of emitting one
    # dash per character.
    collapsed = re.sub(r"[^a-z0-9]+", "-", str(value).lower())
    # A dash at either edge would leak into the final cluster name.
    return collapsed.strip("-")

0 comments on commit 27fc323

Please sign in to comment.