[SPARK-3988][SQL] add public API for date type #2901

Closed · wants to merge 6 commits
37 changes: 25 additions & 12 deletions python/pyspark/sql.py
@@ -49,7 +49,7 @@


__all__ = [
"StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
"SQLContext", "HiveContext", "SchemaRDD", "Row"]
@@ -132,6 +132,14 @@ class BooleanType(PrimitiveType):
"""


class DateType(PrimitiveType):

"""Spark SQL DateType

The data type representing datetime.date values.
"""


class TimestampType(PrimitiveType):

"""Spark SQL TimestampType
@@ -438,7 +446,7 @@ def _parse_datatype_json_value(json_value):
return _all_complex_types[json_value["type"]].fromJson(json_value)


# Mapping Python types to Spark SQL DateType
# Mapping Python types to Spark SQL DataType
_type_mappings = {
bool: BooleanType,
int: IntegerType,
@@ -448,8 +456,8 @@ def _parse_datatype_json_value(json_value):
unicode: StringType,
bytearray: BinaryType,
decimal.Decimal: DecimalType,
datetime.date: DateType,
datetime.datetime: TimestampType,
datetime.date: TimestampType,
datetime.time: TimestampType,
}

@@ -656,10 +664,10 @@ def _infer_schema_type(obj, dataType):
"""
Fill the dataType with types infered from obj

>>> schema = _parse_schema_abstract("a b c")
>>> row = (1, 1.0, "str")
>>> schema = _parse_schema_abstract("a b c d")
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
StructType...IntegerType...DoubleType...StringType...
StructType...IntegerType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
@@ -703,6 +711,7 @@ def _infer_schema_type(obj, dataType):
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
BinaryType: (bytearray,),
DateType: (datetime.date,),
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
@@ -740,7 +749,7 @@ def _verify_type(obj, dataType):

# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept abject in type %s"
raise TypeError("%s can not accept object in type %s"
% (dataType, type(obj)))

if isinstance(dataType, ArrayType):
Expand All @@ -767,7 +776,7 @@ def _restore_object(dataType, obj):
""" Restore object during unpickling. """
# use id(dataType) as key to speed up lookup in dict
# Because of batched pickling, dataType will be the
# same object in mose cases.
# same object in most cases.
k = id(dataType)
cls = _cached_cls.get(k)
if cls is None:
@@ -1065,7 +1074,9 @@ def applySchema(self, rdd, schema):
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]

>>> from datetime import datetime
>>> from datetime import date
Contributor:

Minor: these two lines can be combined.
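For reference, the combined form is just standard Python import syntax:

>>> from datetime import datetime, date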

>>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
... date(2010, 1, 1),
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3], None)])
>>> schema = StructType([
@@ -1075,6 +1086,7 @@ def applySchema(self, rdd, schema):
... StructField("short2", ShortType(), False),
... StructField("int", IntegerType(), False),
... StructField("float", FloatType(), False),
... StructField("date", DateType(), False),
... StructField("time", TimestampType(), False),
... StructField("map",
... MapType(StringType(), IntegerType(), False), False),
@@ -1084,10 +1096,11 @@ def applySchema(self, rdd, schema):
... StructField("null", DoubleType(), True)])
>>> srdd = sqlCtx.applySchema(rdd, schema)
>>> results = srdd.map(
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
... x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0]
(127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
... x.time, x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.datetime(2010, 1, 1, 0, 0),
Contributor Author:

@davies Because we go through Pyrolite, java.sql.Date is serialized the same way as java.sql.Timestamp, since both are subtypes of java.util.Date. This makes the dumps() function produce datetime instead of date for a java.util.Date. I think this is related to your comments in JIRA SPARK-2674.

Contributor:

Oh, I see. After the data is deserialized in Python, we need to do some data conversions, so we can convert datetime to date when the data type is DateType.

Contributor Author:

So should the conversion happen on the Python side or the Scala side? Which one would you prefer?

Contributor:

I think we can only do it on the Python side.

datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
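A minimal sketch of the Python-side conversion discussed in the thread above (a hypothetical helper, not code from this PR; it assumes the DateType class added earlier in this diff is available):

import datetime

def _fixup_date(value, data_type):
    # Pyrolite deserializes java.sql.Date the same way as java.sql.Timestamp,
    # so a DateType column reaches Python as a datetime.datetime. If the
    # declared type is DateType, truncate the value back to a datetime.date.
    if isinstance(data_type, DateType) and isinstance(value, datetime.datetime):
        return value.date()
    return value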

>>> srdd.registerTempTable("table2")
>>> sqlCtx.sql(
@@ -109,6 +109,7 @@ object ScalaReflection {
case obj: FloatType.JvmType => FloatType
case obj: DoubleType.JvmType => DoubleType
case obj: DecimalType.JvmType => DecimalType
case obj: DateType.JvmType => DateType
case obj: TimestampType.JvmType => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
@@ -91,6 +91,7 @@ object DataType {
| "BinaryType" ^^^ BinaryType
| "BooleanType" ^^^ BooleanType
| "DecimalType" ^^^ DecimalType
| "DateType" ^^^ DateType
| "TimestampType" ^^^ TimestampType
)

@@ -198,7 +199,8 @@ trait PrimitiveType extends DataType {
}

object PrimitiveType {
private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
NativeType.all

private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
}
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst

import java.math.BigInteger
import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.scalatest.FunSuite

@@ -43,6 +43,7 @@ case class NullableData(
booleanField: java.lang.Boolean,
stringField: String,
decimalField: BigDecimal,
dateField: Date,
timestampField: Timestamp,
binaryField: Array[Byte])

@@ -96,6 +97,7 @@ class ScalaReflectionSuite extends FunSuite {
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
StructField("decimalField", DecimalType, nullable = true),
StructField("dateField", DateType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
nullable = true))
@@ -199,8 +201,11 @@ class ScalaReflectionSuite extends FunSuite {
// DecimalType
assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))

// DateType
assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))

// TimestampType
assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00")))
assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00")))

// NullType
assert(NullType === typeOfObject(null))
10 changes: 7 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -445,6 +445,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case ByteType => true
case ShortType => true
case FloatType => true
case DateType => true
case TimestampType => true
case ArrayType(_, _) => true
case MapType(_, _, _) => true
@@ -453,9 +454,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
}

// Converts value to the type specified by the data type.
// Because Python does not have data types for TimestampType, FloatType, ShortType, and
// ByteType, we need to explicitly convert values in columns of these data types to the desired
// JVM data types.
// Because Python does not have data types for DateType, TimestampType, FloatType, ShortType,
// and ByteType, we need to explicitly convert values in columns of these data types to the
// desired JVM data types.
def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
// TODO: We should check nullable
case (null, _) => null
@@ -475,6 +476,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
case (e, f) => convert(e, f.dataType)
}): Row

case (c: java.util.Calendar, DateType) =>
new java.sql.Date(c.getTime().getTime())

case (c: java.util.Calendar, TimestampType) =>
new java.sql.Timestamp(c.getTime().getTime())

20 changes: 14 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.json
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
@@ -372,13 +372,20 @@ private[sql] object JsonRDD extends Logging {
}
}

private def toDate(value: Any): Date = {
value match {
// only support string as date
case value: java.lang.String => Date.valueOf(value)
Contributor:

Could you add support for more date formats? At least ISO 8601; see some discussion here: http://stackoverflow.com/questions/10286204/the-right-json-date-format

We would like jsonRDD() to be as robust as possible.

Contributor:

Also, could we infer DateType when given a String in a date format? We should also do this for TimestampType; it would be great if you could fix both in this PR.

}
}
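As an illustration of the reviewer's suggestion above, a rough Python sketch of multi-format date parsing (hypothetical and illustrative only; the actual change would go into the Scala toDate above, and this format list is not what Spark supports):

from datetime import datetime

# Illustrative patterns only; full ISO 8601 (with time-zone offsets) would need
# extra handling beyond what strptime alone provides.
_DATE_FORMATS = ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%m/%d/%Y"]

def parse_date(value):
    """Try each known pattern in turn and return a datetime.date."""
    for fmt in _DATE_FORMATS:
        try:
            return datetime.strptime(value, fmt).date()
        except ValueError:
            pass
    raise ValueError("unrecognized date string: %r" % value)

print(parse_date("2014-10-15"))            # 2014-10-15
print(parse_date("2014-07-25T10:26:00"))   # 2014-07-25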

private def toTimestamp(value: Any): Timestamp = {
value match {
case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
case value: java.lang.Long => new Timestamp(value)
case value: java.lang.String => Timestamp.valueOf(value)
}
}

private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
if (value == null) {
@@ -396,6 +403,7 @@ private[sql] object JsonRDD extends Logging {
case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
case DateType => toDate(value)
case TimestampType => toTimestamp(value)
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.sql.api.java;

import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.HashMap;
@@ -39,6 +40,7 @@ public class JavaRowSuite {
private boolean booleanValue;
private String stringValue;
private byte[] binaryValue;
private Date dateValue;
private Timestamp timestampValue;

@Before
@@ -53,6 +55,7 @@ public void setUp() {
booleanValue = true;
stringValue = "this is a string";
binaryValue = stringValue.getBytes();
dateValue = Date.valueOf("2014-06-30");
timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0");
}

@@ -76,6 +79,7 @@ public void constructSimpleRow() {
new Boolean(booleanValue),
stringValue, // StringType
binaryValue, // BinaryType
dateValue, // DateType
timestampValue, // TimestampType
null // null
);
@@ -114,9 +118,10 @@ public void constructSimpleRow() {
Assert.assertEquals(stringValue, simpleRow.getString(15));
Assert.assertEquals(stringValue, simpleRow.get(15));
Assert.assertEquals(binaryValue, simpleRow.get(16));
Assert.assertEquals(timestampValue, simpleRow.get(17));
Assert.assertEquals(true, simpleRow.isNullAt(18));
Assert.assertEquals(null, simpleRow.get(18));
Assert.assertEquals(dateValue, simpleRow.get(17));
Assert.assertEquals(timestampValue, simpleRow.get(18));
Assert.assertEquals(true, simpleRow.isNullAt(19));
Assert.assertEquals(null, simpleRow.get(19));
}

@Test
@@ -39,6 +39,7 @@ public void createDataTypes() {
checkDataType(DataType.StringType);
checkDataType(DataType.BinaryType);
checkDataType(DataType.BooleanType);
checkDataType(DataType.DateType);
checkDataType(DataType.TimestampType);
checkDataType(DataType.DecimalType);
checkDataType(DataType.DoubleType);
@@ -38,6 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(org.apache.spark.sql.StringType)
checkDataType(org.apache.spark.sql.BinaryType)
checkDataType(org.apache.spark.sql.BooleanType)
checkDataType(org.apache.spark.sql.DateType)
checkDataType(org.apache.spark.sql.TimestampType)
checkDataType(org.apache.spark.sql.DecimalType)
checkDataType(org.apache.spark.sql.DoubleType)
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

class JsonSuite extends QueryTest {
import TestJsonData._
@@ -58,8 +58,11 @@ class JsonSuite extends QueryTest {
checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
checkTypePromotion(new Timestamp(intNumber.toLong),
enforceCorrectType(intNumber.toLong, TimestampType))
val strDate = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType))
val strTime = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))

val strDate = "2014-10-15"
checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType))
}

test("Get compatible type") {
Expand Down