diff --git a/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs b/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs index f2f0dd30e..c2c5e63ee 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs +++ b/src/csharp/Microsoft.Spark.UnitTest/AssemblyLoaderTests.cs @@ -4,7 +4,11 @@ using System; using System.IO; +using System.Reflection; +using System.Runtime.Loader; +using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Utils; +using Moq; using Xunit; namespace Microsoft.Spark.UnitTest @@ -12,6 +16,13 @@ namespace Microsoft.Spark.UnitTest [Collection("Spark Unit Tests")] public class AssemblyLoaderTests { + private readonly Mock _mockJvm; + + public AssemblyLoaderTests(SparkFixture _fixture) + { + _mockJvm = _fixture.MockJvm; + } + [Fact] public void TestAssemblySearchPathResolver() { @@ -45,5 +56,20 @@ public void TestAssemblySearchPathResolver() AssemblySearchPathResolver.AssemblySearchPathsEnvVarName, null); } + + [Fact] + public void TestResolveAssemblyWithRelativePath() + { + _mockJvm.Setup(m => m.CallStaticJavaMethod( + "org.apache.spark.SparkFiles", + "getRootDirectory")) + .Returns("."); + + AssemblyLoader.LoadFromFile = AssemblyLoadContext.Default.LoadFromAssemblyPath; + Assembly expectedAssembly = Assembly.GetExecutingAssembly(); + Assembly actualAssembly = AssemblyLoader.ResolveAssembly(expectedAssembly.FullName); + + Assert.Equal(expectedAssembly, actualAssembly); + } } } diff --git a/src/csharp/Microsoft.Spark/Utils/AssemblyLoader.cs b/src/csharp/Microsoft.Spark/Utils/AssemblyLoader.cs index 3b9b34f5e..fbc6e199a 100644 --- a/src/csharp/Microsoft.Spark/Utils/AssemblyLoader.cs +++ b/src/csharp/Microsoft.Spark/Utils/AssemblyLoader.cs @@ -189,12 +189,12 @@ private static bool TryLoadAssembly(string assemblyFileName, ref Assembly assemb { foreach (string searchPath in s_searchPaths.Value) { - string assemblyPath = Path.Combine(searchPath, assemblyFileName); - if (File.Exists(assemblyPath)) + var assemblyFile = new FileInfo(Path.Combine(searchPath, assemblyFileName)); + if (assemblyFile.Exists) { try { - assembly = LoadFromFile(assemblyPath); + assembly = LoadFromFile(assemblyFile.FullName); return true; } catch (Exception ex) when (