Skip to content

Commit

Permalink
fix prev tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl committed Dec 15, 2023
1 parent 13bd97f commit 13e8bd7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 28 deletions.
8 changes: 5 additions & 3 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,11 @@ def run(self):
except Exception as e:
if self.in_teardown:
return
logger.info(f"Error executing worker task: {e}")
for output_channel in outer.dag_output_channels:
output_channel.set_error(e)
if isinstance(outer.dag_output_channels, list):
for output_channel in outer.dag_output_channels:
output_channel.set_error(e)
else:
outer.dag_output_channels.set_error(e)
self.teardown()

monitor = Monitor()
Expand Down
38 changes: 15 additions & 23 deletions python/ray/dag/tests/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import ray.cluster_utils
from ray.dag import InputNode, MultiOutputNode
from ray.tests.conftest import * # noqa
from ray._private.test_utils import wait_for_condition


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,6 +57,10 @@ def test_basic(ray_start_regular):
assert result == i + 1
output_channel.end_read()

# Note: must tear down before starting a new Ray session, otherwise you'll get
# a segfault from the dangling monitor thread upon the new Ray init.
compiled_dag.teardown()


def test_regular_args(ray_start_regular):
# Test passing regular args to .bind in addition to DAGNode args.
Expand All @@ -74,6 +77,8 @@ def test_regular_args(ray_start_regular):
assert result == (i + 1) * 3
output_channel.end_read()

compiled_dag.teardown()


@pytest.mark.parametrize("num_actors", [1, 4])
def test_scatter_gather_dag(ray_start_regular, num_actors):
Expand All @@ -92,6 +97,8 @@ def test_scatter_gather_dag(ray_start_regular, num_actors):
for chan in output_channels:
chan.end_read()

compiled_dag.teardown()


@pytest.mark.parametrize("num_actors", [1, 4])
def test_chain_dag(ray_start_regular, num_actors):
Expand All @@ -110,17 +117,20 @@ def test_chain_dag(ray_start_regular, num_actors):
assert result == list(range(num_actors))
output_channel.end_read()

compiled_dag.teardown()


def test_dag_exception(ray_start_regular, capsys):
a = Actor.remote(0)
with InputNode() as inp:
dag = a.inc.bind(inp)

compiled_dag = dag.experimental_compile()
compiled_dag.execute("hello")
wait_for_condition(
lambda: "Compiled DAG task aborted with exception" in capsys.readouterr().err
)
output_channel = compiled_dag.execute("hello")
with pytest.raises(TypeError):
output_channel.begin_read()

compiled_dag.teardown()


def test_dag_errors(ray_start_regular):
Expand Down Expand Up @@ -208,24 +218,6 @@ def test_dag_fault_tolerance(ray_start_regular, num_actors):
chan.end_read()


def test_dag_teardown(ray_start_regular):
actor = Actor.remote(0)
with InputNode() as i:
dag = actor.inc.bind(i)

# Test we can go through multiple rounds of setup/teardown without issues.
for _ in range(10):
compiled_dag = dag.experimental_compile()

for _ in range(3):
output_channel = compiled_dag.execute(1)
# TODO(swang): Replace with fake ObjectRef.
output_channel.begin_read()
output_channel.end_read()

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
5 changes: 3 additions & 2 deletions python/ray/experimental/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional

import ray
from ray.exceptions import RaySystemError
from ray.util.annotations import PublicAPI

# Logger for this module. It should be configured at the entry point
Expand Down Expand Up @@ -186,5 +187,5 @@ def set_error(self, e: Exception) -> None:


def _is_write_acquire_failed_error(e: Exception) -> bool:
# XXX detect the exception type better
return "write acquire failed" in str(e)
# TODO(ekl): detect the exception type better
return isinstance(e, RaySystemError)

0 comments on commit 13e8bd7

Please sign in to comment.