Skip to content

Commit

Permalink
Merge pull request #596 from Evangelink/task-return-null
Browse files Browse the repository at this point in the history
VSTHRD114: Do not return null from non-async Task method
  • Loading branch information
AArnott authored Apr 19, 2020
2 parents 5cccea9 + d827da7 commit 6c5108f
Show file tree
Hide file tree
Showing 27 changed files with 961 additions and 52 deletions.
31 changes: 31 additions & 0 deletions doc/analyzers/VSTHRD114.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# VSTHRD114 Avoid returning a null Task

Returning `null` from a non-async `Task`/`Task<T>` method will cause a `NullReferenceException` at runtime. This problem can be avoided by returning `Task.CompletedTask`, `Task.FromResult<T>(null)` or `Task.FromResult(default(T))` instead.

## Examples of patterns that are flagged by this analyzer

Any non-async `Task` returning method with an explicit `return null;` will be flagged.

```csharp
Task DoAsync() {
return null;
}

Task<object> GetSomethingAsync() {
return null;
}
```

## Solution

Return a task like `Task.CompletedTask` or `Task.FromResult`.

```csharp
Task DoAsync() {
return Task.CompletedTask;
}

Task<object> GetSomethingAsync() {
return Task.FromResult<object>(null);
}
```
1 change: 1 addition & 0 deletions doc/analyzers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ ID | Title | Severity | Supports | Default diagnostic severity
[VSTHRD111](VSTHRD111.md) | Use `.ConfigureAwait(bool)` | Advisory | | Hidden
[VSTHRD112](VSTHRD112.md) | Implement `System.IAsyncDisposable` | Advisory | | Info
[VSTHRD113](VSTHRD113.md) | Check for `System.IAsyncDisposable` | Advisory | | Info
[VSTHRD114](VSTHRD114.md) | Avoid returning null from a `Task`-returning method. | Advisory | | Warning
[VSTHRD200](VSTHRD200.md) | Use `Async` naming convention | Guideline | [VSTHRD103](VSTHRD103.md) | Warning

## Severity descriptions
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
namespace Microsoft.VisualStudio.Threading.Analyzers
{
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Simplification;

[ExportCodeFixProvider(LanguageNames.CSharp)]
public class VSTHRD114AvoidReturningNullTaskCodeFix : CodeFixProvider
{
private static readonly ImmutableArray<string> ReusableFixableDiagnosticIds = ImmutableArray.Create(
VSTHRD114AvoidReturningNullTaskAnalyzer.Id);

/// <inheritdoc />
public override ImmutableArray<string> FixableDiagnosticIds => ReusableFixableDiagnosticIds;

/// <inheritdoc />
public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
foreach (var diagnostic in context.Diagnostics)
{
var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
var syntaxRoot = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

if (!(syntaxRoot.FindNode(diagnostic.Location.SourceSpan) is LiteralExpressionSyntax nullLiteral))
{
continue;
}

var methodDeclaration = nullLiteral.FirstAncestorOrSelf<MethodDeclarationSyntax>();
if (methodDeclaration == null)
{
continue;
}

if (!(methodDeclaration.ReturnType is GenericNameSyntax genericReturnType))
{
context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_CompletedTask, ct => ApplyTaskCompletedTaskFix(ct), "CompletedTask"), diagnostic);
}
else
{
if (genericReturnType.TypeArgumentList.Arguments.Count != 1)
{
continue;
}

context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_FromResult, ct => ApplyTaskFromResultFix(genericReturnType.TypeArgumentList.Arguments[0], ct), "FromResult"), diagnostic);
}

Task<Document> ApplyTaskCompletedTaskFix(CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.IdentifierName("CompletedTask"))
.WithAdditionalAnnotations(Simplifier.Annotation);

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}

Task<Document> ApplyTaskFromResultFix(TypeSyntax returnTypeArgument, CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.GenericName("FromResult").AddTypeArgumentListArguments(returnTypeArgument)))
.AddArgumentListArguments(SyntaxFactory.Argument(nullLiteral))
.WithAdditionalAnnotations(Simplifier.Annotation);

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,38 @@ public Test()

this.SolutionTransforms.Add((solution, projectId) =>
{
var parseOptions = (CSharpParseOptions)solution.GetProject(projectId).ParseOptions;
solution = solution.WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.CSharp7_1));
Project project = solution.GetProject(projectId)!;

var parseOptions = (CSharpParseOptions)project.ParseOptions!;
project = project.WithParseOptions(parseOptions.WithLanguageVersion(LanguageVersion.CSharp7_1));

