Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VSTHRD114: Do not return null from non-async Task method #596

Merged
merged 14 commits into from
Apr 19, 2020
Merged
31 changes: 31 additions & 0 deletions doc/analyzers/VSTHRD114.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# VSTHRD114 Avoid returning a null Task

Returning `null` from a non-async `Task`/`Task<T>` method will cause a `NullReferenceException` at runtime. This problem can be avoided by returning `Task.CompletedTask`, `Task.FromResult<T>(null)` or `Task.FromResult(default(T))` instead.

## Examples of patterns that are flagged by this analyzer

Any non-async `Task` returning method with an explicit `return null;` will be flagged.

```csharp
Task DoAsync() {
return null;
}

Task<object> GetSomethingAsync() {
return null;
}
```

## Solution

Return a task like `Task.CompletedTask` or `Task.FromResult`.

```csharp
Task DoAsync() {
return Task.CompletedTask;
}

Task<object> GetSomethingAsync() {
return Task.FromResult<object>(null);
}
```
1 change: 1 addition & 0 deletions doc/analyzers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ ID | Title | Severity | Supports | Default diagnostic severity
[VSTHRD111](VSTHRD111.md) | Use `.ConfigureAwait(bool)` | Advisory | | Hidden
[VSTHRD112](VSTHRD112.md) | Implement `System.IAsyncDisposable` | Advisory | | Info
[VSTHRD113](VSTHRD113.md) | Check for `System.IAsyncDisposable` | Advisory | | Info
[VSTHRD114](VSTHRD114.md) | Avoid returning null from a `Task`-returning method. | Advisory | | Warning
[VSTHRD200](VSTHRD200.md) | Use `Async` naming convention | Guideline | [VSTHRD103](VSTHRD103.md) | Warning

## Severity descriptions
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
namespace Microsoft.VisualStudio.Threading.Analyzers
{
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Simplification;

[ExportCodeFixProvider(LanguageNames.CSharp)]
public class VSTHRD114AvoidReturningNullTaskCodeFix : CodeFixProvider
{
private static readonly ImmutableArray<string> ReusableFixableDiagnosticIds = ImmutableArray.Create(
VSTHRD114AvoidReturningNullTaskAnalyzer.Id);

/// <inheritdoc />
public override ImmutableArray<string> FixableDiagnosticIds => ReusableFixableDiagnosticIds;

/// <inheritdoc />
public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
foreach (var diagnostic in context.Diagnostics)
{
var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
var syntaxRoot = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

if (!(syntaxRoot.FindNode(diagnostic.Location.SourceSpan) is LiteralExpressionSyntax nullLiteral))
{
continue;
}

var methodDeclaration = nullLiteral.FirstAncestorOrSelf<MethodDeclarationSyntax>();
if (methodDeclaration == null)
{
continue;
}

if (!(methodDeclaration.ReturnType is GenericNameSyntax genericReturnType))
{
context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_CompletedTask, ct => ApplyTaskCompletedTaskFix(ct), "CompletedTask"), diagnostic);
}
else
{
if (genericReturnType.TypeArgumentList.Arguments.Count != 1)
{
continue;
}

context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_FromResult, ct => ApplyTaskFromResultFix(genericReturnType.TypeArgumentList.Arguments[0], ct), "FromResult"), diagnostic);
}

Task<Document> ApplyTaskCompletedTaskFix(CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.IdentifierName("CompletedTask"))
.WithAdditionalAnnotations(Simplifier.Annotation);

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}

Task<Document> ApplyTaskFromResultFix(TypeSyntax returnTypeArgument, CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.GenericName("FromResult").AddTypeArgumentListArguments(returnTypeArgument)))
.AddArgumentListArguments(SyntaxFactory.Argument(nullLiteral))
.WithAdditionalAnnotations(Simplifier.Annotation);

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,38 @@ public Test()

this.SolutionTransforms.Add((solution, projectId) =>
{
var parseOptions = (CSharpParseOptions)solution.GetProject(projectId).ParseOptions;
solution = solution.WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.CSharp7_1));
Project project = solution.GetProject(projectId)!;

var parseOptions = (CSharpParseOptions)project.ParseOptions!;
project = project.WithParseOptions(parseOptions.WithLanguageVersion(LanguageVersion.CSharp7_1));

