Fixes UDF driver reuse bug
We need to depend on the type information stored on the node, not on the callable itself, since we mutate the callable's annotations to run Spark. This only impacts the behavior of rerunning things.
skrawcz committed Mar 2, 2023
1 parent a08c392 commit 7558cfa
Showing 2 changed files with 14 additions and 4 deletions.
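
In short: the pandas UDF path rewrites the UDF's `__annotations__` in place (see the third hunk in h_spark.py below), so a second execution that re-inspects the callable sees the mutated return type. Here is a minimal standalone sketch of that failure mode, with a hypothetical function standing in for a Hamilton UDF (the string annotation stands in for a parametrized return type like `htypes.column[pd.Series, float]`):

```python
import inspect

import pandas as pd


# Hypothetical Hamilton-style UDF, not the library's code.
def spend_per_signup(spend: pd.Series, signups: pd.Series) -> "pd.Series[float]":
    return spend / signups


print(inspect.signature(spend_per_signup).return_annotation)  # 'pd.Series[float]'

# First run: the wrapper overwrites the annotation with the base type so
# pyspark's pandas_udf machinery accepts it (mirroring the third hunk below).
spend_per_signup.__annotations__["return"] = pd.Series

# Second run: re-inspecting the callable now yields the mutated base type;
# the 'float' type argument is gone, which is what broke driver reuse.
print(inspect.signature(spend_per_signup).return_annotation)  # <class 'pandas.core.series.Series'>
```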
10 changes: 10 additions & 0 deletions examples/spark/pyspark_udfs/run.py
@@ -88,6 +88,16 @@ def my_spark_job(spark: SparkSession, use_pandas_udfs: bool = False):
     df = df.select(["spend", "signups", "avg_3wk_spend"] + cols_to_append)
     df.explain()
     df.show()
+    # and you can reuse the same driver to execute UDFs on new dataframes:
+    # df2 = spark.createDataFrame(pandas_df)
+    # add some extra values to the DF, e.g. aggregates, etc.
+    # df2 = add_values_to_dataframe(df2)
+    # execute_inputs = {col: df2 for col in df2.columns}
+    # df2 = dr.execute([
+    #     "spend_per_signup",
+    #     "spend_zero_mean_unit_variance",
+    # ], inputs=execute_inputs)
+    # df2.show()
 
 
 def add_values_to_dataframe(df: DataFrame) -> DataFrame:
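
For reference, the commented-out snippet above in runnable form, assuming the names already defined in run.py (`spark`, `pandas_df`, `add_values_to_dataframe`, `dr`) are in scope; with this fix, a second pass over a fresh dataframe now works:

```python
# Reuse the same Hamilton driver against a new dataframe (names from run.py).
df2 = spark.createDataFrame(pandas_df)
# add some extra values to the DF, e.g. aggregates, etc.
df2 = add_values_to_dataframe(df2)
execute_inputs = {col: df2 for col in df2.columns}
df2 = dr.execute(
    [
        "spend_per_signup",
        "spend_zero_mean_unit_variance",
    ],
    inputs=execute_inputs,
)
df2.show()
```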
8 changes: 4 additions & 4 deletions hamilton/experimental/h_spark.py
@@ -255,7 +255,7 @@ def _lambda_udf(
"""
sig = inspect.signature(hamilton_udf)
input_parameters = dict(sig.parameters)
return_type = sig.return_annotation
# return_type = sig.return_annotation
params_from_df = {}

columns_present = set(df.columns)
@@ -287,7 +287,7 @@
         )
     elif all(pandas_annotation.values()):
         # pull from annotation here instead of tag.
-        base_type, type_args = htypes.get_type_information(sig2.return_annotation)
+        base_type, type_args = htypes.get_type_information(node_.type)
         logger.debug("PandasUDF: %s, %s, %s", node_.name, base_type, type_args)
         if not type_args:
             raise ValueError(
@@ -303,8 +303,8 @@
         hamilton_udf.__annotations__["return"] = base_type
         spark_udf = pandas_udf(hamilton_udf, spark_return_type)
     else:
-        logger.debug("RegularUDF: %s, %s", node_.name, return_type)
-        spark_return_type = get_spark_type(actual_kwargs, df, hamilton_udf, return_type)
+        logger.debug("RegularUDF: %s, %s", node_.name, node_.type)
+        spark_return_type = get_spark_type(actual_kwargs, df, hamilton_udf, node_.type)
         spark_udf = udf(hamilton_udf, spark_return_type)
     return df.withColumn(
         node_.name, spark_udf(*[_value for _name, _value in params_from_df.items()])
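
The fix itself is simple: read the return type from the node, which was captured when the graph was built, instead of re-inspecting the callable. A minimal sketch of that idea, with a hypothetical Node class standing in for Hamilton's:

```python
import inspect
from typing import Any, Callable


class Node:
    """Hypothetical stand-in for Hamilton's node abstraction."""

    def __init__(self, name: str, fn: Callable):
        self.name = name
        self.callable = fn
        # Snapshot the return type before anything mutates __annotations__.
        self.type = inspect.signature(fn).return_annotation


def spend_per_signup(spend: float, signups: float) -> float:
    return spend / signups


node_ = Node("spend_per_signup", spend_per_signup)
spend_per_signup.__annotations__["return"] = Any  # simulate the Spark-run mutation
assert node_.type is float  # node_.type is unaffected, so reruns stay correct
```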
