Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ray Data] fix problem: to_pandas failed on datasets returned by from_spark #32968

Merged
merged 10 commits into from
Mar 28, 2023
2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class ArrowBlockBuilder(TableBlockBuilder[T]):
def __init__(self):
if pyarrow is None:
raise ImportError("Run `pip install pyarrow` for Arrow support")
- super().__init__(pyarrow.Table)
+ super().__init__((pyarrow.Table, bytes))

@staticmethod
def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/delegating_block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def add_block(self, block: Block):
return
if self._builder is None:
self._builder = accessor.builder()
- self._builder.add_block(block)
+ self._builder.add_block(accessor.to_block())

def will_build_yield_copy(self) -> bool:
if self._builder is None:
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/tests/test_raydp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ray
import raydp
import torch
import pandas


# RayDP tests require Ray Java. Make sure ray jar is built before running this test.
Expand Down Expand Up @@ -51,6 +52,14 @@ def test_raydp_to_torch_iter(spark):
assert torch.equal(data_features, features) and torch.equal(data_labels, labels)


def test_to_pandas(spark):
    """Check that a dataset built by ``from_spark`` converts to pandas.

    The frame produced by ``to_pandas()`` on the Ray dataset is compared
    with the frame Spark itself produces via ``toPandas()`` for the same
    ``spark.range(100)`` DataFrame.
    """
    spark_df = spark.range(100)
    dataset = ray.data.from_spark(spark_df)
    # Keep the Ray conversion ahead of the Spark one, as in the original test.
    actual = dataset.to_pandas()
    expected = spark_df.toPandas()
    pandas.testing.assert_frame_equal(actual, expected)


if __name__ == "__main__":
import sys

Expand Down