if (this.HasEntryPoint)
{
var compilationOptions = solution.GetProject(projectId).CompilationOptions;
solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithOutputKind(OutputKind.ConsoleApplication));
project = project.WithCompilationOptions(project.CompilationOptions!.WithOutputKind(OutputKind.ConsoleApplication));
}

if (this.IncludeMicrosoftVisualStudioThreading)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
}

if (this.IncludeWindowsBase)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
}

if (this.IncludeVisualStudioSdk)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));

var nugetPackagesFolder = Environment.CurrentDirectory;
foreach (var reference in ReferencesHelper.VSSDKPackageReferences)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
}
}

return solution;
return project.Solution;
});

this.TestState.AdditionalFilesFactories.Add(() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,52 @@ public static partial class VisualBasicCodeFixVerifier<TAnalyzer, TCodeFix>
{
public class Test : VisualBasicCodeFixTest<TAnalyzer, TCodeFix, XUnitVerifier>
{
private static readonly ImmutableArray<string> VSSDKPackageReferences = ImmutableArray.Create(new string[] {
"Microsoft.VisualStudio.Shell.Interop.dll",
"Microsoft.VisualStudio.Shell.Interop.11.0.dll",
"Microsoft.VisualStudio.Shell.Interop.14.0.DesignTime.dll",
"Microsoft.VisualStudio.Shell.Immutable.14.0.dll",
"Microsoft.VisualStudio.Shell.14.0.dll",
});

public Test()
{
this.ReferenceAssemblies = ReferencesHelper.DefaultReferences;

this.SolutionTransforms.Add((solution, projectId) =>
{
var parseOptions = (VisualBasicParseOptions)solution.GetProject(projectId).ParseOptions;
solution = solution.WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.VisualBasic15_5));
Project? project = solution.GetProject(projectId);

var parseOptions = (VisualBasicParseOptions)project!.ParseOptions!;
project = project.WithParseOptions(parseOptions.WithLanguageVersion(LanguageVersion.VisualBasic15_5));

if (this.HasEntryPoint)
{
var compilationOptions = solution.GetProject(projectId).CompilationOptions;
solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithOutputKind(OutputKind.ConsoleApplication));
project = project.WithCompilationOptions(project.CompilationOptions!.WithOutputKind(OutputKind.ConsoleApplication));
}

if (this.IncludeMicrosoftVisualStudioThreading)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
}

if (this.IncludeWindowsBase)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
}

if (this.IncludeVisualStudioSdk)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));

var nugetPackagesFolder = Environment.CurrentDirectory;
foreach (var reference in ReferencesHelper.VSSDKPackageReferences)
foreach (var reference in VisualBasicCodeFixVerifier<TAnalyzer, TCodeFix>.Test.VSSDKPackageReferences)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
}
}

return solution;
return project.Solution;
});

this.TestState.AdditionalFilesFactories.Add(() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Testing;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Testing;
using Microsoft.CodeAnalysis.Testing.Verifiers;
using Xunit;
using Verify = MultiAnalyzerTests.Verifier;

Expand All @@ -32,7 +30,7 @@ Task<int> FooAsync() {
return Task.FromResult(1);
}

Task BarAsync() => null;
Task BarAsync() => Task.CompletedTask;

static void SetTaskSourceIfCompleted<T>(Task<T> task, TaskCompletionSource<T> tcs) {
if (task.IsCompleted) {
Expand Down Expand Up @@ -171,7 +169,7 @@ public Task BAsync() {
E().ToString();
E()();
string v = nameof(E);
return null;
return Task.CompletedTask;
}

internal Task CAsync() {
Expand All @@ -181,7 +179,7 @@ internal Task CAsync() {
E().ToString();
E()();
string v = nameof(E);
return null;
return Task.CompletedTask;
}

private void D<T>() { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,7 @@ static Task<bool> MyMethodAsync()
{
var projectA = solution.AddProject("ProjectA", "ProjectA", LanguageNames.CSharp)
.WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary))
.WithMetadataReferences(solution.GetProject(projectId).MetadataReferences.Concat(test!.TestState.AdditionalReferences))
.WithMetadataReferences(solution.GetProject(projectId)!.MetadataReferences.Concat(test!.TestState.AdditionalReferences))
.AddDocument("SpecialTasks.cs", specialTasksCs).Project;
solution = projectA.Solution;
solution = solution.AddProjectReference(projectId, new ProjectReference(projectA.Id));
Expand Down
Loading