Make run_with_accelerate a pythonic decorator (#2943)

avishniakov authored Aug 27, 2024
1 parent 9b11e5d commit 35813b1
Showing 3 changed files with 123 additions and 106 deletions.
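In short: `run_with_accelerate` used to be called functionally inside a pipeline (`run_with_accelerate(step, ...)(...)`) and refused `@` decoration; this commit inverts that, making `@run_with_accelerate(...)` on a step the supported style and turning functional use inside a pipeline into a `RuntimeError`. A minimal usage sketch of the new style, adapted from the docstring example in the diff below (the step body and parameters are placeholders):

```python
from zenml import pipeline, step
from zenml.integrations.huggingface.steps import run_with_accelerate


@run_with_accelerate(num_processes=4, multi_gpu=True)  # new decorator style
@step
def training_step(some_param: int) -> None:
    ...  # your training code here


@pipeline
def training_pipeline(some_param: int) -> None:
    # the decorated step is invoked like any other step
    training_step(some_param)
```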
192 changes: 97 additions & 95 deletions src/zenml/integrations/huggingface/steps/accelerate_runner.py
@@ -17,15 +17,15 @@
"""Step function to run any ZenML step using Accelerate."""

import functools
import inspect
from typing import Any, Callable, Dict, TypeVar, cast
from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast

import cloudpickle as pickle
from accelerate.commands.launch import ( # type: ignore[import-untyped]
launch_command,
launch_command_parser,
)

from zenml import get_pipeline_context
from zenml.logger import get_logger
from zenml.steps import BaseStep
from zenml.utils.function_utils import _cli_arg_name, create_cli_wrapped_script
@@ -35,28 +35,31 @@


 def run_with_accelerate(
-    step_function: BaseStep,
+    step_function_top_level: Optional[BaseStep] = None,
     **accelerate_launch_kwargs: Any,
-) -> BaseStep:
+) -> Union[Callable[[BaseStep], BaseStep], BaseStep]:
     """Run a function with accelerate.
     Accelerate package: https://huggingface.co/docs/accelerate/en/index
     Example:
         ```python
         from zenml import step, pipeline
         from zenml.integrations.huggingface.steps import run_with_accelerate
+        @run_with_accelerate(num_processes=4, multi_gpu=True)
         @step
         def training_step(some_param: int, ...):
             # your training code is below
             ...
         @pipeline
         def training_pipeline(some_param: int, ...):
-            run_with_accelerate(training_step, num_processes=4)(some_param, ...)
+            training_step(some_param, ...)
         ```
     Args:
-        step_function: The step function to run.
+        step_function_top_level: The step function to run with accelerate [optional].
+            Used in functional calls like `run_with_accelerate(some_func, foo=bar)()`.
         accelerate_launch_kwargs: A dictionary of arguments to pass along to the
             `accelerate launch` command, including hardware selection, resource
             allocation, and training paradigm options. Visit
Expand All @@ -65,100 +68,99 @@ def training_pipeline(some_param: int, ...):
     Returns:
         The accelerate-enabled version of the step.
     Raises:
         RuntimeError: If the decorator is misused.
     """

-    def _decorator(
-        entrypoint: F, accelerate_launch_kwargs: Dict[str, Any]
-    ) -> F:
-        @functools.wraps(entrypoint)
-        def inner(*args: Any, **kwargs: Any) -> Any:
-            if args:
-                raise ValueError(
-                    "Accelerated steps do not support positional arguments."
-                )
-
-            with create_cli_wrapped_script(
-                entrypoint, flavor="accelerate"
-            ) as (
-                script_path,
-                output_path,
-            ):
-                commands = [str(script_path.absolute())]
-                for k, v in kwargs.items():
-                    k = _cli_arg_name(k)
-                    if isinstance(v, bool):
-                        if v:
-                            commands.append(f"--{k}")
-                    elif type(v) in (list, tuple, set):
-                        for each in v:
-                            commands += [f"--{k}", f"{each}"]
-                    else:
-                        commands += [f"--{k}", f"{v}"]
-                logger.debug(commands)
-
-                parser = launch_command_parser()
-                args = parser.parse_args(commands)
-                for k, v in accelerate_launch_kwargs.items():
-                    if k in args:
-                        setattr(args, k, v)
-                    else:
-                        logger.warning(
-                            f"You passed in `{k}` as an `accelerate launch` argument, but it was not accepted. "
-                            "Please check https://huggingface.co/docs/accelerate/en/package_reference/cli#accelerate-launch "
-                            "to find out more about supported arguments and retry."
-                        )
-                try:
-                    launch_command(args)
-                except Exception as e:
-                    logger.error(
-                        "Accelerate training job failed... See error message for details."
-                    )
-                    raise RuntimeError(
-                        "Accelerate training job failed."
-                    ) from e
-                else:
-                    logger.info(
-                        "Accelerate training job finished successfully."
-                    )
-                return pickle.load(open(output_path, "rb"))
-
-        return cast(F, inner)
-
-    import __main__
-
-    if __main__.__file__ == inspect.getsourcefile(step_function.entrypoint):
-        raise RuntimeError(
-            f"`{run_with_accelerate.__name__}` decorator cannot be used "
-            "with steps defined inside the entrypoint script, please move "
-            f"your step `{step_function.name}` code to another file and retry."
-        )
-
-    if f"@{run_with_accelerate.__name__}" in inspect.getsource(
-        step_function.entrypoint
-    ):
-        raise RuntimeError(
-            f"`{run_with_accelerate.__name__}` decorator cannot be used "
-            "directly on steps using '@' syntax, please use a functional "
-            "decoration in your pipeline script instead.\n"
-            "Example (not allowed):\n"
-            f"@{run_with_accelerate.__name__}\n"
-            f"def {step_function.name}(...):\n"
-            "    ...\n"
-            "Example (allowed):\n"
-            "def my_pipeline(...):\n"
-            f"    run_with_accelerate({step_function.name})(...)\n"
-        )
-
-    setattr(
-        step_function,
-        "entrypoint",
-        _decorator(
-            step_function.entrypoint,
-            accelerate_launch_kwargs=accelerate_launch_kwargs,
-        ),
-    )
-
-    return step_function
+    def _decorator(step_function: BaseStep) -> BaseStep:
+        def _wrapper(
+            entrypoint: F, accelerate_launch_kwargs: Dict[str, Any]
+        ) -> F:
+            @functools.wraps(entrypoint)
+            def inner(*args: Any, **kwargs: Any) -> Any:
+                if args:
+                    raise ValueError(
+                        "Accelerated steps do not support positional arguments."
+                    )
+
+                with create_cli_wrapped_script(
+                    entrypoint, flavor="accelerate"
+                ) as (
+                    script_path,
+                    output_path,
+                ):
+                    commands = [str(script_path.absolute())]
+                    for k, v in kwargs.items():
+                        k = _cli_arg_name(k)
+                        if isinstance(v, bool):
+                            if v:
+                                commands.append(f"--{k}")
+                        elif type(v) in (list, tuple, set):
+                            for each in v:
+                                commands += [f"--{k}", f"{each}"]
+                        else:
+                            commands += [f"--{k}", f"{v}"]
+                    logger.debug(commands)
+
+                    parser = launch_command_parser()
+                    args = parser.parse_args(commands)
+                    for k, v in accelerate_launch_kwargs.items():
+                        if k in args:
+                            setattr(args, k, v)
+                        else:
+                            logger.warning(
+                                f"You passed in `{k}` as an `accelerate launch` argument, but it was not accepted. "
+                                "Please check https://huggingface.co/docs/accelerate/en/package_reference/cli#accelerate-launch "
+                                "to find out more about supported arguments and retry."
+                            )
+                    try:
+                        launch_command(args)
+                    except Exception as e:
+                        logger.error(
+                            "Accelerate training job failed... See error message for details."
+                        )
+                        raise RuntimeError(
+                            "Accelerate training job failed."
+                        ) from e
+                    else:
+                        logger.info(
+                            "Accelerate training job finished successfully."
+                        )
+                    return pickle.load(open(output_path, "rb"))
+
+            return cast(F, inner)
+
+        try:
+            get_pipeline_context()
+        except RuntimeError:
+            pass
+        else:
+            raise RuntimeError(
+                f"`{run_with_accelerate.__name__}` decorator cannot be used "
+                "in a functional way with steps, please apply decoration "
+                "directly to a step instead. This behavior will also be "
+                "allowed in the future, but it currently faces technical limitations.\n"
+                "Example (allowed):\n"
+                f"@{run_with_accelerate.__name__}(...)\n"
+                f"def {step_function.name}(...):\n"
+                "    ...\n"
+                "Example (not allowed):\n"
+                "def my_pipeline(...):\n"
+                f"    run_with_accelerate({step_function.name},...)(...)\n"
+            )

+        setattr(
+            step_function, "unwrapped_entrypoint", step_function.entrypoint
+        )
+        setattr(
+            step_function,
+            "entrypoint",
+            _wrapper(
+                step_function.entrypoint,
+                accelerate_launch_kwargs=accelerate_launch_kwargs,
+            ),
+        )
+
+        return step_function
+
+    if step_function_top_level:
+        return _decorator(step_function_top_level)
+    return _decorator
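The `Optional[BaseStep]` first parameter plus the `Union` return type is the usual optional-argument decorator pattern: with a step passed positionally, the decorator is applied immediately; with keyword arguments only, a decorator is returned. A self-contained sketch of just that pattern, using hypothetical names (`with_launcher`, `Step`) rather than the ZenML types:

```python
from typing import Any, Callable, Optional, Union

Step = Callable[..., Any]  # hypothetical stand-in for zenml.steps.BaseStep


def with_launcher(
    step: Optional[Step] = None, **launch_kwargs: Any
) -> Union[Callable[[Step], Step], Step]:
    """Works both as `@with_launcher(...)` and as `with_launcher(fn, ...)`."""

    def _decorator(fn: Step) -> Step:
        def inner(*args: Any, **kwargs: Any) -> Any:
            # a real implementation would shell out to `accelerate launch`
            print(f"launching with {launch_kwargs}")
            return fn(*args, **kwargs)

        return inner

    if step is not None:
        return _decorator(step)  # functional call: with_launcher(fn, gpus=2)
    return _decorator  # decorator factory: @with_launcher(gpus=2)
```

So `run_with_accelerate(train, num_processes=2)` and `@run_with_accelerate(num_processes=2)` both resolve to the same `_decorator`; the new `get_pipeline_context()` check is what forbids the former inside a pipeline body.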
8 changes: 4 additions & 4 deletions src/zenml/utils/function_utils.py
@@ -34,12 +34,12 @@
 import sys
 sys.path.append(r"{func_path}")

-from {func_module} import {func_name} as func_to_wrap
+from {func_module} import {func_name} as step_function

-if entrypoint:=getattr(func_to_wrap, "entrypoint", None):
-    func = _cli_wrapped_function(entrypoint)
+if unwrapped_entrypoint:=getattr(step_function, "unwrapped_entrypoint", None):
+    func = _cli_wrapped_function(unwrapped_entrypoint)
 else:
-    func = _cli_wrapped_function(func_to_wrap)
+    func = _cli_wrapped_function(step_function.entrypoint)
 """
 _CLI_WRAPPED_MAINS = {
     "accelerate": """
29 changes: 22 additions & 7 deletions
@@ -16,6 +16,7 @@
 import shutil
 from pathlib import Path

+import pytest
 import transformers
 from accelerate import Accelerator
 from datasets import load_from_disk
@@ -74,17 +75,20 @@ def get_full_path(folder: str):
     return str(ft_model_dir)


-@pipeline(enable_cache=False)
-def train_pipe():
-    model_dir = run_with_accelerate(train, num_processes=2, use_cpu=True)()
-    # if it is StepArtifact, we are still composing the pipeline
-    if not isinstance(model_dir, StepArtifact):
-        assert isinstance(model_dir, str)
-        assert model_dir == "model_dir"
+train_accelerated = run_with_accelerate(train, num_processes=2, use_cpu=True)


 def test_accelerate_runner_on_cpu_with_toy_model(clean_client):
     """Tests whether the run_with_accelerate wrapper works as expected."""

+    @pipeline(enable_cache=False)
+    def train_pipe():
+        model_dir = train_accelerated()
+        # if it is StepArtifact, we are still composing the pipeline
+        if not isinstance(model_dir, StepArtifact):
+            assert isinstance(model_dir, str)
+            assert model_dir == "model_dir"
+
     try:
         prev_files = os.listdir()
         response = train_pipe()
@@ -93,3 +97,14 @@ def test_accelerate_runner_on_cpu_with_toy_model(clean_client):
         cur_files = os.listdir()
         for each in set(cur_files) - set(prev_files):
             shutil.rmtree(each)
+
+
+def test_accelerate_runner_fails_on_functional_use(clean_client):
+    """Tests that functional use of run_with_accelerate raises."""
+
+    @pipeline(enable_cache=False)
+    def train_pipe():
+        _ = run_with_accelerate(train, num_processes=2, use_cpu=True)
+
+    with pytest.raises(RuntimeError):
+        train_pipe()
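Two details of the reworked tests are worth noting. First, `train_pipe` moved inside the test because module-level functional wrapping is exactly what now raises. Second, the `StepArtifact` guard exists because a ZenML pipeline body is executed once at composition time, when step calls return placeholder artifacts rather than real values. A rough sketch of that two-phase behavior, with hypothetical stand-ins (`StepArtifactStub`, `trace`, `execute`):

```python
from typing import Any, Callable, Union


class StepArtifactStub:
    """Hypothetical stand-in for zenml's StepArtifact placeholder."""


StepFn = Callable[[], Union[StepArtifactStub, str]]


def trace(pipeline_fn: Callable[[StepFn], None]) -> None:
    """Composition pass: step calls yield placeholders, nothing runs."""
    pipeline_fn(lambda: StepArtifactStub())


def execute(pipeline_fn: Callable[[StepFn], None]) -> None:
    """Execution pass: step calls return real outputs."""
    pipeline_fn(lambda: "model_dir")


def train_pipe(train_step: StepFn) -> None:
    model_dir = train_step()
    # Mirrors the test: only assert once we are past composition.
    if not isinstance(model_dir, StepArtifactStub):
        assert model_dir == "model_dir"


trace(train_pipe)    # asserts skipped while the DAG is being built
execute(train_pipe)  # asserts exercised with the real value
```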