if (this.HasEntryPoint)
{
var compilationOptions = solution.GetProject(projectId).CompilationOptions;
solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithOutputKind(OutputKind.ConsoleApplication));
project = project.WithCompilationOptions(project.CompilationOptions!.WithOutputKind(OutputKind.ConsoleApplication));
}

if (this.IncludeMicrosoftVisualStudioThreading)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
}

if (this.IncludeWindowsBase)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
}

if (this.IncludeVisualStudioSdk)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));

var nugetPackagesFolder = Environment.CurrentDirectory;
foreach (var reference in ReferencesHelper.VSSDKPackageReferences)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
}
}

return solution;
return project.Solution;
});

this.TestState.AdditionalFilesFactories.Add(() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,52 @@ public static partial class VisualBasicCodeFixVerifier<TAnalyzer, TCodeFix>
{
public class Test : VisualBasicCodeFixTest<TAnalyzer, TCodeFix, XUnitVerifier>
{
private static readonly ImmutableArray<string> VSSDKPackageReferences = ImmutableArray.Create(new string[] {
"Microsoft.VisualStudio.Shell.Interop.dll",
"Microsoft.VisualStudio.Shell.Interop.11.0.dll",
"Microsoft.VisualStudio.Shell.Interop.14.0.DesignTime.dll",
"Microsoft.VisualStudio.Shell.Immutable.14.0.dll",
"Microsoft.VisualStudio.Shell.14.0.dll",
});

public Test()
{
this.ReferenceAssemblies = ReferencesHelper.DefaultReferences;

this.SolutionTransforms.Add((solution, projectId) =>
{
var parseOptions = (VisualBasicParseOptions)solution.GetProject(projectId).ParseOptions;
solution = solution.WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.VisualBasic15_5));
Project? project = solution.GetProject(projectId);

var parseOptions = (VisualBasicParseOptions)project!.ParseOptions!;
project = project.WithParseOptions(parseOptions.WithLanguageVersion(LanguageVersion.VisualBasic15_5));

if (this.HasEntryPoint)
{
var compilationOptions = solution.GetProject(projectId).CompilationOptions;
solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithOutputKind(OutputKind.ConsoleApplication));
project = project.WithCompilationOptions(project.CompilationOptions!.WithOutputKind(OutputKind.ConsoleApplication));
}

if (this.IncludeMicrosoftVisualStudioThreading)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
}

if (this.IncludeWindowsBase)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
}

if (this.IncludeVisualStudioSdk)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));

var nugetPackagesFolder = Environment.CurrentDirectory;
foreach (var reference in ReferencesHelper.VSSDKPackageReferences)
foreach (var reference in VisualBasicCodeFixVerifier<TAnalyzer, TCodeFix>.Test.VSSDKPackageReferences)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
}
}

return solution;
return project.Solution;
});

this.TestState.AdditionalFilesFactories.Add(() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Testing;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Testing;
using Microsoft.CodeAnalysis.Testing.Verifiers;
using Xunit;
using Verify = MultiAnalyzerTests.Verifier;

Expand All @@ -32,7 +30,7 @@ Task<int> FooAsync() {
return Task.FromResult(1);
}
Task BarAsync() => null;
Task BarAsync() => Task.CompletedTask;
static void SetTaskSourceIfCompleted<T>(Task<T> task, TaskCompletionSource<T> tcs) {
if (task.IsCompleted) {
Expand Down Expand Up @@ -171,7 +169,7 @@ public Task BAsync() {
E().ToString();
E()();
string v = nameof(E);
return null;
return Task.CompletedTask;
}
internal Task CAsync() {
Expand All @@ -181,7 +179,7 @@ internal Task CAsync() {
E().ToString();
E()();
string v = nameof(E);
return null;
return Task.CompletedTask;
}
private void D<T>() { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,7 @@ static Task<bool> MyMethodAsync()
{
var projectA = solution.AddProject("ProjectA", "ProjectA", LanguageNames.CSharp)
.WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary))
.WithMetadataReferences(solution.GetProject(projectId).MetadataReferences.Concat(test!.TestState.AdditionalReferences))
.WithMetadataReferences(solution.GetProject(projectId)!.MetadataReferences.Concat(test!.TestState.AdditionalReferences))
.AddDocument("SpecialTasks.cs", specialTasksCs).Project;
solution = projectA.Solution;
solution = solution.AddProjectReference(projectId, new ProjectReference(projectA.Id));
Expand Down
Loading

0 comments on commit 6c5108f

Please sign in to comment.