diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 8ceb88731132e..2be46a80866e2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -791,6 +791,10 @@ def _restore_object(dataType, obj):
 
 def _create_object(cls, v):
     """ Create an customized object with class `cls`. """
+    # datetime.date would be deserialized as datetime.datetime
+    # from java type, so we need to set it back.
+    if cls is datetime.date and isinstance(v, datetime.datetime):
+        return v.date()
     return cls(v) if v is not None else v
 
 
@@ -804,14 +808,16 @@ def getter(self):
     return getter
 
 
-def _has_struct(dt):
-    """Return whether `dt` is or has StructType in it"""
+def _has_struct_or_date(dt):
+    """Return whether `dt` is or has StructType/DateType in it"""
     if isinstance(dt, StructType):
         return True
     elif isinstance(dt, ArrayType):
-        return _has_struct(dt.elementType)
+        return _has_struct_or_date(dt.elementType)
     elif isinstance(dt, MapType):
-        return _has_struct(dt.valueType)
+        return _has_struct_or_date(dt.valueType)
+    elif isinstance(dt, DateType):
+        return True
     return False
 
 
@@ -824,7 +830,7 @@ def _create_properties(fields):
                 or keyword.iskeyword(name)):
             warnings.warn("field name %s can not be accessed in Python,"
                           "use position to access it instead" % name)
-        if _has_struct(f.dataType):
+        if _has_struct_or_date(f.dataType):
             # delay creating object until accessing it
             getter = _create_getter(f.dataType, i)
         else:
@@ -879,6 +885,9 @@ def Dict(d):
 
         return Dict
 
+    elif isinstance(dataType, DateType):
+        return datetime.date
+
     elif not isinstance(dataType, StructType):
         raise Exception("unexpected data type: %s" % dataType)
 
@@ -1098,7 +1107,7 @@ def applySchema(self, rdd, schema):
         ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
         ...     x.time, x.map["a"], x.struct.b, x.list, x.null))
         >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
-        (127, -128, -32768, 32767, 2147483647, 1.0, datetime.datetime(2010, 1, 1, 0, 0),
+        (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
         datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
 
         >>> srdd.registerTempTable("table2")
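
For reference, a rough standalone illustration (plain Python, no Spark session needed) of the conversion the new `_create_object` branch performs. The helper below is copied from the patch rather than imported from pyspark, and `jvm_value` is just a made-up stand-in for what the Java side hands back for a DateType column:

import datetime


def _create_object(cls, v):
    """ Create an customized object with class `cls`. """
    # datetime.date would be deserialized as datetime.datetime
    # from java type, so we need to set it back.
    if cls is datetime.date and isinstance(v, datetime.datetime):
        return v.date()
    return cls(v) if v is not None else v


# A DateType value arrives as a datetime.datetime at midnight; the new
# branch restores it to a plain datetime.date, matching the doctest change.
jvm_value = datetime.datetime(2010, 1, 1, 0, 0)
assert _create_object(datetime.date, jvm_value) == datetime.date(2010, 1, 1)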