-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-4192][SQL] Internal API for Python UDT #3068
Changes from 8 commits
b7f666d
39f19e0
4e84fce
e98d9d0
f740379
75223db
7c4a6a9
2c9d7e4
dba5ea7
acff637
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -408,6 +408,73 @@ def fromJson(cls, json): | |
return StructType([StructField.fromJson(f) for f in json["fields"]]) | ||
|
||
|
||
class UserDefinedType(DataType):
    """
    :: WARN: Spark Internal Use Only ::
    SQL User-Defined Type (UDT).

    Subclasses must provide the underlying SQL storage type
    (``sqlType()``), the defining Python module (``module()``), the
    class name of the paired Scala UDT (``scalaUDT()``), and the
    ``serialize()``/``deserialize()`` pair that maps user-type objects
    to and from SQL data.
    """

    @classmethod
    def typeName(cls):
        # Lower-cased class name, consistent with DataType.typeName().
        return cls.__name__.lower()

    @classmethod
    def sqlType(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sqlType().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def scalaUDT(cls):
        """
        The class name of the paired Scala UDT.
        """
        raise NotImplementedError("UDT must have a paired Scala UDT.")

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement serialize().")

    def deserialize(self, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement deserialize().")

    def json(self):
        """Returns the compact JSON string representation of this type."""
        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)

    def jsonValue(self):
        """Returns the dict used by json(); fromJson() reverses it."""
        schema = {
            "type": "udt",
            "pyModule": self.module(),
            "pyClass": type(self).__name__,
            "class": self.scalaUDT()
        }
        return schema

    @classmethod
    def fromJson(cls, json):
        """
        Reconstructs a UDT instance from a dict produced by jsonValue().

        NOTE: the parameter shadows the ``json`` module inside this
        method only.
        """
        pyModule = json['pyModule']
        pyClass = json['pyClass']
        # level=-1 requests Python 2's implicit relative import semantics.
        m = __import__(pyModule, globals(), locals(), [pyClass], -1)
        UDT = getattr(m, pyClass)
        return UDT()

    def __eq__(self, other):
        # Two UDT instances are interchangeable iff they have the same class.
        return type(self) == type(other)

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__; without this,
        # `a != b` would fall back to identity comparison.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__ (equal objects, equal hashes).
        return hash(type(self))
|
||
|
||
_all_primitive_types = dict((v.typeName(), v) | ||
for v in globals().itervalues() | ||
if type(v) is PrimitiveTypeSingleton and | ||
|
@@ -460,6 +527,12 @@ def _parse_datatype_json_string(json_string): | |
... complex_arraytype, False) | ||
>>> check_datatype(complex_maptype) | ||
True | ||
>>> check_datatype(ExamplePointUDT()) | ||
True | ||
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), | ||
... StructField("point", ExamplePointUDT(), False)]) | ||
>>> check_datatype(structtype_with_udt) | ||
True | ||
""" | ||
return _parse_datatype_json_value(json.loads(json_string)) | ||
|
||
|
@@ -479,7 +552,13 @@ def _parse_datatype_json_value(json_value): | |
else: | ||
raise ValueError("Could not parse datatype: %s" % json_value) | ||
else: | ||
return _all_complex_types[json_value["type"]].fromJson(json_value) | ||
tpe = json_value["type"] | ||
if tpe in _all_complex_types: | ||
return _all_complex_types[tpe].fromJson(json_value) | ||
elif tpe == 'udt': | ||
return UserDefinedType.fromJson(json_value) | ||
else: | ||
raise ValueError("not supported type: %s" % tpe) | ||
|
||
|
||
# Mapping Python types to Spark SQL DataType | ||
|
@@ -499,10 +578,18 @@ def _parse_datatype_json_value(json_value): | |
|
||
|
||
def _infer_type(obj): | ||
"""Infer the DataType from obj""" | ||
"""Infer the DataType from obj | ||
|
||
>>> p = ExamplePoint(1.0, 2.0) | ||
>>> _infer_type(p) | ||
ExamplePointUDT | ||
""" | ||
if obj is None: | ||
raise ValueError("Can not infer type for None") | ||
|
||
if hasattr(obj, '__UDT__'): | ||
return obj.__UDT__ | ||
|
||
dataType = _type_mappings.get(type(obj)) | ||
if dataType is not None: | ||
return dataType() | ||
|
@@ -548,8 +635,95 @@ def _infer_schema(row): | |
return StructType(fields) | ||
|
||
|
||
def _need_python_to_sql_conversion(dataType):
    """
    Checks whether we need python to sql conversion for the given type.
    For now, only UDTs need this conversion.

    >>> _need_python_to_sql_conversion(DoubleType())
    False
    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
    >>> _need_python_to_sql_conversion(schema0)
    False
    >>> _need_python_to_sql_conversion(ExamplePointUDT())
    True
    >>> schema1 = ArrayType(ExamplePointUDT(), False)
    >>> _need_python_to_sql_conversion(schema1)
    True
    >>> schema2 = StructType([StructField("label", DoubleType(), False),
    ...                       StructField("point", ExamplePointUDT(), False)])
    >>> _need_python_to_sql_conversion(schema2)
    True
    """
    if isinstance(dataType, StructType):
        # Generator (not a list) lets any() short-circuit on the first
        # field that needs conversion.
        return any(_need_python_to_sql_conversion(f.dataType) for f in dataType.fields)
    elif isinstance(dataType, ArrayType):
        return _need_python_to_sql_conversion(dataType.elementType)
    elif isinstance(dataType, MapType):
        return _need_python_to_sql_conversion(dataType.keyType) or \
            _need_python_to_sql_conversion(dataType.valueType)
    elif isinstance(dataType, UserDefinedType):
        # Only UDTs (possibly nested in the container types above) need it.
        return True
    else:
        return False
|
||
|
||
def _python_to_sql_converter(dataType):
    """
    Returns a converter that converts a Python object into a SQL datum for the given type.

    >>> conv = _python_to_sql_converter(DoubleType())
    >>> conv(1.0)
    1.0
    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
    >>> conv([1.0, 2.0])
    [1.0, 2.0]
    >>> conv = _python_to_sql_converter(ExamplePointUDT())
    >>> conv(ExamplePoint(1.0, 2.0))
    [1.0, 2.0]
    >>> schema = StructType([StructField("label", DoubleType(), False),
    ...                      StructField("point", ExamplePointUDT(), False)])
    >>> conv = _python_to_sql_converter(schema)
    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
    (1.0, [1.0, 2.0])
    """
    # Fast path: types without UDTs anywhere need no conversion at all.
    if not _need_python_to_sql_conversion(dataType):
        return lambda x: x

    if isinstance(dataType, StructType):
        field_names = [f.name for f in dataType.fields]
        field_convs = [_python_to_sql_converter(f.dataType) for f in dataType.fields]

        def converter(obj):
            # Rows may arrive as dicts, named tuples, key-value pair
            # lists, or plain tuples; always emit a plain tuple.
            if isinstance(obj, dict):
                return tuple(c(obj.get(n)) for n, c in zip(field_names, field_convs))
            if isinstance(obj, tuple):
                if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
                    # namedtuple / Row: values are already in field order.
                    return tuple(c(v) for c, v in zip(field_convs, obj))
                if all(isinstance(x, tuple) and len(x) == 2 for x in obj):
                    # Sequence of (key, value) pairs: look values up by name.
                    by_name = dict(obj)
                    return tuple(c(by_name.get(n)) for n, c in zip(field_names, field_convs))
                # Plain tuple: positional.
                return tuple(c(v) for c, v in zip(field_convs, obj))
            raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))

        return converter

    if isinstance(dataType, ArrayType):
        elem_conv = _python_to_sql_converter(dataType.elementType)
        return lambda seq: [elem_conv(item) for item in seq]

    if isinstance(dataType, MapType):
        key_conv = _python_to_sql_converter(dataType.keyType)
        val_conv = _python_to_sql_converter(dataType.valueType)
        return lambda m: dict((key_conv(k), val_conv(v)) for k, v in m.items())

    if isinstance(dataType, UserDefinedType):
        # Delegate to the UDT itself; bound method behaves like the lambda.
        return dataType.serialize

    raise ValueError("Unexpected type %r" % dataType)
|
||
|
||
def _create_converter(obj, dataType): | ||
"""Create an converter to drop the names of fields in obj """ | ||
"""Create an converter to drop the names of fields in obj""" | ||
if isinstance(dataType, ArrayType): | ||
conv = _create_converter(obj[0], dataType.elementType) | ||
return lambda row: map(conv, row) | ||
|
@@ -775,11 +949,21 @@ def _verify_type(obj, dataType): | |
Traceback (most recent call last): | ||
... | ||
ValueError:... | ||
>>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) | ||
(Review thread — GitHub boilerplate omitted: "There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.")
Reviewer: it's better to remove these tests for ExamplePoint; they should be in tests.py (or already covered).
Author: It is in the same group of other doctests for this private function. I didn't find one for it.
Reviewer: Your tests in tests.py have covered these internal functions, so I think it's fine to not have them here. |
||
>>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL | ||
Traceback (most recent call last): | ||
... | ||
ValueError:... | ||
""" | ||
# all objects are nullable | ||
if obj is None: | ||
return | ||
|
||
if isinstance(dataType, UserDefinedType): | ||
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): | ||
raise ValueError("%r is not an instance of type %r" % (obj, dataType)) | ||
return | ||
(Review comment) Reviewer: Should we also check that the serialized object matches with dataType.sqlType()? |
||
|
||
_type = type(dataType) | ||
assert _type in _acceptable_types, "unkown datatype: %s" % dataType | ||
|
||
|
@@ -854,6 +1038,8 @@ def _has_struct_or_date(dt): | |
return _has_struct_or_date(dt.valueType) | ||
elif isinstance(dt, DateType): | ||
return True | ||
elif isinstance(dt, UserDefinedType): | ||
return True | ||
return False | ||
|
||
|
||
|
@@ -924,6 +1110,9 @@ def Dict(d): | |
elif isinstance(dataType, DateType): | ||
return datetime.date | ||
|
||
elif isinstance(dataType, UserDefinedType): | ||
return lambda datum: dataType.deserialize(datum) | ||
|
||
elif not isinstance(dataType, StructType): | ||
raise Exception("unexpected data type: %s" % dataType) | ||
|
||
|
@@ -1184,6 +1373,10 @@ def applySchema(self, rdd, schema): | |
for row in rows: | ||
_verify_type(row, schema) | ||
|
||
# convert python objects to sql data | ||
converter = _python_to_sql_converter(schema) | ||
rdd = rdd.map(converter) | ||
|
||
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) | ||
jrdd = self._pythonToJava(rdd._jrdd, batched) | ||
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) | ||
|
@@ -1817,6 +2010,7 @@ def _test(): | |
# let doctest run in pyspark.sql, so DataTypes can be picklable | ||
import pyspark.sql | ||
from pyspark.sql import Row, SQLContext | ||
from pyspark.tests import ExamplePoint, ExamplePointUDT | ||
globs = pyspark.sql.__dict__.copy() | ||
# The small batch size here ensures that we see multiple batches, | ||
# even in these small test examples: | ||
|
@@ -1828,6 +2022,8 @@ def _test(): | |
Row(field1=2, field2="row2"), | ||
Row(field1=3, field2="row3")] | ||
) | ||
globs['ExamplePoint'] = ExamplePoint | ||
globs['ExamplePointUDT'] = ExamplePointUDT | ||
jsonStrings = [ | ||
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}', | ||
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' | ||
|
(Review thread — GitHub boilerplate omitted: "There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.")
Reviewer: Can _create_converter do this?
Author: _create_converter doesn't do this. It is used to drop the names if the user provides Row objects, called by _drop_schema. I think we need to refactor the code a little bit during QA.
Reviewer: Agreed.