[SPARK-45072][CONNECT] Fix outer scopes for ammonite classes
### What changes were proposed in this pull request?
Ammonite places all user code inside a `Helper` class that is nested inside the class it generates for each command. This PR adds a custom code class wrapper for the Ammonite REPL that registers every generated `Helper` class as an outer scope immediately. This makes it possible to instantiate classes defined inside the `Helper` class, even when Spark code is executed as part of the `Helper`'s constructor.
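
For context, this is roughly the shape of the wrapping Ammonite performs. The sketch below is illustrative only; the real generated code uses different names and extra plumbing, and `cmd1`/`Helper` are hypothetical names:
```scala
// Minimal sketch, NOT Ammonite's actual output: each REPL command is compiled
// into a wrapper class, and the user code lands in a nested Helper class whose
// constructor executes the cell.
class cmd1 {
  final class Helper {
    // --- user code for this cell ---
    case class Thing(value: String)
    val things = (0 to 10).map(v => Thing(v.toString))
  }
  val helper = new Helper // running the cell == instantiating Helper
}
```
Because `Thing` is nested inside `Helper`, instantiating it reflectively (as Spark's encoders do) requires the enclosing `Helper` instance, which is exactly what the outer-scope registry provides.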

### Why are the changes needed?
Currently, defining a class and executing a Spark command that uses that class inside the same cell/line fails with a NullPointerException, because the outer scope needed to instantiate the class cannot be resolved. This PR fixes that issue. The following code now executes successfully (including the curly braces):
```scala
{
  case class Thing(val value: String)
  val r = (0 to 10).map( value => Thing(value.toString) )
  spark.createDataFrame(r)
}
```
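
To see why the outer scope matters: a case class defined inside another class is an inner class, so its constructor takes a hidden reference to the enclosing instance. A minimal, Spark-independent illustration (all names here are made up for the example):
```scala
class Outer {
  case class Inner(value: String)
}

object OuterScopeDemo extends App {
  val outer = new Outer
  val ctor = classOf[Outer#Inner].getConstructors.head
  // The first constructor parameter is the hidden outer reference. When
  // Spark's encoder has no registered outer instance to pass here,
  // instantiation fails, surfacing as the NullPointerException above.
  println(ctor.getParameterTypes.mkString(", ")) // class Outer, class java.lang.String
  println(ctor.newInstance(outer, "hello"))      // Inner(hello)
}
```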

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
I added more tests to the `ReplE2ESuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#42807 from hvanhovell/SPARK-45072.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Sep 5, 2023
1 parent 5fff242 commit 40943c2
Showing 3 changed files with 71 additions and 12 deletions.
ConnectRepl.scala
```diff
@@ -22,7 +22,8 @@ import java.util.concurrent.Semaphore
 import scala.util.control.NonFatal
 
 import ammonite.compiler.CodeClassWrapper
-import ammonite.util.Bind
+import ammonite.compiler.iface.CodeWrapper
+import ammonite.util.{Bind, Imports, Name, Util}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SparkSession
@@ -94,8 +95,8 @@ object ConnectRepl {
     val main = ammonite.Main(
       welcomeBanner = Option(splash),
       predefCode = predefCode,
-      replCodeWrapper = CodeClassWrapper,
-      scriptCodeWrapper = CodeClassWrapper,
+      replCodeWrapper = ExtendedCodeClassWrapper,
+      scriptCodeWrapper = ExtendedCodeClassWrapper,
       inputStream = inputStream,
       outputStream = outputStream,
       errorStream = errorStream)
@@ -107,3 +108,25 @@ object ConnectRepl {
     }
   }
 }
+
+/**
+ * [[CodeWrapper]] that makes sure new Helper classes are always registered as an outer scope.
+ */
+@DeveloperApi
+object ExtendedCodeClassWrapper extends CodeWrapper {
+  override def wrapperPath: Seq[Name] = CodeClassWrapper.wrapperPath
+  override def apply(
+      code: String,
+      source: Util.CodeSource,
+      imports: Imports,
+      printCode: String,
+      indexedWrapper: Name,
+      extraCode: String): (String, String, Int) = {
+    val (top, bottom, level) =
+      CodeClassWrapper(code, source, imports, printCode, indexedWrapper, extraCode)
+    // Make sure we register the Helper before anything else, so outer scopes work as expected.
+    val augmentedTop = top +
+      "\norg.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)\n"
+    (augmentedTop, bottom, level)
+  }
+}
```
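
The wrapper delegates to `CodeClassWrapper` and appends a single statement to the generated `top` preamble, so the `Helper` instance is registered before any user statement runs. Conceptually, the emitted class begins like this (a sketch with an illustrative `Helper` name, not the literal generated source):
```scala
// Sketch of the effect: the registration call is the first statement in the
// Helper body, so it runs before any user code in the constructor.
final class Helper {
  org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)

  // --- user code for the cell follows ---
}
```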
ReplE2ESuite.scala
```diff
@@ -79,12 +79,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
 
   override def afterEach(): Unit = {
     semaphore.drainPermits()
-    if (ammoniteOut != null) {
-      ammoniteOut.reset()
-    }
   }
 
   def runCommandsInShell(input: String): String = {
+    ammoniteOut.reset()
     require(input.nonEmpty)
     // Pad the input with a semaphore release so that we know when the execution of the provided
     // input is complete.
@@ -105,6 +103,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     getCleanString(ammoniteOut)
   }
 
+  def runCommandsUsingSingleCellInShell(input: String): String = {
+    runCommandsInShell("{\n" + input + "\n}")
+  }
+
   def assertContains(message: String, output: String): Unit = {
     val isContain = output.contains(message)
     assert(
@@ -263,6 +265,31 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
   }
 
+  test("Single Cell Compilation") {
+    val input =
+      """
+        |case class C1(value: Int)
+        |case class C2(value: Int)
+        |val h1 = classOf[C1].getDeclaringClass
+        |val h2 = classOf[C2].getDeclaringClass
+        |val same = h1 == h2
+        |""".stripMargin
+    assertContains("same: Boolean = false", runCommandsInShell(input))
+    assertContains("same: Boolean = true", runCommandsUsingSingleCellInShell(input))
+  }
+
+  test("Local relation containing REPL generated class") {
+    val input =
+      """
+        |case class MyTestClass(value: Int)
+        |val data = (0 to 10).map(MyTestClass)
+        |spark.createDataset(data).map(mtc => mtc.value).select(sum($"value")).as[Long].head
+        |""".stripMargin
+    val expected = "Long = 55L"
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
+  }
+
   test("Collect REPL generated class") {
     val input =
       """
@@ -275,8 +302,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
        |  map(mtc => s"MyTestClass(${mtc.value})").
        |  mkString("[", ", ", "]")
      """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)
+    val expected = """String = "[MyTestClass(1), MyTestClass(3)]""""
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
   }
 
   test("REPL class in encoder") {
@@ -288,8 +316,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
        |  map(mtc => mtc.value).
        |  collect()
      """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("Array[Int] = Array(0, 1, 2)", output)
+    val expected = "Array[Int] = Array(0, 1, 2)"
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
   }
 
   test("REPL class in UDF") {
@@ -301,8 +330,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
        |  map(mtc => s"MyTestClass(${mtc.value})").
        |  mkString("[", ", ", "]")
      """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("""String = "[MyTestClass(0), MyTestClass(1)]"""", output)
+    val expected = """String = "[MyTestClass(0), MyTestClass(1)]""""
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
   }
 
   test("streaming works with REPL generated code") {
```
CheckConnectJvmClientCompatibility.scala
```diff
@@ -357,6 +357,12 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.application.ConnectRepl$" // developer API
       ),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.application.ExtendedCodeClassWrapper" // developer API
+      ),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.application.ExtendedCodeClassWrapper$" // developer API
+      ),
 
       // SparkSession
       // developer API
```