Skip to content

Commit

Permalink
[Data] Fix bug where Ray Data incorrectly emits progress bar warning (#…
Browse files Browse the repository at this point in the history
…47680)

Fixes #47679

---------

Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
  • Loading branch information
bveeramani authored Sep 16, 2024
1 parent 9495e72 commit 575d6af
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
16 changes: 10 additions & 6 deletions python/ray/data/_internal/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,6 @@ def _truncate_name(self, name: str) -> str:
):
return name

if log_once("ray_data_truncate_operator_name"):
logger.warning(
f"Truncating long operator name to {self.MAX_NAME_LENGTH} characters."
"To disable this behavior, set `ray.data.DataContext.get_current()."
"DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`."
)
op_names = name.split("->")
if len(op_names) == 1:
return op_names[0]
Expand All @@ -141,6 +135,13 @@ def _truncate_name(self, name: str) -> str:
+ len(op_names[-1])
) > self.MAX_NAME_LENGTH:
truncated_op_names.append("...")
if log_once("ray_data_truncate_operator_name"):
logger.warning(
f"Truncating long operator name to {self.MAX_NAME_LENGTH} "
"characters. To disable this behavior, set "
"`ray.data.DataContext.get_current()."
"DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`."
)
break
truncated_op_names.append(op_name)
truncated_op_names.append(op_names[-1])
Expand Down Expand Up @@ -199,6 +200,9 @@ def set_description(self, name: str) -> None:
self._desc = name
self._bar.set_description(self._desc)

def get_description(self) -> str:
return self._desc

def refresh(self):
if self._bar:
self._bar.refresh()
Expand Down
47 changes: 43 additions & 4 deletions python/ray/data/tests/test_progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import logging
from unittest.mock import patch

import pytest
from pytest import fixture
Expand Down Expand Up @@ -39,7 +41,7 @@ def wrapped_close():
bar.close = wrapped_close

# Test basic usage
pb = ProgressBar("", total, "", enabled=True)
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
for _ in range(total):
Expand All @@ -50,7 +52,7 @@ def wrapped_close():
assert total_at_close == total

# Test if update() exceeds the original total, the total will be updated.
pb = ProgressBar("", total, "", enabled=True)
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
Expand All @@ -62,7 +64,7 @@ def wrapped_close():
assert total_at_close == new_total

# Test that if the bar is not complete at close(), the total will be updated.
pb = ProgressBar("", total, "")
pb = ProgressBar("", total, "unit")
assert pb._bar is not None
patch_close(pb._bar)
new_total = total // 2
Expand All @@ -74,7 +76,7 @@ def wrapped_close():
assert total_at_close == new_total

# Test updating the total
pb = ProgressBar("", total, "", enabled=True)
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
Expand All @@ -84,3 +86,40 @@ def wrapped_close():
pb.update(total + 1, total)
assert pb._bar.total == total + 1
pb.close()


@pytest.mark.parametrize(
"name, expected_description, max_line_length, should_emit_warning",
[
("Op", "Op", 2, False),
("Op->Op", "Op->Op", 5, False),
("Op->Op->Op", "Op->...->Op", 9, True),
("Op->Op->Op", "Op->Op->Op", 10, False),
# Test case for https://github.com/ray-project/ray/issues/47679.
("spam", "spam", 1, False),
],
)
def test_progress_bar_truncates_chained_operators(
name,
expected_description,
max_line_length,
should_emit_warning,
caplog,
propagate_logs,
):
with patch.object(ProgressBar, "MAX_NAME_LENGTH", max_line_length):
pb = ProgressBar(name, None, "unit")

assert pb.get_description() == expected_description
if should_emit_warning:
assert any(
record.levelno == logging.WARNING
and "Truncating long operator name" in record.message
for record in caplog.records
), caplog.records


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))

0 comments on commit 575d6af

Please sign in to comment.