Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable running a single step on the active stack #2942

Merged
merged 12 commits into from
Aug 27, 2024
3 changes: 3 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
f"{ENV_ZENML_SERVER_PREFIX}USE_LEGACY_DASHBOARD"
)
ENV_ZENML_SERVER_AUTO_ACTIVATE = f"{ENV_ZENML_SERVER_PREFIX}AUTO_ACTIVATE"
# Name of the environment variable that disables running a step as a
# single-step pipeline on the active stack when the step is called outside
# of a pipeline. It is read as a boolean (default: not set / False); when
# enabled, the step entrypoint function is called directly instead.
ENV_ZENML_DISABLE_RUNNING_SINGLE_STEPS_ON_STACK = (
    "ZENML_DISABLE_RUNNING_SINGLE_STEPS_ON_STACK"
)

# Logging variables
IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False)
Expand Down
83 changes: 79 additions & 4 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
from zenml.client_lazy_loader import ClientLazyLoader
from zenml.config.retry_config import StepRetryConfig
from zenml.config.source import Source
from zenml.constants import STEP_SOURCE_PARAMETER_NAME
from zenml.constants import (
ENV_ZENML_DISABLE_RUNNING_SINGLE_STEPS_ON_STACK,
STEP_SOURCE_PARAMETER_NAME,
handle_bool_env_var,
)
from zenml.exceptions import MissingStepParameterError, StepInterfaceError
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
Expand Down Expand Up @@ -586,9 +590,16 @@ def __call__(
from zenml.new.pipelines.pipeline import Pipeline

if not Pipeline.ACTIVE_PIPELINE:
# The step is being called outside the context of a pipeline,
# we simply call the entrypoint
return self.call_entrypoint(*args, **kwargs)
# The step is being called outside the context of a pipeline, either
# run the step function or run it as a single step pipeline on the
# active stack
run_as_single_step_pipeline = not handle_bool_env_var(
ENV_ZENML_DISABLE_RUNNING_SINGLE_STEPS_ON_STACK, default=False
)
if run_as_single_step_pipeline:
return self.run_as_single_step_pipeline(*args, **kwargs)
else:
return self.call_entrypoint(*args, **kwargs)

(
input_artifacts,
Expand Down Expand Up @@ -660,6 +671,70 @@ def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:

return self.entrypoint(**validated_args)

def run_as_single_step_pipeline(self, *args: Any, **kwargs: Any) -> Any:
    """Run this step as an unlisted single-step pipeline on the active stack.

    The call arguments are bound to the signature of the step entrypoint.
    Arguments that pass the entrypoint input validation are forwarded
    unchanged, all others are wrapped in an `ExternalArtifact`. The step is
    then invoked inside a temporary pipeline with caching disabled, and the
    output artifacts of the resulting step run are loaded and returned.

    Args:
        *args: Positional arguments of the step entrypoint function.
        **kwargs: Keyword arguments of the step entrypoint function.

    Returns:
        `None` if the step has no outputs, the loaded output value if the
        step has exactly one output, otherwise a tuple of all loaded output
        values in the order of the step output configuration.

    Raises:
        RuntimeError: If running the pipeline raised an exception or the
            pipeline run did not finish with status `COMPLETED`.
    """
    from zenml import ExternalArtifact, pipeline
    from zenml.client import Client
    from zenml.enums import ExecutionStatus
    from zenml.new.pipelines.run_utils import (
        wait_for_pipeline_run_to_finish,
    )

    logger.info(
        "Running single step pipeline to execute step `%s`", self.name
    )

    # 1. Prepare the step inputs
    bound_args = inspect.signature(self.entrypoint).bind(*args, **kwargs)

    inputs = {}
    for key, value in bound_args.arguments.items():
        try:
            self.entrypoint_definition.validate_input(key=key, value=value)
        except Exception:
            # Deliberate fallback: a value that fails the input validation
            # is handed to the step as an external artifact instead.
            inputs[key] = ExternalArtifact(value=value)
        else:
            inputs[key] = value

    # 2. Create the single-step pipeline
    orchestrator = Client().active_stack.orchestrator

    pipeline_settings = {}
    if "synchronous" in orchestrator.config.model_fields:
        # Make sure the orchestrator runs sync so we stream the logs
        key = settings_utils.get_stack_component_setting_key(orchestrator)
        pipeline_settings[key] = {"synchronous": True}

    @pipeline(enable_cache=False, settings=pipeline_settings)
    def single_step_pipeline() -> None:
        self(**inputs)

    # 3. Run the pipeline and wait for it to finish
    single_step_pipeline = single_step_pipeline.with_options(unlisted=True)
    try:
        run = single_step_pipeline()
    except Exception as e:
        # Exceptions do not interpolate logging-style `%s` arguments, so
        # the message has to be formatted before it is passed in.
        raise RuntimeError(f"Failed to execute step {self.name}.") from e

    run = wait_for_pipeline_run_to_finish(run.id)

    if run.status != ExecutionStatus.COMPLETED:
        raise RuntimeError(f"Failed to execute step {self.name}.")

    # 4. Load the output artifacts of the only step in the run
    step_run = next(iter(run.steps.values()))
    outputs = [
        step_run.outputs[output_name].load()
        for output_name in step_run.config.outputs.keys()
    ]

    if not outputs:
        return None
    if len(outputs) == 1:
        return outputs[0]
    return tuple(outputs)

@property
def name(self) -> str:
"""The name of the step.
Expand Down
Loading