[SPARK-4192][SQL] Internal API for Python UDT #3068

Closed · wants to merge 10 commits
Changes from 8 commits
202 changes: 199 additions & 3 deletions python/pyspark/sql.py
@@ -408,6 +408,73 @@ def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])


class UserDefinedType(DataType):
"""
:: WARN: Spark Internal Use Only ::
SQL User-Defined Type (UDT).
"""

@classmethod
def typeName(cls):
return cls.__name__.lower()

@classmethod
def sqlType(cls):
"""
Underlying SQL storage type for this UDT.
"""
raise NotImplementedError("UDT must implement sqlType().")

@classmethod
def module(cls):
"""
The Python module of the UDT.
"""
raise NotImplementedError("UDT must implement module().")

@classmethod
def scalaUDT(cls):
"""
The class name of the paired Scala UDT.
"""
raise NotImplementedError("UDT must have a paired Scala UDT.")

def serialize(self, obj):
"""
Converts a user-type object into a SQL datum.
"""
raise NotImplementedError("UDT must implement serialize().")

def deserialize(self, datum):
"""
Converts a SQL datum into a user-type object.
"""
raise NotImplementedError("UDT must implement deserialize().")

def json(self):
return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)

def jsonValue(self):
schema = {
"type": "udt",
"pyModule": self.module(),
"pyClass": type(self).__name__,
"class": self.scalaUDT()
}
return schema

@classmethod
def fromJson(cls, json):
pyModule = json['pyModule']
pyClass = json['pyClass']
m = __import__(pyModule, globals(), locals(), [pyClass], -1)
UDT = getattr(m, pyClass)
return UDT()

def __eq__(self, other):
return type(self) == type(other)
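
For orientation, a concrete UDT under this API would look roughly like the sketch below. It is modeled on the ExamplePoint/ExamplePointUDT test helpers that the doctests import from pyspark.tests; the field names, storage layout, and Scala class path here are assumptions rather than the actual test code.

```python
# A hypothetical UDT sketch; names and the Scala class path are assumptions.
from pyspark.sql import ArrayType, DoubleType, UserDefinedType


class ExamplePointUDT(UserDefinedType):
    """UDT for a 2-D point, stored internally as an array of two doubles."""

    @classmethod
    def sqlType(cls):
        # Underlying SQL storage: [x, y] as non-null doubles.
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        # Module that fromJson() will import to rebuild this UDT.
        return "pyspark.tests"

    @classmethod
    def scalaUDT(cls):
        # Fully qualified name of the paired Scala UDT (assumed path).
        return "org.apache.spark.sql.test.ExamplePointUDT"

    def serialize(self, obj):
        # Python object -> SQL datum matching sqlType().
        return [obj.x, obj.y]

    def deserialize(self, datum):
        # SQL datum -> Python object.
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint(object):
    """User-facing type; __UDT__ is what _infer_type() looks for."""

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y
```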


_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
@@ -460,6 +527,12 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
>>> check_datatype(ExamplePointUDT())
True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
True
"""
return _parse_datatype_json_value(json.loads(json_string))

@@ -479,7 +552,13 @@ def _parse_datatype_json_value(json_value):
else:
raise ValueError("Could not parse datatype: %s" % json_value)
else:
return _all_complex_types[json_value["type"]].fromJson(json_value)
tpe = json_value["type"]
if tpe in _all_complex_types:
return _all_complex_types[tpe].fromJson(json_value)
elif tpe == 'udt':
return UserDefinedType.fromJson(json_value)
else:
raise ValueError("not supported type: %s" % tpe)


# Mapping Python types to Spark SQL DataType
@@ -499,10 +578,18 @@ def _parse_datatype_json_value(json_value):


def _infer_type(obj):
"""Infer the DataType from obj"""
"""Infer the DataType from obj

>>> p = ExamplePoint(1.0, 2.0)
>>> _infer_type(p)
ExamplePointUDT
"""
if obj is None:
raise ValueError("Can not infer type for None")

if hasattr(obj, '__UDT__'):
return obj.__UDT__

dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()
@@ -548,8 +635,95 @@ def _infer_schema(row):
return StructType(fields)


def _need_python_to_sql_conversion(dataType):
"""
Checks whether we need python to sql conversion for the given type.
For now, only UDTs need this conversion.

>>> _need_python_to_sql_conversion(DoubleType())
False
>>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
... StructField("values", ArrayType(DoubleType(), False), False)])
>>> _need_python_to_sql_conversion(schema0)
False
>>> _need_python_to_sql_conversion(ExamplePointUDT())
True
>>> schema1 = ArrayType(ExamplePointUDT(), False)
>>> _need_python_to_sql_conversion(schema1)
True
>>> schema2 = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> _need_python_to_sql_conversion(schema2)
True
"""
if isinstance(dataType, StructType):
return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
elif isinstance(dataType, ArrayType):
return _need_python_to_sql_conversion(dataType.elementType)
elif isinstance(dataType, MapType):
return _need_python_to_sql_conversion(dataType.keyType) or \
_need_python_to_sql_conversion(dataType.valueType)
elif isinstance(dataType, UserDefinedType):
return True
else:
return False


def _python_to_sql_converter(dataType):
[Review comment, Contributor] Can _create_converter do this?

[Author reply] _create_converter doesn't do this. It is used to drop the names when the user provides Row objects, and it is called by _drop_schema. I think we need to refactor the code a little bit during QA.

[Contributor] Agreed.

"""
Returns a converter that converts a Python object into a SQL datum for the given type.

>>> conv = _python_to_sql_converter(DoubleType())
>>> conv(1.0)
1.0
>>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
>>> conv([1.0, 2.0])
[1.0, 2.0]
>>> conv = _python_to_sql_converter(ExamplePointUDT())
>>> conv(ExamplePoint(1.0, 2.0))
[1.0, 2.0]
>>> schema = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> conv = _python_to_sql_converter(schema)
>>> conv((1.0, ExamplePoint(1.0, 2.0)))
(1.0, [1.0, 2.0])
"""
if not _need_python_to_sql_conversion(dataType):
return lambda x: x

if isinstance(dataType, StructType):
names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
converters = map(_python_to_sql_converter, types)

def converter(obj):
if isinstance(obj, dict):
return tuple(c(obj.get(n)) for n, c in zip(names, converters))
elif isinstance(obj, tuple):
if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
return tuple(c(v) for c, v in zip(converters, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
d = dict(obj)
return tuple(c(d.get(n)) for n, c in zip(names, converters))
else:
return tuple(c(v) for c, v in zip(converters, obj))
else:
raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
return converter
elif isinstance(dataType, ArrayType):
element_converter = _python_to_sql_converter(dataType.elementType)
return lambda a: [element_converter(v) for v in a]
elif isinstance(dataType, MapType):
key_converter = _python_to_sql_converter(dataType.keyType)
value_converter = _python_to_sql_converter(dataType.valueType)
return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
elif isinstance(dataType, UserDefinedType):
return lambda obj: dataType.serialize(obj)
else:
raise ValueError("Unexpected type %r" % dataType)


def _create_converter(obj, dataType):
"""Create an converter to drop the names of fields in obj """
"""Create an converter to drop the names of fields in obj"""
if isinstance(dataType, ArrayType):
conv = _create_converter(obj[0], dataType.elementType)
return lambda row: map(conv, row)
@@ -775,11 +949,21 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
>>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
[Review comment, Contributor] It's better to remove these tests for ExamplePoint; they should be in tests.py (or may already be covered).

[Author reply] They are in the same group as the other doctests for this private function. I didn't find one for _verify_type in SQLTests.

[Contributor] Your tests in tests.py have covered these internal functions, so I think it's fine not to have them here.

>>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
"""
# all objects are nullable
if obj is None:
return

if isinstance(dataType, UserDefinedType):
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
raise ValueError("%r is not an instance of type %r" % (obj, dataType))
return
[Review comment, Contributor] Should we also check that the serialized object matches dataType.sqlType?


_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s" % dataType

@@ -854,6 +1038,8 @@ def _has_struct_or_date(dt):
return _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
elif isinstance(dt, UserDefinedType):
return True
return False


@@ -924,6 +1110,9 @@ def Dict(d):
elif isinstance(dataType, DateType):
return datetime.date

elif isinstance(dataType, UserDefinedType):
return lambda datum: dataType.deserialize(datum)

elif not isinstance(dataType, StructType):
raise Exception("unexpected data type: %s" % dataType)

@@ -1184,6 +1373,10 @@ def applySchema(self, rdd, schema):
for row in rows:
_verify_type(row, schema)

# convert python objects to sql data
converter = _python_to_sql_converter(schema)
rdd = rdd.map(converter)

batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
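
Put together, the intended flow through applySchema would look roughly like this sketch (sc and sqlCtx as in the doctest globals; ExamplePoint/ExamplePointUDT as in the hypothetical sketch above):

```python
# Rows carrying ExamplePoint objects are verified against the schema, then
# converted to their SQL form ([x, y] lists) before being handed to the JVM.
schema = StructType([StructField("label", DoubleType(), False),
                     StructField("point", ExamplePointUDT(), False)])
rdd = sc.parallelize([(1.0, ExamplePoint(1.0, 2.0)),
                      (0.0, ExamplePoint(3.0, 4.0))])
srdd = sqlCtx.applySchema(rdd, schema)
```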
Expand Down Expand Up @@ -1817,6 +2010,7 @@ def _test():
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
from pyspark.tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
@@ -1828,6 +2022,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
globs['ExamplePoint'] = ExamplePoint
globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'