Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48924][PS] Add a pandas-like make_interval helper function #47385

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 11 additions & 24 deletions python/pyspark/pandas/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Generic,
List,
Optional,
Union,
)

import numpy as np
Expand All @@ -42,6 +41,7 @@
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import FrameLike
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.internal import (
InternalField,
InternalFrame,
Expand Down Expand Up @@ -130,19 +130,6 @@ def _resamplekey_type(self) -> DataType:
def _agg_columns_scols(self) -> List[Column]:
    """Return the underlying Spark Column of every aggregation series, in order."""
    return list(map(lambda psser: psser.spark.column, self._agg_columns))

def get_make_interval( # type: ignore[return]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering if we skip some units intentionally here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some units were ignored because they were not invoked at that time; I am adding them back so the helper function covers all pandas units :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great thanks!

self, unit: str, col: Union[Column, int, float]
) -> Column:
col = col if not isinstance(col, (int, float)) else F.lit(col)
if unit == "MONTH":
return F.make_interval(months=col)
if unit == "HOUR":
return F.make_interval(hours=col)
if unit == "MINUTE":
return F.make_interval(mins=col)
if unit == "SECOND":
return F.make_interval(secs=col)

def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
key_type = self._resamplekey_type
origin_scol = F.lit(origin)
Expand Down Expand Up @@ -203,18 +190,18 @@ def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
edge_label = truncated_ts_scol
if left_closed and right_labeled:
edge_label += self.get_make_interval("MONTH", n)
edge_label += SF.make_interval("MONTH", n)
elif right_closed and left_labeled:
edge_label -= self.get_make_interval("MONTH", n)
edge_label -= SF.make_interval("MONTH", n)

if left_labeled:
non_edge_label = F.when(
mod == 0,
truncated_ts_scol - self.get_make_interval("MONTH", n),
).otherwise(truncated_ts_scol - self.get_make_interval("MONTH", mod))
truncated_ts_scol - SF.make_interval("MONTH", n),
).otherwise(truncated_ts_scol - SF.make_interval("MONTH", mod))
else:
non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
truncated_ts_scol - self.get_make_interval("MONTH", mod - n)
truncated_ts_scol - SF.make_interval("MONTH", mod - n)
)

ret = F.to_timestamp(
Expand Down Expand Up @@ -292,19 +279,19 @@ def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:

edge_label = truncated_ts_scol
if left_closed and right_labeled:
edge_label += self.get_make_interval(unit_str, n)
edge_label += SF.make_interval(unit_str, n)
elif right_closed and left_labeled:
edge_label -= self.get_make_interval(unit_str, n)
edge_label -= SF.make_interval(unit_str, n)

if left_labeled:
non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
truncated_ts_scol - self.get_make_interval(unit_str, mod)
truncated_ts_scol - SF.make_interval(unit_str, mod)
)
else:
non_edge_label = F.when(
mod == 0,
truncated_ts_scol + self.get_make_interval(unit_str, n),
).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n))
truncated_ts_scol + SF.make_interval(unit_str, n),
).otherwise(truncated_ts_scol - SF.make_interval(unit_str, mod - n))

ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)

Expand Down
16 changes: 15 additions & 1 deletion python/pyspark/pandas/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"""
Additional Spark functions used in pandas-on-Spark.
"""
from pyspark.sql.column import Column
from pyspark.sql import Column, functions as F
from pyspark.sql.utils import is_remote
from typing import Union


def product(col: Column, dropna: bool) -> Column:
Expand Down Expand Up @@ -171,3 +172,16 @@ def null_index(col: Column) -> Column:

sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


# Maps the upper-case unit names used by the resample code (matching the
# units accepted by `F.date_trunc`) to the keyword-argument names of
# `pyspark.sql.functions.make_interval`.  Hoisted to module level so the
# dict is built once, not on every call.
_UNIT_TO_INTERVAL_KWARG = {
    "YEAR": "years",
    "MONTH": "months",
    "WEEK": "weeks",
    "DAY": "days",
    "HOUR": "hours",
    "MINUTE": "mins",
    "SECOND": "secs",
}


def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
    """
    Build an INTERVAL Column of ``e`` units, covering all pandas offset units.

    Parameters
    ----------
    unit : str
        One of "YEAR", "MONTH", "WEEK", "DAY", "HOUR", "MINUTE", "SECOND".
    e : Column, int or float
        The interval amount. Plain numbers are wrapped via ``F.lit``;
        a Column is passed through ``F.lit`` unchanged (pyspark's ``lit``
        returns Column inputs as-is).

    Returns
    -------
    Column
        ``F.make_interval`` invoked with the single matching keyword argument.

    Raises
    ------
    KeyError
        If ``unit`` is not one of the supported unit names.
    """
    return F.make_interval(**{_UNIT_TO_INTERVAL_KWARG[unit]: F.lit(e)})