Skip to content

Commit

Permalink
Copy (and simplify) VisualStudioSourceInformationProvider from xunit.…
Browse files Browse the repository at this point in the history
…runner.utility
  • Loading branch information
bradwilson committed Sep 25, 2023
1 parent 3974475 commit 85864c8
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 36 deletions.
73 changes: 73 additions & 0 deletions src/xunit.runner.visualstudio/Utility/AppDomainManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#if NETFRAMEWORK

using System;
using System.IO;
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Security;
using System.Security.Permissions;
using Xunit.Internal;

namespace Xunit.Runner.VisualStudio;

class AppDomainManager
{
readonly AppDomain appDomain;

public AppDomainManager(string assemblyFileName)
{
Guard.ArgumentNotNullOrEmpty(assemblyFileName);

assemblyFileName = Path.GetFullPath(assemblyFileName);
Guard.FileExists(assemblyFileName);

var applicationBase = Path.GetDirectoryName(assemblyFileName);
var applicationName = Guid.NewGuid().ToString();
var setup = new AppDomainSetup
{
ApplicationBase = applicationBase,
ApplicationName = applicationName,
ShadowCopyFiles = "true",
ShadowCopyDirectories = applicationBase,
CachePath = Path.Combine(Path.GetTempPath(), applicationName)
};

appDomain = AppDomain.CreateDomain(Path.GetFileNameWithoutExtension(assemblyFileName), AppDomain.CurrentDomain.Evidence, setup, new PermissionSet(PermissionState.Unrestricted));
}

public TObject? CreateObject<TObject>(
AssemblyName assemblyName,
string typeName,
params object[] args)
where TObject : class
{
try
{
return appDomain.CreateInstanceAndUnwrap(assemblyName.FullName, typeName, false, BindingFlags.Default, null, args, null, null) as TObject;
}
catch (TargetInvocationException ex)
{
ExceptionDispatchInfo.Capture(ex.InnerException ?? ex).Throw();
return default; // Will never reach here, but the compiler doesn't know that
}
}

public virtual void Dispose()
{
if (appDomain is not null)
{
var cachePath = appDomain.SetupInformation.CachePath;

try
{
AppDomain.Unload(appDomain);

if (cachePath is not null)
Directory.Delete(cachePath, true);
}
catch { }
}
}
}

#endif
52 changes: 52 additions & 0 deletions src/xunit.runner.visualstudio/Utility/DiaSessionWrapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System;
using Microsoft.VisualStudio.TestPlatform.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.ObjectModel.Navigation;

namespace Xunit.Runner.VisualStudio.Utility;

// This class wraps DiaSession, and uses DiaSessionWrapperHelper to discover when a test is an async test
// (since that requires special handling by DIA). The wrapper helper needs to exist in a separate AppDomain
// so that we can do discovery without locking the assembly under test (for .NET Framework).
class DiaSessionWrapper : IDisposable
{
#if NETFRAMEWORK
readonly AppDomainManager? appDomainManager;
#endif
readonly DiaSessionWrapperHelper? helper;
readonly DiaSession session;

public DiaSessionWrapper(string assemblyFileName)
{
session = new DiaSession(assemblyFileName);

#if NETFRAMEWORK
var adapterFileName = typeof(DiaSessionWrapperHelper).Assembly.GetLocalCodeBase();
if (adapterFileName is not null)
{
appDomainManager = new AppDomainManager(assemblyFileName);
helper = appDomainManager.CreateObject<DiaSessionWrapperHelper>(typeof(DiaSessionWrapperHelper).Assembly.GetName(), typeof(DiaSessionWrapperHelper).FullName!, adapterFileName);
}
#else
helper = new DiaSessionWrapperHelper(assemblyFileName);
#endif
}

public INavigationData? GetNavigationData(
string typeName,
string methodName)
{
if (helper is null)
return null;

helper.Normalize(ref typeName, ref methodName);
return session.GetNavigationDataForMethod(typeName, methodName);
}

public void Dispose()
{
session.Dispose();
#if NETFRAMEWORK
appDomainManager?.Dispose();
#endif
}
}
114 changes: 114 additions & 0 deletions src/xunit.runner.visualstudio/Utility/DiaSessionWrapperHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using Xunit.Sdk;

namespace Xunit.Runner.VisualStudio.Utility;

