Explicitly specify supported types with Pandas UDFs
HyukjinKwon committed Feb 7, 2018
1 parent b96a083 commit ec708d5
Showing 5 changed files with 104 additions and 61 deletions.
4 changes: 2 additions & 2 deletions docs/sql-programming-guide.md
@@ -1676,7 +1676,7 @@ Using the above optimizations with Arrow will produce the same results as when A
enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the
DataFrame to the driver program and should be done on a small subset of the data. Not all Spark
data types are currently supported and an error can be raised if a column has an unsupported type,
see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
Spark will fall back to create the DataFrame without Arrow.
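To make the fallback behavior above concrete, here is a minimal sketch (not part of this commit) that enables Arrow-backed conversion and round-trips a small DataFrame; it assumes pandas and pyarrow are installed and uses the `spark.sql.execution.arrow.enabled` configuration key described in this guide:

```python
from pyspark.sql import SparkSession

# Illustrative sketch only: enable Arrow-backed conversion for
# toPandas() and createDataFrame(pandas_df).
spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# Keep the data small: toPandas() still collects everything to the driver.
df = spark.range(1 << 10).selectExpr("id", "rand() AS value")
pdf = df.toPandas()

# createDataFrame() also uses Arrow when enabled; if the conversion fails,
# Spark falls back to the non-Arrow code path as described above.
df2 = spark.createDataFrame(pdf)
```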

## Pandas UDFs (a.k.a. Vectorized UDFs)
@@ -1734,7 +1734,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p

### Supported SQL Types

Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`,
Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`,
`ArrayType` of `TimestampType`, and nested `StructType`.
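As an illustrative aside (not part of the diff), the effect of this commit is that an unsupported return type is now rejected eagerly when the Pandas UDF is defined, with a `NotImplementedError` naming the offending type, rather than failing later during execution:

```python
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import MapType, StringType, IntegerType

# MapType is one of the types not supported by the Arrow-based conversion,
# so defining a scalar Pandas UDF with it now fails immediately.
try:
    pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
except NotImplementedError as e:
    print(e)  # mentions the invalid returnType and MapType
```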

### Setting Arrow Batch Size
119 changes: 71 additions & 48 deletions python/pyspark/sql/tests.py
@@ -3802,15 +3802,16 @@ def random_udf(v):
return random_udf

def test_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.functions import pandas_udf, col, array
df = self.spark.range(10).select(
col('id').cast('string').alias('str'),
col('id').cast('int').alias('int'),
col('id').alias('long'),
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
col('id').cast('decimal').alias('decimal'),
col('id').cast('boolean').alias('bool'))
col('id').cast('boolean').alias('bool'),
array(col('id')).alias('array_long'))
f = lambda x: x
str_f = pandas_udf(f, StringType())
int_f = pandas_udf(f, IntegerType())
@@ -3819,10 +3820,11 @@ def test_vectorized_udf_basic(self):
double_f = pandas_udf(f, DoubleType())
decimal_f = pandas_udf(f, DecimalType())
bool_f = pandas_udf(f, BooleanType())
array_long_f = pandas_udf(f, ArrayType(LongType()))
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
double_f(col('double')), decimal_f('decimal'),
bool_f(col('bool')))
bool_f(col('bool')), array_long_f('array_long'))
self.assertEquals(df.collect(), res.collect())

def test_register_nondeterministic_vectorized_udf_basic(self):
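Outside the test harness, the newly exercised `array<long>` case in `test_vectorized_udf_basic` looks roughly like the following sketch (assuming an active SparkSession and pandas/pyarrow installed):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, array
from pyspark.sql.types import ArrayType, LongType

spark = SparkSession.builder.getOrCreate()

# Identity scalar Pandas UDF over an array<long> column.
identity = pandas_udf(lambda s: s, ArrayType(LongType()))

df = spark.range(10).select(array(col("id")).alias("array_long"))
df.select(identity(col("array_long"))).show()
```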
@@ -4027,10 +4029,11 @@ def test_vectorized_udf_chained(self):
def test_vectorized_udf_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
df.select(f(col('id'))).collect()
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))

def test_vectorized_udf_return_scalar(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4065,13 +4068,18 @@ def test_vectorized_udf_varargs(self):
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
from pyspark.sql.functions import pandas_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('map'))).collect()
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))

with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*BinaryType'):
pandas_udf(lambda x: x, BinaryType())

def test_vectorized_udf_dates(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4289,14 +4297,15 @@ def data(self):
.withColumn("v", explode(col('vs'))).drop('vs')

def test_simple(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
df = self.data.withColumn("arr", array(col("id")))

foo_udf = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
StructType(
[StructField('id', LongType()),
StructField('v', IntegerType()),
StructField('arr', ArrayType(LongType())),
StructField('v1', DoubleType()),
StructField('v2', LongType())]),
PandasUDFType.GROUPED_MAP
@@ -4399,17 +4408,15 @@ def test_datatype_string(self):

def test_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data

foo = pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
PandasUDFType.GROUPED_MAP
)

with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
df.groupby('id').apply(foo).sort('id').toPandas()
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*MapType'):
pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
PandasUDFType.GROUPED_MAP)

def test_wrong_args(self):
from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
@@ -4428,23 +4435,30 @@ def test_wrong_args(self):
df.groupby('id').apply(
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
df.groupby('id').apply(
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
df.groupby('id').apply(
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]),
PandasUDFType.SCALAR))
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))

def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
from pyspark.sql.functions import pandas_udf, PandasUDFType
schema = StructType(
[StructField("id", LongType(), True),
StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.groupby('id').apply(f).collect()
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*MapType'):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

schema = StructType(
[StructField("id", LongType(), True),
StructField("arr_ts", ArrayType(TimestampType()), True)])
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
@@ -4509,23 +4523,32 @@ def weighted_mean(v, w):
return weighted_mean

def test_manual(self):
from pyspark.sql.functions import pandas_udf, col, array

df = self.data
sum_udf = self.pandas_agg_sum_udf
mean_udf = self.pandas_agg_mean_udf

result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id')
mean_arr_udf = pandas_udf(
self.pandas_agg_mean_udf.func,
ArrayType(self.pandas_agg_mean_udf.returnType),
self.pandas_agg_mean_udf.evalType)

result1 = df.groupby('id').agg(
sum_udf(df.v),
mean_udf(df.v),
mean_arr_udf(array(df.v))).sort('id')
expected1 = self.spark.createDataFrame(
[[0, 245.0, 24.5],
[1, 255.0, 25.5],
[2, 265.0, 26.5],
[3, 275.0, 27.5],
[4, 285.0, 28.5],
[5, 295.0, 29.5],
[6, 305.0, 30.5],
[7, 315.0, 31.5],
[8, 325.0, 32.5],
[9, 335.0, 33.5]],
['id', 'sum(v)', 'avg(v)'])
[[0, 245.0, 24.5, [24.5]],
[1, 255.0, 25.5, [25.5]],
[2, 265.0, 26.5, [26.5]],
[3, 275.0, 27.5, [27.5]],
[4, 285.0, 28.5, [28.5]],
[5, 295.0, 29.5, [29.5]],
[6, 305.0, 30.5, [30.5]],
[7, 315.0, 31.5, [31.5]],
[8, 325.0, 32.5, [32.5]],
[9, 335.0, 33.5, [33.5]]],
['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])

self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

@@ -4562,14 +4585,14 @@ def test_basic(self):
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())

def test_unsupported_types(self):
from pyspark.sql.types import ArrayType, DoubleType, MapType
from pyspark.sql.types import DoubleType, MapType
from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG)
@pandas_udf(ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return [v.mean(), v.std()]
return v

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
4 changes: 4 additions & 0 deletions python/pyspark/sql/types.py
@@ -1638,6 +1638,8 @@ def to_arrow_type(dt):
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == ArrayType:
if type(dt.elementType) == TimestampType:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -1680,6 +1682,8 @@ def from_arrow_type(at):
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
36 changes: 26 additions & 10 deletions python/pyspark/sql/udf.py
@@ -23,7 +23,7 @@
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
_parse_datatype_string
_parse_datatype_string, to_arrow_type, to_arrow_schema

__all__ = ["UDFRegistration"]

@@ -112,15 +112,31 @@ def returnType(self):
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)

if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
and not isinstance(self._returnType_placeholder, StructType):
raise ValueError("Invalid returnType: returnType must be a StructType for "
"pandas_udf with function type GROUPED_MAP")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \
and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)):
raise NotImplementedError(
"ArrayType, StructType and MapType are not supported with "
"PandasUDFType.GROUPED_AGG")
if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
try:
to_arrow_schema(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a grouped map Pandas UDF: "
"%s is not supported" % str(self._returnType_placeholder))
else:
raise TypeError("Invalid returnType for a grouped map Pandas "
"UDF: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a scalar Pandas UDF: %s is "
"not supported" % str(self._returnType_placeholder))
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a grouped aggregate Pandas UDF: "
"%s is not supported" % str(self._returnType_placeholder))

return self._returnType_placeholder
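In other words, the `returnType` property now probes Arrow convertibility up front and reports the problem per UDF kind. A hedged sketch of the grouped map case, mirroring `test_wrong_return_type` above:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# The schema string contains a map<int, int> field, which to_arrow_schema()
# cannot convert, so the UDF definition itself raises NotImplementedError.
try:
    pandas_udf(lambda pdf: pdf, 'id long, v map<int, int>',
               PandasUDFType.GROUPED_MAP)
except NotImplementedError as e:
    print(e)
```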

2 changes: 1 addition & 1 deletion python/pyspark/worker.py
@@ -116,7 +116,7 @@ def wrap_grouped_agg_pandas_udf(f, return_type):
def wrapped(*series):
import pandas as pd
result = f(*series)
return pd.Series(result)
return pd.Series([result])

return lambda *a: (wrapped(*a), arrow_return_type)
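The one-line `worker.py` change matters once grouped aggregate Pandas UDFs may return array values (as in `test_manual` above): wrapping the result in a list keeps it a single element of the output series instead of letting pandas expand it. A quick pandas-only illustration:

```python
import pandas as pd

# Suppose a grouped aggregate UDF with an ArrayType return type produced this:
result = [24.5]

print(pd.Series(result))    # one row per list element: 0    24.5
print(pd.Series([result]))  # a single row holding the whole list: 0    [24.5]
```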

Expand Down
