diff --git a/core/src/main/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31.scala b/core/src/main/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31.scala index c3b2182a0..5917f59dc 100644 --- a/core/src/main/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31.scala +++ b/core/src/main/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31.scala @@ -1,10 +1,21 @@ package doric package syntax +import doric.sem.Location import org.apache.spark.sql.{functions => f} private[syntax] trait StringColumns31 { + /** + * Throws an exception with the provided error message. + * + * @throws java.lang.RuntimeException with the error message + * @group String Type + * @see [[org.apache.spark.sql.functions.raise_error]] + */ + def raiseError(str: String)(implicit l: Location): NullColumn = + str.lit.raiseError + implicit class StringOperationsSyntax31(s: DoricColumn[String]) { /** @@ -20,6 +31,8 @@ private[syntax] trait StringColumns31 { * @group String Type * @see [[org.apache.spark.sql.functions.raise_error]] */ - def raiseError: NullColumn = s.elem.map(f.raise_error).toDC + def raiseError(implicit l: Location): NullColumn = + ds"""$s + located at . ${l.getLocation.lit}""".elem.map(f.raise_error).toDC } } diff --git a/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31Spec.scala b/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31Spec.scala index 4c72d13ab..cbad9a91d 100644 --- a/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31Spec.scala +++ b/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/StringColumns31Spec.scala @@ -1,9 +1,8 @@ package doric package syntax -import org.scalatest.EitherValues +import org.scalatest.{Assertion, EitherValues} import org.scalatest.matchers.should.Matchers - import org.apache.spark.sql.{functions => f} import org.apache.spark.sql.types.NullType @@ -15,7 +14,18 @@ class StringColumns31Spec describe("raiseError doric function") { import spark.implicits._ - val df = List("this is an error").toDF("errorMsg") + lazy val errorMsg = "this is an error" + lazy val df = List(errorMsg).toDF("errorMsg") + + def validateExceptions( + doricExc: RuntimeException, + sparkExc: RuntimeException + ): Assertion = { +// doricExc.getMessage should fullyMatch regex +// s"""${sparkExc.getMessage} +// located at . (${this.getClass.getSimpleName}.scala:33)""" + doricExc.getMessage should startWith(sparkExc.getMessage) + } it("should work as spark raise_error function") { import java.lang.{RuntimeException => exception} @@ -30,7 +40,23 @@ class StringColumns31Spec df.select(f.raise_error(f.col("errorMsg"))).collect() } - doricErr.getMessage shouldBe sparkErr.getMessage + validateExceptions(doricErr, sparkErr) + } + + it("should be available for strings") { + import java.lang.{RuntimeException => exception} + + val doricErr = intercept[exception] { + val res = df.select(raiseError(errorMsg)) + + res.schema.head.dataType shouldBe NullType + res.collect() + } + val sparkErr = intercept[exception] { + df.select(f.raise_error(f.col("errorMsg"))).collect() + } + + validateExceptions(doricErr, sparkErr) } }