class DiaSessionWrapperHelper : LongLivedMarshalByRefObject
{
readonly Assembly? assembly;
readonly Dictionary<string, Type> typeNameMap;

public DiaSessionWrapperHelper(string assemblyFileName)
{
try
{
#if NETFRAMEWORK
assembly = Assembly.ReflectionOnlyLoadFrom(assemblyFileName);
var assemblyDirectory = Path.GetDirectoryName(assemblyFileName);

if (assemblyDirectory is not null)
AppDomain.CurrentDomain.ReflectionOnlyAssemblyResolve += (sender, args) =>
{
try
{
// Try to load it normally
var name = AppDomain.CurrentDomain.ApplyPolicy(args.Name);
return Assembly.ReflectionOnlyLoad(name);
}
catch
{
try
{
// If a normal implicit load fails, try to load it from the directory that
// the test assembly lives in
return Assembly.ReflectionOnlyLoadFrom(
Path.Combine(
assemblyDirectory,
new AssemblyName(args.Name).Name + ".dll"
)
);
}
catch
{
// If all else fails, say we couldn't find it
return null;
}
}
};
#else
assembly = Assembly.Load(new AssemblyName { Name = Path.GetFileNameWithoutExtension(assemblyFileName) });
#endif
}
catch { }

if (assembly is not null)
{
Type?[]? types = null;

try
{
types = assembly.GetTypes();
}
catch (ReflectionTypeLoadException ex)
{
types = ex.Types;
}
catch { } // Ignore anything other than ReflectionTypeLoadException

if (types is not null)
typeNameMap =
types
.WhereNotNull()
.Where(t => !string.IsNullOrEmpty(t.FullName))
.ToDictionaryIgnoringDuplicateKeys(k => k.FullName!);
}

typeNameMap ??= new();
}

public void Normalize(
ref string typeName,
ref string methodName)
{
try
{
if (assembly is null)
return;

if (typeNameMap.TryGetValue(typeName, out var type) && type is not null)
{
var method = type.GetMethod(methodName);
if (method is not null && method.DeclaringType is not null && method.DeclaringType.FullName is not null)
{
// DiaSession only ever wants you to ask for the declaring type
typeName = method.DeclaringType.FullName;

// See if this is an async method by looking for [AsyncStateMachine] on the method,
// which means we need to pass the state machine's "MoveNext" method.
var stateMachineType = method.GetCustomAttribute<AsyncStateMachineAttribute>()?.StateMachineType;
if (stateMachineType is not null && stateMachineType.FullName is not null)
{
typeName = stateMachineType.FullName;
methodName = "MoveNext";
}
}
}
}
catch { }
}
}
7 changes: 7 additions & 0 deletions src/xunit.runner.visualstudio/Utility/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

