Skip to content

Commit

Permalink
Add another interesting test case
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Wang <peter.wang9812@gmail.com>
  • Loading branch information
Peter Wang committed Jun 26, 2024
1 parent 6564543 commit 924f653
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
def is_object_fixable_error(e: ArrowConversionError) -> bool:
"""Returns whether this error can be fixed by using an ArrowPythonObjectArray"""
return any(
err in "".join(traceback.format_exception(e))
err in "".join(traceback.format_exception(type(e), e, e.__traceback__))
for err in ARROW_OBJECT_FIXABLE_ERRORS
)

Expand Down
6 changes: 5 additions & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
"""
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import ArrowConversionError
from ray.data.extensions import (
ArrowPythonObjectArray,
ArrowPythonObjectType,
Expand All @@ -199,7 +200,10 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":

# If the result contains pyarrow schemas, unify them
schemas_to_unify = [b.schema for b in blocks]
schema = unify_schemas(schemas_to_unify)
try:
schema = unify_schemas(schemas_to_unify)
except Exception as e:
raise ArrowConversionError(str(blocks)) from e
if (
any(isinstance(type_, pa.ExtensionType) for type_ in schema.types)
or cols_with_null_list
Expand Down
3 changes: 2 additions & 1 deletion python/ray/data/tests/test_transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def test_fallback_to_pandas_on_incompatible_data(
"op, data",
[
("map", [1, 2**100]),
("map_batches", [[1.0], [2**4]]),
],
)
def test_pyarrow_conversion_error_detailed_info(
Expand All @@ -508,7 +509,7 @@ def test_pyarrow_conversion_error_detailed_info(
error_msg = str(e.value)
expected_msg = "ArrowConversionError: Error converting data to Arrow:"
assert expected_msg in error_msg, error_msg
assert "'my_data'" in error_msg, error_msg
assert "my_data" in error_msg, error_msg


if __name__ == "__main__":
Expand Down

0 comments on commit 924f653

Please sign in to comment.