Skip to content

Commit

Permalink
refactor: Remove @curried.curry dependency from limit_rows
Browse files Browse the repository at this point in the history
This pattern is what I've used in most cases. It works and type-checks correctly, but could be improved by rewriting it as a decorator.
A decorator would have the benefit of preserving the return type, if one were to inspect the intermediate function.

Also, replacing the `TypeVar` with a `Union` seems to have satisfied `mypy`.
  • Loading branch information
dangotbanned committed May 26, 2024
1 parent aae983a commit b9dc070
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
import json
import os
import random
Expand Down Expand Up @@ -89,12 +90,19 @@ class MaxRowsError(Exception):
pass


@curried.curry
def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType:
@overload
def limit_rows(data: Literal[None] = ..., max_rows: Optional[int] = ...) -> partial: ...
@overload
def limit_rows(data: DataType, max_rows: Optional[int]) -> DataType: ...
def limit_rows(
data: Optional[DataType] = None, max_rows: Optional[int] = 5000
) -> Union[partial, DataType]:
"""Raise MaxRowsError if the data model has more than max_rows.
If max_rows is None, then do not perform any check.
"""
if data is None:
return partial(limit_rows, max_rows=max_rows)
check_data_type(data)

def raise_max_rows_error():
Expand All @@ -111,7 +119,7 @@ def raise_max_rows_error():
"on how to plot large datasets."
)

if hasattr(data, "__geo_interface__"):
if isinstance(data, SupportsGeoInterface):
if data.__geo_interface__["type"] == "FeatureCollection":
values = data.__geo_interface__["features"]
else:
Expand All @@ -122,9 +130,7 @@ def raise_max_rows_error():
if "values" in data:
values = data["values"]
else:
# mypy gets confused as it doesn't see Dict[Any, Any]
# as equivalent to TDataType
return data # type: ignore[return-value]
return data
elif isinstance(data, DataFrameLike):
pa_table = arrow_table_from_dfi_dataframe(data)
if max_rows is not None and pa_table.num_rows > max_rows:
Expand Down

0 comments on commit b9dc070

Please sign in to comment.