Skip to content

Commit

Permalink
[feat][python] Support Pytorch task in python api (#11975)
Browse files Browse the repository at this point in the history
Co-authored-by: Jiajie Zhong <zhongjiajie955@gmail.com>
  • Loading branch information
jieguangzhou and zhongjiajie authored Sep 16, 2022
1 parent 864a908 commit 5202e5c
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ In this section
sub_process

sagemaker
pytorch
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Pytorch
=======


A Pytorch task type's example and dive into information of **PyDolphinScheduler**.

Example
-------

.. literalinclude:: ../../../src/pydolphinscheduler/examples/task_pytorch_example.py
:start-after: [start workflow_declare]
:end-before: [end workflow_declare]

Dive Into
---------

.. automodule:: pydolphinscheduler.tasks.pytorch


YAML file example
-----------------

.. literalinclude:: ../../../examples/yaml_define/Pytorch.yaml
:start-after: # under the License.
:language: yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Define the workflow
workflow:
name: "Pytorch"

# Define the tasks under the workflow
tasks:

# run project with existing environment
- name: task_existing_env
task_type: pytorch
script: main.py
script_params: --dry-run --no-cuda
project_path: https://github.com/pytorch/examples#mnist
python_command: /home/anaconda3/envs/pytorch/bin/python3


# run project with creating conda environment
- name: task_conda_env
task_type: pytorch
script: main.py
script_params: --dry-run --no-cuda
project_path: https://github.com/pytorch/examples#mnist
is_create_environment: True
python_env_tool: conda
requirements: requirements.txt
conda_python_version: 3.7

# run project with creating virtualenv environment
- name: task_virtualenv_env
task_type: pytorch
script: main.py
script_params: --dry-run --no-cuda
project_path: https://github.com/pytorch/examples#mnist
is_create_environment: True
python_env_tool: virtualenv
requirements: requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class TaskType(str):
SPARK = "SPARK"
MR = "MR"
SAGEMAKER = "SAGEMAKER"
PYTORCH = "PYTORCH"


class DefaultTaskCodeNum(str):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# [start workflow_declare]
"""A example workflow for task pytorch."""

from pydolphinscheduler.core.process_definition import ProcessDefinition
from pydolphinscheduler.tasks.pytorch import Pytorch

with ProcessDefinition(
name="task_pytorch_example",
tenant="tenant_exists",
) as pd:

# run project with existing environment
task_existing_env = Pytorch(
name="task_existing_env",
script="main.py",
script_params="--dry-run --no-cuda",
project_path="https://github.com/pytorch/examples#mnist",
python_command="/home/anaconda3/envs/pytorch/bin/python3",
)

# run project with creating conda environment
task_conda_env = Pytorch(
name="task_conda_env",
script="main.py",
script_params="--dry-run --no-cuda",
project_path="https://github.com/pytorch/examples#mnist",
is_create_environment=True,
python_env_tool="conda",
requirements="requirements.txt",
conda_python_version="3.7",
)

# run project with creating virtualenv environment
task_virtualenv_env = Pytorch(
name="task_virtualenv_env",
script="main.py",
script_params="--dry-run --no-cuda",
project_path="https://github.com/pytorch/examples#mnist",
is_create_environment=True,
python_env_tool="virtualenv",
requirements="requirements.txt",
)

pd.submit()
# [end workflow_declare]
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pydolphinscheduler.tasks.map_reduce import MR
from pydolphinscheduler.tasks.procedure import Procedure
from pydolphinscheduler.tasks.python import Python
from pydolphinscheduler.tasks.pytorch import Pytorch
from pydolphinscheduler.tasks.sagemaker import SageMaker
from pydolphinscheduler.tasks.shell import Shell
from pydolphinscheduler.tasks.spark import Spark
Expand All @@ -42,6 +43,7 @@
"MR",
"Procedure",
"Python",
"Pytorch",
"Shell",
"Spark",
"Sql",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Task Pytorch."""
from typing import Optional

from pydolphinscheduler.constants import TaskType
from pydolphinscheduler.core.task import Task


class DEFAULT:
"""Default values for Pytorch."""

is_create_environment = False
project_path = "."
python_command = "${PYTHON_HOME}"


class Pytorch(Task):
"""Task Pytorch object, declare behavior for Pytorch task to dolphinscheduler.
See also: `DolphinScheduler Pytorch Task Plugin
<https://dolphinscheduler.apache.org/en-us/docs/dev/user_doc/guide/task/pytorch.html>`_
:param name: task name
:param script: Entry to the Python script file that you want to run.
:param script_params: Input parameters at run time.
:param project_path: The path to the project. Default "." .
:param is_create_environment: is create environment. Default False.
:param python_command: The path to the python command. Default "${PYTHON_HOME}".
:param python_env_tool: The python environment tool. Default "conda".
:param requirements: The path to the requirements.txt file. Default "requirements.txt".
:param conda_python_version: The python version of conda environment. Default "3.7".
"""

_task_custom_attr = {
"script",
"script_params",
"other_params",
"python_path",
"is_create_environment",
"python_command",
"python_env_tool",
"requirements",
"conda_python_version",
}

def __init__(
self,
name: str,
script: str,
script_params: str = "",
project_path: Optional[str] = DEFAULT.project_path,
is_create_environment: Optional[bool] = DEFAULT.is_create_environment,
python_command: Optional[str] = DEFAULT.python_command,
python_env_tool: Optional[str] = "conda",
requirements: Optional[str] = "requirements.txt",
conda_python_version: Optional[str] = "3.7",
*args,
**kwargs,
):
"""Init Pytorch task."""
super().__init__(name, TaskType.PYTORCH, *args, **kwargs)
self.script = script
self.script_params = script_params
self.is_create_environment = is_create_environment
self.python_path = project_path
self.python_command = python_command
self.python_env_tool = python_env_tool
self.requirements = requirements
self.conda_python_version = conda_python_version

@property
def other_params(self):
"""Return other params."""
conds = [
self.is_create_environment != DEFAULT.is_create_environment,
self.python_path != DEFAULT.project_path,
self.python_command != DEFAULT.python_command,
]
return any(conds)
Loading

0 comments on commit 5202e5c

Please sign in to comment.