Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Add Raw AWS Batch Task #782

Merged
merged 16 commits into from
Feb 17, 2022
38 changes: 35 additions & 3 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def _compute_array_job_index():
offset = 0
if _os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"):
offset = int(_os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"))
return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
if _os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"):
return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
return offset


def _map_job_index_to_child_index(local_input_dir, datadir, index):
Expand Down Expand Up @@ -380,11 +382,34 @@ def _execute_map_task(
raw_output_data_prefix,
max_concurrency,
test,
is_aws_batch_single_job: bool,
dynamic_addl_distro: str,
dynamic_dest_dir: str,
resolver: str,
resolver_args: List[str],
):
"""
This function should be called by map task and aws-batch task
resolver should be something like:
flytekit.core.python_auto_container.default_task_resolver
resolver args should be something like
task_module app.workflows task_name task_1
have dashes seems to mess up click, like --task_module seems to interfere

:param inputs: Where to read inputs
:param output_prefix: Where to write primitive outputs
:param raw_output_data_prefix: Where to write offloaded data (files, directories, dataframes).
:param test: Dry run
:param is_aws_batch_single_job: True if the aws batch job type is Single job
:param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be
nested).
:param resolver_args: Args that will be passed to the aforementioned resolver's load_task function
:param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode this captures where the
compressed code archive has been uploaded.
:param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode this captures where compressed
code archives should be installed in the flyte task container.
:return:
"""
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")

Expand All @@ -394,8 +419,12 @@ def _execute_map_task(
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
if not isinstance(_task_def, PythonFunctionTask):
raise Exception("Map tasks cannot be run with instance tasks.")
map_task = MapPythonTask(_task_def, max_concurrency)
raise Exception("Map tasks cannot be run with instance tasks.", _task_def)

if is_aws_batch_single_job:
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
map_task = _task_def
else:
map_task = MapPythonTask(_task_def, max_concurrency)

task_index = _compute_array_job_index()
output_prefix = _os.path.join(output_prefix, str(task_index))
Expand Down Expand Up @@ -508,6 +537,7 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--max-concurrency", type=int, required=False)
@_click.option("--test", is_flag=True)
@_click.option("--is-aws-batch-single-job", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=True)
Expand All @@ -522,6 +552,7 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
is_aws_batch_single_job,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
Expand All @@ -535,6 +566,7 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
is_aws_batch_single_job,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(

collection_interface = transform_interface_to_list_interface(python_function_task.python_interface)
instance = next(self._ids)
name = f"{python_function_task._task_function.__module__}.mapper_{python_function_task._task_function.__name__}_{instance}"
name = f"{python_function_task.task_function.__module__}.mapper_{python_function_task.task_function.__name__}_{instance}"

self._run_task = python_function_task
self._max_concurrency = concurrency
Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-aws-batch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flytekit AWS Batch Plugin

Flyte backend can be connected with AWS batch. Once enabled, it allows you to run flyte task on AWS batch service

To install the plugin, run the following command:

```bash
pip install flytekitplugins-awsbatch
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .task import AWSBatch
71 changes: 71 additions & 0 deletions plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

from dataclasses_json import dataclass_json
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins


@dataclass_json
@dataclass
class AWSBatch(object):
"""
Use this to configure a job definition for a AWS batch job. Task's marked with this will automatically execute
natively onto AWS batch service.
Refer to AWS job definition template for more detail: https://docs.aws.amazon.com/batch/latest/userguide/job-definition-template.html
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
"""

parameters: Optional[Dict[str, str]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are all five of these in the SubmitJobInput documentation? Or are some of these Flyte concepts?

Copy link
Member Author

@pingsutw pingsutw Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all are in the SubmitJobInput documentation

schedulingPriority: Optional[int] = None
PlatformCapabilities: Optional[List[str]] = None
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
PropagateTags: Optional[bool] = None
RetryStrategy: Optional[Dict[str, Union[str, int, dict]]] = None
Tags: Optional[Dict[str, str]] = None
Timeout: Optional[Dict[str, int]] = None


class AWSBatchFunctionTask(PythonFunctionTask):
"""
Actual Plugin that transforms the local python code for execution within AWS batch job
"""

_AWS_BATCH_TASK_TYPE = "aws-batch"

def __init__(self, task_config: AWSBatch, task_function: Callable, **kwargs):
if task_config is None:
task_config = AWSBatch()
super(AWSBatchFunctionTask, self).__init__(
task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, **kwargs
)
self._run_task = PythonFunctionTask(task_config=None, task_function=task_function)
self._task_config = task_config

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
s = Struct()
s.update(self._task_config.to_dict())
return json_format.MessageToDict(s)

def get_command(self, settings: SerializationSettings) -> List[str]:
container_args = [
"pyflyte-map-execute",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this to me? If I have one Python task, running one container, once - it still uses pyflyte-map-execute? Why is that? Seems kinda messy...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, can we not make a AWS batch task and then have a map-task override ride the command?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or if needed have a separate task

"--inputs",
"{{.input}}",
"--output-prefix",
"{{.outputPrefix}}",
"--raw-output-data-prefix",
"{{.rawOutputDataPrefix}}",
"--is-aws-batch-single-job",
"--resolver",
self._run_task.task_resolver.location,
"--",
*self._run_task.task_resolver.loader_args(settings, self._run_task),
]

return container_args


# Inject the AWS batch plugin into flytekits dynamic plugin loading system
TaskPlugins.register_pythontask_plugin(AWSBatch, AWSBatchFunctionTask)
2 changes: 2 additions & 0 deletions plugins/flytekit-aws-batch/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
-e file:.#egg=flytekitplugins-aws-batch
148 changes: 148 additions & 0 deletions plugins/flytekit-aws-batch/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# pip-compile requirements.in
#
-e file:.#egg=flytekitplugins-aws-batch
# via -r requirements.in
arrow==1.2.1
# via jinja2-time
binaryornot==0.4.4
# via cookiecutter
certifi==2021.10.8
# via requests
chardet==4.0.0
# via binaryornot
charset-normalizer==2.0.7
# via requests
checksumdir==1.2.0
# via flytekit
click==7.1.2
# via
# cookiecutter
# flytekit
cloudpickle==2.0.0
# via flytekit
cookiecutter==1.7.3
# via flytekit
croniter==1.0.15
# via flytekit
dataclasses-json==0.5.6
# via flytekit
decorator==5.1.0
# via retry
deprecated==1.2.13
# via flytekit
diskcache==5.2.1
# via flytekit
docker-image-py==0.1.12
# via flytekit
docstring-parser==0.12
# via flytekit
flyteidl==0.21.8
# via flytekit
flytekit==0.24.0
# via flytekitplugins-aws-batch
grpcio==1.41.1
# via flytekit
idna==3.3
# via requests
importlib-metadata==4.8.2
# via keyring
jinja2==3.0.3
# via
# cookiecutter
# jinja2-time
jinja2-time==0.2.0
# via cookiecutter
keyring==23.2.1
# via flytekit
markupsafe==2.0.1
# via jinja2
marshmallow==3.14.0
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.13.0
# via flytekit
mypy-extensions==0.4.3
# via typing-inspect
natsort==8.0.0
# via flytekit
numpy==1.21.4
# via
# pandas
# pyarrow
pandas==1.3.4
# via flytekit
poyo==0.5.0
# via cookiecutter
protobuf==3.19.1
# via
# flyteidl
# flytekit
py==1.11.0
# via retry
pyarrow==6.0.0
# via flytekit
python-dateutil==2.8.1
# via
# arrow
# croniter
# flytekit
# pandas
python-json-logger==2.0.2
# via flytekit
python-slugify==5.0.2
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
pytz==2018.4
# via
# flytekit
# pandas
regex==2021.11.10
# via docker-image-py
requests==2.26.0
# via
# cookiecutter
# flytekit
# responses
responses==0.15.0
# via flytekit
retry==0.9.2
# via flytekit
six==1.16.0
# via
# cookiecutter
# flytekit
# grpcio
# python-dateutil
# responses
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
# via flytekit
text-unidecode==1.3
# via python-slugify
typing-extensions==3.10.0.2
# via typing-inspect
typing-inspect==0.7.1
# via dataclasses-json
urllib3==1.26.7
# via
# flytekit
# requests
# responses
wheel==0.37.0
# via flytekit
wrapt==1.13.3
# via
# deprecated
# flytekit
zipp==3.6.0
# via importlib-metadata
34 changes: 34 additions & 0 deletions plugins/flytekit-aws-batch/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from setuptools import setup

PLUGIN_NAME = "awsbatch"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=0.19.0,<1.0.0"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="admin@flyte.org",
description="This package holds the AWS Batch plugins for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
Empty file.
Loading