static class DictionaryExtensions
{
public static Dictionary<TKey, TValue> ToDictionaryIgnoringDuplicateKeys<TKey, TValue>(
this IEnumerable<TValue> inputValues,
Func<TValue, TKey> keySelector,
IEqualityComparer<TKey>? comparer = null)
where TKey : notnull =>
ToDictionaryIgnoringDuplicateKeys(inputValues, keySelector, x => x, comparer);

public static Dictionary<TKey, TValue> ToDictionaryIgnoringDuplicateKeys<TInput, TKey, TValue>(
this IEnumerable<TInput> inputValues,
Func<TInput, TKey> keySelector,
Expand Down
35 changes: 35 additions & 0 deletions src/xunit.runner.visualstudio/Utility/Guard.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.CompilerServices;

namespace Xunit.Internal;
Expand Down Expand Up @@ -49,4 +50,38 @@ public static T ArgumentNotNullOrEmpty<T>(

return argValue;
}

/// <summary>
/// Ensures that an argument is valid.
/// </summary>
/// <param name="message">The exception message to use when the argument is not valid</param>
/// <param name="test">The validity test value</param>
/// <param name="argName">The name of the argument</param>
/// <returns>The argument value as a non-null value</returns>
/// <exception cref="ArgumentException">Thrown when the argument is not valid</exception>
public static void ArgumentValid(
string message,
bool test,
string? argName = null)
{
if (!test)
throw new ArgumentException(message, argName);
}

/// <summary>
/// Ensures that a filename argument is not null or empty, and that the file exists on disk.
/// </summary>
/// <param name="fileName">The file name value</param>
/// <param name="argName">The name of the argument</param>
/// <returns>The file name as a non-null value</returns>
/// <exception cref="ArgumentException">Thrown when the argument is null, empty, or not on disk</exception>
public static string FileExists(
[NotNull] string? fileName,
[CallerArgumentExpression(nameof(fileName))] string? argName = null)
{
ArgumentNotNullOrEmpty(fileName, argName);
ArgumentValid($"File not found: {fileName}", File.Exists(fileName), argName?.TrimStart('@'));

return fileName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using Xunit.Abstractions;
using Xunit.Runner.VisualStudio.Utility;
using Xunit.Sdk;

namespace Xunit.Runner.VisualStudio;

/// <summary>
/// An implementation of <see cref="ISourceInformationProvider"/> that will provide source information
/// when running inside of Visual Studio (via the DiaSession class).
/// </summary>
public class VisualStudioSourceInformationProvider : LongLivedMarshalByRefObject, ISourceInformationProvider
{
static readonly SourceInformation EmptySourceInformation = new();

readonly DiaSessionWrapper session;

/// <summary>
/// Initializes a new instance of the <see cref="VisualStudioSourceInformationProvider" /> class.
/// </summary>
/// <param name="assemblyFileName">The assembly file name.</param>
public VisualStudioSourceInformationProvider(string assemblyFileName)
{
session = new DiaSessionWrapper(assemblyFileName);
}

/// <inheritdoc/>
public ISourceInformation GetSourceInformation(ITestCase testCase)
{
var navData = session.GetNavigationData(testCase.TestMethod.TestClass.Class.Name, testCase.TestMethod.Method.Name);
if (navData is null || navData.FileName is null)
return EmptySourceInformation;

return new SourceInformation
{
FileName = navData.FileName,
LineNumber = navData.MinLineNumber
};
}

/// <inheritdoc/>
public void Dispose()
{
session.Dispose();
}
}
8 changes: 5 additions & 3 deletions src/xunit.runner.visualstudio/VsTestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,9 @@ void DiscoverTests<TVisitor>(
var diagnosticSink = DiagnosticMessageSink.ForDiagnostics(logger, fileName, assembly.Configuration.DiagnosticMessagesOrDefault);
var appDomain = assembly.Configuration.AppDomain ?? AppDomainDefaultBehavior;

using var framework = new XunitFrontController(appDomain, assembly.AssemblyFilename, shadowCopy: shadowCopy, diagnosticMessageSink: MessageSinkAdapter.Wrap(diagnosticSink));
if (!DiscoverTestsInSource(framework, logger, testPlatformContext, runSettings, visitorFactory, visitComplete, assembly))
using var sourceInformationProvider = new VisualStudioSourceInformationProvider(assembly.AssemblyFilename);
using var controller = new XunitFrontController(appDomain, assembly.AssemblyFilename, shadowCopy: shadowCopy, sourceInformationProvider: sourceInformationProvider, diagnosticMessageSink: MessageSinkAdapter.Wrap(diagnosticSink));
if (!DiscoverTestsInSource(controller, logger, testPlatformContext, runSettings, visitorFactory, visitComplete, assembly))
break;
}
}
Expand Down Expand Up @@ -427,7 +428,8 @@ void RunTestsInAssembly(

var diagnosticSink = DiagnosticMessageSink.ForDiagnostics(logger, assemblyDisplayName, runInfo.Assembly.Configuration.DiagnosticMessagesOrDefault);
var diagnosticMessageSink = MessageSinkAdapter.Wrap(diagnosticSink);
using var controller = new XunitFrontController(appDomain, assemblyFileName, shadowCopy: shadowCopy, diagnosticMessageSink: diagnosticMessageSink);
using var sourceInformationProvider = new VisualStudioSourceInformationProvider(assemblyFileName);
using var controller = new XunitFrontController(appDomain, assemblyFileName, shadowCopy: shadowCopy, sourceInformationProvider: sourceInformationProvider, diagnosticMessageSink: diagnosticMessageSink);
var testCasesMap = new Dictionary<string, TestCase>();
var testCases = new List<ITestCase>();
if (runInfo.TestCases is null || !runInfo.TestCases.Any())
Expand Down
Loading

0 comments on commit 85864c8

Please sign in to comment.