Skip to content

Commit

Permalink
fix(pandas): make quantile/multiquantile with filter work
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Jan 5, 2023
1 parent 78bfd1d commit 6b5abd6
Showing 1 changed file with 78 additions and 23 deletions.
101 changes: 78 additions & 23 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,29 +314,64 @@ def execute_series_clip(op, data, lower, upper, **kwargs):
return data.clip(lower=lower, upper=upper)


@execute_node.register(ops.Quantile, (pd.Series, SeriesGroupBy), numeric_types)
def execute_series_quantile(op, data, quantile, aggcontext=None, **kwargs):
return aggcontext.agg(data, 'quantile', q=quantile, interpolation=op.interpolation)
@execute_node.register(ops.Quantile, pd.Series, numeric_types, (pd.Series, type(None)))
def execute_series_quantile(op, data, quantile, mask, aggcontext=None, **kwargs):
return aggcontext.agg(
data if mask is None else data.loc[mask],
'quantile',
q=quantile,
interpolation=op.interpolation,
)


@execute_node.register(ops.MultiQuantile, pd.Series, np.ndarray)
def execute_series_quantile_multi(op, data, quantile, aggcontext=None, **kwargs):
result = aggcontext.agg(
data, 'quantile', q=quantile, interpolation=op.interpolation
@execute_node.register(
ops.Quantile, SeriesGroupBy, numeric_types, (SeriesGroupBy, type(None))
)
def execute_series_group_by_quantile(
op, data, quantile, mask, aggcontext=None, **kwargs
):
return aggcontext.agg(
data,
(
"quantile"
if mask is None
else functools.partial(_filtered_reduction, mask.obj, pd.Series.quantile)
),
q=quantile,
interpolation=op.interpolation,
)
return np.array(result)


@execute_node.register(ops.MultiQuantile, SeriesGroupBy, np.ndarray)
@execute_node.register(
ops.MultiQuantile, pd.Series, np.ndarray, (pd.Series, type(None))
)
def execute_series_quantile_multi(op, data, quantile, mask, aggcontext=None, **kwargs):
return np.array(
aggcontext.agg(
data if mask is None else data.loc[mask],
"quantile",
q=quantile,
interpolation=op.interpolation,
)
)


@execute_node.register(
ops.MultiQuantile, SeriesGroupBy, np.ndarray, (SeriesGroupBy, type(None))
)
def execute_series_quantile_multi_groupby(
op, data, quantile, aggcontext=None, **kwargs
op, data, quantile, mask, aggcontext=None, **kwargs
):
def q(x, quantile, interpolation):
result = x.quantile(quantile, interpolation=interpolation).tolist()
res = [result for _ in range(len(x))]
return res
return [result for _ in range(len(x))]

result = aggcontext.agg(data, q, quantile, op.interpolation)
result = aggcontext.agg(
data,
q if mask is None else functools.partial(_filtered_reduction, mask.obj, q),
quantile,
op.interpolation,
)
return result


Expand Down Expand Up @@ -679,8 +714,24 @@ def execute_variance_series(op, data, mask, aggcontext=None, **kwargs):
)


@execute_node.register((ops.Any, ops.All), (pd.Series, SeriesGroupBy))
def execute_any_all_series(op, data, aggcontext=None, **kwargs):
@execute_node.register((ops.Any, ops.All), pd.Series, (pd.Series, type(None)))
def execute_any_all_series(op, data, mask, aggcontext=None, **kwargs):
if mask is not None:
data = data.loc[mask]
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
result = aggcontext.agg(data, type(op).__name__.lower())
else:
result = aggcontext.agg(
data, lambda data: getattr(data, type(op).__name__.lower())()
)
try:
return result.astype(bool)
except TypeError:
return result


@execute_node.register((ops.Any, ops.All), SeriesGroupBy, type(None))
def execute_any_all_series_group_by(op, data, mask, aggcontext=None, **kwargs):
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
result = aggcontext.agg(data, type(op).__name__.lower())
else:
Expand All @@ -693,24 +744,28 @@ def execute_any_all_series(op, data, aggcontext=None, **kwargs):
return result


@execute_node.register(ops.NotAny, (pd.Series, SeriesGroupBy))
def execute_notany_series(op, data, aggcontext=None, **kwargs):
@execute_node.register((ops.NotAny, ops.NotAll), pd.Series, (pd.Series, type(None)))
def execute_notany_notall_series(op, data, mask, aggcontext=None, **kwargs):
name = type(op).__name__.lower()[len("Not") :]
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
result = ~(aggcontext.agg(data, 'any'))
result = ~aggcontext.agg(data, name)
else:
result = aggcontext.agg(data, lambda data: ~(data.any()))
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))
try:
return result.astype(bool)
except TypeError:
return result


@execute_node.register(ops.NotAll, (pd.Series, SeriesGroupBy))
def execute_notall_series(op, data, aggcontext=None, **kwargs):
@execute_node.register((ops.NotAny, ops.NotAll), SeriesGroupBy, type(None))
def execute_notany_notall_series_group_by(op, data, mask, aggcontext=None, **kwargs):
name = type(op).__name__.lower()[len("Not") :]
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
result = ~(aggcontext.agg(data, 'all'))
result = ~aggcontext.agg(data, name)
else:
result = aggcontext.agg(data, lambda data: ~(data.all()))
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))
try:
return result.astype(bool)
except TypeError:
Expand Down

0 comments on commit 6b5abd6

Please sign in to comment.