diff --git a/altair/utils/core.py b/altair/utils/core.py index 8ecaa896b..41e886001 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -397,6 +397,31 @@ def to_list_if_array(val): return df +def sanitize_arrow_table(pa_table): + """Sanitize arrow table for JSON serialization""" + import pyarrow as pa + import pyarrow.compute as pc + + arrays = [] + schema = pa_table.schema + for name in schema.names: + array = pa_table[name] + dtype = schema.field(name).type + if str(dtype).startswith("timestamp"): + arrays.append(pc.strftime(array)) + elif str(dtype).startswith("duration"): + raise ValueError( + 'Field "{col_name}" has type "{dtype}" which is ' + "not supported by Altair. Please convert to " + "either a timestamp or a numerical value." + "".format(col_name=name, dtype=dtype) + ) + else: + arrays.append(array) + + return pa.Table.from_arrays(arrays, names=schema.names) + + def parse_shorthand( shorthand, data=None, diff --git a/altair/utils/data.py b/altair/utils/data.py index 28e66bfab..c72fd2ea9 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -8,7 +8,7 @@ from toolz import curried from typing import Callable -from .core import sanitize_dataframe +from .core import sanitize_dataframe, sanitize_arrow_table from .core import sanitize_geo_interface from .deprecation import AltairDeprecationWarning from .plugin_registry import PluginRegistry @@ -166,7 +166,7 @@ def to_values(data): elif hasattr(data, "__dataframe__"): # experimental interchange dataframe support pi = import_pyarrow_interchange() - pa_table = pi.from_dataframe(data) + pa_table = sanitize_arrow_table(pi.from_dataframe(data)) return {"values": pa_table.to_pylist()} @@ -185,8 +185,6 @@ def check_data_type(data): # ============================================================================== # Private utilities # ============================================================================== - - def _compute_data_hash(data_str): return hashlib.md5(data_str.encode()).hexdigest() diff --git a/pyproject.toml b/pyproject.toml index 727690b22..d17777c24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,8 @@ dev = [ "mypy", "pandas-stubs", "types-jsonschema", - "types-setuptools" + "types-setuptools", + "pyarrow>=11" ] doc = [ "sphinx", diff --git a/tests/utils/test_dataframe_interchange.py b/tests/utils/test_dataframe_interchange.py new file mode 100644 index 000000000..6fc3a393f --- /dev/null +++ b/tests/utils/test_dataframe_interchange.py @@ -0,0 +1,57 @@ +from datetime import datetime +import pyarrow as pa +import pandas as pd +import pytest +import sys +import os + +from altair.utils.data import to_values + + +def windows_has_tzdata(): + """ + From PyArrow: python/pyarrow/tests/util.py + + This is the default location where tz.cpp will look for (until we make + this configurable at run-time) + """ + tzdata_path = os.path.expandvars(r"%USERPROFILE%\Downloads\tzdata") + return os.path.exists(tzdata_path) + + +# Skip test on Windows when the tz database is not configured. +# See https://github.com/altair-viz/altair/issues/3050. +@pytest.mark.skipif( + sys.platform == "win32" and not windows_has_tzdata(), + reason="Timezone database is not installed on Windows", +) +def test_arrow_timestamp_conversion(): + """Test that arrow timestamp values are converted to ISO-8601 strings""" + data = { + "date": [datetime(2004, 8, 1), datetime(2004, 9, 1), None], + "value": [102, 129, 139], + } + pa_table = pa.table(data) + + values = to_values(pa_table) + expected_values = { + "values": [ + {"date": "2004-08-01T00:00:00.000000", "value": 102}, + {"date": "2004-09-01T00:00:00.000000", "value": 129}, + {"date": None, "value": 139}, + ] + } + assert values == expected_values + + +def test_duration_raises(): + td = pd.timedelta_range(0, periods=3, freq="h") + df = pd.DataFrame(td).reset_index() + df.columns = ["id", "timedelta"] + pa_table = pa.table(df) + with pytest.raises(ValueError) as e: + to_values(pa_table) + + # Check that exception mentions the duration[ns] type, + # which is what the pandas timedelta is converted into + assert "duration[ns]" in e.value.args[0]