Skip to content

Commit

Permalink
Merge pull request #73493 from CyrusNajmabadi/dependentProjectsFinder
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored May 16, 2024
2 parents e85832f + 7307033 commit 6d1245f
Showing 1 changed file with 98 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
Expand All @@ -28,7 +29,15 @@ internal static partial class DependentProjectsFinder
/// Cache from the <see cref="MetadataId"/> for a particular <see cref="PortableExecutableReference"/> to the
/// name of the <see cref="IAssemblySymbol"/> defined by it.
/// </summary>
private static ImmutableDictionary<MetadataId, string> s_metadataIdToAssemblyName = ImmutableDictionary<MetadataId, string>.Empty;
private static readonly Dictionary<MetadataId, string?> s_metadataIdToAssemblyName = new();
private static readonly SemaphoreSlim s_metadataIdToAssemblyNameGate = new(initialCount: 1);

private static readonly ConditionalWeakTable<
Solution,
Dictionary<
(IAssemblySymbol assembly, Project? sourceProject, SymbolVisibility visibility),
ImmutableArray<(Project project, bool hasInternalsAccess)>>> s_solutionToDependentProjectMap = new();
private static readonly SemaphoreSlim s_solutionToDependentProjectMapGate = new(initialCount: 1);

public static async Task<ImmutableArray<Project>> GetDependentProjectsAsync(
Solution solution, ImmutableArray<ISymbol> symbols, IImmutableSet<Project> projects, CancellationToken cancellationToken)
Expand Down Expand Up @@ -128,24 +137,56 @@ private static async Task<ImmutableArray<Project>> GetDependentProjectsWorkerAsy
SymbolVisibility visibility,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
var dictionary = s_solutionToDependentProjectMap.GetValue(solution, static _ => new());

var dependentProjects = new HashSet<(Project, bool hasInternalsAccess)>();
var key = (symbolOrigination.assembly, symbolOrigination.sourceProject, visibility);
ImmutableArray<(Project project, bool hasInternalsAccess)> dependentProjects;

// If a symbol was defined in source, then it is always visible to the project it
// was defined in.
if (symbolOrigination.sourceProject != null)
dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true));
// Check cache first.
using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
if (dictionary.TryGetValue(key, out dependentProjects))
return dependentProjects;
}

// Compute if not in cache.
dependentProjects = await ComputeDependentProjectsWorkerAsync(
solution, symbolOrigination, visibility, cancellationToken).ConfigureAwait(false);

// If it's not private, then we need to find possible references.
if (visibility != SymbolVisibility.Private)
AddNonSubmissionDependentProjects(solution, symbolOrigination, dependentProjects, cancellationToken);
// Try to add to cache, returning existing value if another thread already added it.
using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
return dictionary.GetOrAdd(key, dependentProjects);
}

static async Task<ImmutableArray<(Project project, bool hasInternalsAccess)>> ComputeDependentProjectsWorkerAsync(
Solution solution,
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
SymbolVisibility visibility,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using var _ = PooledHashSet<(Project, bool hasInternalsAccess)>.GetInstance(out var dependentProjects);

// If a symbol was defined in source, then it is always visible to the project it
// was defined in.
if (symbolOrigination.sourceProject != null)
dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true));

// If it's not private, then we need to find possible references.
if (visibility != SymbolVisibility.Private)
{
await AddNonSubmissionDependentProjectsAsync(
solution, symbolOrigination, dependentProjects, cancellationToken).ConfigureAwait(false);
}

// submission projects are special here. The fields generated inside the Script object is private, but
// further submissions can bind to them.
await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false);
// submission projects are special here. The fields generated inside the Script object is private, but
// further submissions can bind to them.
await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false);

return [.. dependentProjects];
return [.. dependentProjects];
}
}

