Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release/0.0.121 #61

Merged
merged 2 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "zrb"
version = "0.0.120"
version = "0.0.121"
authors = [
{ name="Go Frendi Gunawan", email="gofrendiasgard@gmail.com" },
]
Expand Down
3 changes: 3 additions & 0 deletions src/zrb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from zrb.runner import runner
from zrb.task.decorator import python_task
from zrb.task.any_task import AnyTask
from zrb.task.parallel import AnyParallel, Parallel
from zrb.task.task import Task
from zrb.task.cmd_task import CmdTask
from zrb.task.docker_compose_task import DockerComposeTask, ServiceConfig
Expand Down Expand Up @@ -33,6 +34,8 @@

assert runner
assert AnyTask
assert AnyParallel
assert Parallel
assert python_task
assert Task
assert CmdTask
Expand Down
13 changes: 10 additions & 3 deletions src/zrb/task/base_task/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from zrb.task.base_task.component.renderer import Renderer
from zrb.task.base_task.component.base_task_model import BaseTaskModel
from zrb.task.parallel import AnyParallel
from zrb.advertisement import advertisements
from zrb.task_group.group import Group
from zrb.task_env.env import Env
Expand Down Expand Up @@ -102,9 +103,15 @@ def __init__(
self.__is_execution_triggered: bool = False
self.__is_execution_started: bool = False

def __rshift__(self, other_task: AnyTask):
other_task.add_upstream(self)
return other_task
def __rshift__(self, operand: Union[AnyParallel, AnyTask]):
if isinstance(operand, AnyTask):
operand.add_upstream(self)
return operand
if isinstance(operand, AnyParallel):
other_tasks: List[AnyTask] = operand.get_tasks()
for other_task in other_tasks:
other_task.add_upstream(self)
return operand

def copy(self) -> AnyTask:
return copy.deepcopy(self)
Expand Down
36 changes: 36 additions & 0 deletions src/zrb/task/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from zrb.helper.typing import TypeVar, List, Union
from abc import ABC, abstractmethod
from zrb.helper.typecheck import typechecked
from zrb.task.any_task import AnyTask


TParallel = TypeVar('TParallel', bound='Parallel')


class AnyParallel(ABC):
@abstractmethod
def get_tasks(self) -> List[AnyTask]:
pass


@typechecked
class Parallel(AnyParallel):
def __init__(self, *tasks: AnyTask):
self.__tasks = list(tasks)

def get_tasks(self) -> List[AnyTask]:
return self.__tasks

def __rshift__(
self, operand: Union[AnyTask, AnyParallel]
) -> Union[AnyTask, AnyParallel]:
if isinstance(operand, AnyTask):
for task in self.__tasks:
operand.add_upstream(task)
return operand
if isinstance(operand, AnyParallel):
other_tasks: List[AnyTask] = operand.get_tasks()
for task in self.__tasks:
for other_task in other_tasks:
other_task.add_upstream(task)
return operand
83 changes: 59 additions & 24 deletions test/task/test_task_shift_right_operator.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,63 @@
from zrb.task.task import Task
from zrb.task.parallel import Parallel


def test_consistent_task_shift_right_operator():
result = []
task_1 = Task(
name='task-1',
run=lambda *args, **kwargs: result.append(1)
)
task_2 = Task(
name='task-2',
run=lambda *args, **kwargs: result.append(2)
)
task_3 = Task(
name='task-3',
run=lambda *args, **kwargs: result.append(3)
)
task_4 = Task(
name='task-4',
run=lambda *args, **kwargs: result.append(4)
)
def test_task_shift_right_operator():
task_1 = Task(name='task-1')
task_2 = Task(name='task-2')
task_3 = Task(name='task-3')
task_4 = Task(name='task-4')
# define DAG
task_1 >> task_2 >> task_3 >> task_4
function = task_4.to_function()
function()
assert result[0] == 1
assert result[1] == 2
assert result[2] == 3
assert result[3] == 4
# test 1
upstream_1 = task_1._get_upstreams()
assert len(upstream_1) == 0
# test 2
upstream_2 = task_2._get_upstreams()
assert len(upstream_2) == 1
assert upstream_2[0] == task_1
# test 3
upstream_3 = task_3._get_upstreams()
assert len(upstream_3) == 1
assert upstream_3[0] == task_2
# test 4
upstream_4 = task_4._get_upstreams()
assert len(upstream_4) == 1
assert upstream_4[0] == task_3


def test_task_shift_right_operator_with_parallel():
task_1 = Task(name='task-1')
task_2 = Task(name='task-2')
task_3 = Task(name='task-3')
task_4 = Task(name='task-4')
task_5 = Task(name='task-5')
task_6 = Task(name='task-6')
# define DAG
task_1 >> Parallel(task_2, task_3) >> Parallel(task_4, task_5) >> task_6
# test 1
upstream_1 = task_1._get_upstreams()
assert len(upstream_1) == 0
# test 2
upstream_2 = task_2._get_upstreams()
assert len(upstream_2) == 1
assert upstream_2[0] == task_1
# test 3
upstream_3 = task_3._get_upstreams()
assert len(upstream_3) == 1
assert upstream_3[0] == task_1
# test 4
upstream_4 = task_4._get_upstreams()
assert len(upstream_4) == 2
assert upstream_4[0] == task_2
assert upstream_4[1] == task_3
# test 5
upstream_5 = task_5._get_upstreams()
assert len(upstream_5) == 2
assert upstream_5[0] == task_2
assert upstream_5[1] == task_3
# test 6
upstream_6 = task_6._get_upstreams()
assert len(upstream_6) == 2
assert upstream_6[0] == task_4
assert upstream_6[1] == task_5