Skip to content

Commit

Permalink
fix: Pass native dataframe to data transformers (#3550)
Browse files Browse the repository at this point in the history
Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
  • Loading branch information
MarcoGorelli and dangotbanned authored Aug 21, 2024
1 parent 5de5138 commit c984002
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 53 deletions.
6 changes: 0 additions & 6 deletions altair/utils/_vegafusion_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
)
from weakref import WeakValueDictionary

import narwhals.stable.v1 as nw

from altair.utils._importers import import_vegafusion
from altair.utils.core import DataFrameLike
from altair.utils.data import (
Expand Down Expand Up @@ -71,10 +69,6 @@ def vegafusion_data_transformer(
data: DataType | None = None, max_rows: int = 100000
) -> Callable[..., Any] | _VegaFusionReturnType:
"""VegaFusion Data Transformer."""
# Vegafusion does not support Narwhals, so if `data` is a Narwhals
# object, we make sure to extract the native object and let Vegafusion handle it.
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data = nw.to_native(data, strict=False)
if data is None:
return vegafusion_data_transformer
elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface):
Expand Down
76 changes: 39 additions & 37 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,7 @@ def to_values(data: DataType) -> ToValuesReturnType:
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data_native = nw.to_native(data, strict=False)
if isinstance(data_native, SupportsGeoInterface):
if _is_pandas_dataframe(data_native):
data_native = sanitize_pandas_dataframe(data_native)
# Maybe the type could be further clarified here that it is
# SupportGeoInterface and then the ignore statement is not needed?
data_sanitized = sanitize_geo_interface(data_native.__geo_interface__)
return {"values": data_sanitized}
return {"values": _from_geo_interface(data_native)}
elif _is_pandas_dataframe(data_native):
data_native = sanitize_pandas_dataframe(data_native)
return {"values": data_native.to_dict(orient="records")}
Expand Down Expand Up @@ -350,32 +345,45 @@ def _compute_data_hash(data_str: str) -> str:
return hashlib.sha256(data_str.encode()).hexdigest()[:32]


def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]:
"""
Santize a ``__geo_interface__`` w/ pre-santize step for ``pandas`` if needed.
Notes
-----
Split out to resolve typing issues related to:
- Intersection types
- ``typing.TypeGuard``
- ``pd.DataFrame.__getattr__``
"""
if _is_pandas_dataframe(data):
data = sanitize_pandas_dataframe(data)
return sanitize_geo_interface(data.__geo_interface__)


def _data_to_json_string(data: DataType) -> str:
"""Return a JSON string representation of the input data."""
check_data_type(data)
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data_native = nw.to_native(data, strict=False)
if isinstance(data_native, SupportsGeoInterface):
if _is_pandas_dataframe(data_native):
data_native = sanitize_pandas_dataframe(data_native)
data_native = sanitize_geo_interface(data_native.__geo_interface__)
return json.dumps(data_native)
elif _is_pandas_dataframe(data_native):
data = sanitize_pandas_dataframe(data_native)
return data_native.to_json(orient="records", double_precision=15)
elif isinstance(data_native, dict):
if "values" not in data_native:
if isinstance(data, SupportsGeoInterface):
return json.dumps(_from_geo_interface(data))
elif _is_pandas_dataframe(data):
data = sanitize_pandas_dataframe(data)
return data.to_json(orient="records", double_precision=15)
elif isinstance(data, dict):
if "values" not in data:
msg = "values expected in data dict, but not present."
raise KeyError(msg)
return json.dumps(data_native["values"], sort_keys=True)
elif isinstance(data, nw.DataFrame):
return json.dumps(data.rows(named=True))
else:
msg = "to_json only works with data expressed as " "a DataFrame or as a dict"
raise NotImplementedError(msg)
return json.dumps(data["values"], sort_keys=True)
try:
data_nw = nw.from_native(data, eager_only=True)
except TypeError as exc:
msg = "to_json only works with data expressed as a DataFrame or as a dict"
raise NotImplementedError(msg) from exc
data_nw = sanitize_narwhals_dataframe(data_nw)
return json.dumps(data_nw.rows(named=True))


