Skip to content

Commit

Permalink
Fix VSTHRD003 to allow for awaiting more Task properties
Browse files Browse the repository at this point in the history
When integrating into the VS repo, recent enhancements to VSTHRD003 made insertion of the new analyzers problematic. In particular, two patterns were quite common and arguably should be allowed:

- Awaiting on the `JoinableTask.Task` property on the result of a `JoinableTaskFactory.RunAsync` method call. In general, if a method returns an object, awaiting one of that object's properties *might not* be fair game, but it seems more likely than not. So in this change, that is now allowed.
- Awaiting a `TaskCompletionSource<T>.Task` property on a TCS that was created in that same method. Test methods often do this. Awaiting a TCS.Task property is never guaranteed to be safe as far as analyzers can tell (though the 3rd threading rule can make it safe), and when common patterns appear that are _typically_ safe, we should allow it.
  • Loading branch information
AArnott committed Oct 25, 2023
1 parent 4b595ea commit b561132
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ internal static ContainingFunctionData GetContainingFunction(CSharpSyntaxNode? s
return new ContainingFunctionData(simpleLambda, simpleLambda.AsyncKeyword != default(SyntaxToken), SyntaxFactory.ParameterList().AddParameters(simpleLambda.Parameter), simpleLambda.Body, simpleLambda.WithBody);
}

if (syntaxNode is LocalFunctionStatementSyntax localFunc)
{
return new ContainingFunctionData(localFunc, localFunc.Modifiers.Any(SyntaxKind.AsyncKeyword), localFunc.ParameterList, (CSharpSyntaxNode?)localFunc.ExpressionBody ?? localFunc.Body, localFunc.WithBody);
}

if (syntaxNode is AnonymousMethodExpressionSyntax anonymousMethod)
{
return new ContainingFunctionData(anonymousMethod, anonymousMethod.AsyncKeyword != default(SyntaxToken), anonymousMethod.ParameterList, anonymousMethod.Body, anonymousMethod.WithBody);
Expand Down Expand Up @@ -125,7 +130,7 @@ internal static bool IsOnLeftHandOfAssignment(SyntaxNode syntaxNode)
return false;
}

