refactor: Simplify op lt
Signed-off-by: Weixin Deng <weixin@cs.washington.edu>
dengwxn committed Sep 25, 2024
1 parent b6a2c9e commit bbce790
Showing 2 changed files with 70 additions and 80 deletions.
28 changes: 9 additions & 19 deletions python/ray/dag/dag_node_operation.py
@@ -112,29 +112,19 @@ def compare(lhs: "_DAGOperationGraphNode", rhs: "_DAGOperationGraphNode"):
return lhs.operation.exec_task_idx < rhs.operation.exec_task_idx
return lhs.task_idx < rhs.task_idx

+ # When both nodes belong to the same actor, use the default comparison.
if self.actor_handle == other.actor_handle:
- # When both nodes belong to the same actor, use the default comparison.
return compare(self, other)

- if not (self.is_nccl_op or other.is_nccl_op):
- # When both nodes are not NCCL operations, use the default comparison.
- return compare(self, other)
- elif self.is_nccl_op != other.is_nccl_op:
- # When one node is an NCCL operation and the other is not, prioritize
- # the non-NCCL operation.
- return not self.is_nccl_op
- else:
- # Both nodes are NCCL operations.
- if self.is_nccl_write and other.is_nccl_write:
- # When both nodes are NCCL writes, use the default comparison.
- return compare(self, other)
- elif self.is_nccl_compute and other.is_nccl_compute:
- # When both nodes are NCCL computes, use the default comparison.
- return compare(self, other)
+ # Both nodes belong to different actors.
+ if self.is_nccl_op != other.is_nccl_op:
+ # When one node is an NCCL operation and the other is not, prioritize
+ # the non-NCCL operation.
+ return not self.is_nccl_op
else:
- # When one node is an NCCL write and the other is an NCCL compute,
- # prioritize the NCCL write.
- return self.is_nccl_write
+ # When either both nodes are NCCL operations or both nodes are not
+ # NCCL operations, use the default comparison.
+ return compare(self, other)

def __eq__(self, other: "_DAGOperationGraphNode"):
"""
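The new cross-actor rule above can be read as: nodes on the same actor always use the default comparison; across actors, a non-NCCL operation is prioritized over an NCCL operation, and any other pairing (both NCCL or both non-NCCL) falls back to the default comparison. Note that the old branches also prioritized an NCCL write over an NCCL compute across actors, a case that now falls back to the default comparison. Below is a minimal, self-contained sketch of the new rule for readers who want to try it in isolation; the toy _Node dataclass and its fields (actor_id, task_idx, is_nccl_op) are illustrative stand-ins, not the actual _DAGOperationGraphNode API.

from dataclasses import dataclass


@dataclass
class _Node:
    actor_id: int      # stand-in for actor_handle
    task_idx: int      # stand-in for the key used by the default comparison
    is_nccl_op: bool

    def __lt__(self, other: "_Node") -> bool:
        def compare(lhs: "_Node", rhs: "_Node") -> bool:
            # Default comparison: the node with the smaller task index wins.
            return lhs.task_idx < rhs.task_idx

        # When both nodes belong to the same actor, use the default comparison.
        if self.actor_id == other.actor_id:
            return compare(self, other)
        # Both nodes belong to different actors.
        if self.is_nccl_op != other.is_nccl_op:
            # Exactly one node is an NCCL operation: prioritize the non-NCCL one.
            return not self.is_nccl_op
        # Either both or neither node is an NCCL operation: default comparison.
        return compare(self, other)


# The non-NCCL node is prioritized even though its task index is larger.
assert _Node(actor_id=0, task_idx=5, is_nccl_op=False) < _Node(actor_id=1, task_idx=2, is_nccl_op=True)
# Two NCCL operations on different actors fall back to the default comparison.
assert _Node(actor_id=0, task_idx=1, is_nccl_op=True) < _Node(actor_id=1, task_idx=2, is_nccl_op=True)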
122 changes: 61 additions & 61 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -997,67 +997,6 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
compiled_dag.teardown()


# @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
# def test_torch_tensor_nccl_all_reduce_wrong_shape(ray_start_regular):
# """
# Test a dag containing all-reduce errors when given tensors of wrong shapes.
# """
# if not USE_GPU:
# pytest.skip("NCCL tests require GPUs")

# assert (
# sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
# ), "This test requires at least 2 GPUs"

# actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

# num_workers = 2
# workers = [actor_cls.remote() for _ in range(num_workers)]

# dtype = torch.float16

# with InputNode() as inp:
# computes = [
# worker.compute_with_tuple_args.bind(inp, i)
# for i, worker in enumerate(workers)
# ]
# collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
# recvs = [
# worker.recv.bind(collective)
# for worker, collective in zip(workers, collectives)
# ]
# dag = MultiOutputNode(recvs)

# compiled_dag = dag.experimental_compile()

# ref = compiled_dag.execute(
# [((20,), dtype, idx + 1) for idx in range(num_workers)]
# )
# reduced_val = (1 + num_workers) * num_workers / 2
# assert ray.get(ref) == [(reduced_val, (20,), dtype) for _ in range(num_workers)]

# ref = compiled_dag.execute(
# [((10 + idx,), dtype, idx + 1) for idx in range(num_workers)]
# )
# # The shapes mismatch but no errors are thrown.
# # [TODO] Throw error when shapes mismatch. Make sure it does not hang.
# with pytest.raises(RayChannelError):
# ray.get(ref)

# # The DAG will be torn down after any task throws an application-level
# # exception, such as when the task returns torch.Tensors of the wrong
# # shape or dtype. Check that we can no longer submit to the DAG.
# ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers])
# with pytest.raises(RayChannelError):
# ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers])

# compiled_dag.teardown()

# # [TODO:andy] Check if this requires time.sleep to avoid some issue with
# # following tests.
# # time.sleep(3)


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_all_reduce_custom_comm(ray_start_regular):
"""
@@ -1853,6 +1792,67 @@ def test_torch_tensor_nccl_all_reduce_scheduling_one_ready_group(ray_start_regul
compiled_dag.teardown()


# @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
# def test_torch_tensor_nccl_all_reduce_wrong_shape(ray_start_regular):
# """
# Test a dag containing all-reduce errors when given tensors of wrong shapes.
# """
# if not USE_GPU:
# pytest.skip("NCCL tests require GPUs")

# assert (
# sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
# ), "This test requires at least 2 GPUs"

# actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

# num_workers = 2
# workers = [actor_cls.remote() for _ in range(num_workers)]

# dtype = torch.float16

# with InputNode() as inp:
# computes = [
# worker.compute_with_tuple_args.bind(inp, i)
# for i, worker in enumerate(workers)
# ]
# collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
# recvs = [
# worker.recv.bind(collective)
# for worker, collective in zip(workers, collectives)
# ]
# dag = MultiOutputNode(recvs)

# compiled_dag = dag.experimental_compile()

# ref = compiled_dag.execute(
# [((20,), dtype, idx + 1) for idx in range(num_workers)]
# )
# reduced_val = (1 + num_workers) * num_workers / 2
# assert ray.get(ref) == [(reduced_val, (20,), dtype) for _ in range(num_workers)]

# ref = compiled_dag.execute(
# [((10 + idx,), dtype, idx + 1) for idx in range(num_workers)]
# )
# # The shapes mismatch but no errors are thrown.
# # [TODO] Throw error when shapes mismatch. Make sure it does not hang.
# with pytest.raises(RayChannelError):
# ray.get(ref)

# # The DAG will be torn down after any task throws an application-level
# # exception, such as when the task returns torch.Tensors of the wrong
# # shape or dtype. Check that we can no longer submit to the DAG.
# ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers])
# with pytest.raises(RayChannelError):
# ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers])

# compiled_dag.teardown()

# # [TODO:andy] Check if this requires time.sleep to avoid some issue with
# # following tests.
# # time.sleep(3)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
