Support for specifying custom date format for date and timestamp types. #280

Closed
1 change: 1 addition & 0 deletions README.md
@@ -56,6 +56,7 @@ When reading files the API accepts several options:
* `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
* `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
* `nullValue`: specify a string that indicates a null value; any fields matching this string will be set as nulls in the DataFrame
* `dateFormat`: specify a string that indicates a date format. Custom date formats follow the formats at [`java.text.SimpleDateFormat`](https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html). This applies to both `DateType` and `TimestampType`. By default it is `null`, which means dates and timestamps are parsed with `java.sql.Date.valueOf()` and `java.sql.Timestamp.valueOf()` (see the usage sketch after this file's changes).

The package also supports saving simple (non-nested) DataFrames. When writing files the API accepts several options:
* `path`: location of files.
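For reference, a minimal usage sketch of the new `dateFormat` read option through the data source API. This is illustrative only: the file name and pattern are assumptions, and it presumes Spark 1.4+ with the spark-csv package on the classpath and an existing `sqlContext`.

```scala
// Reads values such as "26/08/2015 18:00" as TimestampType when inferSchema is on.
val df = sqlContext.read
  .format("com.databricks.spark.csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .option("dateFormat", "dd/MM/yyyy HH:mm")
  .load("events.csv")
```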
12 changes: 10 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -41,6 +41,7 @@ class CsvParser extends Serializable {
private var inferSchema: Boolean = false
private var codec: String = null
private var nullValue: String = ""
private var dateFormat: String = null

def withUseHeader(flag: Boolean): CsvParser = {
this.useHeader = flag
@@ -117,6 +118,11 @@ class CsvParser extends Serializable {
this
}

def withDateFormat(dateFormat: String): CsvParser = {
this.dateFormat = dateFormat
this
}

/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -136,7 +142,8 @@
schema,
inferSchema,
codec,
nullValue)(sqlContext)
nullValue,
dateFormat)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}

@@ -157,7 +164,8 @@
schema,
inferSchema,
codec,
nullValue)(sqlContext)
nullValue,
dateFormat)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
}
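The same option through the builder API added above, mirroring the tests later in this PR; the path and pattern are illustrative.

```scala
// withDateFormat feeds the pattern through to schema inference and type casting.
val df = new CsvParser()
  .withUseHeader(true)
  .withInferSchema(true)
  .withDateFormat("dd/MM/yyyy HH:mm")
  .csvFile(sqlContext, "src/test/resources/dates.csv")
```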
17 changes: 13 additions & 4 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -16,6 +16,7 @@
package com.databricks.spark.csv

import java.io.IOException
import java.text.SimpleDateFormat

import scala.collection.JavaConversions._
import scala.util.control.NonFatal
@@ -47,9 +48,13 @@ case class CsvRelation protected[spark] (
userSchema: StructType = null,
inferCsvSchema: Boolean,
codec: String = null,
nullValue: String = "")(@transient val sqlContext: SQLContext)
nullValue: String = "",
dateFormat: String = null)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

// Share the date format object, since parsing the date pattern is expensive.
private val dateFormatter = if (dateFormat != null) new SimpleDateFormat(dateFormat) else null

private val logger = LoggerFactory.getLogger(CsvRelation.getClass)

// Parse mode flags
@@ -96,6 +101,7 @@ }
}

override def buildScan: RDD[Row] = {
val simpleDateFormatter = dateFormatter
val schemaFields = schema.fields
tokenRdd(schemaFields.map(_.name)).flatMap { tokens =>

@@ -112,7 +118,7 @@
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
treatEmptyValuesAsNulls, nullValue)
treatEmptyValuesAsNulls, nullValue, simpleDateFormatter)
index = index + 1
}
Some(Row.fromSeq(rowArray))
@@ -142,6 +148,7 @@
* both the indices produced by `requiredColumns` and the ones of tokens.
*/
override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
val simpleDateFormatter = dateFormatter
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val shouldTableScan = schemaFields.deep == requiredFields.deep
@@ -189,7 +196,8 @@
field.dataType,
field.nullable,
treatEmptyValuesAsNulls,
nullValue
nullValue,
simpleDateFormatter
)
subIndex = subIndex + 1
}
@@ -237,7 +245,8 @@
firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
}
if (this.inferCsvSchema) {
InferSchema(tokenRdd(header), header, nullValue)
val simpleDateFormatter = dateFormatter
InferSchema(tokenRdd(header), header, nullValue, simpleDateFormatter)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
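The local `val simpleDateFormatter = dateFormatter` in both `buildScan` variants follows the usual Spark closure-capture pattern: copying the field into a local value lets the task closure serialize just the `SimpleDateFormat` rather than the whole relation. A minimal sketch of the idea, with purely illustrative names:

```scala
import java.text.SimpleDateFormat
import org.apache.spark.rdd.RDD

class DateParsing(pattern: String) {
  private val fmt = new SimpleDateFormat(pattern)

  def toMillis(rdd: RDD[String]): RDD[Long] = {
    // Capture only the formatter in the closure, not `this`, so Spark does not
    // attempt to serialize the enclosing object with the task.
    val localFmt = fmt
    rdd.map(value => localFmt.parse(value).getTime)
  }
}
```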
5 changes: 4 additions & 1 deletion src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -138,6 +138,8 @@ class DefaultSource
}
val nullValue = parameters.getOrElse("nullValue", "")

val dateFormat = parameters.getOrElse("dateFormat", null)

val codec = parameters.getOrElse("codec", null)

CsvRelation(
@@ -156,7 +158,8 @@
schema,
inferSchemaFlag,
codec,
nullValue)(sqlContext)
nullValue,
dateFormat)(sqlContext)
}

override def createRelation(
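Since `DefaultSource` picks the value out of the `parameters` map, the same key should also work from SQL `OPTIONS`; a hedged sketch where the table name, path and pattern are assumptions:

```scala
// Registers a temporary table backed by the CSV source with a custom date format.
sqlContext.sql(
  """CREATE TEMPORARY TABLE dates
    |USING com.databricks.spark.csv
    |OPTIONS (path "src/test/resources/dates.csv", header "true",
    |         inferSchema "true", dateFormat "dd/MM/yyyy HH:mm")
  """.stripMargin)
```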
116 changes: 64 additions & 52 deletions src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -16,6 +16,7 @@
package com.databricks.spark.csv.util

import java.sql.Timestamp
import java.text.SimpleDateFormat

import scala.util.control.Exception._

@@ -32,13 +33,13 @@ private[csv] object InferSchema {
* 3. Replace any null types with string type
*/
def apply(
tokenRdd: RDD[Array[String]],
header: Array[String],
nullValue: String = ""): StructType = {

tokenRdd: RDD[Array[String]],
header: Array[String],
nullValue: String = "",
dateFormatter: SimpleDateFormat = null): StructType = {
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(
inferRowType(nullValue),
inferRowType(nullValue, dateFormatter),
mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
@@ -52,11 +53,11 @@ private[csv] object InferSchema {
StructType(structFields)
}

private def inferRowType(nullValue: String)
private def inferRowType(nullValue: String, dateFormatter: SimpleDateFormat)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue, dateFormatter)
i+=1
}
rowSoFar
@@ -75,8 +76,62 @@
* point checking if it is an Int, as the final type must be Double or higher.
*/
private[csv] def inferField(typeSoFar: DataType,
field: String,
nullValue: String = ""): DataType = {
field: String,
nullValue: String = "",
dateFormatter: SimpleDateFormat = null): DataType = {
def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
Inline review thread on this block:

Member: Indent is off for this entire block

Member Author: Um.. Do you mean the indentation correction as below?

  • from
  private[csv] def inferField(typeSoFar: DataType,
    field: String,
    nullValue: String = "",
    dateFormatter: SimpleDateFormat = null): DataType = {
    def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
      IntegerType
    } else {
      tryParseLong(field)
    }
...
  • to
  private[csv] def inferField(typeSoFar: DataType,
    field: String,
    nullValue: String = "",
    dateFormatter: SimpleDateFormat = null): DataType = {
      def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
        IntegerType
      } else {
        tryParseLong(field)
      }
...

Member: Oh I see. The problem is for lines above:

  private[csv] def inferField(typeSoFar: DataType,
      field: String,
      nullValue: String = "",
      dateFormatter: SimpleDateFormat = null): DataType = {
    def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
      IntegerType
    } else {
      tryParseLong(field)
    }

Member Author: I see. Thanks!

IntegerType
} else {
tryParseLong(field)
}

def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
tryParseDouble(field)
}

def tryParseDouble(field: String): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
} else {
tryParseTimestamp(field)
}
}

def tryParseTimestamp(field: String): DataType = {
if (dateFormatter != null) {
// This case applies when a custom `dateFormat` is set.
if ((allCatch opt dateFormatter.parse(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
}
} else {
// We keep this for backwards compatibility.
if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
}
}
}

def tryParseBoolean(field: String): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
stringType()
}
}

// Defining a function to return the StringType constant is necessary in order to work around
// a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
// see issue #128 for more details.
def stringType(): DataType = {
StringType
}

if (field == null || field.isEmpty || field == nullValue) {
typeSoFar
} else {
@@ -94,49 +149,6 @@
}
}

private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
IntegerType
} else {
tryParseLong(field)
}

private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
tryParseDouble(field)
}

private def tryParseDouble(field: String): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
} else {
tryParseTimestamp(field)
}
}

def tryParseTimestamp(field: String): DataType = {
if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
}
}

def tryParseBoolean(field: String): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
stringType()
}
}

// Defining a function to return the StringType constant is necessary in order to work around
// a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
// see issue #128 for more details.
private def stringType(): DataType = {
StringType
}

/**
* Copied from internal Spark api
* [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
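To summarize the widening chain implemented above (integer → long → double → timestamp → boolean → string), here is a sketch of how `inferField` is expected to behave. It is not part of the PR: since the method is `private[csv]`, the calls would have to live under the `com.databricks.spark.csv` package, and the sample values are assumptions.

```scala
import java.text.SimpleDateFormat
import com.databricks.spark.csv.util.InferSchema
import org.apache.spark.sql.types._

val fmt = new SimpleDateFormat("dd/MM/yyyy HH:mm")
// Starting from NullType, each value settles on the narrowest type that parses.
assert(InferSchema.inferField(NullType, "42") == IntegerType)
assert(InferSchema.inferField(NullType, "9999999999") == LongType)
assert(InferSchema.inferField(NullType, "3.14") == DoubleType)
assert(InferSchema.inferField(NullType, "26/08/2015 18:00", "", fmt) == TimestampType)
assert(InferSchema.inferField(NullType, "true") == BooleanType)
assert(InferSchema.inferField(NullType, "not a number") == StringType)
```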
11 changes: 7 additions & 4 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -17,7 +17,7 @@ package com.databricks.spark.csv.util

import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.text.NumberFormat
import java.text.{SimpleDateFormat, NumberFormat}
import java.util.Locale

import org.apache.spark.sql.types._
@@ -44,7 +44,8 @@ object TypeCast {
castType: DataType,
nullable: Boolean = true,
treatEmptyValuesAsNulls: Boolean = false,
nullValue: String = ""): Any = {
nullValue: String = "",
dateFormatter: SimpleDateFormat = null): Any = {
// if nullValue is not an empty string, don't require treatEmptyValuesAsNulls
// to be set to true
val nullValueIsNotEmpty = nullValue != ""
@@ -65,9 +66,11 @@
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
// TODO(hossein): would be good to support other common timestamp formats
case _: TimestampType if dateFormatter != null =>
new Timestamp(dateFormatter.parse(datum).getTime)
case _: TimestampType => Timestamp.valueOf(datum)
// TODO(hossein): would be good to support other common date formats
case _: DateType if dateFormatter != null =>
new Date(dateFormatter.parse(datum).getTime)
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
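A corresponding sketch of the two new `castTo` branches; the values are illustrative, and since `castTo` returns `Any` the results are cast explicitly here.

```scala
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import com.databricks.spark.csv.util.TypeCast
import org.apache.spark.sql.types.{DateType, TimestampType}

val fmt = new SimpleDateFormat("dd/MM/yyyy HH:mm")
// With a formatter, both branches go through SimpleDateFormat.parse.
val ts = TypeCast.castTo("26/08/2015 18:00", TimestampType, dateFormatter = fmt)
  .asInstanceOf[Timestamp]
val d = TypeCast.castTo("26/08/2015 18:00", DateType, dateFormatter = fmt)
  .asInstanceOf[Date]
// Without a formatter the previous behaviour is unchanged.
val legacy = TypeCast.castTo("2015-08-26 18:00:00", TimestampType).asInstanceOf[Timestamp]
```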
4 changes: 4 additions & 0 deletions src/test/resources/dates.csv
@@ -0,0 +1,4 @@
date
26/08/2015 18:00
27/10/2014 18:30
28/01/2016 20:00
47 changes: 46 additions & 1 deletion src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -17,7 +17,8 @@ package com.databricks.spark.csv

import java.io.File
import java.nio.charset.UnsupportedCharsetException
import java.sql.Timestamp
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import scala.io.Source

import com.databricks.spark.csv.util.ParseModes
@@ -44,6 +45,7 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
val commentsFile = "src/test/resources/comments.csv"
val disableCommentsFile = "src/test/resources/disable_comments.csv"
val boolFile = "src/test/resources/bool.csv"
val datesFile = "src/test/resources/dates.csv"
private val simpleDatasetFile = "src/test/resources/simple.csv"

val numCars = 3
@@ -772,6 +774,49 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(results.toSeq.map(_.toSeq) === expected)
}

test("Inferring timestamp types via custom date format") {
val results = new CsvParser()
.withUseHeader(true)
.withParserLib(parserLib)
.withDateFormat("dd/MM/yyyy hh:mm")
.withInferSchema(true)
.csvFile(sqlContext, datesFile)
.select("date")
.collect()

val dateFormatter = new SimpleDateFormat("dd/MM/yyyy hh:mm")
val expected =
Seq(Seq(new Timestamp(dateFormatter.parse("26/08/2015 18:00").getTime)),
Seq(new Timestamp(dateFormatter.parse("27/10/2014 18:30").getTime)),
Seq(new Timestamp(dateFormatter.parse("28/01/2016 20:00").getTime)))
assert(results.toSeq.map(_.toSeq) === expected)
}

test("Load date types via custom date format") {
val customSchema = new StructType(Array(StructField("date", DateType, true)))
val results = new CsvParser()
.withSchema(customSchema)
.withUseHeader(true)
.withParserLib(parserLib)
.withDateFormat("dd/MM/yyyy hh:mm")
.csvFile(sqlContext, datesFile)
.select("date")
.collect()

val dateFormatter = new SimpleDateFormat("dd/MM/yyyy hh:mm")
val expected = Seq(
new Date(dateFormatter.parse("26/08/2015 18:00").getTime),
new Date(dateFormatter.parse("27/10/2014 18:30").getTime),
new Date(dateFormatter.parse("28/01/2016 20:00").getTime))
val dates = results.toSeq.map(_.toSeq.head)
expected.zip(dates).foreach {
case (expectedDate, date) =>
// Since `java.sql.Date.toString()` drops the time-of-day portion, we only check
// that the dates (day, month and year) are the same via `toString()`.
assert(expectedDate.toString === date.toString)
}
}

test("Setting comment to null disables comment support") {
val results: Array[Row] = new CsvParser()
.withDelimiter(',')