[SPARK-23352][PYTHON][BRANCH-2.3] Explicitly specify supported types in Pandas UDFs #20588

Closed
4 changes: 2 additions & 2 deletions docs/sql-programming-guide.md
@@ -1676,7 +1676,7 @@
 enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the
 DataFrame to the driver program and should be done on a small subset of the data. Not all Spark
 data types are currently supported and an error can be raised if a column has an unsupported type,
-see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
+see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
 Spark will fall back to create the DataFrame without Arrow.

 ## Pandas UDFs (a.k.a. Vectorized UDFs)
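As an aside, a minimal sketch of the workflow this passage describes, assuming the Spark 2.3 config key `spark.sql.execution.arrow.enabled` (the sample DataFrame is illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Arrow-based columnar transfer is off by default in Spark 2.3.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

df = spark.range(100).selectExpr("id", "id * 2.0 AS value")

# Even with Arrow, toPandas() collects every record to the driver,
# so keep it to a small subset of the data.
pdf = df.limit(10).toPandas()

# createDataFrame() from a Pandas DataFrame also uses Arrow when enabled;
# if an error occurs (e.g. an unsupported column type), Spark falls back
# to the non-Arrow code path.
df2 = spark.createDataFrame(pdf)
```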
@@ -1734,7 +1734,7 @@

 ### Supported SQL Types

-Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`,
+Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`,
 `ArrayType` of `TimestampType`, and nested `StructType`.

 ### Setting Arrow Batch Size
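To see how the documented restriction surfaces in practice, a small sketch based on the tests in this PR (the lambdas and the message excerpt are illustrative):

```python
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType, MapType

# A supported returnType: the UDF is created normally.
plus_one = pandas_udf(lambda s: s + 1, LongType())

# An unsupported returnType: after this change the error is raised eagerly,
# when the UDF is defined, instead of deep inside query execution.
try:
    pandas_udf(lambda s: s, MapType(LongType(), LongType()))
except NotImplementedError as e:
    print(e)  # Invalid returnType with scalar Pandas UDFs: ... is not supported
```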
86 changes: 50 additions & 36 deletions python/pyspark/sql/tests.py
@@ -3736,10 +3736,10 @@ def foo(x):
         self.assertEqual(foo.returnType, schema)
         self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

-        @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR)
+        @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
         def foo(x):
             return x
-        self.assertEqual(foo.returnType, schema)
+        self.assertEqual(foo.returnType, DoubleType())
         self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

         @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
@@ -3776,7 +3776,7 @@ def zero_with_type():
             @pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
             def foo(df):
                 return df
-        with self.assertRaisesRegexp(ValueError, 'Invalid returnType'):
+        with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
             @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
             def foo(df):
                 return df
@@ -3825,15 +3825,16 @@ def random_udf(v):
         return random_udf

     def test_vectorized_udf_basic(self):
-        from pyspark.sql.functions import pandas_udf, col
+        from pyspark.sql.functions import pandas_udf, col, array
         df = self.spark.range(10).select(
             col('id').cast('string').alias('str'),
             col('id').cast('int').alias('int'),
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
             col('id').cast('decimal').alias('decimal'),
-            col('id').cast('boolean').alias('bool'))
+            col('id').cast('boolean').alias('bool'),
+            array(col('id')).alias('array_long'))
         f = lambda x: x
         str_f = pandas_udf(f, StringType())
         int_f = pandas_udf(f, IntegerType())
@@ -3842,10 +3843,11 @@ def test_vectorized_udf_basic(self):
         double_f = pandas_udf(f, DoubleType())
         decimal_f = pandas_udf(f, DecimalType())
         bool_f = pandas_udf(f, BooleanType())
+        array_long_f = pandas_udf(f, ArrayType(LongType()))
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
                         double_f(col('double')), decimal_f('decimal'),
-                        bool_f(col('bool')))
+                        bool_f(col('bool')), array_long_f('array_long'))
         self.assertEquals(df.collect(), res.collect())

     def test_register_nondeterministic_vectorized_udf_basic(self):
@@ -4050,10 +4052,11 @@ def test_vectorized_udf_chained(self):
     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
-        f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
-                df.select(f(col('id'))).collect()
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*scalar Pandas UDF.*MapType'):
+                pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))

     def test_vectorized_udf_return_scalar(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4088,13 +4091,18 @@ def test_vectorized_udf_varargs(self):
         self.assertEquals(df.collect(), res.collect())

     def test_vectorized_udf_unsupported_types(self):
-        from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
-        df = self.spark.createDataFrame([(None,)], schema=schema)
-        f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+        from pyspark.sql.functions import pandas_udf
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
-                df.select(f(col('map'))).collect()
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*scalar Pandas UDF.*MapType'):
+                pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
+                pandas_udf(lambda x: x, BinaryType())

     def test_vectorized_udf_dates(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4325,15 +4333,16 @@ def data(self):
             .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
             .withColumn("v", explode(col('vs'))).drop('vs')

-    def test_simple(self):
-        from pyspark.sql.functions import pandas_udf, PandasUDFType
-        df = self.data
+    def test_supported_types(self):
> **Review comment (Member):** I start to worry about the test coverage of vectorized UDFs and of Arrow-based conversion to/from Pandas DataFrames. Do we have any plan in PySpark to test all the data types?
>
> **Member (author):** Yup, agree. I was thinking of doing it, but if you (or your colleagues) are working on that or have a plan, no need to block it on me :). Please go ahead.
>
> **Member:** Maybe open a JIRA and ask the OSS community to do it?
>
> **Member (author):** Yup. Filed SPARK-23401.
+        from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+        df = self.data.withColumn("arr", array(col("id")))

         foo_udf = pandas_udf(
             lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
             StructType(
                 [StructField('id', LongType()),
                  StructField('v', IntegerType()),
+                 StructField('arr', ArrayType(LongType())),
                  StructField('v1', DoubleType()),
                  StructField('v2', LongType())]),
             PandasUDFType.GROUPED_MAP
@@ -4436,17 +4445,15 @@ def test_datatype_string(self):

     def test_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
-        df = self.data
-
-        foo = pandas_udf(
-            lambda pdf: pdf,
-            'id long, v map<int, int>',
-            PandasUDFType.GROUPED_MAP
-        )
-
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
-                df.groupby('id').apply(foo).sort('id').toPandas()
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*grouped map Pandas UDF.*MapType'):
+                pandas_udf(
+                    lambda pdf: pdf,
+                    'id long, v map<int, int>',
+                    PandasUDFType.GROUPED_MAP)

     def test_wrong_args(self):
         from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
@@ -4465,23 +4472,30 @@ def test_wrong_args(self):
                 df.groupby('id').apply(
                     pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
             with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
-                df.groupby('id').apply(
-                    pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
+                df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
             with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
                 df.groupby('id').apply(
-                    pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]),
-                               PandasUDFType.SCALAR))
+                    pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))

     def test_unsupported_types(self):
-        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
         schema = StructType(
             [StructField("id", LongType(), True),
              StructField("map", MapType(StringType(), IntegerType()), True)])
-        df = self.spark.createDataFrame([(1, None,)], schema=schema)
-        f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP)
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
-                df.groupby('id').apply(f).collect()
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*grouped map Pandas UDF.*MapType'):
+                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
+
+        schema = StructType(
+            [StructField("id", LongType(), True),
+             StructField("arr_ts", ArrayType(TimestampType()), True)])
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
+                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
4 changes: 4 additions & 0 deletions python/pyspark/sql/types.py
@@ -1638,6 +1638,8 @@ def to_arrow_type(dt):
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
     elif type(dt) == ArrayType:
+        if type(dt.elementType) == TimestampType:
+            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -1680,6 +1682,8 @@ def from_arrow_type(at):
     elif types.is_timestamp(at):
         spark_type = TimestampType()
     elif types.is_list(at):
+        if types.is_timestamp(at.value_type):
+            raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
         spark_type = ArrayType(from_arrow_type(at.value_type))
     else:
         raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
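A quick sketch of what the new guards reject; `to_arrow_type` and `from_arrow_type` are internal helpers of `pyspark.sql.types`, and the printed shapes are illustrative:

```python
import pyarrow as pa

from pyspark.sql.types import (ArrayType, LongType, TimestampType,
                               from_arrow_type, to_arrow_type)

# Arrays of primitive element types still convert in both directions.
at = to_arrow_type(ArrayType(LongType()))
print(at)                    # list<item: int64>
print(from_arrow_type(at))   # ArrayType(LongType,true)

# Arrays of timestamps are now rejected explicitly, both to and from Arrow.
try:
    to_arrow_type(ArrayType(TimestampType()))
except TypeError as e:
    print(e)  # Unsupported type in conversion to Arrow: ArrayType(TimestampType,true)

try:
    from_arrow_type(pa.list_(pa.timestamp('us', tz='UTC')))
except TypeError as e:
    print(e)  # Unsupported type in conversion from Arrow: list<item: timestamp[us, tz=UTC]>
```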
25 changes: 20 additions & 5 deletions python/pyspark/sql/udf.py
@@ -22,7 +22,8 @@
 from pyspark import SparkContext, since
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_java_column, _to_seq
-from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
+from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
+    to_arrow_type, to_arrow_schema

 __all__ = ["UDFRegistration"]
@@ -109,10 +110,24 @@ def returnType(self):
         else:
             self._returnType_placeholder = _parse_datatype_string(self._returnType)

-        if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
-                and not isinstance(self._returnType_placeholder, StructType):
-            raise ValueError("Invalid returnType: returnType must be a StructType for "
-                             "pandas_udf with function type GROUPED_MAP")
+        if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+            try:
+                to_arrow_type(self._returnType_placeholder)
+            except TypeError:
+                raise NotImplementedError(
+                    "Invalid returnType with scalar Pandas UDFs: %s is "
+                    "not supported" % str(self._returnType_placeholder))
+        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+            if isinstance(self._returnType_placeholder, StructType):
+                try:
+                    to_arrow_schema(self._returnType_placeholder)
+                except TypeError:
+                    raise NotImplementedError(
+                        "Invalid returnType with grouped map Pandas UDFs: "
+                        "%s is not supported" % str(self._returnType_placeholder))
+            else:
+                raise TypeError("Invalid returnType for grouped map Pandas "
+                                "UDFs: returnType must be a StructType.")

         return self._returnType_placeholder
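A sketch of the two error paths introduced above, mirroring the updated tests (the example schema is illustrative):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import (DoubleType, IntegerType, MapType,
                               StringType, StructField, StructType)

# GROUPED_MAP with a non-StructType returnType: TypeError (was ValueError).
try:
    pandas_udf(lambda pdf: pdf, DoubleType(), PandasUDFType.GROUPED_MAP)
except TypeError as e:
    print(e)

# GROUPED_MAP whose StructType contains a field Arrow cannot convert:
# NotImplementedError, raised as soon as the UDF is defined.
schema = StructType([StructField("m", MapType(StringType(), IntegerType()))])
try:
    pandas_udf(lambda pdf: pdf, schema, PandasUDFType.GROUPED_MAP)
except NotImplementedError as e:
    print(e)
```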
2 changes: 1 addition & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1052,7 +1052,7 @@ object SQLConf {
       "for use with pyspark.sql.DataFrame.toPandas, and " +
       "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " +
       "The following data types are unsupported: " +
-      "MapType, ArrayType of TimestampType, and nested StructType.")
+      "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.")
     .booleanConf
     .createWithDefault(false)