private static async Task AddSubmissionDependentProjectsAsync(
Expand All @@ -154,7 +195,7 @@ private static async Task AddSubmissionDependentProjectsAsync(
if (sourceProject?.IsSubmission != true)
return;

var projectIdsToReferencingSubmissionIds = new Dictionary<ProjectId, List<ProjectId>>();
using var _1 = PooledDictionary<ProjectId, List<ProjectId>>.GetInstance(out var projectIdsToReferencingSubmissionIds);

// search only submission project
foreach (var projectId in solution.ProjectIds)
Expand All @@ -173,15 +214,7 @@ private static async Task AddSubmissionDependentProjectsAsync(
{
var referencedProject = solution.GetProject(previous.Assembly, cancellationToken);
if (referencedProject != null)
{
if (!projectIdsToReferencingSubmissionIds.TryGetValue(referencedProject.Id, out var referencingSubmissions))
{
referencingSubmissions = [];
projectIdsToReferencingSubmissionIds.Add(referencedProject.Id, referencingSubmissions);
}

referencingSubmissions.Add(project.Id);
}
projectIdsToReferencingSubmissionIds.MultiAdd(referencedProject.Id, project.Id);
}
}
}
Expand All @@ -191,7 +224,7 @@ private static async Task AddSubmissionDependentProjectsAsync(
// and 2, even though 2 doesn't have a direct reference to 1. Hence we need to take
// our current set of projects and find the transitive closure over backwards
// submission previous references.
using var _ = ArrayBuilder<ProjectId>.GetInstance(out var projectIdsToProcess);
using var _2 = ArrayBuilder<ProjectId>.GetInstance(out var projectIdsToProcess);
foreach (var dependentProject in dependentProjects.Select(dp => dp.project.Id))
projectIdsToProcess.Push(dependentProject);

Expand Down Expand Up @@ -221,7 +254,7 @@ private static bool IsInternalsVisibleToAttribute(AttributeData attr)
attrType.ContainingNamespace.ContainingNamespace.ContainingNamespace.ContainingNamespace?.IsGlobalNamespace == true;
}

private static void AddNonSubmissionDependentProjects(
private static async Task AddNonSubmissionDependentProjectsAsync(
Solution solution,
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
HashSet<(Project project, bool hasInternalsAccess)> dependentProjects,
Expand All @@ -235,7 +268,7 @@ private static void AddNonSubmissionDependentProjects(
foreach (var project in solution.Projects)
{
if (!project.SupportsCompilation ||
!HasReferenceTo(symbolOrigination, project, cancellationToken))
!await HasReferenceToAsync(symbolOrigination, project, cancellationToken).ConfigureAwait(false))
{
continue;
}
Expand Down Expand Up @@ -270,7 +303,7 @@ private static HashSet<string> GetInternalsVisibleToSet(IAssemblySymbol assembly
return set;
}

private static bool HasReferenceTo(
private static async Task<bool> HasReferenceToAsync(
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
Project project,
CancellationToken cancellationToken)
Expand All @@ -284,10 +317,11 @@ private static bool HasReferenceTo(
return project.ProjectReferences.Any(p => p.ProjectId == symbolOrigination.sourceProject.Id);

// Otherwise, if the symbol is from metadata, see if the project's compilation references that metadata assembly.
return HasReferenceToAssembly(project, symbolOrigination.assembly.Name, cancellationToken);
return await HasReferenceToAssemblyAsync(
project, symbolOrigination.assembly.Name, cancellationToken).ConfigureAwait(false);
}

private static bool HasReferenceToAssembly(Project project, string assemblyName, CancellationToken cancellationToken)
private static async Task<bool> HasReferenceToAssemblyAsync(Project project, string assemblyName, CancellationToken cancellationToken)
{
Contract.ThrowIfFalse(project.SupportsCompilation);

Expand All @@ -307,31 +341,50 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName,
if (metadataId is null)
continue;

if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
uncomputedReferences.Add((peReference, metadataId));
continue;
if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
{
// We already know the assembly name for this metadata id. If it matches the one we're looking for,
// we're done. Otherwise, keep looking.
if (name == assemblyName)
return true;
else
continue;
}
}

if (name == assemblyName)
return true;
// We didn't know the name for the metadata id. Add it to the list of things we need to compute below.
uncomputedReferences.Add((peReference, metadataId));
}

if (uncomputedReferences.Count == 0)
return false;

Compilation? compilation = null;
var compilation = CreateCompilation(project);

foreach (var (peReference, metadataId) in uncomputedReferences)
{
cancellationToken.ThrowIfCancellationRequested();

if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
// Attempt to get the assembly name for this pe-reference. If we fail, we still want to add that info into
// the dictionary (by mapping us to 'null'). That way we don't keep trying to compute it over and over.
var name = compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName }
? metadataAssemblyName
: null;

using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
// Defer creating the compilation till needed.
CreateCompilation(project, ref compilation);
if (compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName })
name = ImmutableInterlocked.GetOrAdd(ref s_metadataIdToAssemblyName, metadataId, metadataAssemblyName);
// Overwrite an existing null name with a non-null one.
if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var existingName) &&
existingName == null &&
name != null)
{
s_metadataIdToAssemblyName[metadataId] = name;
}

// Return whatever is in the map, adding ourselves if something is not already there.
name = s_metadataIdToAssemblyName.GetOrAdd(metadataId, name);
}

if (name == assemblyName)
Expand All @@ -340,18 +393,15 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName,

return false;

static void CreateCompilation(Project project, [NotNull] ref Compilation? compilation)
static Compilation CreateCompilation(Project project)
{
if (compilation != null)
return;

// Use the project's compilation if it has one.
if (project.TryGetCompilation(out compilation))
return;
if (project.TryGetCompilation(out var compilation))
return compilation;

// Perf: check metadata reference using newly created empty compilation with only metadata references.
var factory = project.Services.GetRequiredService<ICompilationFactoryService>();
compilation = factory
return factory
.CreateCompilation(project.AssemblyName, project.CompilationOptions!)
.AddReferences(project.MetadataReferences);
}
Expand Down

0 comments on commit 6d1245f

Please sign in to comment.