Skip to content

Commit

Permalink
implement custom boolean parsing logic for CSV and JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Feb 14, 2022
1 parent c500bfb commit 1231fc6
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 5 deletions.
10 changes: 9 additions & 1 deletion integration_tests/src/test/resources/boolean.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
{ "number": true }
{ "number": True }
{ "number": TRUE }
{ "number": false }
{ "number": null }
{ "number": False }
{ "number": FALSE }
{ "number": null }
{ "number": y }
{ "number": n }
{ "number": 0 }
{ "number": 1 }
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@ False
TRUE
FALSE
BAD
y
n
yes
no
1
0
t
f
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._

import ai.rapids.cudf
import ai.rapids.cudf.{HostMemoryBuffer, Schema, Table}
import ai.rapids.cudf.{ColumnVector, DType, HostMemoryBuffer, Scalar, Schema, Table}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

Expand Down Expand Up @@ -423,4 +423,28 @@ class CSVPartitionReader(
* @return the file format short name
*/
override def getFileFormatShortName: String = "CSV"

/**
 * Converts a string column read from CSV into a boolean column.
 *
 * Each value is passed through `strip()` and `lower()` before being compared, so
 * "true" and "false" are matched case-insensitively. Values equal to "true" map to
 * true, values equal to "false" map to false, and everything else becomes null.
 *
 * @param input string column to convert; it is not closed by this method
 * @return a new boolean column that the caller is responsible for closing
 */
override def castStringToBool(input: ColumnVector): ColumnVector = {
  withResource(withResource(input.strip())(_.lower())) { normalized =>
    // Builds a mask of the rows whose normalized text equals the given literal.
    def matches(literal: String): ColumnVector =
      withResource(Scalar.fromString(literal))(s => normalized.equalTo(s))

    withResource(matches("true")) { trueMask =>
      withResource(matches("false")) { falseMask =>
        withResource(trueMask.or(falseMask)) { recognized =>
          withResource(Scalar.fromNull(DType.BOOL8)) { missing =>
            // Recognized rows take the "is true" result; the rest become null.
            recognized.ifElse(trueMask, missing)
          }
        }
      }
    }
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ object GpuCast extends Arm {
}
}

def castStringToBool(input: ColumnVector, ansiEnabled: Boolean): ColumnVector = {
private def castStringToBool(input: ColumnVector, ansiEnabled: Boolean): ColumnVector = {
val trueStrings = Seq("t", "true", "y", "yes", "1")
val falseStrings = Seq("f", "false", "n", "no", "0")
val boolStrings = trueStrings ++ falseStrings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ abstract class GpuTextBasedPartitionReader(
for (i <- 0 until table.getNumberOfColumns) {
val castColumn = dataSchema.fields(i).dataType match {
case DataTypes.BooleanType =>
GpuCast.castStringToBool(table.getColumn(i), ansiEnabled)
castStringToBool(table.getColumn(i))
case DataTypes.FloatType =>
GpuCast.castStringToFloats(table.getColumn(i), ansiEnabled, DType.FLOAT32)
case DataTypes.DoubleType =>
Expand All @@ -220,6 +220,8 @@ abstract class GpuTextBasedPartitionReader(
}
}

/**
 * Format-specific conversion of a string column to a boolean column.
 *
 * Called for every column whose schema type is `BooleanType`. Each concrete reader
 * supplies its own implementation to decide which strings count as valid booleans.
 *
 * @param input string column to convert
 * @return the resulting boolean column
 */
def castStringToBool(input: ColumnVector): ColumnVector

/**
* Read the host buffer to GPU table
* @param dataBuffer host buffer to be read
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._

import ai.rapids.cudf
import ai.rapids.cudf.{HostMemoryBuffer, Schema, Table}
import ai.rapids.cudf.{ColumnVector, DType, HostMemoryBuffer, Scalar, Schema, Table}
import com.nvidia.spark.rapids._
import org.apache.hadoop.conf.Configuration

Expand Down Expand Up @@ -334,4 +334,24 @@ class JsonPartitionReader(
Some(new Table(prunedColumnVectors: _*))
}
}

/**
 * Converts a string column read from JSON into a boolean column.
 *
 * Only exact, lower-case matches of "true" and "false" are accepted; no stripping
 * or case folding is applied. Values equal to "true" map to true, values equal to
 * "false" map to false, and everything else becomes null.
 *
 * @param input string column to convert; it is not closed by this method
 * @return a new boolean column that the caller is responsible for closing
 */
override def castStringToBool(input: ColumnVector): ColumnVector = {
  // Builds a mask of the rows whose text equals the given literal exactly.
  def matches(literal: String): ColumnVector =
    withResource(Scalar.fromString(literal))(s => input.equalTo(s))

  withResource(matches("true")) { trueMask =>
    withResource(matches("false")) { falseMask =>
      withResource(trueMask.or(falseMask)) { recognized =>
        withResource(Scalar.fromNull(DType.BOOL8)) { missing =>
          // Recognized rows take the "is true" result; the rest become null.
          recognized.ifElse(trueMask, missing)
        }
      }
    }
  }
}

}

0 comments on commit 1231fc6

Please sign in to comment.