Configurable null values #76

Open · wants to merge 4 commits into base: master
9 changes: 8 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -34,6 +34,7 @@ class CsvParser {
private var ignoreLeadingWhiteSpace: Boolean = false
private var ignoreTrailingWhiteSpace: Boolean = false
private var parserLib: String = ParserLibs.DEFAULT
private var nullValues: Seq[String] = Seq("")


def withUseHeader(flag: Boolean): CsvParser = {
@@ -81,6 +82,11 @@ class CsvParser {
this
}

def withNullValues(nullValues: Seq[String]): CsvParser = {
this.nullValues = nullValues
this
}

/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -94,7 +100,8 @@
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
schema)(sqlContext)
schema,
nullValues.toSet)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}

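For context, a minimal sketch of how the new builder option is used from the DSL, mirroring the test added below; the SQLContext here is a placeholder:

import org.apache.spark.sql.{DataFrame, SQLContext}
import com.databricks.spark.csv.CsvParser

// Sketch only: sqlContext and the file path are placeholders.
def readWithNulls(sqlContext: SQLContext): DataFrame =
  new CsvParser()
    .withUseHeader(true)
    .withNullValues(Seq("NULL", "NA")) // tokens to treat as missing values
    .csvFile(sqlContext, "src/test/resources/missing-values.csv")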
9 changes: 6 additions & 3 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -41,7 +41,8 @@ case class CsvRelation protected[spark] (
parserLib: String,
ignoreLeadingWhiteSpace: Boolean,
ignoreTrailingWhiteSpace: Boolean,
userSchema: StructType = null)(@transient val sqlContext: SQLContext)
userSchema: StructType = null,
nullValues: Set[String] = Set(""))(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with InsertableRelation {

private val logger = LoggerFactory.getLogger(CsvRelation.getClass)
@@ -153,7 +154,8 @@ case class CsvRelation protected[spark] (
try {
index = 0
while (index < schemaFields.length) {
rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
val token = if (nullValues.contains(tokens(index))) "" else tokens(index)
rowArray(index) = TypeCast.castTo(token, schemaFields(index).dataType)
index = index + 1
}
Some(Row.fromSeq(rowArray))
@@ -195,7 +197,8 @@ case class CsvRelation protected[spark] (
throw new RuntimeException(s"Malformed line in FAILFAST mode: $line")
} else {
while (index < schemaFields.length) {
rowArray(index) = TypeCast.castTo(tokens.get(index), schemaFields(index).dataType)
val token = if (nullValues.contains(tokens.get(index))) "" else tokens.get(index)
rowArray(index) = TypeCast.castTo(token, schemaFields(index).dataType)
index = index + 1
}
Some(Row.fromSeq(rowArray))
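Both parse paths apply the same two-step scheme: a token found in nullValues is first normalized to the empty string, and TypeCast.castTo (next file) then maps the empty string to null for every non-string type. A standalone sketch of that normalization, with names taken from this diff:

// Per-token substitution performed before casting (sketch).
val nullValues: Set[String] = Set("NULL", "NA")

def normalize(token: String): String =
  if (nullValues.contains(token)) "" else token

normalize("NA")    // ""      -> castTo returns null for non-string columns
normalize("Tesla") // "Tesla" -> cast as usual

Note that for a StringType column a matched token becomes an empty string rather than null, since castTo returns string data unchanged; that is exactly the behavior the review comment below asks to make configurable.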
34 changes: 19 additions & 15 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -33,21 +33,25 @@ object TypeCast {
* @param castType SparkSQL type
*/
private[csv] def castTo(datum: String, castType: DataType): Any = {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType => datum.toFloat
case _: DoubleType => datum.toDouble
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
// TODO(hossein): would be good to support other common timestamp formats
case _: TimestampType => Timestamp.valueOf(datum)
// TODO(hossein): would be good to support other common date formats
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
if (datum.isEmpty && castType != StringType) {

Review comment: It'd be nice if this was another option; i.e., in my application we have decided to standardize on parsing empty string fields as nulls rather than empty strings. (A sketch of such an option follows this diff.)

null
} else {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType => datum.toFloat
case _: DoubleType => datum.toDouble
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
// TODO(hossein): would be good to support other common timestamp formats
case _: TimestampType => Timestamp.valueOf(datum)
// TODO(hossein): would be good to support other common date formats
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
}
}

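As the inline review comment suggests, the empty-string-to-null rule could itself be made configurable for string columns. A hypothetical extension, not part of this PR: the treatEmptyAsNull flag is invented here for illustration, and the match arms are copied from the diff above.

import java.math.BigDecimal
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.types._

object TypeCastSketch {
  // Hypothetical variant: treatEmptyAsNull (not in this PR) also maps empty
  // strings to null for StringType columns.
  def castTo(datum: String, castType: DataType, treatEmptyAsNull: Boolean = false): Any = {
    if (datum.isEmpty && (castType != StringType || treatEmptyAsNull)) {
      null
    } else {
      castType match {
        case _: ByteType => datum.toByte
        case _: ShortType => datum.toShort
        case _: IntegerType => datum.toInt
        case _: LongType => datum.toLong
        case _: FloatType => datum.toFloat
        case _: DoubleType => datum.toDouble
        case _: BooleanType => datum.toBoolean
        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
        case _: TimestampType => Timestamp.valueOf(datum)
        case _: DateType => Date.valueOf(datum)
        case _: StringType => datum
        case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
      }
    }
  }
}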
5 changes: 5 additions & 0 deletions src/test/resources/missing-values.csv
@@ -0,0 +1,5 @@
year,make,model,comment,blank
"2012","Tesla","S","No comment",
1997,Ford,E350,"Go get one now they are going fast",
2015,Chevy,Volt
NA,NULL,"T","Comment"
20 changes: 20 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
val carsAltFile = "src/test/resources/cars-alternative.csv"
val emptyFile = "src/test/resources/empty.csv"
val escapeFile = "src/test/resources/escape.csv"
val carsWithNAs = "src/test/resources/missing-values.csv"
val tempEmptyDir = "target/test/empty2/"

val numCars = 3
@@ -93,6 +94,25 @@
assert(results.size === numCars - 1)
}

test("DSL test for handling NULL values") {
val results = new CsvParser()
.withUseHeader(true)
.withParserLib("univocity")
.withNullValues(Seq("NULL", "NA"))
.csvFile(TestSQLContext, carsWithNAs)
.collect()

assert(results.size === numCars + 1)

val results2 = new CsvParser()
.withUseHeader(true)
.withNullValues(Seq("NULL", "NA", "NaN"))
.csvFile(TestSQLContext, carsWithNAs)
.collect()

assert(results2.size === numCars + 1)
}

test("DSL test for FAILFAST parsing mode") {
val parser = new CsvParser()
.withParseMode("FAILFAST")