diff --git a/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs b/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs index 380a3c3bf55f0..dc87229e40fca 100644 --- a/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs +++ b/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; @@ -28,7 +29,15 @@ internal static partial class DependentProjectsFinder /// Cache from the for a particular to the /// name of the defined by it. /// - private static ImmutableDictionary s_metadataIdToAssemblyName = ImmutableDictionary.Empty; + private static readonly Dictionary s_metadataIdToAssemblyName = new(); + private static readonly SemaphoreSlim s_metadataIdToAssemblyNameGate = new(initialCount: 1); + + private static readonly ConditionalWeakTable< + Solution, + Dictionary< + (IAssemblySymbol assembly, Project? sourceProject, SymbolVisibility visibility), + ImmutableArray<(Project project, bool hasInternalsAccess)>>> s_solutionToDependentProjectMap = new(); + private static readonly SemaphoreSlim s_solutionToDependentProjectMapGate = new(initialCount: 1); public static async Task> GetDependentProjectsAsync( Solution solution, ImmutableArray symbols, IImmutableSet projects, CancellationToken cancellationToken) @@ -128,24 +137,56 @@ private static async Task> GetDependentProjectsWorkerAsy SymbolVisibility visibility, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + var dictionary = s_solutionToDependentProjectMap.GetValue(solution, static _ => new()); - var dependentProjects = new HashSet<(Project, bool hasInternalsAccess)>(); + var key = (symbolOrigination.assembly, symbolOrigination.sourceProject, visibility); + ImmutableArray<(Project project, bool hasInternalsAccess)> dependentProjects; - // If a symbol was defined in source, then it is always visible to the project it - // was defined in. - if (symbolOrigination.sourceProject != null) - dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true)); + // Check cache first. + using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false)) + { + if (dictionary.TryGetValue(key, out dependentProjects)) + return dependentProjects; + } + + // Compute if not in cache. + dependentProjects = await ComputeDependentProjectsWorkerAsync( + solution, symbolOrigination, visibility, cancellationToken).ConfigureAwait(false); - // If it's not private, then we need to find possible references. - if (visibility != SymbolVisibility.Private) - AddNonSubmissionDependentProjects(solution, symbolOrigination, dependentProjects, cancellationToken); + // Try to add to cache, returning existing value if another thread already added it. + using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false)) + { + return dictionary.GetOrAdd(key, dependentProjects); + } + + static async Task> ComputeDependentProjectsWorkerAsync( + Solution solution, + (IAssemblySymbol assembly, Project? sourceProject) symbolOrigination, + SymbolVisibility visibility, + CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + using var _ = PooledHashSet<(Project, bool hasInternalsAccess)>.GetInstance(out var dependentProjects); + + // If a symbol was defined in source, then it is always visible to the project it + // was defined in. + if (symbolOrigination.sourceProject != null) + dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true)); + + // If it's not private, then we need to find possible references. + if (visibility != SymbolVisibility.Private) + { + await AddNonSubmissionDependentProjectsAsync( + solution, symbolOrigination, dependentProjects, cancellationToken).ConfigureAwait(false); + } - // submission projects are special here. The fields generated inside the Script object is private, but - // further submissions can bind to them. - await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false); + // submission projects are special here. The fields generated inside the Script object is private, but + // further submissions can bind to them. + await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false); - return [.. dependentProjects]; + return [.. dependentProjects]; + } } private static async Task AddSubmissionDependentProjectsAsync( @@ -154,7 +195,7 @@ private static async Task AddSubmissionDependentProjectsAsync( if (sourceProject?.IsSubmission != true) return; - var projectIdsToReferencingSubmissionIds = new Dictionary>(); + using var _1 = PooledDictionary>.GetInstance(out var projectIdsToReferencingSubmissionIds); // search only submission project foreach (var projectId in solution.ProjectIds) @@ -173,15 +214,7 @@ private static async Task AddSubmissionDependentProjectsAsync( { var referencedProject = solution.GetProject(previous.Assembly, cancellationToken); if (referencedProject != null) - { - if (!projectIdsToReferencingSubmissionIds.TryGetValue(referencedProject.Id, out var referencingSubmissions)) - { - referencingSubmissions = []; - projectIdsToReferencingSubmissionIds.Add(referencedProject.Id, referencingSubmissions); - } - - referencingSubmissions.Add(project.Id); - } + projectIdsToReferencingSubmissionIds.MultiAdd(referencedProject.Id, project.Id); } } } @@ -191,7 +224,7 @@ private static async Task AddSubmissionDependentProjectsAsync( // and 2, even though 2 doesn't have a direct reference to 1. Hence we need to take // our current set of projects and find the transitive closure over backwards // submission previous references. - using var _ = ArrayBuilder.GetInstance(out var projectIdsToProcess); + using var _2 = ArrayBuilder.GetInstance(out var projectIdsToProcess); foreach (var dependentProject in dependentProjects.Select(dp => dp.project.Id)) projectIdsToProcess.Push(dependentProject); @@ -221,7 +254,7 @@ private static bool IsInternalsVisibleToAttribute(AttributeData attr) attrType.ContainingNamespace.ContainingNamespace.ContainingNamespace.ContainingNamespace?.IsGlobalNamespace == true; } - private static void AddNonSubmissionDependentProjects( + private static async Task AddNonSubmissionDependentProjectsAsync( Solution solution, (IAssemblySymbol assembly, Project? sourceProject) symbolOrigination, HashSet<(Project project, bool hasInternalsAccess)> dependentProjects, @@ -235,7 +268,7 @@ private static void AddNonSubmissionDependentProjects( foreach (var project in solution.Projects) { if (!project.SupportsCompilation || - !HasReferenceTo(symbolOrigination, project, cancellationToken)) + !await HasReferenceToAsync(symbolOrigination, project, cancellationToken).ConfigureAwait(false)) { continue; } @@ -270,7 +303,7 @@ private static HashSet GetInternalsVisibleToSet(IAssemblySymbol assembly return set; } - private static bool HasReferenceTo( + private static async Task HasReferenceToAsync( (IAssemblySymbol assembly, Project? sourceProject) symbolOrigination, Project project, CancellationToken cancellationToken) @@ -284,10 +317,11 @@ private static bool HasReferenceTo( return project.ProjectReferences.Any(p => p.ProjectId == symbolOrigination.sourceProject.Id); // Otherwise, if the symbol is from metadata, see if the project's compilation references that metadata assembly. - return HasReferenceToAssembly(project, symbolOrigination.assembly.Name, cancellationToken); + return await HasReferenceToAssemblyAsync( + project, symbolOrigination.assembly.Name, cancellationToken).ConfigureAwait(false); } - private static bool HasReferenceToAssembly(Project project, string assemblyName, CancellationToken cancellationToken) + private static async Task HasReferenceToAssemblyAsync(Project project, string assemblyName, CancellationToken cancellationToken) { Contract.ThrowIfFalse(project.SupportsCompilation); @@ -307,31 +341,50 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName, if (metadataId is null) continue; - if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name)) + using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false)) { - uncomputedReferences.Add((peReference, metadataId)); - continue; + if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name)) + { + // We already know the assembly name for this metadata id. If it matches the one we're looking for, + // we're done. Otherwise, keep looking. + if (name == assemblyName) + return true; + else + continue; + } } - if (name == assemblyName) - return true; + // We didn't know the name for the metadata id. Add it to the list of things we need to compute below. + uncomputedReferences.Add((peReference, metadataId)); } if (uncomputedReferences.Count == 0) return false; - Compilation? compilation = null; + var compilation = CreateCompilation(project); foreach (var (peReference, metadataId) in uncomputedReferences) { cancellationToken.ThrowIfCancellationRequested(); - if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name)) + // Attempt to get the assembly name for this pe-reference. If we fail, we still want to add that info into + // the dictionary (by mapping us to 'null'). That way we don't keep trying to compute it over and over. + var name = compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName } + ? metadataAssemblyName + : null; + + using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false)) { - // Defer creating the compilation till needed. - CreateCompilation(project, ref compilation); - if (compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName }) - name = ImmutableInterlocked.GetOrAdd(ref s_metadataIdToAssemblyName, metadataId, metadataAssemblyName); + // Overwrite an existing null name with a non-null one. + if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var existingName) && + existingName == null && + name != null) + { + s_metadataIdToAssemblyName[metadataId] = name; + } + + // Return whatever is in the map, adding ourselves if something is not already there. + name = s_metadataIdToAssemblyName.GetOrAdd(metadataId, name); } if (name == assemblyName) @@ -340,18 +393,15 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName, return false; - static void CreateCompilation(Project project, [NotNull] ref Compilation? compilation) + static Compilation CreateCompilation(Project project) { - if (compilation != null) - return; - // Use the project's compilation if it has one. - if (project.TryGetCompilation(out compilation)) - return; + if (project.TryGetCompilation(out var compilation)) + return compilation; // Perf: check metadata reference using newly created empty compilation with only metadata references. var factory = project.Services.GetRequiredService(); - compilation = factory + return factory .CreateCompilation(project.AssemblyName, project.CompilationOptions!) .AddReferences(project.MetadataReferences); }