internal static bool IsAssignedWithin(SyntaxNode container, SemanticModel semanticModel, ISymbol variable, CancellationToken cancellationToken)
internal static IEnumerable<ExpressionSyntax> FindAssignedValuesWithin(SyntaxNode container, SemanticModel semanticModel, ISymbol variable, CancellationToken cancellationToken)
{
if (semanticModel is null)
{
Expand All @@ -139,22 +144,36 @@ internal static bool IsAssignedWithin(SyntaxNode container, SemanticModel semant

if (container is null)
{
return false;
yield break;
}

foreach (SyntaxNode? node in container.DescendantNodesAndSelf(n => !(n is AnonymousFunctionExpressionSyntax)))
foreach (SyntaxNode? node in container.DescendantNodesAndSelf(n => !(n is AnonymousFunctionExpressionSyntax or LocalFunctionStatementSyntax)))
{
cancellationToken.ThrowIfCancellationRequested();
if (node is AssignmentExpressionSyntax assignment)
{
ISymbol? assignedSymbol = semanticModel.GetSymbolInfo(assignment.Left, cancellationToken).Symbol;
if (variable.Equals(assignedSymbol, SymbolEqualityComparer.Default))
{
return true;
yield return assignment.Right;
}
}
}

return false;
if (node is LocalDeclarationStatementSyntax localDeclarationStatement)
{
foreach (VariableDeclaratorSyntax localDeclVar in localDeclarationStatement.Declaration.Variables)
{
if (localDeclVar.Initializer is not null)
{
ISymbol? assignedSymbol = semanticModel.GetDeclaredSymbol(localDeclVar, cancellationToken);
if (variable.Equals(assignedSymbol, SymbolEqualityComparer.Default))
{
yield return localDeclVar.Initializer.Value;
}
}
}
}
}
}

internal static MemberAccessExpressionSyntax MemberAccess(IReadOnlyList<string> qualifiers, SimpleNameSyntax simpleName)
Expand Down Expand Up @@ -277,7 +296,7 @@ internal override bool IsAsyncMethod(SyntaxNode syntaxNode)

internal readonly struct ContainingFunctionData
{
internal ContainingFunctionData(CSharpSyntaxNode function, bool isAsync, ParameterListSyntax? parameterList, CSharpSyntaxNode? blockOrExpression, Func<CSharpSyntaxNode, CSharpSyntaxNode> bodyReplacement)
internal ContainingFunctionData(CSharpSyntaxNode function, bool isAsync, ParameterListSyntax? parameterList, CSharpSyntaxNode? blockOrExpression, Func<BlockSyntax, CSharpSyntaxNode> bodyReplacement)
{
this.Function = function;
this.IsAsync = isAsync;
Expand All @@ -294,6 +313,6 @@ internal ContainingFunctionData(CSharpSyntaxNode function, bool isAsync, Paramet

internal CSharpSyntaxNode? BlockOrExpression { get; }

internal Func<CSharpSyntaxNode, CSharpSyntaxNode> BodyReplacement { get; }
internal Func<BlockSyntax, CSharpSyntaxNode> BodyReplacement { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,25 +155,28 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
return null;
}

SymbolInfo symbolToConsider = semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken);
ExpressionSyntax focusedExpression = expressionSyntax;
SymbolInfo symbolToConsider = semanticModel.GetSymbolInfo(focusedExpression, cancellationToken);
if (CommonInterest.TaskConfigureAwait.Any(configureAwait => configureAwait.IsMatch(symbolToConsider.Symbol)))
{
// If the invocation is wrapped inside parentheses then drill down to get the invocation.
while (expressionSyntax is ParenthesizedExpressionSyntax parenthesizedExprSyntax)
while (focusedExpression is ParenthesizedExpressionSyntax parenthesizedExprSyntax)
{
expressionSyntax = parenthesizedExprSyntax.Expression;
focusedExpression = parenthesizedExprSyntax.Expression;
}

Debug.Assert(expressionSyntax is InvocationExpressionSyntax, "expressionSyntax should be an invocation");
Debug.Assert(focusedExpression is InvocationExpressionSyntax, "focusedExpression should be an invocation");

if (((InvocationExpressionSyntax)expressionSyntax).Expression is MemberAccessExpressionSyntax memberAccessExpression)
if (((InvocationExpressionSyntax)focusedExpression).Expression is MemberAccessExpressionSyntax memberAccessExpression)
{
focusedExpression = memberAccessExpression.Expression;
symbolToConsider = semanticModel.GetSymbolInfo(memberAccessExpression.Expression, cancellationToken);
}
}

ITypeSymbol symbolType;
bool dataflowAnalysisCompatibleVariable = false;
CSharpUtils.ContainingFunctionData? containingFunc = null;
switch (symbolToConsider.Symbol)
{
case ILocalSymbol localSymbol:
Expand All @@ -182,6 +185,29 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
break;
case IPropertySymbol propertySymbol when !IsSymbolAlwaysOkToAwait(propertySymbol):
symbolType = propertySymbol.Type;

if (focusedExpression is MemberAccessExpressionSyntax memberAccessExpression)
{
// Do not report a warning if the task is a member of an object that was returned from an invocation made in this method.
if (memberAccessExpression.Expression is InvocationExpressionSyntax)
{
return null;
}

// Do not report a warning if the task is a member of an object that was created in this method.
if (memberAccessExpression.Expression is IdentifierNameSyntax identifier &&
semanticModel.GetSymbolInfo(identifier, cancellationToken).Symbol is ILocalSymbol local)
{
// Search for assignments to the local and see if it was to a new object.
containingFunc ??= CSharpUtils.GetContainingFunction(focusedExpression);
if (containingFunc.Value.BlockOrExpression is not null &&
CSharpUtils.FindAssignedValuesWithin(containingFunc.Value.BlockOrExpression, semanticModel, local, cancellationToken).Any(v => v is ObjectCreationExpressionSyntax))
{
return null;
}
}
}

break;
case IParameterSymbol parameterSymbol:
symbolType = parameterSymbol.Type;
Expand Down Expand Up @@ -247,7 +273,7 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)

break;
case IMethodSymbol methodSymbol:
if (Utils.IsTask(methodSymbol.ReturnType) && expressionSyntax is InvocationExpressionSyntax invocationExpressionSyntax)
if (Utils.IsTask(methodSymbol.ReturnType) && focusedExpression is InvocationExpressionSyntax invocationExpressionSyntax)
{
// Consider all arguments
IEnumerable<ExpressionSyntax>? expressionsToConsider = invocationExpressionSyntax.ArgumentList.Arguments.Select(a => a.Expression);
Expand Down Expand Up @@ -275,8 +301,8 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
}

// Report warning if the task was not initialized within the current delegate or lambda expression
CSharpUtils.ContainingFunctionData containingFunc = CSharpUtils.GetContainingFunction(expressionSyntax);
if (containingFunc.BlockOrExpression is BlockSyntax delegateBlock)
containingFunc ??= CSharpUtils.GetContainingFunction(focusedExpression);
if (containingFunc.Value.BlockOrExpression is BlockSyntax delegateBlock)
{
if (dataflowAnalysisCompatibleVariable)
{
Expand All @@ -285,9 +311,9 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)

// When possible (await is direct child of the block and not a field), execute data flow analysis by passing first and last statement to capture only what happens before the await
// Check if the await is direct child of the code block (first parent is ExpressionStantement, second parent is the block itself)
if (delegateBlock.Equals(expressionSyntax.Parent?.Parent?.Parent))
if (delegateBlock.Equals(focusedExpression.Parent?.Parent?.Parent))
{
dataFlowAnalysis = semanticModel.AnalyzeDataFlow(delegateBlock.ChildNodes().First(), expressionSyntax.Parent.Parent);
dataFlowAnalysis = semanticModel.AnalyzeDataFlow(delegateBlock.ChildNodes().First(), focusedExpression.Parent.Parent);
}
else
{
Expand All @@ -297,22 +323,22 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)

if (dataFlowAnalysis?.WrittenInside.Contains(symbolToConsider.Symbol) is false)
{
return Diagnostic.Create(Descriptor, expressionSyntax.GetLocation());
return Diagnostic.Create(Descriptor, focusedExpression.GetLocation());
}
}
else
{
// Do the best we can searching for assignment statements.
if (!CSharpUtils.IsAssignedWithin(containingFunc.BlockOrExpression, semanticModel, symbolToConsider.Symbol, cancellationToken))
if (!CSharpUtils.FindAssignedValuesWithin(containingFunc.Value.BlockOrExpression, semanticModel, symbolToConsider.Symbol, cancellationToken).Any())
{
return Diagnostic.Create(Descriptor, expressionSyntax.GetLocation());
return Diagnostic.Create(Descriptor, focusedExpression.GetLocation());
}
}
}
else
{
// It's not a block, it's just a lambda expression, so the variable must be external.
return Diagnostic.Create(Descriptor, expressionSyntax.GetLocation());
return Diagnostic.Create(Descriptor, focusedExpression.GetLocation());
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,11 @@ class Tests
public async Task AwaitAndGetResult()
{{
await task.ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
await [|task|].ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
}}
}}
";
DiagnosticResult expected = this.CreateDiagnostic(10, 15, 21 + continueOnCapturedContext.ToString().Length);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Theory]
Expand All @@ -223,12 +222,11 @@ class Tests
public async Task<int> AwaitAndGetResult()
{{
return await task.ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
return await [|task|].ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
}}
}}
";
DiagnosticResult expected = this.CreateDiagnostic(10, 22, 21 + continueOnCapturedContext.ToString().Length);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
Expand All @@ -244,12 +242,11 @@ class Tests
public async Task AwaitAndGetResult()
{
await task.ConfigureAwaitRunInline();
await [|task|].ConfigureAwaitRunInline();
}
}
";
DiagnosticResult expected = this.CreateDiagnostic(11, 15, 30);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
Expand All @@ -265,12 +262,11 @@ class Tests
public async Task<int> AwaitAndGetResult()
{
return await task.ConfigureAwaitRunInline();
return await [|task|].ConfigureAwaitRunInline();
}
}
";
DiagnosticResult expected = this.CreateDiagnostic(11, 22, 30);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
Expand Down Expand Up @@ -1280,6 +1276,64 @@ async Task GetTask()
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
public async Task DoNotReportWarningWhenAwaitingTaskPropertyOfObjectCreatedInContext()
{
var test = @"
using System.Threading.Tasks;
class Tests
{
private Task MyTaskProperty { get; set; }
static async Task GetTask()
{
// our own property.
var obj = new Tests();
await obj.MyTaskProperty;
// local with initializer
var tcs = new TaskCompletionSource<int>();
await tcs.Task;
// Assign later
TaskCompletionSource<int> tcs2;
tcs2 = new TaskCompletionSource<int>();
await tcs2.Task;
// Assigned, but not to a newly created object.
TaskCompletionSource<int> tcs3 = tcs2;
await [|tcs3.Task|];
}
}
";
await CSVerify.VerifyAnalyzerAsync(test);
}

/// <summary>
/// This is important to allow folks to return jtf.RunAsync(...).Task from a method.
/// </summary>
[Fact]
public async Task DoNotReportWarningWhenAwaitingTaskPropertyOfObjectReturnedFromMethod()
{
var test = @"
using System.Threading.Tasks;
class Tests
{
private Task MyTaskProperty { get; set; }
static Tests NewTests() => new Tests();
static async Task GetTask()
{
await NewTests().MyTaskProperty;
}
}
";
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
public async Task ReportWarningWhenAwaitingTaskPropertyThatWasNotSetInContext()
{
Expand All @@ -1293,6 +1347,7 @@ class Tests
async Task GetTask()
{
await [|this.MyTaskProperty|];
await [|MyTaskProperty|];
}
}
";
Expand Down

0 comments on commit b561132

Please sign in to comment.