+
+#### Forward compatibility
+
+The following table describes the oldest .NET for Apache Spark release that the current worker is compatible with.
+
+| Oldest compatible .NET for Apache Spark release version |
+|:--------------------------------------------------------:|
+| v0.9.0 |
+
+### Supported Spark Versions
+
+The following table outlines the supported Spark versions along with the corresponding microsoft-spark JAR to use:
+
diff --git a/docs/release-notes/0.12/release-0.12.md b/docs/release-notes/0.12/release-0.12.md
new file mode 100644
index 000000000..7000299ea
--- /dev/null
+++ b/docs/release-notes/0.12/release-0.12.md
@@ -0,0 +1,115 @@
+# .NET for Apache Spark 0.12 Release Notes
+
+### New Features/Improvements and Bug Fixes
+
+* Expose `DataStreamWriter.ForeachBatch` API ([#549](https://github.com/dotnet/spark/pull/549))
+* Support for [dotnet-interactive](https://github.com/dotnet/interactive) ([#515](https://github.com/dotnet/spark/pull/515)) ([#517](https://github.com/dotnet/spark/pull/517)) ([#554](https://github.com/dotnet/spark/pull/554))
+* Support for [Hyperspace v0.1.0](https://github.com/microsoft/hyperspace) APIs ([#555](https://github.com/dotnet/spark/pull/555))
+* Support for Spark 2.4.6 ([#547](https://github.com/dotnet/spark/pull/547))
+* Bug fixes:
+ * Udf bug caused by `BroadcastVariablesRegistry` ([#551](https://github.com/dotnet/spark/pull/551))
+ * Null checks for `TimestampType` and `DateType` ([#530](https://github.com/dotnet/spark/pull/530))
+* Update `Microsoft.Data.Analysis` to v`0.4.0` ([#528](https://github.com/dotnet/spark/pull/528))
+
+### Infrastructure / Documentation / Etc.
+
+* Improve build pipeline ([#510](https://github.com/dotnet/spark/pull/510)) ([#511](https://github.com/dotnet/spark/pull/511)) ([#512](https://github.com/dotnet/spark/pull/512)) ([#513](https://github.com/dotnet/spark/pull/513)) ([#524](https://github.com/dotnet/spark/pull/524))
+* Update AppName for the C# Spark Examples ([#548](https://github.com/dotnet/spark/pull/548))
+* Update maven links in build documentation ([#558](https://github.com/dotnet/spark/pull/558)) ([#560](https://github.com/dotnet/spark/pull/560))
+
+### Breaking Changes
+
+* None
+
+### Known Issues
+
+* Broadcast variables do not work with [dotnet-interactive](https://github.com/dotnet/interactive) ([#561](https://github.com/dotnet/spark/pull/561))
+
+### Compatibility
+
+#### Backward compatibility
+
+The following table describes the oldest version of the worker that the current version is compatible with, along with new features that are incompatible with the worker.
+
+#### Forward compatibility
+
+The following table describes the oldest .NET for Apache Spark release that the current worker is compatible with.
+
+| Oldest compatible .NET for Apache Spark release version |
+|:--------------------------------------------------------:|
+| v0.9.0 |
+
+### Supported Spark Versions
+
+The following table outlines the supported Spark versions along with the corresponding microsoft-spark JAR to use:
+
diff --git a/docs/udf-guide.md b/docs/udf-guide.md
new file mode 100644
index 000000000..6a2905bf4
--- /dev/null
+++ b/docs/udf-guide.md
@@ -0,0 +1,171 @@
+# Guide to User-Defined Functions (UDFs)
+
+This is a guide to show how to use UDFs in .NET for Apache Spark.
+
+## What are UDFs
+
+[User-Defined Functions (UDFs)](https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/expressions/UserDefinedFunction.html) are a feature of Spark that allow developers to use custom functions to extend the system's built-in functionality. They transform values from a single row within a table to produce a single corresponding output value per row based on the logic defined in the UDF.
+
+Let's take the following as an example for a UDF definition:
+
+```csharp
+string s1 = "hello";
+Func<Column, Column> udf = Udf<string, string>(
+    str => $"{s1} {str}");
+```
+The above defined UDF takes a `string` as an input (in the form of a [Column](https://github.com/dotnet/spark/blob/master/src/csharp/Microsoft.Spark/Sql/Column.cs#L14) of a [Dataframe](https://github.com/dotnet/spark/blob/master/src/csharp/Microsoft.Spark/Sql/DataFrame.cs#L24)), and returns a `string` with `hello` prepended to the input.
+
+For a sample DataFrame, let's take the following `df`:
+
+```text
++-------+
+| name|
++-------+
+|Michael|
+| Andy|
+| Justin|
++-------+
+```
+
+Now let's apply the `udf` defined above to the DataFrame `df`:
+
+```csharp
+DataFrame udfResult = df.Select(udf(df["name"]));
+```
+
+This returns the following DataFrame `udfResult`:
+
+```text
++-------------+
+| name|
++-------------+
+|hello Michael|
+| hello Andy|
+| hello Justin|
++-------------+
+```
+To get a better understanding of how to implement UDFs, please take a look at the [UDF helper functions](https://github.com/dotnet/spark/blob/master/src/csharp/Microsoft.Spark/Sql/Functions.cs#L3616) and some [test examples](https://github.com/dotnet/spark/blob/master/src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs#L49).
+
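+Putting the pieces above together, a minimal end-to-end program might look like the following sketch. The session setup and the `people.json` input file are assumptions for illustration, not part of the guide:
+
+```csharp
+using System;
+using Microsoft.Spark.Sql;
+using static Microsoft.Spark.Sql.Functions;
+
+class Program
+{
+    static void Main(string[] args)
+    {
+        SparkSession spark = SparkSession
+            .Builder()
+            .AppName("udf example")
+            .GetOrCreate();
+
+        // Assumes a people.json with a `name` column, as in the Spark examples.
+        DataFrame df = spark.Read().Json("people.json");
+
+        string s1 = "hello";
+        Func<Column, Column> udf = Udf<string, string>(str => $"{s1} {str}");
+
+        df.Select(udf(df["name"])).Show();
+
+        spark.Stop();
+    }
+}
+```
+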
+## UDF serialization
+
+Since UDFs are functions that need to be executed on the workers, they have to be serialized and sent to the workers as part of the payload from the driver. This involves serializing the [delegate](https://docs.microsoft.com/en-us/dotnet/csharp/programming-guide/delegates/) which is a reference to the method, along with its [target](https://docs.microsoft.com/en-us/dotnet/api/system.delegate.target?view=netframework-4.8) which is the class instance on which the current delegate invokes the instance method. Please take a look at this [code](https://github.com/dotnet/spark/blob/master/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs#L149) to get a better understanding of how UDF serialization is being done.
+
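+To see what actually gets captured, you can inspect a delegate's `Method` and `Target` directly. The snippet below is a minimal illustration of the concept, not the serialization code .NET for Apache Spark uses:
+
+```csharp
+using System;
+
+public class Example
+{
+    public static void Main()
+    {
+        string greeting = "hello";
+        Func<string, string> udf = str => $"{greeting} {str}";
+
+        // A delegate is a (method, target) pair: the method is the
+        // compiler-generated lambda body, and the target is the closure
+        // object that captured `greeting`.
+        Console.WriteLine(udf.Method); // e.g. System.String <Main>b__0_0(System.String)
+        Console.WriteLine(udf.Target); // the compiler-generated closure class
+    }
+}
+```
+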
+## Good to know while implementing UDFs
+
+One behavior to be aware of while implementing UDFs in .NET for Apache Spark is how the target of the UDF gets serialized. .NET for Apache Spark uses .NET Core, which does not support serializing delegates, so serialization is instead done by reflecting over the delegate's target, the object on which the delegate is defined. When multiple delegates are defined in a common scope, they have a shared closure, and that shared closure becomes the target that gets serialized. Let's take an example to illustrate what that means.
+
+The following code snippet defines two string variables that are referenced by two function delegates, each returning the respective string:
+
+```csharp
+using System;
+
+public class C {
+ public void M() {
+ string s1 = "s1";
+ string s2 = "s2";
+        Func<string, string> a = str => s1;
+        Func<string, string> b = str => s2;
+ }
+}
+```
+
+The compiler lowers the above C# code to the following (decompiled via [sharplab.io](https://sharplab.io)):
+
+```csharp
+public class C
+{
+ [CompilerGenerated]
+ private sealed class <>c__DisplayClass0_0
+ {
+ public string s1;
+
+ public string s2;
+
+        internal string <M>b__0(string str)
+ {
+ return s1;
+ }
+
+        internal string <M>b__1(string str)
+ {
+ return s2;
+ }
+ }
+
+ public void M()
+ {
+ <>c__DisplayClass0_0 <>c__DisplayClass0_ = new <>c__DisplayClass0_0();
+ <>c__DisplayClass0_.s1 = "s1";
+ <>c__DisplayClass0_.s2 = "s2";
+        Func<string, string> func = new Func<string, string>(<>c__DisplayClass0_.<M>b__0);
+        Func<string, string> func2 = new Func<string, string>(<>c__DisplayClass0_.<M>b__1);
+ }
+}
+```
+As can be seen in the decompiled code above, both `func` and `func2` share the same closure `<>c__DisplayClass0_0`, which is the target that is serialized when serializing the delegates `func` and `func2`. Hence, even though `Func<string, string> a` references only `s1`, `s2` also gets serialized when sending over the bytes to the workers.
+
+This can lead to some unexpected behaviors at runtime (like in the case of using [broadcast variables](broadcast-guide.md)), which is why we recommend restricting the visibility of the variables used in a function to that function's scope.
+
+Going back to the above example, the following is the recommended way to implement the desired behavior of the previous code snippet:
+
+```csharp
+using System;
+
+public class C {
+ public void M() {
+ {
+ string s1 = "s1";
+            Func<string, string> a = str => s1;
+ }
+ {
+ string s2 = "s2";
+            Func<string, string> b = str => s2;
+ }
+ }
+}
+```
+
+The compiler lowers the above C# code to the following (decompiled via [sharplab.io](https://sharplab.io)):
+
+```csharp
+public class C
+{
+ [CompilerGenerated]
+ private sealed class <>c__DisplayClass0_0
+ {
+ public string s1;
+
+        internal string <M>b__0(string str)
+ {
+ return s1;
+ }
+ }
+
+ [CompilerGenerated]
+ private sealed class <>c__DisplayClass0_1
+ {
+ public string s2;
+
+        internal string <M>b__1(string str)
+ {
+ return s2;
+ }
+ }
+
+ public void M()
+ {
+ <>c__DisplayClass0_0 <>c__DisplayClass0_ = new <>c__DisplayClass0_0();
+ <>c__DisplayClass0_.s1 = "s1";
+        Func<string, string> func = new Func<string, string>(<>c__DisplayClass0_.<M>b__0);
+ <>c__DisplayClass0_1 <>c__DisplayClass0_2 = new <>c__DisplayClass0_1();
+ <>c__DisplayClass0_2.s2 = "s2";
+        Func<string, string> func2 = new Func<string, string>(<>c__DisplayClass0_2.<M>b__1);
+ }
+}
+```
+
+Here we see that `func` and `func2` no longer share a closure and have their own separate closures `<>c__DisplayClass0_0` and `<>c__DisplayClass0_1` respectively. When used as the target for serialization, nothing other than the referenced variables will get serialized for the delegate.
+
+This behavior is important to keep in mind while implementing multiple UDFs in a common scope.
+To learn more about UDFs in general, please review the following articles that explain UDFs and how to use them: [UDFs in Databricks (Scala)](https://docs.databricks.com/spark/latest/spark-sql/udf-scala.html), [Spark UDFs and some gotchas](https://medium.com/@achilleus/spark-udfs-we-can-use-them-but-should-we-use-them-2c5a561fde6d).
\ No newline at end of file
diff --git a/eng/Versions.props b/eng/Versions.props
index dc954bcc5..1219678bb 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -1,7 +1,7 @@
-    <VersionPrefix>0.11.0</VersionPrefix>
+    <VersionPrefix>0.12.1</VersionPrefix>
    <PreReleaseVersionLabel>prerelease</PreReleaseVersionLabel>
    <RestoreSources>
      $(RestoreSources);
diff --git a/examples/Microsoft.Spark.CSharp.Examples/MachineLearning/Sentiment/Program.cs b/examples/Microsoft.Spark.CSharp.Examples/MachineLearning/Sentiment/Program.cs
index efb85e468..51f63078d 100644
--- a/examples/Microsoft.Spark.CSharp.Examples/MachineLearning/Sentiment/Program.cs
+++ b/examples/Microsoft.Spark.CSharp.Examples/MachineLearning/Sentiment/Program.cs
@@ -27,7 +27,7 @@ public void Run(string[] args)
SparkSession spark = SparkSession
.Builder()
- .AppName(".NET for Apache Spark Sentiment Analysis")
+ .AppName("Sentiment Analysis using .NET for Apache Spark")
.GetOrCreate();
// Read in and display Yelp reviews
diff --git a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Basic.cs b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Basic.cs
index 6ef95eefa..e09c79e20 100644
--- a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Basic.cs
+++ b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Basic.cs
@@ -26,7 +26,7 @@ public void Run(string[] args)
SparkSession spark = SparkSession
.Builder()
- .AppName(".NET Spark SQL basic example")
+ .AppName("SQL basic example using .NET for Apache Spark")
.Config("spark.some.config.option", "some-value")
.GetOrCreate();
@@ -108,6 +108,25 @@ public void Run(string[] args)
DataFrame joinedDf3 = df.Join(df, df["name"] == df["name"], "outer");
joinedDf3.Show();
+
+ // Union of two data frames
+ DataFrame unionDf = df.Union(df);
+ unionDf.Show();
+
+ // Add new column to data frame
+ df.WithColumn("location", Lit("Seattle")).Show();
+
+ // Rename existing column
+ df.WithColumnRenamed("name", "fullname").Show();
+
+ // Filter rows with null age
+ df.Filter(Col("age").IsNull()).Show();
+
+ // Fill null values in age column with -1
+ df.Na().Fill(-1, new[] { "age" }).Show();
+
+ // Drop age column
+ df.Drop(new[] { "age" }).Show();
spark.Stop();
}
diff --git a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Datasource.cs b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Datasource.cs
index cf41eeceb..0945df791 100644
--- a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Datasource.cs
+++ b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/Datasource.cs
@@ -32,7 +32,7 @@ public void Run(string[] args)
SparkSession spark = SparkSession
.Builder()
- .AppName(".NET Spark SQL Datasource example")
+ .AppName("SQL Datasource example using .NET for Apache Spark")
.Config("spark.some.config.option", "some-value")
.GetOrCreate();
diff --git a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorDataFrameUdfs.cs b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorDataFrameUdfs.cs
index 697301733..aafea7256 100644
--- a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorDataFrameUdfs.cs
+++ b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorDataFrameUdfs.cs
@@ -31,7 +31,7 @@ public void Run(string[] args)
.Builder()
// Lower the shuffle partitions to speed up groupBy() operations.
.Config("spark.sql.shuffle.partitions", "3")
- .AppName(".NET Spark SQL VectorUdfs example")
+ .AppName("SQL VectorUdfs example using .NET for Apache Spark")
.GetOrCreate();
DataFrame df = spark.Read().Schema("age INT, name STRING").Json(args[0]);
diff --git a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorUdfs.cs b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorUdfs.cs
index 369cc3aff..2497d5ef3 100644
--- a/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorUdfs.cs
+++ b/examples/Microsoft.Spark.CSharp.Examples/Sql/Batch/VectorUdfs.cs
@@ -29,7 +29,7 @@ public void Run(string[] args)
.Builder()
// Lower the shuffle partitions to speed up groupBy() operations.
.Config("spark.sql.shuffle.partitions", "3")
- .AppName(".NET Spark SQL VectorUdfs example")
+ .AppName("SQL VectorUdfs example using .NET for Apache Spark")
.GetOrCreate();
DataFrame df = spark.Read().Schema("age INT, name STRING").Json(args[0]);
diff --git a/examples/Microsoft.Spark.FSharp.Examples/Sql/Basic.fs b/examples/Microsoft.Spark.FSharp.Examples/Sql/Basic.fs
index 4e503fac9..6af1f81f7 100644
--- a/examples/Microsoft.Spark.FSharp.Examples/Sql/Basic.fs
+++ b/examples/Microsoft.Spark.FSharp.Examples/Sql/Basic.fs
@@ -78,6 +78,25 @@ type Basic() =
let joinedDf3 = df.Join(df, df.["name"].EqualTo(df.["name"]), "outer")
joinedDf3.Show()
+
+ // Union of two data frames
+ let unionDf = df.Union(df)
+ unionDf.Show()
+
+ // Add new column to data frame
+ df.WithColumn("location", Functions.Lit("Seattle")).Show()
+
+ // Rename existing column
+ df.WithColumnRenamed("name", "fullname").Show()
+
+ // Filter rows with null age
+ df.Filter(df.["age"].IsNull()).Show()
+
+ // Fill null values in age column with -1
+ df.Na().Fill(-1L, ["age"]).Show()
+
+ // Drop age column
+ df.Drop(df.["age"]).Show()
spark.Stop()
0
diff --git a/script/download-spark-distros.cmd b/script/download-spark-distros.cmd
index d02bb49a7..0d2435a00 100644
--- a/script/download-spark-distros.cmd
+++ b/script/download-spark-distros.cmd
@@ -23,5 +23,7 @@ curl -k -L -o spark-2.4.1.tgz https://archive.apache.org/dist/spark/spark-2.4.1/
curl -k -L -o spark-2.4.3.tgz https://archive.apache.org/dist/spark/spark-2.4.3/spark-2.4.3-bin-hadoop2.7.tgz && tar xzvf spark-2.4.3.tgz
curl -k -L -o spark-2.4.4.tgz https://archive.apache.org/dist/spark/spark-2.4.4/spark-2.4.4-bin-hadoop2.7.tgz && tar xzvf spark-2.4.4.tgz
curl -k -L -o spark-2.4.5.tgz https://archive.apache.org/dist/spark/spark-2.4.5/spark-2.4.5-bin-hadoop2.7.tgz && tar xzvf spark-2.4.5.tgz
+curl -k -L -o spark-2.4.6.tgz https://archive.apache.org/dist/spark/spark-2.4.6/spark-2.4.6-bin-hadoop2.7.tgz && tar xzvf spark-2.4.6.tgz
+
-endlocal
\ No newline at end of file
+endlocal
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs
index 9c0472485..9ca3851f0 100644
--- a/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaFixture.cs
@@ -16,7 +16,7 @@ public DeltaFixture()
{
Environment.SetEnvironmentVariable(
SparkFixture.EnvironmentVariableNames.ExtraSparkSubmitArgs,
- "--packages io.delta:delta-core_2.11:0.6.0 " +
+ "--packages io.delta:delta-core_2.11:0.6.1 " +
"--conf spark.databricks.delta.snapshotPartitions=2 " +
"--conf spark.sql.sources.parallelPartitionDiscovery.parallelism=5");
SparkFixture = new SparkFixture();
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaTableTests.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaTableTests.cs
index 69249d8c5..fab7c74dc 100644
--- a/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaTableTests.cs
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Delta.E2ETest/DeltaTableTests.cs
@@ -11,6 +11,7 @@
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Streaming;
using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.Extensions.Delta.E2ETest
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj
new file mode 100644
index 000000000..391582751
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest.csproj
@@ -0,0 +1,23 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>netcoreapp3.1</TargetFramework>
+    <RootNamespace>Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest</RootNamespace>
+    <IsPackable>false</IsPackable>
+  </PropertyGroup>
+
+</Project>
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/PackageResolverTests.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/PackageResolverTests.cs
new file mode 100644
index 000000000..219c533ff
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest/PackageResolverTests.cs
@@ -0,0 +1,95 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Microsoft.DotNet.Interactive.Utility;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Microsoft.Spark.Utils;
+using Moq;
+using Xunit;
+
+namespace Microsoft.Spark.Extensions.DotNet.Interactive.UnitTest
+{
+ public class PackageResolverTests
+ {
+ [Fact]
+ public void TestPackageResolver()
+ {
+ using var tempDir = new TemporaryDirectory();
+
+ string packageName = "package.name";
+ string packageVersion = "0.1.0";
+ string packageRootPath =
+ Path.Combine(tempDir.Path, "path", "to", "packages", packageName, packageVersion);
+ string packageFrameworkPath = Path.Combine(packageRootPath, "lib", "framework");
+
+ Directory.CreateDirectory(packageRootPath);
+ var nugetFile = new FileInfo(
+ Path.Combine(packageRootPath, $"{packageName}.{packageVersion}.nupkg"));
+ using (File.Create(nugetFile.FullName))
+ {
+ }
+
+            var assemblyPaths = new List<FileInfo>
+ {
+ new FileInfo(Path.Combine(packageFrameworkPath, "1.dll")),
+ new FileInfo(Path.Combine(packageFrameworkPath, "2.dll"))
+ };
+            var probingPaths = new List<DirectoryInfo> { new DirectoryInfo(packageRootPath) };
+
+            var mockSupportNugetWrapper = new Mock<SupportNugetWrapper>();
+ mockSupportNugetWrapper
+ .SetupGet(m => m.ResolvedPackageReferences)
+ .Returns(new ResolvedPackageReference[]
+ {
+ new ResolvedPackageReference(
+ packageName,
+ packageVersion,
+ assemblyPaths,
+ new DirectoryInfo(packageRootPath),
+ probingPaths)
+ });
+
+ var packageResolver = new PackageResolver(mockSupportNugetWrapper.Object);
+            IEnumerable<string> actualFiles = packageResolver.GetFiles(tempDir.Path);
+
+ string metadataFilePath =
+ Path.Combine(tempDir.Path, DependencyProviderUtils.CreateFileName(1));
+ var expectedFiles = new string[]
+ {
+ nugetFile.FullName,
+ metadataFilePath
+ };
+ Assert.True(expectedFiles.SequenceEqual(actualFiles));
+ Assert.True(File.Exists(metadataFilePath));
+
+ DependencyProviderUtils.Metadata actualMetadata =
+ DependencyProviderUtils.Metadata.Deserialize(metadataFilePath);
+ var expectedMetadata = new DependencyProviderUtils.Metadata
+ {
+ AssemblyProbingPaths = new string[]
+ {
+ Path.Combine(packageName, packageVersion, "lib", "framework", "1.dll"),
+ Path.Combine(packageName, packageVersion, "lib", "framework", "2.dll")
+ },
+ NativeProbingPaths = new string[]
+ {
+ Path.Combine(packageName, packageVersion)
+ },
+ NuGets = new DependencyProviderUtils.NuGetMetadata[]
+ {
+ new DependencyProviderUtils.NuGetMetadata
+ {
+ FileName = $"{packageName}.{packageVersion}.nupkg",
+ PackageName = packageName,
+ PackageVersion = packageVersion
+ }
+ }
+ };
+ Assert.True(expectedMetadata.Equals(actualMetadata));
+ }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs
new file mode 100644
index 000000000..bb30e4957
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs
@@ -0,0 +1,156 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis;
+using Microsoft.DotNet.Interactive;
+using Microsoft.DotNet.Interactive.Commands;
+using Microsoft.DotNet.Interactive.CSharp;
+using Microsoft.DotNet.Interactive.Utility;
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Utils;
+
+namespace Microsoft.Spark.Extensions.DotNet.Interactive
+{
+    /// <summary>
+    /// A kernel extension used when running .NET for Apache Spark with Microsoft.DotNet.Interactive.
+    /// Adds nuget and assembly dependencies to the default <see cref="SparkSession"/>
+    /// using <see cref="SparkContext.AddFile(string, bool)"/>.
+    /// </summary>
+ public class AssemblyKernelExtension : IKernelExtension
+ {
+ private const string TempDirEnvVar = "DOTNET_SPARK_EXTENSION_INTERACTIVE_TMPDIR";
+
+ private readonly PackageResolver _packageResolver =
+ new PackageResolver(new SupportNugetWrapper());
+
+        /// <summary>
+        /// Called by the Microsoft.DotNet.Interactive Assembly Extension Loader.
+        /// </summary>
+        /// <param name="kernel">The kernel calling this method.</param>
+        /// <returns><see cref="Task.CompletedTask"/> when the extension is loaded.</returns>
+ public Task OnLoadAsync(IKernel kernel)
+ {
+ if (kernel is CompositeKernel kernelBase)
+ {
+ Environment.SetEnvironmentVariable(Constants.RunningREPLEnvVar, "true");
+
+ DirectoryInfo tempDir = CreateTempDirectory();
+ kernelBase.RegisterForDisposal(new DisposableDirectory(tempDir));
+
+ kernelBase.AddMiddleware(async (command, context, next) =>
+ {
+ await next(command, context);
+
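+                    // After the submission has been handled, emit the compiled REPL
+                    // state as an assembly and ship it, along with any newly resolved
+                    // nuget package files, to the workers via SparkContext.AddFile.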
+                    if ((context.HandlingKernel is CSharpKernel csharpKernel) &&
+                        (command is SubmitCode) &&
+                        TryGetSparkSession(out SparkSession sparkSession) &&
+                        TryEmitAssembly(csharpKernel, tempDir.FullName, out string assemblyPath))
+ {
+ sparkSession.SparkContext.AddFile(assemblyPath);
+
+ foreach (string filePath in GetPackageFiles(tempDir.FullName))
+ {
+ sparkSession.SparkContext.AddFile(filePath);
+ }
+ }
+ });
+ }
+
+ return Task.CompletedTask;
+ }
+
+ private DirectoryInfo CreateTempDirectory()
+ {
+ string envTempDir = Environment.GetEnvironmentVariable(TempDirEnvVar);
+ string tempDirBasePath = string.IsNullOrEmpty(envTempDir) ?
+ Directory.GetCurrentDirectory() :
+ envTempDir;
+
+ if (!IsPathValid(tempDirBasePath))
+ {
+ throw new Exception($"[{GetType().Name}] Spaces in " +
+ $"'{tempDirBasePath}' is unsupported. Set the {TempDirEnvVar} " +
+ "environment variable to control the base path. Please see " +
+ "https://issues.apache.org/jira/browse/SPARK-30126 and " +
+ "https://github.com/apache/spark/pull/26773 for more details.");
+ }
+
+ return Directory.CreateDirectory(
+ Path.Combine(tempDirBasePath, Path.GetRandomFileName()));
+ }
+
+ private bool TryEmitAssembly(CSharpKernel kernel, string dstPath, out string assemblyPath)
+ {
+ Compilation compilation = kernel.ScriptState.Script.GetCompilation();
+ string assemblyName =
+ AssemblyLoader.NormalizeAssemblyName(compilation.AssemblyName);
+ assemblyPath = Path.Combine(dstPath, $"{assemblyName}.dll");
+ if (!File.Exists(assemblyPath))
+ {
+ FileSystemExtensions.Emit(compilation, assemblyPath);
+ return true;
+ }
+
+ throw new Exception(
+ $"TryEmitAssembly() unexpected duplicate assembly: ${assemblyPath}");
+ }
+
+ private bool TryGetSparkSession(out SparkSession sparkSession)
+ {
+ sparkSession = SparkSession.GetDefaultSession();
+ return sparkSession != null;
+ }
+
+        private IEnumerable<string> GetPackageFiles(string path)
+ {
+ foreach (string filePath in _packageResolver.GetFiles(path))
+ {
+ if (IsPathValid(filePath))
+ {
+ yield return filePath;
+ }
+ else
+ {
+ // Copy file to a path without spaces.
+ string fileDestPath = Path.Combine(
+ path,
+ Path.GetFileName(filePath).Replace(" ", string.Empty));
+ File.Copy(filePath, fileDestPath);
+ yield return fileDestPath;
+ }
+ }
+ }
+
+        /// <summary>
+        /// In some versions of Spark, spaces are unsupported when using
+        /// <see cref="SparkContext.AddFile(string, bool)"/>.
+        ///
+        /// For more details please see:
+        /// - https://issues.apache.org/jira/browse/SPARK-30126
+        /// - https://github.com/apache/spark/pull/26773
+        /// </summary>
+        /// <param name="path">The path to validate.</param>
+        /// <returns>true if the path is supported by Spark, false otherwise.</returns>
+ private bool IsPathValid(string path)
+ {
+ if (!path.Contains(" "))
+ {
+ return true;
+ }
+
+ Version version = SparkEnvironment.SparkVersion;
+ return (version.Major, version.Minor, version.Build) switch
+ {
+ (2, _, _) => false,
+ (3, 0, _) => true,
+ _ => throw new NotSupportedException($"Spark {version} not supported.")
+ };
+ }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/Microsoft.Spark.Extensions.DotNet.Interactive.csproj b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/Microsoft.Spark.Extensions.DotNet.Interactive.csproj
new file mode 100644
index 000000000..da330c762
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/Microsoft.Spark.Extensions.DotNet.Interactive.csproj
@@ -0,0 +1,38 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <OutputType>Library</OutputType>
+    <TargetFramework>netcoreapp3.1</TargetFramework>
+    <RootNamespace>Microsoft.Spark.Extensions.DotNet.Interactive</RootNamespace>
+    <IsPackable>true</IsPackable>
+    <GenerateDocumentationFile>true</GenerateDocumentationFile>
+    <NoWarn>NU5100;$(NoWarn)</NoWarn>
+    <Description>DotNet Interactive Extension for .NET for Apache Spark</Description>
+    <PackageReleaseNotes>https://github.com/dotnet/spark/tree/master/docs/release-notes</PackageReleaseNotes>
+    <PackageTags>spark;dotnet;csharp;interactive;dotnet-interactive</PackageTags>
+  </PropertyGroup>
+
+</Project>
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/PackageResolver.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/PackageResolver.cs
new file mode 100644
index 000000000..f9a76e43f
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/PackageResolver.cs
@@ -0,0 +1,165 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading;
+using Microsoft.DotNet.Interactive.Utility;
+using Microsoft.Spark.Utils;
+
+namespace Microsoft.Spark.Extensions.DotNet.Interactive
+{
+ internal class PackageResolver
+ {
+ private readonly SupportNugetWrapper _supportNugetWrapper;
+        private readonly ConcurrentDictionary<string, int> _filesCopied;
+ private long _metadataCounter;
+
+ internal PackageResolver(SupportNugetWrapper supportNugetWrapper)
+ {
+ _supportNugetWrapper = supportNugetWrapper;
+            _filesCopied = new ConcurrentDictionary<string, int>();
+ _metadataCounter = 0;
+ }
+
+        /// <summary>
+        /// Generates and serializes a <see cref="DependencyProviderUtils.Metadata"/> to
+        /// <paramref name="writePath"/>. Returns a list of file paths which include the
+        /// serialized <see cref="DependencyProviderUtils.Metadata"/> and nuget file
+        /// dependencies.
+        /// </summary>
+        /// <param name="writePath">Path to write metadata.</param>
+        /// <returns>
+        /// List of file paths of the serialized <see cref="DependencyProviderUtils.Metadata"/>
+        /// and nuget file dependencies.
+        /// </returns>
+        internal IEnumerable<string> GetFiles(string writePath)
+ {
+            IEnumerable<ResolvedNuGetPackage> nugetPackagesToCopy = GetNewPackages();
+
+            var assemblyProbingPaths = new List<string>();
+            var nativeProbingPaths = new List<string>();
+            var nugetMetadata = new List<DependencyProviderUtils.NuGetMetadata>();
+
+ foreach (ResolvedNuGetPackage package in nugetPackagesToCopy)
+ {
+ ResolvedPackageReference resolvedPackage = package.ResolvedPackage;
+
+ foreach (FileInfo asmPath in resolvedPackage.AssemblyPaths)
+ {
+ // asmPath.FullName
+ // /path/to/packages/package.name/package.version/lib/framework/1.dll
+ // resolvedPackage.PackageRoot
+ // /path/to/packages/package.name/package.version/
+ // GetRelativeToPackages(..)
+ // package.name/package.version/lib/framework/1.dll
+ assemblyProbingPaths.Add(
+ GetPathRelativeToPackages(
+ asmPath.FullName,
+ resolvedPackage.PackageRoot));
+ }
+
+ foreach (DirectoryInfo probePath in resolvedPackage.ProbingPaths)
+ {
+ // probePath.FullName
+ // /path/to/packages/package.name/package.version/
+ // resolvedPackage.PackageRoot
+ // /path/to/packages/package.name/package.version/
+ // GetRelativeToPackages(..)
+ // package.name/package.version
+ nativeProbingPaths.Add(
+ GetPathRelativeToPackages(
+ probePath.FullName,
+ resolvedPackage.PackageRoot));
+ }
+
+ nugetMetadata.Add(
+ new DependencyProviderUtils.NuGetMetadata
+ {
+ FileName = package.NuGetFile.Name,
+ PackageName = resolvedPackage.PackageName,
+ PackageVersion = resolvedPackage.PackageVersion
+ });
+
+ yield return package.NuGetFile.FullName;
+ }
+
+ if (nugetMetadata.Count > 0)
+ {
+ var metadataPath =
+ Path.Combine(
+ writePath,
+ DependencyProviderUtils.CreateFileName(
+ Interlocked.Increment(ref _metadataCounter)));
+ new DependencyProviderUtils.Metadata
+ {
+ AssemblyProbingPaths = assemblyProbingPaths.ToArray(),
+ NativeProbingPaths = nativeProbingPaths.ToArray(),
+ NuGets = nugetMetadata.ToArray()
+ }.Serialize(metadataPath);
+
+ yield return metadataPath;
+ }
+ }
+
+        /// <summary>
+        /// Return the delta of the list of packages that have been introduced
+        /// since the last call.
+        /// </summary>
+        /// <returns>The delta of the list of packages.</returns>
+        private IEnumerable<ResolvedNuGetPackage> GetNewPackages()
+        {
+            IEnumerable<ResolvedPackageReference> packages =
+                _supportNugetWrapper.ResolvedPackageReferences;
+ foreach (ResolvedPackageReference package in packages)
+ {
+                IEnumerable<FileInfo> files =
+                    package.PackageRoot.EnumerateFiles("*.nupkg", SearchOption.AllDirectories);
+
+ foreach (FileInfo file in files)
+ {
+ if (_filesCopied.TryAdd(file.Name, 1))
+ {
+ yield return new ResolvedNuGetPackage
+ {
+ ResolvedPackage = package,
+ NuGetFile = file
+ };
+ }
+ }
+ }
+ }
+
+        /// <summary>
+        /// Given a <paramref name="path"/>, get the relative path to the packages directory.
+        /// The package is a subfolder within the packages directory.
+        ///
+        /// Examples:
+        /// path:
+        /// /path/to/packages/package.name/package.version/lib/framework/1.dll
+        /// directory:
+        /// /path/to/packages/package.name/package.version/
+        /// relative path:
+        /// package.name/package.version/lib/framework/1.dll
+        ///
+        /// path:
+        /// /path/to/packages/package.name/package.version/
+        /// directory:
+        /// /path/to/packages/package.name/package.version/
+        /// relative path:
+        /// package.name/package.version
+        /// </summary>
+        /// <param name="path">The full path used to determine the relative path.</param>
+        /// <param name="directory">The package directory.</param>
+        /// <returns>The relative path to the packages directory.</returns>
+ private string GetPathRelativeToPackages(string path, DirectoryInfo directory)
+ {
+ string strippedRoot = path
+ .Substring(directory.FullName.Length)
+ .Trim(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar);
+ return Path.Combine(directory.Parent.Name, directory.Name, strippedRoot);
+ }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/ResolvedNugetPackage.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/ResolvedNugetPackage.cs
new file mode 100644
index 000000000..57106c16a
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/ResolvedNugetPackage.cs
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using Microsoft.DotNet.Interactive.Utility;
+
+namespace Microsoft.Spark.Extensions.DotNet.Interactive
+{
+ internal class ResolvedNuGetPackage
+ {
+ public ResolvedPackageReference ResolvedPackage { get; set; }
+ public FileInfo NuGetFile { get; set; }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/SupportNugetWrapper.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/SupportNugetWrapper.cs
new file mode 100644
index 000000000..489e39e94
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/SupportNugetWrapper.cs
@@ -0,0 +1,13 @@
+using System.Collections.Generic;
+using Microsoft.DotNet.Interactive;
+using Microsoft.DotNet.Interactive.Utility;
+
+namespace Microsoft.Spark.Extensions.DotNet.Interactive
+{
+ internal class SupportNugetWrapper
+ {
+        internal virtual IEnumerable<ResolvedPackageReference> ResolvedPackageReferences =>
+ ((ISupportNuget)KernelInvocationContext.Current.HandlingKernel)
+ .ResolvedPackageReferences;
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Constants.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Constants.cs
new file mode 100644
index 000000000..969dd85f1
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Constants.cs
@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Spark.Extensions.Hyperspace.E2ETest
+{
+    /// <summary>
+    /// Constants related to the Hyperspace test suite.
+    /// </summary>
+ internal class Constants
+ {
+ public const string HyperspaceTestContainerName = "Hyperspace Tests";
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/HyperspaceFixture.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/HyperspaceFixture.cs
new file mode 100644
index 000000000..8578c77f0
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/HyperspaceFixture.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Spark.E2ETest;
+using Xunit;
+
+namespace Microsoft.Spark.Extensions.Hyperspace.E2ETest
+{
+ public class HyperspaceFixture
+ {
+ public HyperspaceFixture()
+ {
+ Environment.SetEnvironmentVariable(
+ SparkFixture.EnvironmentVariableNames.ExtraSparkSubmitArgs,
+ "--packages com.microsoft.hyperspace:hyperspace-core_2.11:0.1.0");
+
+ SparkFixture = new SparkFixture();
+ }
+
+ public SparkFixture SparkFixture { get; private set; }
+ }
+
+ [CollectionDefinition(Constants.HyperspaceTestContainerName)]
+    public class HyperspaceTestCollection : ICollectionFixture<HyperspaceFixture>
+ {
+ // This class has no code, and is never created. Its purpose is simply
+ // to be the place to apply [CollectionDefinition] and all the
+ // ICollectionFixture<> interfaces.
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/HyperspaceTests.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/HyperspaceTests.cs
new file mode 100644
index 000000000..12e8bca60
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/HyperspaceTests.cs
@@ -0,0 +1,141 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.Extensions.Hyperspace.Index;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Xunit;
+
+namespace Microsoft.Spark.Extensions.Hyperspace.E2ETest
+{
+    /// <summary>
+    /// Test suite for Hyperspace index management APIs.
+    /// </summary>
+ [Collection(Constants.HyperspaceTestContainerName)]
+ public class HyperspaceTests : IDisposable
+ {
+ private readonly SparkSession _spark;
+ private readonly TemporaryDirectory _hyperspaceSystemDirectory;
+ private readonly Hyperspace _hyperspace;
+
+ // Fields needed for sample DataFrame.
+ private readonly DataFrame _sampleDataFrame;
+ private readonly string _sampleIndexName;
+ private readonly IndexConfig _sampleIndexConfig;
+
+ public HyperspaceTests(HyperspaceFixture fixture)
+ {
+ _spark = fixture.SparkFixture.Spark;
+ _hyperspaceSystemDirectory = new TemporaryDirectory();
+ _spark.Conf().Set("spark.hyperspace.system.path", _hyperspaceSystemDirectory.Path);
+ _hyperspace = new Hyperspace(_spark);
+
+ _sampleDataFrame = _spark.Read()
+ .Option("header", true)
+ .Option("delimiter", ";")
+ .Csv("Resources\\people.csv");
+ _sampleIndexName = "sample_dataframe";
+ _sampleIndexConfig = new IndexConfig(_sampleIndexName, new[] { "job" }, new[] { "name" });
+ _hyperspace.CreateIndex(_sampleDataFrame, _sampleIndexConfig);
+ }
+
+        /// <summary>
+        /// Clean up the Hyperspace system directory in between tests.
+        /// </summary>
+ public void Dispose()
+ {
+ _hyperspaceSystemDirectory.Dispose();
+ }
+
+        /// <summary>
+        /// Test the method signatures for all Hyperspace APIs.
+        /// </summary>
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestSignatures()
+ {
+ // Indexes API.
+            Assert.IsType<DataFrame>(_hyperspace.Indexes());
+
+ // Delete and Restore APIs.
+ _hyperspace.DeleteIndex(_sampleIndexName);
+ _hyperspace.RestoreIndex(_sampleIndexName);
+
+ // Refresh API.
+ _hyperspace.RefreshIndex(_sampleIndexName);
+
+ // Cancel API.
+            Assert.Throws<Exception>(() => _hyperspace.Cancel(_sampleIndexName));
+
+ // Explain API.
+ _hyperspace.Explain(_sampleDataFrame, true);
+ _hyperspace.Explain(_sampleDataFrame, true, s => Console.WriteLine(s));
+
+ // Delete and Vacuum APIs.
+ _hyperspace.DeleteIndex(_sampleIndexName);
+ _hyperspace.VacuumIndex(_sampleIndexName);
+
+ // Enable and disable Hyperspace.
+            Assert.IsType<SparkSession>(_spark.EnableHyperspace());
+            Assert.IsType<SparkSession>(_spark.DisableHyperspace());
+            Assert.IsType<bool>(_spark.IsHyperspaceEnabled());
+ }
+
+        /// <summary>
+        /// Test E2E functionality of index CRUD APIs.
+        /// </summary>
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestIndexCreateAndDelete()
+ {
+ // Should be one active index.
+ DataFrame indexes = _hyperspace.Indexes();
+ Assert.Equal(1, indexes.Count());
+ Assert.Equal(_sampleIndexName, indexes.SelectExpr("name").First()[0]);
+ Assert.Equal(States.Active, indexes.SelectExpr("state").First()[0]);
+
+ // Delete the index then verify it has been deleted.
+ _hyperspace.DeleteIndex(_sampleIndexName);
+ indexes = _hyperspace.Indexes();
+ Assert.Equal(1, indexes.Count());
+ Assert.Equal(States.Deleted, indexes.SelectExpr("state").First()[0]);
+
+ // Restore the index to active state and verify it is back.
+ _hyperspace.RestoreIndex(_sampleIndexName);
+ indexes = _hyperspace.Indexes();
+ Assert.Equal(1, indexes.Count());
+ Assert.Equal(States.Active, indexes.SelectExpr("state").First()[0]);
+
+ // Delete and vacuum the index, then verify it is gone.
+ _hyperspace.DeleteIndex(_sampleIndexName);
+ _hyperspace.VacuumIndex(_sampleIndexName);
+ Assert.Equal(0, _hyperspace.Indexes().Count());
+ }
+
+        /// <summary>
+        /// Test that the explain API generates the expected string.
+        /// </summary>
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestExplainAPI()
+ {
+ // Run a query that hits the index.
+ DataFrame queryDataFrame = _sampleDataFrame
+ .Where("job == 'Developer'")
+ .Select("name");
+
+ string explainString = string.Empty;
+ _hyperspace.Explain(queryDataFrame, true, s => explainString = s);
+ Assert.False(string.IsNullOrEmpty(explainString));
+ }
+
+        /// <summary>
+        /// Index states used in testing.
+        /// </summary>
+ private static class States
+ {
+ public const string Active = "ACTIVE";
+ public const string Deleted = "DELETED";
+ }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Index/IndexConfigTests.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Index/IndexConfigTests.cs
new file mode 100644
index 000000000..b96f85432
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Index/IndexConfigTests.cs
@@ -0,0 +1,86 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.Extensions.Hyperspace.Index;
+using Xunit;
+
+namespace Microsoft.Spark.Extensions.Hyperspace.E2ETest.Index
+{
+    /// <summary>
+    /// Test suite for Hyperspace IndexConfig tests.
+    /// </summary>
+ [Collection(Constants.HyperspaceTestContainerName)]
+ public class IndexConfigTests
+ {
+ public IndexConfigTests(HyperspaceFixture fixture)
+ {
+ }
+
+        /// <summary>
+        /// Test the method signatures for IndexConfig and IndexConfigBuilder APIs.
+        /// </summary>
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestSignatures()
+ {
+ string indexName = "testIndexName";
+ var indexConfig = new IndexConfig(indexName, new[] { "Id" }, new string[] { });
+            Assert.IsType<string>(indexConfig.IndexName);
+            Assert.IsType<List<string>>(indexConfig.IndexedColumns);
+            Assert.IsType<List<string>>(indexConfig.IncludedColumns);
+            Assert.IsType<Builder>(IndexConfig.Builder());
+            Assert.IsType<bool>(indexConfig.Equals(indexConfig));
+            Assert.IsType<int>(indexConfig.GetHashCode());
+            Assert.IsType<string>(indexConfig.ToString());
+
+            Builder builder = IndexConfig.Builder();
+            Assert.IsType<Builder>(builder);
+            Assert.IsType<Builder>(builder.IndexName("indexName"));
+            Assert.IsType<Builder>(builder.IndexBy("indexed1", "indexed2"));
+            Assert.IsType<Builder>(builder.Include("included1"));
+            Assert.IsType<IndexConfig>(builder.Create());
+ }
+
+        /// <summary>
+        /// Test creating an IndexConfig using its class constructor.
+        /// </summary>
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestIndexConfigConstructor()
+ {
+ string indexName = "indexName";
+ string[] indexedColumns = { "idx1" };
+ string[] includedColumns = { "inc1", "inc2", "inc3" };
+ var config = new IndexConfig(indexName, indexedColumns, includedColumns);
+
+ // Validate that the config was built correctly.
+ Assert.Equal(indexName, config.IndexName);
+ Assert.Equal(indexedColumns, config.IndexedColumns);
+ Assert.Equal(includedColumns, config.IncludedColumns);
+ }
+
+        /// <summary>
+        /// Test creating an IndexConfig using the builder pattern.
+        /// </summary>
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestIndexConfigBuilder()
+ {
+ string indexName = "indexName";
+ string[] indexedColumns = { "idx1" };
+ string[] includedColumns = { "inc1", "inc2", "inc3" };
+
+ Builder builder = IndexConfig.Builder();
+ builder.IndexName(indexName);
+ builder.Include(includedColumns[0], includedColumns[1], includedColumns[2]);
+ builder.IndexBy(indexedColumns[0]);
+
+ // Validate that the config was built correctly.
+ IndexConfig config = builder.Create();
+ Assert.Equal(indexName, config.IndexName);
+ Assert.Equal(indexedColumns, config.IndexedColumns);
+ Assert.Equal(includedColumns, config.IncludedColumns);
+ }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Microsoft.Spark.Extensions.Hyperspace.E2ETest.csproj b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Microsoft.Spark.Extensions.Hyperspace.E2ETest.csproj
new file mode 100644
index 000000000..231022e4b
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace.E2ETest/Microsoft.Spark.Extensions.Hyperspace.E2ETest.csproj
@@ -0,0 +1,13 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>netcoreapp3.1</TargetFramework>
+    <IsPackable>false</IsPackable>
+  </PropertyGroup>
+
+</Project>
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Hyperspace.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Hyperspace.cs
new file mode 100644
index 000000000..13509779d
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Hyperspace.cs
@@ -0,0 +1,113 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Spark.Extensions.Hyperspace.Index;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+
+namespace Microsoft.Spark.Extensions.Hyperspace
+{
+    /// <summary>
+    /// .NET for Apache Spark binding for Hyperspace index management APIs.
+    /// </summary>
+ public class Hyperspace : IJvmObjectReferenceProvider
+ {
+ private static readonly string s_hyperspaceClassName =
+ "com.microsoft.hyperspace.Hyperspace";
+ private readonly SparkSession _spark;
+ private readonly IJvmBridge _jvmBridge;
+ private readonly JvmObjectReference _jvmObject;
+
+ public Hyperspace(SparkSession spark)
+ {
+ _spark = spark;
+ _jvmBridge = ((IJvmObjectReferenceProvider)spark).Reference.Jvm;
+ _jvmObject = _jvmBridge.CallConstructor(s_hyperspaceClassName, spark);
+ }
+
+ JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+        /// <summary>
+        /// Collect all the index metadata.
+        /// </summary>
+        /// <returns>All index metadata as a <see cref="DataFrame"/>.</returns>
+ public DataFrame Indexes() =>
+ new DataFrame((JvmObjectReference)_jvmObject.Invoke("indexes"));
+
+        /// <summary>
+        /// Create index.
+        /// </summary>
+        /// <param name="df">The DataFrame object to build index on.</param>
+        /// <param name="indexConfig">The configuration of index to be created.</param>
+ public void CreateIndex(DataFrame df, IndexConfig indexConfig) =>
+ _jvmObject.Invoke("createIndex", df, indexConfig);
+
+        /// <summary>
+        /// Soft deletes the index with given index name.
+        /// </summary>
+        /// <param name="indexName">The name of index to delete.</param>
+        public void DeleteIndex(string indexName) => _jvmObject.Invoke("deleteIndex", indexName);
+
+        /// <summary>
+        /// Restores index with given index name.
+        /// </summary>
+        /// <param name="indexName">Name of the index to restore.</param>
+        public void RestoreIndex(string indexName) => _jvmObject.Invoke("restoreIndex", indexName);
+
+        /// <summary>
+        /// Does hard delete of indexes marked as DELETED.
+        /// </summary>
+        /// <param name="indexName">Name of the index to vacuum.</param>
+        public void VacuumIndex(string indexName) => _jvmObject.Invoke("vacuumIndex", indexName);
+
+        /// <summary>
+        /// Update indexes for the latest version of the data.
+        /// </summary>
+        /// <param name="indexName">Name of the index to refresh.</param>
+        public void RefreshIndex(string indexName) => _jvmObject.Invoke("refreshIndex", indexName);
+
+        /// <summary>
+        /// Cancel api to bring back index from an inconsistent state to the last known stable
+        /// state.
+        ///
+        /// E.g. if index fails during creation, in CREATING state.
+        /// The index will not allow any index modifying operations unless a cancel is called.
+        ///
+        /// Note: Cancel from VACUUMING state will move it forward to DOESNOTEXIST
+        /// state.
+        ///
+        /// Note: If no previous stable state exists, cancel will move it to DOESNOTEXIST
+        /// state.
+        /// </summary>
+        /// <param name="indexName">Name of the index to cancel.</param>
+ public void Cancel(string indexName) => _jvmObject.Invoke("cancel", indexName);
+
+        /// <summary>
+        /// Explains how indexes will be applied to the given dataframe.
+        /// </summary>
+        /// <param name="df">dataFrame</param>
+        /// <param name="verbose">Flag to enable verbose mode.</param>
+ public void Explain(DataFrame df, bool verbose) =>
+ Explain(df, verbose, s => Console.WriteLine(s));
+
+        /// <summary>
+        /// Explains how indexes will be applied to the given dataframe.
+        /// </summary>
+        /// <param name="df">dataFrame</param>
+        /// <param name="verbose">Flag to enable verbose mode.</param>
+        /// <param name="redirectFunc">Function to redirect output of explain.</param>
+        public void Explain(DataFrame df, bool verbose, Action<string> redirectFunc)
+ {
+ var explainString = (string)_jvmBridge.CallStaticJavaMethod(
+ "com.microsoft.hyperspace.index.plananalysis.PlanAnalyzer",
+ "explainString",
+ df,
+ _spark,
+ Indexes(),
+ verbose);
+ redirectFunc(explainString);
+ }
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/HyperspaceSparkSessionExtensions.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/HyperspaceSparkSessionExtensions.cs
new file mode 100644
index 000000000..3c43f369c
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/HyperspaceSparkSessionExtensions.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+
+namespace Microsoft.Spark.Extensions.Hyperspace
+{
+    /// <summary>
+    /// Hyperspace-specific extension methods on <see cref="SparkSession"/>.
+    /// </summary>
+ public static class HyperspaceSparkSessionExtensions
+ {
+ private static readonly string s_pythonUtilsClassName =
+ "com.microsoft.hyperspace.util.PythonUtils";
+
+        /// <summary>
+        /// Plug in Hyperspace-specific rules.
+        /// </summary>
+        /// <param name="session">A spark session that does not contain Hyperspace-specific rules.</param>
+        /// <returns>A spark session that contains Hyperspace-specific rules.</returns>
+ public static SparkSession EnableHyperspace(this SparkSession session) =>
+ new SparkSession(
+ (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_pythonUtilsClassName,
+ "enableHyperspace",
+ session));
+
+        /// <summary>
+        /// Plug out Hyperspace-specific rules.
+        /// </summary>
+        /// <param name="session">A spark session that contains Hyperspace-specific rules.</param>
+        /// <returns>A spark session that does not contain Hyperspace-specific rules.</returns>
+ public static SparkSession DisableHyperspace(this SparkSession session) =>
+ new SparkSession(
+ (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_pythonUtilsClassName,
+ "disableHyperspace",
+ session));
+
+        /// <summary>
+        /// Checks if Hyperspace is enabled or not.
+        /// </summary>
+        /// <param name="session">The spark session.</param>
+        /// <returns>True if Hyperspace is enabled or false otherwise.</returns>
+ public static bool IsHyperspaceEnabled(this SparkSession session) =>
+ (bool)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_pythonUtilsClassName,
+ "isHyperspaceEnabled",
+ session);
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Index/Builder.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Index/Builder.cs
new file mode 100644
index 000000000..4623de3e7
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Index/Builder.cs
@@ -0,0 +1,74 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.Extensions.Hyperspace.Index
+{
+    /// <summary>
+    /// Builder for <see cref="IndexConfig"/>.
+    /// </summary>
+ public sealed class Builder : IJvmObjectReferenceProvider
+ {
+ private readonly JvmObjectReference _jvmObject;
+
+ internal Builder(JvmObjectReference jvmObject)
+ {
+ _jvmObject = jvmObject;
+ }
+
+ JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+        /// <summary>
+        /// Updates index name for <see cref="IndexConfig"/>.
+        /// </summary>
+        /// <param name="indexName">Index name for the <see cref="IndexConfig"/>.</param>
+        /// <returns>A <see cref="Builder"/> object with updated index name.</returns>
+ public Builder IndexName(string indexName)
+ {
+ _jvmObject.Invoke("indexName", indexName);
+ return this;
+ }
+
+        /// <summary>
+        /// Updates indexed column names for <see cref="IndexConfig"/>.
+        ///
+        /// Note: API signature supports passing one or more arguments.
+        /// </summary>
+        /// <param name="indexedColumn">Indexed column for the
+        /// <see cref="IndexConfig"/>.</param>
+        /// <param name="indexedColumns">Additional indexed columns for the
+        /// <see cref="IndexConfig"/>.</param>
+        /// <returns>A <see cref="Builder"/> object with updated indexed columns.</returns>
+ public Builder IndexBy(string indexedColumn, params string[] indexedColumns)
+ {
+ _jvmObject.Invoke("indexBy", indexedColumn, indexedColumns);
+ return this;
+ }
+
+        /// <summary>
+        /// Updates included columns for <see cref="IndexConfig"/>.
+        ///
+        /// Note: API signature supports passing one or more arguments.
+        /// </summary>
+        /// <param name="includedColumn">Included column for <see cref="IndexConfig"/>.</param>
+        /// <param name="includedColumns">Additional included columns for
+        /// <see cref="IndexConfig"/>.</param>
+        /// <returns>A <see cref="Builder"/> object with updated included columns.</returns>
+ public Builder Include(string includedColumn, params string[] includedColumns)
+ {
+ _jvmObject.Invoke("include", includedColumn, includedColumns);
+ return this;
+ }
+
+        /// <summary>
+        /// Creates an <see cref="IndexConfig"/> from the supplied index name, indexed columns
+        /// and included columns.
+        /// </summary>
+        /// <returns>An <see cref="IndexConfig"/> object.</returns>
+ public IndexConfig Create() =>
+ new IndexConfig((JvmObjectReference)_jvmObject.Invoke("create"));
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Index/IndexConfig.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Index/IndexConfig.cs
new file mode 100644
index 000000000..030dda2ca
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Index/IndexConfig.cs
@@ -0,0 +1,92 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Internal.Scala;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.Extensions.Hyperspace.Index
+{
+    /// <summary>
+    /// <see cref="IndexConfig"/> specifies the configuration of an index.
+    /// </summary>
+ public sealed class IndexConfig : IJvmObjectReferenceProvider
+ {
+ private static readonly string s_className = "com.microsoft.hyperspace.index.IndexConfig";
+ private readonly JvmObjectReference _jvmObject;
+
+        /// <summary>
+        /// <see cref="IndexConfig"/> specifies the configuration of an index.
+        /// </summary>
+        /// <param name="indexName">Index name.</param>
+        /// <param name="indexedColumns">Columns from which an index is created.</param>
+        public IndexConfig(string indexName, IEnumerable<string> indexedColumns)
+ : this(indexName, indexedColumns, new string[] { })
+ {
+ }
+
+        /// <summary>
+        /// <see cref="IndexConfig"/> specifies the configuration of an index.
+        /// </summary>
+        /// <param name="indexName">Index name.</param>
+        /// <param name="indexedColumns">Columns from which an index is created.</param>
+        /// <param name="includedColumns">Columns to be included in the index.</param>
+        public IndexConfig(
+            string indexName,
+            IEnumerable<string> indexedColumns,
+            IEnumerable<string> includedColumns)
+ {
+ IndexName = indexName;
+            IndexedColumns = new List<string>(indexedColumns);
+            IncludedColumns = new List<string>(includedColumns);
+
+ _jvmObject = (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_className,
+ "apply",
+ IndexName,
+ IndexedColumns,
+ IncludedColumns);
+ }
+
+        /// <summary>
+        /// <see cref="IndexConfig"/> specifies the configuration of an index.
+        /// </summary>
+        /// <param name="jvmObject">JVM object reference.</param>
+ internal IndexConfig(JvmObjectReference jvmObject)
+ {
+ _jvmObject = jvmObject;
+ IndexName = (string)_jvmObject.Invoke("indexName");
+            IndexedColumns = new List<string>(
+                new Seq<string>((JvmObjectReference)_jvmObject.Invoke("indexedColumns")));
+            IncludedColumns = new List<string>(
+                new Seq<string>((JvmObjectReference)_jvmObject.Invoke("includedColumns")));
+ }
+
+ JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+ public string IndexName { get; private set; }
+
+        public List<string> IndexedColumns { get; private set; }
+
+        public List<string> IncludedColumns { get; private set; }
+
+        /// <summary>
+        /// Creates new <see cref="Builder"/> for constructing an
+        /// <see cref="IndexConfig"/>.
+        /// </summary>
+        /// <returns>A <see cref="Builder"/> object.</returns>
+ public static Builder Builder() =>
+ new Builder(
+ (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_className,
+ "builder"));
+
+ public override bool Equals(object that) => (bool)_jvmObject.Invoke("equals", that);
+
+ public override int GetHashCode() => (int)_jvmObject.Invoke("hashCode");
+
+ public override string ToString() => (string)_jvmObject.Invoke("toString");
+ }
+}
diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Microsoft.Spark.Extensions.Hyperspace.csproj b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Microsoft.Spark.Extensions.Hyperspace.csproj
new file mode 100644
index 000000000..d85c62f71
--- /dev/null
+++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.Hyperspace/Microsoft.Spark.Extensions.Hyperspace.csproj
@@ -0,0 +1,13 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFrameworks>netstandard2.0;netstandard2.1</TargetFrameworks>
+    <IsPackable>true</IsPackable>
+    <GenerateDocumentationFile>true</GenerateDocumentationFile>
+  </PropertyGroup>
+
+</Project>
diff --git a/src/csharp/Extensions/README.md b/src/csharp/Extensions/README.md
new file mode 100644
index 000000000..fa32b6946
--- /dev/null
+++ b/src/csharp/Extensions/README.md
@@ -0,0 +1,19 @@
+# .NET for Apache Spark Extensions
+
+## Table of Contents
+* [NuGet Packages](#nuget-packages)
+
+## NuGet Packages
+
+The following .NET for Apache Spark extensions are available as NuGet packages:
+
+### First-Party
+
+* [Microsoft.Spark.Extensions.Azure.Synapse.Analytics](https://www.nuget.org/packages/Microsoft.Spark.Extensions.Azure.Synapse.Analytics/)
+* [Microsoft.Spark.Extensions.Delta](https://www.nuget.org/packages/Microsoft.Spark.Extensions.Delta/)
+* [Microsoft.Spark.Extensions.DotNet.Interactive](https://www.nuget.org/packages/Microsoft.Spark.Extensions.DotNet.Interactive/)
+* [Microsoft.Spark.Extensions.Hyperspace](https://www.nuget.org/packages/Microsoft.Spark.Extensions.Hyperspace/)
+
+### Third-Party
+
+* Community-created extensions can be added here.
\ No newline at end of file
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
index 000c8f27e..511f5a122 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
@@ -1,10 +1,8 @@
using System;
-using System.Collections.Generic;
using System.Linq;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
-using static Microsoft.Spark.Sql.Functions;
using Xunit;
+using static Microsoft.Spark.Sql.Functions;
namespace Microsoft.Spark.E2ETest.IpcTests
{
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmBridgeTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmBridgeTests.cs
new file mode 100644
index 000000000..3ae609f5c
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmBridgeTests.cs
@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests
+{
+ [Collection("Spark E2E Tests")]
+ public class JvmBridgeTests
+ {
+ private readonly SparkSession _spark;
+
+ public JvmBridgeTests(SparkFixture fixture)
+ {
+ _spark = fixture.Spark;
+ }
+
+ [Fact]
+ public void TestInnerJvmException()
+ {
+ try
+ {
+ _spark.Sql("THROW!!!");
+ }
+ catch (Exception ex)
+ {
+ Assert.NotNull(ex.InnerException);
+                Assert.IsType<JvmException>(ex.InnerException);
+ Assert.False(string.IsNullOrWhiteSpace(ex.InnerException.Message));
+ }
+ }
+ }
+}
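
The test above pins down the contract that JVM-side errors surface as the `InnerException` of the thrown .NET exception. A hedged sketch of how calling code can rely on this (the failing query is hypothetical):

```csharp
try
{
    spark.Sql("SELECT * FROM table_that_does_not_exist").Show();
}
catch (Exception ex) when (ex.InnerException != null)
{
    // The original JVM error message is preserved on InnerException
    // instead of being buried in the outer IPC failure.
    Console.WriteLine(ex.InnerException.Message);
}
```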
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs
index 11037bc6d..e9193fd0b 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs
@@ -4,9 +4,10 @@
using System.Collections.Generic;
using System.IO;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.ML.Feature.Param;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
@@ -58,6 +59,19 @@ public void TestBucketizer()
Bucketizer loadedBucketizer = Bucketizer.Load(savePath);
Assert.Equal(bucketizer.Uid(), loadedBucketizer.Uid());
}
+
+ Assert.NotEmpty(bucketizer.ExplainParams());
+
+ Param handleInvalidParam = bucketizer.GetParam("handleInvalid");
+ Assert.NotEmpty(handleInvalidParam.Doc);
+ Assert.NotEmpty(handleInvalidParam.Name);
+ Assert.Equal(handleInvalidParam.Parent, bucketizer.Uid());
+
+ Assert.NotEmpty(bucketizer.ExplainParam(handleInvalidParam));
+ bucketizer.Set(handleInvalidParam, "keep");
+ Assert.Equal("keep", bucketizer.GetHandleInvalid());
+
+ Assert.Equal("error", bucketizer.Clear(handleInvalidParam).GetHandleInvalid());
}
[Fact]
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs
index 7b6882bea..df459ed7a 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs
@@ -2,13 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Collections.Generic;
using System.IO;
-using System.Linq;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs
index 623b7322c..202187809 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs
@@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.
using System.IO;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs
index 3dea63de7..72da97887 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs
@@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.
using System.IO;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs
index 8cdb4e03a..4b1998f50 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs
@@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.
using System.IO;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs
index 4845e011a..a5227149b 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs
@@ -2,11 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.IO;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs
index 30e14ed28..1d5da5335 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs
@@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.
using System.IO;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Param/ParamTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Param/ParamTests.cs
new file mode 100644
index 000000000..ecb9166e1
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Param/ParamTests.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.ML.Feature.Param;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.ParamTests
+{
+ [Collection("Spark E2E Tests")]
+ public class ParamTests
+ {
+ private readonly SparkSession _spark;
+
+ public ParamTests(SparkFixture fixture)
+ {
+ _spark = fixture.Spark;
+ }
+
+ [Fact]
+ public void Test()
+ {
+ const string expectedParent = "parent";
+ const string expectedName = "name";
+ const string expectedDoc = "doc";
+
+ var param = new Param(expectedParent, expectedName, expectedDoc);
+
+ Assert.Equal(expectedParent, param.Parent);
+ Assert.Equal(expectedDoc, param.Doc);
+ Assert.Equal(expectedName, param.Name);
+ }
+ }
+}
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs
index 07fbf2372..ca752570a 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs
@@ -3,7 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
-using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs
index 7359bdb6b..46e899a87 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs
@@ -3,13 +3,13 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections.Generic;
using System.Linq;
using Apache.Arrow;
using Microsoft.Data.Analysis;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
using static Microsoft.Spark.Sql.Functions;
using static Microsoft.Spark.UnitTest.TestUtils.ArrowTestUtils;
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs
index a7e214160..4f0d06742 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs
@@ -3,8 +3,8 @@
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs
index c312ddc6c..5a70a6698 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs
@@ -94,7 +94,7 @@ public void TestCreateDataFrame()
            // Calling CreateDataFrame(IEnumerable<string> _) without schema
{
-            var data = new List<string>(new string[] { "Alice", "Bob" });
+ var data = new string[] { "Alice", "Bob", null };
StructType schema = SchemaWithSingleColumn(new StringType());
DataFrame df = _spark.CreateDataFrame(data);
@@ -103,7 +103,16 @@ public void TestCreateDataFrame()
            // Calling CreateDataFrame(IEnumerable<int> _) without schema
{
-            var data = new List<int>(new int[] { 1, 2 });
+ var data = new int[] { 1, 2 };
+ StructType schema = SchemaWithSingleColumn(new IntegerType(), false);
+
+ DataFrame df = _spark.CreateDataFrame(data);
+ ValidateDataFrame(df, data.Select(a => new object[] { a }), schema);
+ }
+
+            // Calling CreateDataFrame(IEnumerable<int?> _) without schema
+ {
+ var data = new int?[] { 1, 2, null };
StructType schema = SchemaWithSingleColumn(new IntegerType());
DataFrame df = _spark.CreateDataFrame(data);
@@ -112,7 +121,16 @@ public void TestCreateDataFrame()
            // Calling CreateDataFrame(IEnumerable<double> _) without schema
{
-            var data = new List<double>(new double[] { 1.2, 2.3 });
+ var data = new double[] { 1.2, 2.3 };
+ StructType schema = SchemaWithSingleColumn(new DoubleType(), false);
+
+ DataFrame df = _spark.CreateDataFrame(data);
+ ValidateDataFrame(df, data.Select(a => new object[] { a }), schema);
+ }
+
+            // Calling CreateDataFrame(IEnumerable<double?> _) without schema
+ {
+ var data = new double?[] { 1.2, 2.3, null };
StructType schema = SchemaWithSingleColumn(new DoubleType());
DataFrame df = _spark.CreateDataFrame(data);
@@ -121,19 +139,29 @@ public void TestCreateDataFrame()
            // Calling CreateDataFrame(IEnumerable<bool> _) without schema
{
-            var data = new List<bool>(new bool[] { true, false });
+ var data = new bool[] { true, false };
+ StructType schema = SchemaWithSingleColumn(new BooleanType(), false);
+
+ DataFrame df = _spark.CreateDataFrame(data);
+ ValidateDataFrame(df, data.Select(a => new object[] { a }), schema);
+ }
+
+            // Calling CreateDataFrame(IEnumerable<bool?> _) without schema
+ {
+ var data = new bool?[] { true, false, null };
StructType schema = SchemaWithSingleColumn(new BooleanType());
DataFrame df = _spark.CreateDataFrame(data);
ValidateDataFrame(df, data.Select(a => new object[] { a }), schema);
}
-
+
            // Calling CreateDataFrame(IEnumerable<Date> _) without schema
{
var data = new Date[]
{
new Date(2020, 1, 1),
- new Date(2020, 1, 2)
+ new Date(2020, 1, 2),
+ null
};
StructType schema = SchemaWithSingleColumn(new DateType());
@@ -151,7 +179,8 @@ public void TestCreateDataFrameWithTimestamp()
var data = new Timestamp[]
{
new Timestamp(2020, 1, 1, 0, 0, 0, 0),
- new Timestamp(2020, 1, 2, 15, 30, 30, 0)
+ new Timestamp(2020, 1, 2, 15, 30, 30, 0),
+ null
};
StructType schema = SchemaWithSingleColumn(new TimestampType());
@@ -172,8 +201,9 @@ private void ValidateDataFrame(
/// Returns a single column schema of the given datatype.
        /// </summary>
        /// <param name="dataType">Datatype of the column</param>
+        /// <param name="isNullable">Indicates if values of the column can be null</param>
        /// <returns>Schema as StructType</returns>
- private StructType SchemaWithSingleColumn(DataType dataType) =>
- new StructType(new[] { new StructField("_1", dataType) });
+ private StructType SchemaWithSingleColumn(DataType dataType, bool isNullable = true) =>
+ new StructType(new[] { new StructField("_1", dataType, isNullable) });
}
}
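
With the change above, the inferred schema's nullability now tracks the element type: non-nullable primitive arrays produce `nullable = false` columns, while nullable element types allow nulls. A minimal sketch, assuming an active `SparkSession` named `spark`:

```csharp
// int[] infers a non-nullable IntegerType column named "_1"...
DataFrame dfNonNull = spark.CreateDataFrame(new int[] { 1, 2 });
dfNonNull.PrintSchema();   // _1: integer (nullable = false)

// ...while int?[] infers a nullable column that can carry nulls.
DataFrame dfNullable = spark.CreateDataFrame(new int?[] { 1, 2, null });
dfNullable.PrintSchema();  // _1: integer (nullable = true)
```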
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs
index 4e87dc6c6..0983035f4 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs
@@ -6,10 +6,12 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Threading;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Streaming;
using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
using static Microsoft.Spark.Sql.Functions;
@@ -66,6 +68,69 @@ public void TestSignaturesV2_3_X()
            Assert.IsType<DataStreamWriter>(dsw.Trigger(Trigger.Once()));
}
+ [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
+ public void TestForeachBatch()
+ {
+ // Temporary folder to put our test stream input.
+ using var srcTempDirectory = new TemporaryDirectory();
+ // Temporary folder to write ForeachBatch output.
+ using var dstTempDirectory = new TemporaryDirectory();
+
+            Func<Column, Column> outerUdf = Udf<int, int>(i => i + 100);
+
+ // id column: [0, 1, ..., 9]
+ WriteCsv(0, 10, Path.Combine(srcTempDirectory.Path, "input1.csv"));
+
+ DataStreamWriter dsw = _spark
+ .ReadStream()
+ .Schema("id INT")
+ .Csv(srcTempDirectory.Path)
+ .WriteStream()
+ .ForeachBatch((df, id) =>
+ {
+                    Func<Column, Column> innerUdf = Udf<int, int>(i => i + 200);
+ df.Select(outerUdf(innerUdf(Col("id"))))
+ .Write()
+ .Csv(Path.Combine(dstTempDirectory.Path, id.ToString()));
+ });
+
+ StreamingQuery sq = dsw.Start();
+
+ // Process until all available data in the source has been processed and committed
+ // to the ForeachBatch sink.
+ sq.ProcessAllAvailable();
+
+ // Add new file to the source path. The spark stream will read any new files
+ // added to the source path.
+ // id column: [10, 11, ..., 19]
+ WriteCsv(10, 10, Path.Combine(srcTempDirectory.Path, "input2.csv"));
+
+ // Process until all available data in the source has been processed and committed
+ // to the ForeachBatch sink.
+ sq.ProcessAllAvailable();
+ sq.Stop();
+
+ // Verify folders in the destination path.
+ string[] csvPaths =
+ Directory.GetDirectories(dstTempDirectory.Path).OrderBy(s => s).ToArray();
+ var expectedPaths = new string[]
+ {
+ Path.Combine(dstTempDirectory.Path, "0"),
+ Path.Combine(dstTempDirectory.Path, "1"),
+ };
+ Assert.True(expectedPaths.SequenceEqual(csvPaths));
+
+ // Read the generated csv paths and verify contents.
+ DataFrame df = _spark
+ .Read()
+ .Schema("id INT")
+ .Csv(csvPaths[0], csvPaths[1])
+ .Sort("id");
+
+            IEnumerable<int> actualIds = df.Collect().Select(r => r.GetAs<int>("id"));
+ Assert.True(Enumerable.Range(300, 20).SequenceEqual(actualIds));
+ }
+
[SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
public void TestForeach()
{
@@ -199,6 +264,15 @@ private void TestAndValidateForeach(
foreachWriterOutputDF.Collect().Select(r => r.Values));
}
+ private void WriteCsv(int start, int count, string path)
+ {
+ using var streamWriter = new StreamWriter(path);
+ foreach (int i in Enumerable.Range(start, count))
+ {
+ streamWriter.WriteLine(i);
+ }
+ }
+
[Serializable]
private class TestForeachWriter : IForeachWriter
{
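
`TestForeachBatch` above exercises the new `DataStreamWriter.ForeachBatch` API from this release. A standalone sketch of the same pattern, assuming the usual `Microsoft.Spark.Sql` and `Microsoft.Spark.Sql.Streaming` usings and hypothetical input/output paths:

```csharp
string inputPath = "/tmp/stream-input";   // hypothetical CSV source folder
string outputPath = "/tmp/stream-output"; // hypothetical sink folder

StreamingQuery query = spark
    .ReadStream()
    .Schema("id INT")
    .Csv(inputPath)
    .WriteStream()
    .ForeachBatch((batchDf, batchId) =>
    {
        // Invoked once per micro-batch; batchDf is an ordinary DataFrame,
        // so any batch transformation or sink can be used here.
        batchDf.Write().Mode(SaveMode.Append).Parquet($"{outputPath}/{batchId}");
    })
    .Start();

query.ProcessAllAvailable();
query.Stop();
```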
diff --git a/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj b/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj
index abe436ec9..7a6240ecc 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj
+++ b/src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj
@@ -12,6 +12,7 @@
+
@@ -23,6 +24,7 @@
+
diff --git a/src/csharp/Microsoft.Spark.E2ETest/SparkFixture.cs b/src/csharp/Microsoft.Spark.E2ETest/SparkFixture.cs
index fc8272c5b..6d8dadbac 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/SparkFixture.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/SparkFixture.cs
@@ -7,9 +7,9 @@
using System.IO;
using System.Reflection;
using System.Runtime.InteropServices;
-using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;
namespace Microsoft.Spark.E2ETest
diff --git a/src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs b/src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs
index e4c4cabb9..92422c205 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs
@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Threading;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Xunit;
@@ -166,5 +167,29 @@ public void TestUdfWithReturnAsTimestampType()
}
}
}
+
+        /// <summary>
+        /// Test to validate UDFs defined in separate threads work.
+        /// </summary>
+ [Fact]
+ public void TestUdfWithMultipleThreads()
+ {
+ try
+ {
+                void DefineUdf() => Udf<string, string>(str => str);
+
+ // Define a UDF in the main thread.
+                Udf<string, string>(str => str);
+
+ // Verify a UDF can be defined in a separate thread.
+ Thread t = new Thread(DefineUdf);
+ t.Start();
+ t.Join();
+ }
+ catch (Exception)
+ {
+ Assert.True(false);
+ }
+ }
}
}
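
The thread test above only verifies that UDF definition is safe off the main thread; for context, a typical definition and application looks like this sketch (the DataFrame `df` and the column name `name` are hypothetical):

```csharp
using static Microsoft.Spark.Sql.Functions;

// Define a string-to-string UDF and apply it to a column.
Func<Column, Column> toUpper = Udf<string, string>(s => s?.ToUpper());
DataFrame result = df.Select(toUpper(Col("name")));
result.Show();
```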
diff --git a/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs b/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs
index da7d05197..c2c5e63ee 100644
--- a/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs
+++ b/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs
@@ -4,22 +4,35 @@
using System;
using System.IO;
+using System.Reflection;
+using System.Runtime.Loader;
+using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Utils;
+using Moq;
using Xunit;
namespace Microsoft.Spark.UnitTest
{
+ [Collection("Spark Unit Tests")]
public class AssemblyLoaderTests
{
+        private readonly Mock<IJvmBridge> _mockJvm;
+
+ public AssemblyLoaderTests(SparkFixture _fixture)
+ {
+ _mockJvm = _fixture.MockJvm;
+ }
+
[Fact]
public void TestAssemblySearchPathResolver()
{
+ string sparkFilesDir = SparkFiles.GetRootDirectory();
string curDir = Directory.GetCurrentDirectory();
string appDir = AppDomain.CurrentDomain.BaseDirectory;
// Test the default scenario.
string[] searchPaths = AssemblySearchPathResolver.GetAssemblySearchPaths();
- Assert.Equal(new[] { curDir, appDir }, searchPaths);
+ Assert.Equal(new[] { sparkFilesDir, curDir, appDir }, searchPaths);
// Test the case where DOTNET_ASSEMBLY_SEARCH_PATHS is defined.
char sep = Path.PathSeparator;
@@ -34,6 +47,7 @@ public void TestAssemblySearchPathResolver()
"mydir2",
Path.Combine(curDir, $".{sep}mydir3"),
Path.Combine(curDir, $".{sep}mydir4"),
+ sparkFilesDir,
curDir,
appDir },
searchPaths);
@@ -42,5 +56,20 @@ public void TestAssemblySearchPathResolver()
AssemblySearchPathResolver.AssemblySearchPathsEnvVarName,
null);
}
+
+ [Fact]
+ public void TestResolveAssemblyWithRelativePath()
+ {
+ _mockJvm.Setup(m => m.CallStaticJavaMethod(
+ "org.apache.spark.SparkFiles",
+ "getRootDirectory"))
+ .Returns(".");
+
+ AssemblyLoader.LoadFromFile = AssemblyLoadContext.Default.LoadFromAssemblyPath;
+ Assembly expectedAssembly = Assembly.GetExecutingAssembly();
+ Assembly actualAssembly = AssemblyLoader.ResolveAssembly(expectedAssembly.FullName);
+
+ Assert.Equal(expectedAssembly, actualAssembly);
+ }
}
}
diff --git a/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs b/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs
new file mode 100644
index 000000000..04266e814
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs
@@ -0,0 +1,239 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Network;
+using Moq;
+using Xunit;
+
+namespace Microsoft.Spark.UnitTest
+{
+ [Collection("Spark Unit Tests")]
+ public class CallbackTests
+ {
+        private readonly Mock<IJvmBridge> _mockJvm;
+
+ public CallbackTests(SparkFixture fixture)
+ {
+ _mockJvm = fixture.MockJvm;
+ }
+
+ [Fact]
+ public async Task TestCallbackIds()
+ {
+ int numToRegister = 100;
+ var callbackServer = new CallbackServer(_mockJvm.Object, false);
+ var callbackHandler = new TestCallbackHandler();
+
+            var ids = new ConcurrentBag<int>();
+            var tasks = new List<Task>();
+ for (int i = 0; i < numToRegister; ++i)
+ {
+ tasks.Add(
+ Task.Run(() => ids.Add(callbackServer.RegisterCallback(callbackHandler))));
+ }
+
+ await Task.WhenAll(tasks);
+
+            IOrderedEnumerable<int> actualIds = ids.OrderBy(i => i);
+            IEnumerable<int> expectedIds = Enumerable.Range(1, numToRegister);
+ Assert.True(expectedIds.SequenceEqual(actualIds));
+ }
+
+ [Fact]
+ public void TestCallbackServer()
+ {
+ var callbackServer = new CallbackServer(_mockJvm.Object, false);
+ var callbackHandler = new TestCallbackHandler();
+
+ callbackHandler.Id = callbackServer.RegisterCallback(callbackHandler);
+ Assert.Equal(1, callbackHandler.Id);
+
+ using ISocketWrapper callbackSocket = SocketFactory.CreateSocket();
+ callbackServer.Run(callbackSocket);
+
+ int connectionNumber = 10;
+ for (int i = 0; i < connectionNumber; ++i)
+ {
+ var ipEndpoint = (IPEndPoint)callbackSocket.LocalEndPoint;
+ ISocketWrapper clientSocket = SocketFactory.CreateSocket();
+ clientSocket.Connect(ipEndpoint.Address, ipEndpoint.Port);
+
+ WriteAndReadTestData(clientSocket, callbackHandler, i, new CancellationToken());
+ }
+
+ Assert.Equal(connectionNumber, callbackServer.CurrentNumConnections);
+
+            IOrderedEnumerable<int> actualValues = callbackHandler.Inputs.OrderBy(i => i);
+            IEnumerable<int> expectedValues = Enumerable
+ .Range(0, connectionNumber)
+ .Select(i => callbackHandler.Apply(i))
+ .OrderBy(i => i);
+ Assert.True(expectedValues.SequenceEqual(actualValues));
+ }
+
+ [Fact]
+ public void TestCallbackHandlers()
+ {
+ var tokenSource = new CancellationTokenSource();
+            var callbackHandlersDict = new ConcurrentDictionary<int, ICallbackHandler>();
+ int inputToHandler = 1;
+ {
+ // Test CallbackConnection using a ICallbackHandler that runs
+ // normally without error.
+ var callbackHandler = new TestCallbackHandler
+ {
+ Id = 1
+ };
+ callbackHandlersDict[callbackHandler.Id] = callbackHandler;
+ TestCallbackConnection(
+ callbackHandlersDict,
+ callbackHandler,
+ inputToHandler,
+ tokenSource.Token);
+ Assert.Single(callbackHandler.Inputs);
+ Assert.Equal(
+ callbackHandler.Apply(inputToHandler),
+ callbackHandler.Inputs.First());
+ }
+ {
+ // Test CallbackConnection using a ICallbackHandler that
+ // throws an exception.
+ var callbackHandler = new ThrowsExceptionHandler
+ {
+ Id = 2
+ };
+ callbackHandlersDict[callbackHandler.Id] = callbackHandler;
+ TestCallbackConnection(
+ callbackHandlersDict,
+ callbackHandler,
+ inputToHandler,
+ tokenSource.Token);
+ Assert.Empty(callbackHandler.Inputs);
+ }
+ {
+ // Test CallbackConnection when cancellation has been requested for the token.
+ tokenSource.Cancel();
+ var callbackHandler = new TestCallbackHandler
+ {
+ Id = 3
+ };
+ callbackHandlersDict[callbackHandler.Id] = callbackHandler;
+ TestCallbackConnection(
+ callbackHandlersDict,
+ callbackHandler,
+ inputToHandler,
+ tokenSource.Token);
+ Assert.Empty(callbackHandler.Inputs);
+ }
+ }
+
+ private void TestCallbackConnection(
+            ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict,
+ ITestCallbackHandler callbackHandler,
+ int inputToHandler,
+ CancellationToken token)
+ {
+ using ISocketWrapper serverListener = SocketFactory.CreateSocket();
+ serverListener.Listen();
+
+ var ipEndpoint = (IPEndPoint)serverListener.LocalEndPoint;
+ ISocketWrapper clientSocket = SocketFactory.CreateSocket();
+ clientSocket.Connect(ipEndpoint.Address, ipEndpoint.Port);
+
+ var callbackConnection = new CallbackConnection(0, clientSocket, callbackHandlersDict);
+ Task.Run(() => callbackConnection.Run(token));
+
+ using ISocketWrapper serverSocket = serverListener.Accept();
+ WriteAndReadTestData(serverSocket, callbackHandler, inputToHandler, token);
+ }
+
+ private void WriteAndReadTestData(
+ ISocketWrapper socket,
+ ITestCallbackHandler callbackHandler,
+ int inputToHandler,
+ CancellationToken token)
+ {
+ Stream inputStream = socket.InputStream;
+ Stream outputStream = socket.OutputStream;
+
+ SerDe.Write(outputStream, (int)CallbackFlags.CALLBACK);
+ SerDe.Write(outputStream, callbackHandler.Id);
+ SerDe.Write(outputStream, sizeof(int));
+ SerDe.Write(outputStream, inputToHandler);
+ SerDe.Write(outputStream, (int)CallbackFlags.END_OF_STREAM);
+ outputStream.Flush();
+
+ if (token.IsCancellationRequested)
+ {
+                Assert.Throws<IOException>(() => SerDe.ReadInt32(inputStream));
+ }
+ else
+ {
+ int callbackFlag = SerDe.ReadInt32(inputStream);
+ if (callbackFlag == (int)CallbackFlags.DOTNET_EXCEPTION_THROWN)
+ {
+ string exceptionMessage = SerDe.ReadString(inputStream);
+ Assert.False(string.IsNullOrEmpty(exceptionMessage));
+ Assert.Contains(callbackHandler.ExceptionMessage, exceptionMessage);
+ }
+ else
+ {
+ Assert.Equal((int)CallbackFlags.END_OF_STREAM, callbackFlag);
+ }
+ }
+ }
+
+ private class TestCallbackHandler : ICallbackHandler, ITestCallbackHandler
+ {
+ public void Run(Stream inputStream) => Inputs.Add(Apply(SerDe.ReadInt32(inputStream)));
+
+            public ConcurrentBag<int> Inputs { get; } = new ConcurrentBag<int>();
+
+ public int Id { get; set; }
+
+ public bool Throws { get; } = false;
+
+ public string ExceptionMessage => throw new NotImplementedException();
+
+ public int Apply(int i) => 10 * i;
+ }
+
+ private class ThrowsExceptionHandler : ICallbackHandler, ITestCallbackHandler
+ {
+ public void Run(Stream inputStream) => throw new Exception(ExceptionMessage);
+
+            public ConcurrentBag<int> Inputs { get; } = new ConcurrentBag<int>();
+
+ public int Id { get; set; }
+
+ public bool Throws { get; } = true;
+
+ public string ExceptionMessage { get; } = "Dotnet Callback Handler Exception Message";
+
+ public int Apply(int i) => throw new NotImplementedException();
+ }
+
+ private interface ITestCallbackHandler
+ {
+            ConcurrentBag<int> Inputs { get; }
+
+ int Id { get; set; }
+
+ bool Throws { get; }
+
+ string ExceptionMessage { get; }
+
+ int Apply(int i);
+ }
+ }
+}
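
The tests above drive `CallbackConnection` through its wire protocol: a `CALLBACK` flag, the handler id, the payload length, the payload itself, then `END_OF_STREAM`. A handler only has to implement `ICallbackHandler.Run`; a minimal sketch follows (note that `ICallbackHandler` and `SerDe` are internal to Microsoft.Spark, so this pattern applies inside the library and its test assemblies):

```csharp
using System;
using System.IO;
using Microsoft.Spark.Interop.Ipc;

// Hypothetical handler: reads a single int from the callback payload and logs it.
internal sealed class LoggingCallbackHandler : ICallbackHandler
{
    public void Run(Stream inputStream)
    {
        int value = SerDe.ReadInt32(inputStream);
        Console.WriteLine($"Callback received: {value}");
    }
}
```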
diff --git a/src/csharp/Microsoft.Spark.UnitTest/CollectionUtilsTests.cs b/src/csharp/Microsoft.Spark.UnitTest/CollectionUtilsTests.cs
new file mode 100644
index 000000000..9a723b2b5
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.UnitTest/CollectionUtilsTests.cs
@@ -0,0 +1,26 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Utils;
+using Xunit;
+
+namespace Microsoft.Spark.UnitTest
+{
+ public class CollectionUtilsTests
+ {
+ [Fact]
+ public void TestArrayEquals()
+ {
+ Assert.False(CollectionUtils.ArrayEquals(new int[] { 1 }, null));
+ Assert.False(CollectionUtils.ArrayEquals(null, new int[] { 1 }));
+ Assert.False(CollectionUtils.ArrayEquals(new int[] { }, new int[] { 1 }));
+ Assert.False(CollectionUtils.ArrayEquals(new int[] { 1 }, new int[] { }));
+ Assert.False(CollectionUtils.ArrayEquals(new int[] { 1 }, new int[] { 1, 2 }));
+ Assert.False(CollectionUtils.ArrayEquals(new int[] { 1 }, new int[] { 2 }));
+
+ Assert.True(CollectionUtils.ArrayEquals