Skip to content

Commit

Permalink
refactor(pandas-format): move to classmethods to pickup super class behavior where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 18, 2023
1 parent a3fac3e commit 7bb0470
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 71 deletions.
38 changes: 16 additions & 22 deletions ibis/backends/snowflake/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,29 @@


class SnowflakePandasData(PandasData):
@staticmethod
def convert_JSON(s, dtype, pandas_type):
converter = SnowflakePandasData.convert_JSON_element(dtype)
@classmethod
def convert_JSON(cls, s, dtype, pandas_type):
converter = cls.convert_JSON_element(dtype)
return s.map(converter, na_action="ignore").astype("object")

convert_Struct = convert_Map = convert_JSON

@staticmethod
def get_element_converter(dtype):
funcgen = getattr(
SnowflakePandasData,
f"convert_{type(dtype).__name__}_element",
lambda _: lambda x: x,
)
return funcgen(dtype)

def convert_Timestamp_element(dtype):
return lambda values: list(map(datetime.datetime.fromisoformat, values))
@classmethod
def convert_Timestamp_element(cls, dtype):
return datetime.datetime.fromisoformat

def convert_Date_element(dtype):
return lambda values: list(map(datetime.date.fromisoformat, values))
@classmethod
def convert_Date_element(cls, dtype):
return datetime.date.fromisoformat

def convert_Time_element(dtype):
return lambda values: list(map(datetime.time.fromisoformat, values))
@classmethod
def convert_Time_element(cls, dtype):
return datetime.time.fromisoformat

@staticmethod
def convert_Array(s, dtype, pandas_type):
raw_json_objects = SnowflakePandasData.convert_JSON(s, dtype, pandas_type)
converter = SnowflakePandasData.get_element_converter(dtype.value_type)
@classmethod
def convert_Array(cls, s, dtype, pandas_type):
raw_json_objects = cls.convert_JSON(s, dtype, pandas_type)
converter = cls.get_element_converter(dtype.value_type)
return raw_json_objects.map(converter, na_action="ignore")


Expand Down
112 changes: 63 additions & 49 deletions ibis/formats/pandas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import json
import warnings

Expand All @@ -10,6 +11,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.common.temporal import normalize_timezone
from ibis.formats import DataMapper, SchemaMapper, TableProxy
from ibis.formats.numpy import NumpyType
from ibis.formats.pyarrow import PyArrowData, PyArrowSchema, PyArrowType
Expand Down Expand Up @@ -132,8 +134,8 @@ def convert_column(cls, obj, dtype):
assert not isinstance(result, np.ndarray), f"{convert_method} -> {type(result)}"
return result

@staticmethod
def convert_GeoSpatial(s, dtype, pandas_type):
@classmethod
def convert_GeoSpatial(cls, s, dtype, pandas_type):
return s

convert_Point = (
Expand All @@ -144,15 +146,15 @@ def convert_GeoSpatial(s, dtype, pandas_type):
convert_MultiLineString
) = convert_MultiPoint = convert_MultiPolygon = convert_GeoSpatial

@staticmethod
def convert_default(s, dtype, pandas_type):
@classmethod
def convert_default(cls, s, dtype, pandas_type):
try:
return s.astype(pandas_type)
except Exception: # noqa: BLE001
return s

@staticmethod
def convert_Boolean(s, dtype, pandas_type):
@classmethod
def convert_Boolean(cls, s, dtype, pandas_type):
if s.empty:
return s.astype(pandas_type)
elif pdt.is_object_dtype(s.dtype):
Expand All @@ -162,8 +164,8 @@ def convert_Boolean(s, dtype, pandas_type):
else:
return s

@staticmethod
def convert_Timestamp(s, dtype, pandas_type):
@classmethod
def convert_Timestamp(cls, s, dtype, pandas_type):
if isinstance(dtype, pd.DatetimeTZDtype):
return s.dt.tz_convert(dtype.timezone)
elif pdt.is_datetime64_dtype(s.dtype):
Expand All @@ -184,14 +186,14 @@ def convert_Timestamp(s, dtype, pandas_type):
except TypeError:
return pd.to_datetime(s).dt.tz_localize(dtype.timezone)

@staticmethod
def convert_Date(s, dtype, pandas_type):
@classmethod
def convert_Date(cls, s, dtype, pandas_type):
if isinstance(s.dtype, pd.DatetimeTZDtype):
s = s.dt.tz_convert("UTC").dt.tz_localize(None)
return s.astype(pandas_type, errors="ignore").dt.normalize()

