[SPARK-48924][PS] Add a pandas-like make_interval helper function
### What changes were proposed in this pull request?
Add a pandas-like `make_interval` helper function

### Why are the changes needed?
Factor the interval construction out into a helper function so it can be reused.
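
For illustration (a sketch, not part of the original description): the helper maps a unit name to the matching keyword argument of `F.make_interval`, so call sites pass the unit as a string instead of branching per unit:

```python
from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

# Both expressions build the same INTERVAL column; the helper only
# dispatches the unit name to the right keyword of F.make_interval.
via_helper = SF.make_interval("MONTH", 3)
direct = F.make_interval(months=F.lit(3))
```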

### Does this PR introduce _any_ user-facing change?
No, internal change only

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47385 from zhengruifeng/ps_simplify_make_interval.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
zhengruifeng authored and jingz-db committed Jul 22, 2024
1 parent 3f95864 commit 425ada9
Showing 2 changed files with 26 additions and 25 deletions.
35 changes: 11 additions & 24 deletions python/pyspark/pandas/resample.py
```diff
@@ -25,7 +25,6 @@
     Generic,
     List,
     Optional,
-    Union,
 )
 
 import numpy as np
@@ -42,6 +41,7 @@
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import FrameLike
 from pyspark.pandas.frame import DataFrame
+from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.internal import (
     InternalField,
     InternalFrame,
@@ -130,19 +130,6 @@ def _resamplekey_type(self) -> DataType:
     def _agg_columns_scols(self) -> List[Column]:
         return [s.spark.column for s in self._agg_columns]
 
-    def get_make_interval(  # type: ignore[return]
-        self, unit: str, col: Union[Column, int, float]
-    ) -> Column:
-        col = col if not isinstance(col, (int, float)) else F.lit(col)
-        if unit == "MONTH":
-            return F.make_interval(months=col)
-        if unit == "HOUR":
-            return F.make_interval(hours=col)
-        if unit == "MINUTE":
-            return F.make_interval(mins=col)
-        if unit == "SECOND":
-            return F.make_interval(secs=col)
-
     def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
         key_type = self._resamplekey_type
         origin_scol = F.lit(origin)
@@ -203,18 +190,18 @@ def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
             truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
             edge_label = truncated_ts_scol
             if left_closed and right_labeled:
-                edge_label += self.get_make_interval("MONTH", n)
+                edge_label += SF.make_interval("MONTH", n)
             elif right_closed and left_labeled:
-                edge_label -= self.get_make_interval("MONTH", n)
+                edge_label -= SF.make_interval("MONTH", n)
 
             if left_labeled:
                 non_edge_label = F.when(
                     mod == 0,
-                    truncated_ts_scol - self.get_make_interval("MONTH", n),
-                ).otherwise(truncated_ts_scol - self.get_make_interval("MONTH", mod))
+                    truncated_ts_scol - SF.make_interval("MONTH", n),
+                ).otherwise(truncated_ts_scol - SF.make_interval("MONTH", mod))
             else:
                 non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
-                    truncated_ts_scol - self.get_make_interval("MONTH", mod - n)
+                    truncated_ts_scol - SF.make_interval("MONTH", mod - n)
                 )
 
             ret = F.to_timestamp(
@@ -292,19 +279,19 @@ def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
 
             edge_label = truncated_ts_scol
             if left_closed and right_labeled:
-                edge_label += self.get_make_interval(unit_str, n)
+                edge_label += SF.make_interval(unit_str, n)
             elif right_closed and left_labeled:
-                edge_label -= self.get_make_interval(unit_str, n)
+                edge_label -= SF.make_interval(unit_str, n)
 
             if left_labeled:
                 non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
-                    truncated_ts_scol - self.get_make_interval(unit_str, mod)
+                    truncated_ts_scol - SF.make_interval(unit_str, mod)
                 )
             else:
                 non_edge_label = F.when(
                     mod == 0,
-                    truncated_ts_scol + self.get_make_interval(unit_str, n),
-                ).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n))
+                    truncated_ts_scol + SF.make_interval(unit_str, n),
+                ).otherwise(truncated_ts_scol - SF.make_interval(unit_str, mod - n))
 
             ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)
```
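
The hunks above change only how the interval is constructed; the bin-label arithmetic is untouched. A minimal sketch of that pattern (illustrative, not from the PR; assumes Spark 3.5+, where `F.make_interval` is available):

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
df = spark.range(1).select(
    F.lit("2024-07-01 00:00:00").cast("timestamp").alias("ts")
)
# Shifting a timestamp by an interval column, the same pattern
# _bin_timestamp uses to move a bin label by n units or by `mod`.
df.select((F.col("ts") + SF.make_interval("HOUR", 2)).alias("label")).show()
# expected value: 2024-07-01 02:00:00
```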
16 changes: 15 additions & 1 deletion python/pyspark/pandas/spark/functions.py
```diff
@@ -17,8 +17,9 @@
 """
 Additional Spark functions used in pandas-on-Spark.
 """
-from pyspark.sql.column import Column
+from pyspark.sql import Column, functions as F
 from pyspark.sql.utils import is_remote
+from typing import Union
 
 
 def product(col: Column, dropna: bool) -> Column:
@@ -171,3 +172,16 @@ def null_index(col: Column) -> Column:
 
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
+    unit_mapping = {
+        "YEAR": "years",
+        "MONTH": "months",
+        "WEEK": "weeks",
+        "DAY": "days",
+        "HOUR": "hours",
+        "MINUTE": "mins",
+        "SECOND": "secs",
+    }
+    return F.make_interval(**{unit_mapping[unit]: F.lit(e)})
```
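
For reference, a hypothetical usage sketch of the new helper (the column name `m` is illustrative, not from the PR). Column-valued amounts work because `F.lit` returns `Column` inputs unchanged in recent Spark versions:

```python
from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

week = SF.make_interval("WEEK", 1)             # literal amount: one week
mins = SF.make_interval("MINUTE", F.col("m"))  # column-valued amount
# An unknown unit name raises KeyError, since the lookup is a plain dict access.
```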
