[SPARK-31261][SQL] Avoid NPE when reading bad CSV input with `columnNameOfCorruptRecord` specified

### What changes were proposed in this pull request?

SPARK-25387 avoids an NPE for bad CSV input in general, but when bad CSV input is read with `columnNameOfCorruptRecord` specified, `getCurrentInput` is still called to fetch the raw record and throws an NPE, because univocity's `currentParsedContent()` can return null for such input.
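
As context, a minimal repro sketch (the standalone session setup is illustrative; the schema and input mirror the regression test added below). Before this fix, the read throws a `NullPointerException` from `getCurrentInput`; with the fix, the row comes back as `[null, null]`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object Spark31261Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("SPARK-31261 repro").getOrCreate()
    import spark.implicits._

    // Schema with an explicit corrupt-record column, as in the new test.
    val schema = StructType(
      StructField("a", IntegerType) :: StructField("_corrupt_record", StringType) :: Nil)
    // Bad input: control characters for which univocity parses no content.
    val input = Seq("\u0000\u0000\u0001234").toDS()

    // Pre-fix: NPE while building the corrupt record; post-fix: Row(null, null).
    spark.read
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .schema(schema)
      .csv(input)
      .show()

    spark.stop()
  }
}
```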

### Why are the changes needed?

Bug fix.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Add a test.

Closes #28029 from wzhfy/corrupt_column_npe.

Authored-by: Zhenhua Wang <wzh_zju@163.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
(cherry picked from commit 791d2ba)
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
wzhfy authored and HyukjinKwon committed Mar 29, 2020
1 parent 4217f75 commit 801d6a9
Showing 2 changed files with 16 additions and 1 deletion.
@@ -82,7 +82,8 @@ class UnivocityParser(

  // Retrieve the raw record string.
  private def getCurrentInput: UTF8String = {
-   UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd)
+   val currentContent = tokenizer.getContext.currentParsedContent()
+   if (currentContent == null) null else UTF8String.fromString(currentContent.stripLineEnd)
  }

  // This parser first picks some tokens from the input tokens, according to the required schema,
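
Why the guard works: univocity's `currentParsedContent()` can return null when nothing was parsed for the current record (as with the control-character input above), and a null `UTF8String` simply leaves the corrupt-record column null, which the test below verifies. The same guard can also be written with `Option`; a standalone sketch of that idiom (the helper name is hypothetical, not part of the patch):

```scala
import org.apache.spark.unsafe.types.UTF8String

object NullGuardSketch {
  // Hypothetical helper mirroring the fixed getCurrentInput: Option(null) is
  // None, so orNull yields null instead of throwing when content is missing.
  def toUtf8OrNull(currentContent: String): UTF8String =
    Option(currentContent).map(c => UTF8String.fromString(c.stripLineEnd)).orNull

  def main(args: Array[String]): Unit = {
    assert(toUtf8OrNull(null) == null)
    assert(toUtf8OrNull("a,b\n").toString == "a,b")
  }
}
```

The patch uses a plain `if` instead, likely to avoid the extra `Option` allocation on the per-record parsing hot path.
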
@@ -1822,6 +1822,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
    assert(spark.read.csv(input).collect().toSet == Set(Row()))
  }

+ test("SPARK-31261: bad csv input with `columnNameCorruptRecord` should not cause NPE") {
+   val schema = StructType(
+     StructField("a", IntegerType) :: StructField("_corrupt_record", StringType) :: Nil)
+   val input = spark.createDataset(Seq("\u0000\u0000\u0001234"))
+
+   checkAnswer(
+     spark.read
+       .option("columnNameOfCorruptRecord", "_corrupt_record")
+       .schema(schema)
+       .csv(input),
+     Row(null, null))
+   assert(spark.read.csv(input).collect().toSet == Set(Row()))
+ }
+
  test("field names of inferred schema shouldn't compare to the first row") {
    val input = Seq("1,2").toDS()
    val df = spark.read.option("enforceSchema", false).csv(input)