@staticmethod
def convert_Interval(s, dtype, pandas_type):
@classmethod
def convert_Interval(cls, s, dtype, pandas_type):
values = s.values
try:
result = values.astype(pandas_type)
Expand All @@ -201,42 +203,41 @@ def convert_Interval(s, dtype, pandas_type):
result = s.__class__(result, index=s.index, name=s.name)
return result

@staticmethod
def convert_String(s, dtype, pandas_type):
@classmethod
def convert_String(cls, s, dtype, pandas_type):
return s.astype(pandas_type, errors="ignore")

@staticmethod
def convert_UUID(s, dtype, pandas_type):
return s.map(PandasData.get_element_converter(dtype), na_action="ignore")

@staticmethod
def convert_Struct(s, dtype, pandas_type):
return s.map(PandasData.get_element_converter(dtype), na_action="ignore")
@classmethod
def convert_UUID(cls, s, dtype, pandas_type):
return s.map(cls.get_element_converter(dtype), na_action="ignore")

@staticmethod
def convert_Array(s, dtype, pandas_type):
return s.map(PandasData.get_element_converter(dtype), na_action="ignore")
@classmethod
def convert_Struct(cls, s, dtype, pandas_type):
return s.map(cls.get_element_converter(dtype), na_action="ignore")

@staticmethod
def convert_Map(s, dtype, pandas_type):
return s.map(PandasData.get_element_converter(dtype), na_action="ignore")
@classmethod
def convert_Array(cls, s, dtype, pandas_type):
return s.map(cls.get_element_converter(dtype), na_action="ignore")

@staticmethod
def convert_JSON(s, dtype, pandas_type):
return s.map(
PandasData.get_element_converter(dtype), na_action="ignore"
).astype("object")
@classmethod
def convert_Map(cls, s, dtype, pandas_type):
return s.map(cls.get_element_converter(dtype), na_action="ignore")

@staticmethod
def get_element_converter(dtype):
funcgen = getattr(
PandasData, f"convert_{type(dtype).__name__}_element", lambda _: lambda x: x
@classmethod
def convert_JSON(cls, s, dtype, pandas_type):
return s.map(cls.get_element_converter(dtype), na_action="ignore").astype(
"object"
)

@classmethod
def get_element_converter(cls, dtype):
name = f"convert_{type(dtype).__name__}_element"
funcgen = getattr(cls, name, lambda _: lambda x: x)
return funcgen(dtype)

@staticmethod
def convert_Struct_element(dtype):
converters = tuple(map(PandasData.get_element_converter, dtype.types))
@classmethod
def convert_Struct_element(cls, dtype):
converters = tuple(map(cls.get_element_converter, dtype.types))

def convert(values, names=dtype.names, converters=converters):
items = values.items() if isinstance(values, dict) else zip(names, values)
Expand All @@ -247,8 +248,8 @@ def convert(values, names=dtype.names, converters=converters):

return convert

@staticmethod
def convert_JSON_element(_):
@classmethod
def convert_JSON_element(cls, _):
def try_json(x):
if x is None:
return x
Expand All @@ -259,23 +260,36 @@ def try_json(x):

return try_json

@staticmethod
def convert_Array_element(dtype):
convert_value = PandasData.get_element_converter(dtype.value_type)
@classmethod
def convert_Timestamp_element(cls, dtype):
def converter(value, dtype=dtype):
with contextlib.suppress(AttributeError):
value = value.item()

if (tz := dtype.timezone) is not None:
return value.astimezone(normalize_timezone(tz))

return value.replace(tzinfo=None)

return converter

@classmethod
def convert_Array_element(cls, dtype):
convert_value = cls.get_element_converter(dtype.value_type)
return lambda values: [
convert_value(value) if value is not None else value for value in values
]

@staticmethod
def convert_Map_element(dtype):
convert_value = PandasData.get_element_converter(dtype.value_type)
@classmethod
def convert_Map_element(cls, dtype):
convert_value = cls.get_element_converter(dtype.value_type)
return lambda row: {
key: convert_value(value) if value is not None else value
for key, value in dict(row).items()
}

@staticmethod
def convert_UUID_element(_):
@classmethod
def convert_UUID_element(cls, _):
from uuid import UUID

return lambda v: v if isinstance(v, UUID) else UUID(v)
Expand Down

0 comments on commit 7bb0470

Please sign in to comment.