Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add serialiser for Polars using Dataframe Interchange Protocol #298

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pytest = ">= 7.0.0, < 8"
altair = ">= 5.2.0, < 6"
httpx = ">=0.26.0, < 1"
alfred-cli = "^2.2.7"
polars = "^0.20.15"

[tool.poetry.extras]
ds = ["pandas", "pyarrow", "plotly"]
Expand Down
19 changes: 13 additions & 6 deletions src/streamsync/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def serialise(self, v: Any) -> Union[Dict, List, str, bool, int, float, None]:
return None
return v

if hasattr(v, "__dataframe__"):
return self._serialize_dataframe(v)

if "matplotlib.figure.Figure" in v_mro:
return self._serialise_matplotlib_fig(v)
if "plotly.graph_objs._figure.Figure" in v_mro:
Expand All @@ -128,8 +131,6 @@ def serialise(self, v: Any) -> Union[Dict, List, str, bool, int, float, None]:
return float(v)
if "numpy.ndarray" in v_mro:
return self._serialise_list_recursively(v.tolist())
if "pandas.core.frame.DataFrame" in v_mro:
return self._serialise_pandas_dataframe(v)
if "pyarrow.lib.Table" in v_mro:
return self._serialise_pyarrow_table(v)

Expand Down Expand Up @@ -161,11 +162,17 @@ def _serialise_matplotlib_fig(self, fig) -> str:
plt.close(fig)
return FileWrapper(iobytes, "image/png").get_as_dataurl()

def _serialise_pandas_dataframe(self, df):
import pyarrow as pa # type: ignore
def _serialize_dataframe(self, df) -> str:
"""
Serialize a dataframe with pyarrow a dataframe that implements
the Dataframe Interchange Protocol i.e. the __dataframe__() method

pa_table = pa.Table.from_pandas(df, preserve_index=True)
return self._serialise_pyarrow_table(pa_table)
:param df: dataframe that implements Dataframe Interchange Protocol (__dataframe__ method)
:return: a arrow file as a dataurl (application/vnd.apache.arrow.file)
"""
import pyarrow.interchange # type: ignore
table = pyarrow.interchange.from_dataframe(df)
return self._serialise_pyarrow_table(table)

def _serialise_pyarrow_table(self, table):
import pyarrow as pa # type: ignore
Expand Down
15 changes: 15 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import streamsync as ss
from streamsync.ss_types import StreamsyncEvent
import pandas as pd
import polars as pl
import plotly.express as px
import pytest
import altair
Expand Down Expand Up @@ -569,6 +570,20 @@ def test_pandas_df(self) -> None:
assert table.column("name")[0].as_py() == "Byte"
assert table.column("length_cm")[2].as_py() == 32

def test_polars_df(self) -> None:
d = {
"name": "Normal name",
"df": pl.read_csv(self.df_path)
}
s = self.sts.serialise(d)
assert s.get("name") == "Normal name"
df_durl = s.get("df")
df_buffer = urllib.request.urlopen(df_durl)
reader = pa.ipc.open_file(df_buffer)
table = reader.read_all()
assert table.column("name")[0].as_py() == "Byte"
assert table.column("length_cm")[2].as_py() == 32

class TestEvaluator:

def test_evaluate_field_simple(self) -> None:
Expand Down
Loading