-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat][python] Support Pytorch task in python api (#11975)
Co-authored-by: Jiajie Zhong <zhongjiajie955@gmail.com>
- Loading branch information
1 parent
864a908
commit 5202e5c
Showing
8 changed files
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,3 +42,4 @@ In this section | |
sub_process | ||
|
||
sagemaker | ||
pytorch |
42 changes: 42 additions & 0 deletions
42
dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/pytorch.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
.. Licensed to the Apache Software Foundation (ASF) under one | ||
or more contributor license agreements. See the NOTICE file | ||
distributed with this work for additional information | ||
regarding copyright ownership. The ASF licenses this file | ||
to you under the Apache License, Version 2.0 (the | ||
"License"); you may not use this file except in compliance | ||
with the License. You may obtain a copy of the License at | ||
.. http://www.apache.org/licenses/LICENSE-2.0 | ||
.. Unless required by applicable law or agreed to in writing, | ||
software distributed under the License is distributed on an | ||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations | ||
under the License. | ||
Pytorch | ||
======= | ||
|
||
|
||
A Pytorch task type's example and dive into information of **PyDolphinScheduler**. | ||
|
||
Example | ||
------- | ||
|
||
.. literalinclude:: ../../../src/pydolphinscheduler/examples/task_pytorch_example.py | ||
:start-after: [start workflow_declare] | ||
:end-before: [end workflow_declare] | ||
|
||
Dive Into | ||
--------- | ||
|
||
.. automodule:: pydolphinscheduler.tasks.pytorch | ||
|
||
|
||
YAML file example | ||
----------------- | ||
|
||
.. literalinclude:: ../../../examples/yaml_define/Pytorch.yaml | ||
:start-after: # under the License. | ||
:language: yaml |
53 changes: 53 additions & 0 deletions
53
dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/Pytorch.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# Define the workflow | ||
workflow: | ||
name: "Pytorch" | ||
|
||
# Define the tasks under the workflow | ||
tasks: | ||
|
||
# run project with existing environment | ||
- name: task_existing_env | ||
task_type: pytorch | ||
script: main.py | ||
script_params: --dry-run --no-cuda | ||
project_path: https://github.com/pytorch/examples#mnist | ||
python_command: /home/anaconda3/envs/pytorch/bin/python3 | ||
|
||
|
||
# run project with creating conda environment | ||
- name: task_conda_env | ||
task_type: pytorch | ||
script: main.py | ||
script_params: --dry-run --no-cuda | ||
project_path: https://github.com/pytorch/examples#mnist | ||
is_create_environment: True | ||
python_env_tool: conda | ||
requirements: requirements.txt | ||
conda_python_version: 3.7 | ||
|
||
# run project with creating virtualenv environment | ||
- name: task_virtualenv_env | ||
task_type: pytorch | ||
script: main.py | ||
script_params: --dry-run --no-cuda | ||
project_path: https://github.com/pytorch/examples#mnist | ||
is_create_environment: True | ||
python_env_tool: virtualenv | ||
requirements: requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
...heduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_pytorch_example.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# [start workflow_declare] | ||
"""A example workflow for task pytorch.""" | ||
|
||
from pydolphinscheduler.core.process_definition import ProcessDefinition | ||
from pydolphinscheduler.tasks.pytorch import Pytorch | ||
|
||
with ProcessDefinition( | ||
name="task_pytorch_example", | ||
tenant="tenant_exists", | ||
) as pd: | ||
|
||
# run project with existing environment | ||
task_existing_env = Pytorch( | ||
name="task_existing_env", | ||
script="main.py", | ||
script_params="--dry-run --no-cuda", | ||
project_path="https://github.com/pytorch/examples#mnist", | ||
python_command="/home/anaconda3/envs/pytorch/bin/python3", | ||
) | ||
|
||
# run project with creating conda environment | ||
task_conda_env = Pytorch( | ||
name="task_conda_env", | ||
script="main.py", | ||
script_params="--dry-run --no-cuda", | ||
project_path="https://github.com/pytorch/examples#mnist", | ||
is_create_environment=True, | ||
python_env_tool="conda", | ||
requirements="requirements.txt", | ||
conda_python_version="3.7", | ||
) | ||
|
||
# run project with creating virtualenv environment | ||
task_virtualenv_env = Pytorch( | ||
name="task_virtualenv_env", | ||
script="main.py", | ||
script_params="--dry-run --no-cuda", | ||
project_path="https://github.com/pytorch/examples#mnist", | ||
is_create_environment=True, | ||
python_env_tool="virtualenv", | ||
requirements="requirements.txt", | ||
) | ||
|
||
pd.submit() | ||
# [end workflow_declare] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/pytorch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Task Pytorch.""" | ||
from typing import Optional | ||
|
||
from pydolphinscheduler.constants import TaskType | ||
from pydolphinscheduler.core.task import Task | ||
|
||
|
||
class DEFAULT: | ||
"""Default values for Pytorch.""" | ||
|
||
is_create_environment = False | ||
project_path = "." | ||
python_command = "${PYTHON_HOME}" | ||
|
||
|
||
class Pytorch(Task): | ||
"""Task Pytorch object, declare behavior for Pytorch task to dolphinscheduler. | ||
See also: `DolphinScheduler Pytorch Task Plugin | ||
<https://dolphinscheduler.apache.org/en-us/docs/dev/user_doc/guide/task/pytorch.html>`_ | ||
:param name: task name | ||
:param script: Entry to the Python script file that you want to run. | ||
:param script_params: Input parameters at run time. | ||
:param project_path: The path to the project. Default "." . | ||
:param is_create_environment: is create environment. Default False. | ||
:param python_command: The path to the python command. Default "${PYTHON_HOME}". | ||
:param python_env_tool: The python environment tool. Default "conda". | ||
:param requirements: The path to the requirements.txt file. Default "requirements.txt". | ||
:param conda_python_version: The python version of conda environment. Default "3.7". | ||
""" | ||
|
||
_task_custom_attr = { | ||
"script", | ||
"script_params", | ||
"other_params", | ||
"python_path", | ||
"is_create_environment", | ||
"python_command", | ||
"python_env_tool", | ||
"requirements", | ||
"conda_python_version", | ||
} | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
script: str, | ||
script_params: str = "", | ||
project_path: Optional[str] = DEFAULT.project_path, | ||
is_create_environment: Optional[bool] = DEFAULT.is_create_environment, | ||
python_command: Optional[str] = DEFAULT.python_command, | ||
python_env_tool: Optional[str] = "conda", | ||
requirements: Optional[str] = "requirements.txt", | ||
conda_python_version: Optional[str] = "3.7", | ||
*args, | ||
**kwargs, | ||
): | ||
"""Init Pytorch task.""" | ||
super().__init__(name, TaskType.PYTORCH, *args, **kwargs) | ||
self.script = script | ||
self.script_params = script_params | ||
self.is_create_environment = is_create_environment | ||
self.python_path = project_path | ||
self.python_command = python_command | ||
self.python_env_tool = python_env_tool | ||
self.requirements = requirements | ||
self.conda_python_version = conda_python_version | ||
|
||
@property | ||
def other_params(self): | ||
"""Return other params.""" | ||
conds = [ | ||
self.is_create_environment != DEFAULT.is_create_environment, | ||
self.python_path != DEFAULT.project_path, | ||
self.python_command != DEFAULT.python_command, | ||
] | ||
return any(conds) |
Oops, something went wrong.