def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str:
def _data_to_csv_string(data: DataType) -> str:
"""Return a CSV string representation of the input data."""
check_data_type(data)
if isinstance(data, SupportsGeoInterface):
Expand All @@ -398,18 +406,12 @@ def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str:
msg = "pandas is required to convert a dict to a CSV string"
raise ImportError(msg) from exc
return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
elif isinstance(data, DataFrameLike):
# experimental interchange dataframe support
import pyarrow as pa
import pyarrow.csv as pa_csv

pa_table = arrow_table_from_dfi_dataframe(data)
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(pa_table, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
else:
msg = "to_csv only works with data expressed as " "a DataFrame or as a dict"
raise NotImplementedError(msg)
try:
data_nw = nw.from_native(data, eager_only=True)
except TypeError as exc:
msg = "to_csv only works with data expressed as a DataFrame or as a dict"
raise NotImplementedError(msg) from exc
return data_nw.write_csv()


def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table:
Expand Down
3 changes: 2 additions & 1 deletion altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing_extensions import TypeAlias

import jsonschema
import narwhals.stable.v1 as nw

from altair import utils
from altair.expr import core as _expr_core
Expand Down Expand Up @@ -274,7 +275,7 @@ def _prepare_data(
# convert dataframes or objects with __geo_interface__ to dict
elif not isinstance(data, dict) and _is_data_type(data):
if func := data_transformers.get():
data = func(data)
data = func(nw.to_native(data, strict=False))

# convert string input to a URLData
elif isinstance(data, str):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
# If you update the minimum required jsonschema version, also update it in build.yml
"jsonschema>=3.0",
"packaging",
"narwhals>=1.1.0"
"narwhals>=1.5.2"
]
description = "Vega-Altair: A declarative statistical visualization library for Python."
readme = "README.md"
Expand Down
23 changes: 15 additions & 8 deletions tests/utils/test_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, SupportsIndex, TypeVar

import narwhals.stable.v1 as nw
import pandas as pd
Expand All @@ -15,6 +17,8 @@
to_values,
)

T = TypeVar("T")


def _pipe(data: Any, *funcs: Callable[..., Any]) -> Any:
# Redefined to maintain existing tests
Expand All @@ -24,13 +28,15 @@ def _pipe(data: Any, *funcs: Callable[..., Any]) -> Any:
return data


def _create_dataframe(N):
data = pd.DataFrame({"x": range(N), "y": range(N)})
def _create_dataframe(
n: SupportsIndex, /, tp: Callable[..., T] | type[Any] = pd.DataFrame
) -> T | Any:
data = tp({"x": range(n), "y": range(n)})
return data


def _create_data_with_values(N):
data = {"values": [{"x": i, "y": i + 1} for i in range(N)]}
def _create_data_with_values(n: SupportsIndex, /) -> dict[str, Any]:
data = {"values": [{"x": i, "y": i + 1} for i in range(n)]}
return data


Expand Down Expand Up @@ -127,19 +133,20 @@ def test_dict_to_json():
assert data == {"values": output}


def test_dataframe_to_csv():
@pytest.mark.parametrize("tp", [pd.DataFrame, pl.DataFrame], ids=["pandas", "polars"])
def test_dataframe_to_csv(tp: type[Any]) -> None:
"""
Test to_csv with dataframe input.
- make certain the filename is deterministic
- make certain the file contents match the data.
"""
data = _create_dataframe(10)
data = _create_dataframe(10, tp=tp)
try:
result1 = _pipe(data, to_csv)
result2 = _pipe(data, to_csv)
filename = result1["url"]
output = pd.read_csv(filename)
output = tp(pd.read_csv(filename))
finally:
Path(filename).unlink()

Expand Down

0 comments on commit c984002

Please sign in to comment.