diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CodeFixVerifier.cs b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CodeFixVerifier.cs index 9100ecdd5..23ac4850e 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CodeFixVerifier.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CodeFixVerifier.cs @@ -165,14 +165,21 @@ private void VerifyFix(string language, DiagnosticAnalyzer analyzer, CodeFixProv } } - Assert.True(fixApplied, "No code fix offered."); + if (newSources != null && newSources[0] != null) + { + Assert.True(fixApplied, "No code fix offered."); - // After applying all of the code fixes, compare the resulting string to the inputted one - int j = 0; - foreach (var document in project.Documents) + // After applying all of the code fixes, compare the resulting string to the inputted one + int j = 0; + foreach (var document in project.Documents) + { + var actual = GetStringFromDocument(document); + Assert.Equal(newSources[j++], actual, ignoreLineEndingDifferences: true); + } + } + else { - var actual = GetStringFromDocument(document); - Assert.Equal(newSources[j++], actual, ignoreLineEndingDifferences: true); + Assert.False(fixApplied, "No code fix expected, but was offered."); } } } diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD103UseAsyncOptionAnalyzerTests.cs b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD103UseAsyncOptionAnalyzerTests.cs index a53ab1829..64295b6bb 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD103UseAsyncOptionAnalyzerTests.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD103UseAsyncOptionAnalyzerTests.cs @@ -254,6 +254,116 @@ async Task T() { this.VerifyCSharpFix(test, withFix); } + [Fact] + public void IVsTaskWaitInTaskReturningMethodGeneratesWarning() + { + var test = @" +using System.Threading.Tasks; +using Microsoft.VisualStudio.Shell; +using Microsoft.VisualStudio.Shell.Interop; +using Task = System.Threading.Tasks.Task; + +class Test { + Task T() { + IVsTask t = null; + t.Wait(); + return Task.FromResult(1); + } +} +"; + + var withFix = @" +using System.Threading.Tasks; +using Microsoft.VisualStudio.Shell; +using Microsoft.VisualStudio.Shell.Interop; +using Task = System.Threading.Tasks.Task; + +class Test { + async Task T() { + IVsTask t = null; + await t; + } +} +"; + this.expect = this.CreateDiagnostic(10, 11, 4); + this.VerifyCSharpDiagnostic(test, this.expect); + this.VerifyCSharpFix(test, withFix); + } + + [Fact] + public void IVsTaskGetResultInTaskReturningMethodGeneratesWarning() + { + var test = @" +using System.Threading.Tasks; +using Microsoft.VisualStudio.Shell; +using Microsoft.VisualStudio.Shell.Interop; +using Task = System.Threading.Tasks.Task; + +class Test { + Task T() { + IVsTask t = null; + object result = t.GetResult(); + return Task.FromResult(1); + } +} +"; + + var withFix = @" +using System.Threading.Tasks; +using Microsoft.VisualStudio.Shell; +using Microsoft.VisualStudio.Shell.Interop; +using Task = System.Threading.Tasks.Task; + +class Test { + async Task T() { + IVsTask t = null; + object result = await t; + } +} +"; + this.expect = this.CreateDiagnostic(10, 27, 9); + this.VerifyCSharpDiagnostic(test, this.expect); + this.VerifyCSharpFix(test, withFix); + } + + /// + /// Ensures we don't offer a code fix when the required using directive is not already present. + /// + [Fact] + public void IVsTaskGetResultInTaskReturningMethod_WithoutUsing_OffersNoFix() + { + var test = @" +using System.Threading.Tasks; +using Microsoft.VisualStudio.Shell.Interop; + +class Test { + Task T() { + IVsTask t = null; + object result = t.GetResult(); + return Task.FromResult(1); + } +} +"; + + string withFix = null; +//// var withFix = @" +//// using System.Threading.Tasks; +//// using Microsoft.VisualStudio.Shell; +//// using Microsoft.VisualStudio.Shell.Interop; +//// using Task = System.Threading.Tasks.Task; +//// +//// class Test { +//// async Task T() { +//// IVsTask t = null; +//// object result = await t; +//// } +//// } +//// "; + this.expect = this.CreateDiagnostic(8, 27, 9); + this.VerifyCSharpDiagnostic(test, this.expect); + this.VerifyCSharpFix(test, withFix); + } + [Fact] public void TaskOfTResultInTaskReturningMethodGeneratesWarning() { @@ -319,7 +429,7 @@ void T() { } "; - this.expect.Locations = new[] { new DiagnosticResultLocation("Test0.cs", 9, 28, 9, 34) }; + this.expect = this.CreateDiagnostic(9, 28, 6); this.VerifyCSharpDiagnostic(test, this.expect); this.VerifyCSharpFix(test, withFix); } @@ -874,7 +984,7 @@ public void NoDiagnosticAndNoExceptionForProperties() class Test { string Foo => string.Empty; - string Bar => string.Join(""a"", string.Empty); + string Bar => string.Join(""a"", string.Empty); } "; @@ -1078,5 +1188,14 @@ Task MethodAsync() this.VerifyCSharpDiagnostic(test); } + + private DiagnosticResult CreateDiagnostic(int line, int column, int length, string messagePattern = null) => + new DiagnosticResult + { + Id = this.expect.Id, + MessagePattern = messagePattern ?? this.expect.MessagePattern, + Severity = this.expect.Severity, + Locations = new[] { new DiagnosticResultLocation("Test0.cs", line, column, line, column + length) }, + }; } } diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers/CommonInterest.cs b/src/Microsoft.VisualStudio.Threading.Analyzers/CommonInterest.cs index db7315d1f..bba9b76be 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers/CommonInterest.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers/CommonInterest.cs @@ -34,7 +34,11 @@ internal static class CommonInterest new SyncBlockingMethod(new QualifiedMember(new QualifiedType(Namespaces.SystemRuntimeCompilerServices, nameof(TaskAwaiter)), nameof(TaskAwaiter.GetResult)), null), }; - internal static readonly IEnumerable SyncBlockingMethods = JTFSyncBlockers.Concat(ProblematicSyncBlockingMethods); + internal static readonly IEnumerable SyncBlockingMethods = JTFSyncBlockers.Concat(ProblematicSyncBlockingMethods).Concat(new[] + { + new SyncBlockingMethod(new QualifiedMember(new QualifiedType(Namespaces.MicrosoftVisualStudioShellInterop, "IVsTask"), "Wait"), extensionMethodNamespace: Namespaces.MicrosoftVisualStudioShell), + new SyncBlockingMethod(new QualifiedMember(new QualifiedType(Namespaces.MicrosoftVisualStudioShellInterop, "IVsTask"), "GetResult"), extensionMethodNamespace: Namespaces.MicrosoftVisualStudioShell), + }); internal static readonly IEnumerable LegacyThreadSwitchingMethods = new[] { @@ -381,15 +385,18 @@ public bool IsMatch(ISymbol symbol) [DebuggerDisplay("{" + nameof(Method) + "} -> {" + nameof(AsyncAlternativeMethodName) + "}")] internal struct SyncBlockingMethod { - public SyncBlockingMethod(QualifiedMember method, string asyncAlternativeMethodName) + public SyncBlockingMethod(QualifiedMember method, string asyncAlternativeMethodName = null, IReadOnlyList extensionMethodNamespace = null) { this.Method = method; this.AsyncAlternativeMethodName = asyncAlternativeMethodName; + this.ExtensionMethodNamespace = extensionMethodNamespace; } public QualifiedMember Method { get; } public string AsyncAlternativeMethodName { get; } + + public IReadOnlyList ExtensionMethodNamespace { get; } } } } diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers/Namespaces.cs b/src/Microsoft.VisualStudio.Threading.Analyzers/Namespaces.cs index 24421503c..83998a6c8 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers/Namespaces.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers/Namespaces.cs @@ -62,5 +62,13 @@ internal static class Namespaces "VisualStudio", "Shell", }; + + internal static readonly IReadOnlyList MicrosoftVisualStudioShellInterop = new[] + { + "Microsoft", + "VisualStudio", + "Shell", + "Interop", + }; } } diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionAnalyzer.cs b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionAnalyzer.cs index 1aabb9733..9b0cbb03b 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionAnalyzer.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionAnalyzer.cs @@ -182,7 +182,8 @@ private static bool InspectMemberAccess(SyntaxNodeAnalysisContext context, Membe if (item.Method.IsMatch(memberSymbol)) { var location = memberAccessSyntax.Name.GetLocation(); - var properties = ImmutableDictionary.Empty; + var properties = ImmutableDictionary.Empty + .Add(VSTHRD103UseAsyncOptionCodeFix.ExtensionMethodNamespaceKeyName, item.ExtensionMethodNamespace != null ? string.Join(".", item.ExtensionMethodNamespace) : string.Empty); DiagnosticDescriptor descriptor; var messageArgs = new List(2); messageArgs.Add(item.Method.Name); diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionCodeFix.cs b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionCodeFix.cs index 0fadad67e..30162a538 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionCodeFix.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD103UseAsyncOptionCodeFix.cs @@ -33,6 +33,8 @@ public class VSTHRD103UseAsyncOptionCodeFix : CodeFixProvider { internal const string AsyncMethodKeyName = "AsyncMethodName"; + internal const string ExtensionMethodNamespaceKeyName = "ExtensionMethodNamespace"; + private static readonly ImmutableArray ReusableFixableDiagnosticIds = ImmutableArray.Create( VSTHRD103UseAsyncOptionAnalyzer.Id); @@ -40,15 +42,41 @@ public class VSTHRD103UseAsyncOptionCodeFix : CodeFixProvider public override ImmutableArray FixableDiagnosticIds => ReusableFixableDiagnosticIds; /// - public override Task RegisterCodeFixesAsync(CodeFixContext context) + public override async Task RegisterCodeFixesAsync(CodeFixContext context) { var diagnostic = context.Diagnostics.FirstOrDefault(d => d.Properties.ContainsKey(AsyncMethodKeyName)); if (diagnostic != null) { - context.RegisterCodeFix(new ReplaceSyncMethodCallWithAwaitAsync(context.Document, diagnostic), diagnostic); - } + // Check that the method we're replacing the sync blocking call with actually exists. + // This is particularly useful when the method is an extension method, since the using directive + // would need to be present (or the namespace imply it) and we don't yet add missing using directives. + bool asyncAlternativeExists = false; + string asyncMethodName = diagnostic.Properties[AsyncMethodKeyName]; + if (string.IsNullOrEmpty(asyncMethodName)) + { + asyncMethodName = "GetAwaiter"; + } - return Task.FromResult(null); + var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false); + var syntaxRoot = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + var blockingIdentifier = syntaxRoot.FindNode(diagnostic.Location.SourceSpan) as IdentifierNameSyntax; + var memberAccessExpression = blockingIdentifier?.Parent as MemberAccessExpressionSyntax; + + // Check whether this code was already calling the awaiter (in a synchronous fashion). + asyncAlternativeExists |= memberAccessExpression?.Expression is InvocationExpressionSyntax invoke && invoke.Expression is MemberAccessExpressionSyntax parentMemberAccess && parentMemberAccess.Name.Identifier.Text == nameof(Task.GetAwaiter); + + if (!asyncAlternativeExists) + { + // If we fail to recognize the container, assume it exists since the analyzer thought it would. + var container = memberAccessExpression != null ? semanticModel.GetTypeInfo(memberAccessExpression.Expression, context.CancellationToken).ConvertedType : null; + asyncAlternativeExists = container == null || semanticModel.LookupSymbols(diagnostic.Location.SourceSpan.Start, name: asyncMethodName, container: container, includeReducedExtensionMethods: true).Any(); + } + + if (asyncAlternativeExists) + { + context.RegisterCodeFix(new ReplaceSyncMethodCallWithAwaitAsync(context.Document, diagnostic), diagnostic); + } + } } /// @@ -80,6 +108,8 @@ public override string Title private string AlternativeAsyncMethod => this.diagnostic.Properties[AsyncMethodKeyName]; + private string ExtensionMethodNamespace => this.diagnostic.Properties[ExtensionMethodNamespaceKeyName]; + protected override async Task GetChangedSolutionAsync(CancellationToken cancellationToken) { var document = this.document;