Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixer: Prefer Memory overloads for Stream async Read/Write methods #3592

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Immutable;
using System.Composition;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Operations;

namespace Microsoft.NetCore.Analyzers.Runtime
{
[ExportCodeFixProvider(LanguageNames.CSharp, LanguageNames.VisualBasic), Shared]
public class PreferStreamAsyncMemoryOverloadsFixer : CodeFixProvider
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
{
public sealed override ImmutableArray<string> FixableDiagnosticIds =>
ImmutableArray.Create(PreferStreamAsyncMemoryOverloads.RuleId);

public sealed override FixAllProvider GetFixAllProvider() =>
WellKnownFixAllProviders.BatchFixer;

public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
Document doc = context.Document;
CancellationToken ct = context.CancellationToken;
SyntaxNode root = await doc.GetSyntaxRootAsync(ct).ConfigureAwait(false);
if (root.FindNode(context.Span) is SyntaxNode node)
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
{
SemanticModel model = await doc.GetSemanticModelAsync(ct).ConfigureAwait(false);
if (model.GetOperation(node, ct) is IInvocationOperation invocation)
{
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
string methodName = invocation.TargetMethod.Name;

string title;
if (methodName == "ReadAsync")
{
title = MicrosoftNetCoreAnalyzersResources.PreferStreamReadAsyncMemoryOverloadsTitle;
}
else if (methodName == "WriteAsync")
{
title = MicrosoftNetCoreAnalyzersResources.PreferStreamWriteAsyncMemoryOverloadsTitle;
}
else
{
return;
}

context.RegisterCodeFix(
CodeAction.Create(
title: title,
createChangedDocument: c => FixInvocation(doc, root, invocation, methodName),
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
equivalenceKey: MicrosoftNetCoreAnalyzersResources.PreferStreamAsyncMemoryOverloadsMessage),
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
context.Diagnostics);
}
}
}

private static Task<Document> FixInvocation(Document doc, SyntaxNode root, IInvocationOperation invocation, string methodName)
{
SyntaxGenerator generator = SyntaxGenerator.GetGenerator(doc);

// The stream object
SyntaxNode instanceNode = invocation.Instance.Syntax;

// Need the byte array object so we can invoke its AsMemory() method
SyntaxNode bufferInstanceNode = invocation.Arguments[0].Value.Syntax; // byte[] buffer

// These arguments are not modified, just moved inside AsMemory
SyntaxNode offsetNode = invocation.Arguments[1].Syntax; // int offset
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
SyntaxNode countNode = invocation.Arguments[2].Syntax; // int count

// Generate an invocation of the AsMemory() method from the byte array object
SyntaxNode asMemoryExpressionNode = generator.MemberAccessExpression(bufferInstanceNode, "AsMemory");
SyntaxNode asMemoryInvocationNode = generator.InvocationExpression(asMemoryExpressionNode, offsetNode, countNode);

// Create a new async method call for the stream object, no arguments yet
SyntaxNode asyncMethodNode = generator.MemberAccessExpression(instanceNode, methodName);

// Add the arguments to the async method call, with or without CancellationToken
SyntaxNode newInvocationExpression;
if (invocation.Arguments.Length > 3)
{
newInvocationExpression = generator.InvocationExpression(
asyncMethodNode, asMemoryInvocationNode,
invocation.Arguments[3].Syntax /* CancellationToken */);
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
newInvocationExpression = generator.InvocationExpression(asyncMethodNode, asMemoryInvocationNode);
}

SyntaxNode newInvocation = generator.ReplaceNode(root, invocation.Syntax, newInvocationExpression);
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
return Task.FromResult(doc.WithSyntaxRoot(newInvocation));
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ public sealed class PreferStreamAsyncMemoryOverloads : DiagnosticAnalyzer
{
internal const string RuleId = "CA1835";

private static readonly LocalizableString s_localizableTitleRead = new LocalizableResourceString(
nameof(MicrosoftNetCoreAnalyzersResources.PreferStreamReadAsyncMemoryOverloadsTitle),
private static readonly LocalizableString s_localizableMessage = new LocalizableResourceString(
nameof(MicrosoftNetCoreAnalyzersResources.PreferStreamAsyncMemoryOverloadsMessage),
MicrosoftNetCoreAnalyzersResources.ResourceManager,
typeof(MicrosoftNetCoreAnalyzersResources));

private static readonly LocalizableString s_localizableMessage = new LocalizableResourceString(
nameof(MicrosoftNetCoreAnalyzersResources.PreferStreamAsyncMemoryOverloadsMessage),
private static readonly LocalizableString s_localizableTitleRead = new LocalizableResourceString(
nameof(MicrosoftNetCoreAnalyzersResources.PreferStreamReadAsyncMemoryOverloadsTitle),
MicrosoftNetCoreAnalyzersResources.ResourceManager,
typeof(MicrosoftNetCoreAnalyzersResources));

Expand Down Expand Up @@ -100,6 +100,7 @@ private void AnalyzeCompilationStart(CompilationStartAnalysisContext context)
{
return;
}

if (!context.Compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemMemory1, out INamedTypeSymbol? memoryType))
{
return;
Expand Down Expand Up @@ -139,21 +140,7 @@ private void AnalyzeCompilationStart(CompilationStartAnalysisContext context)
ParameterInfo.GetParameterInfo(cancellationTokenType)
};

// Create the arrays with the exact parameter order of the desired methods
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
var preferredReadAsyncParameters = new[]
{
ParameterInfo.GetParameterInfo(readOnlyMemoryType), // ReadOnlyMemory<byte> buffer
ParameterInfo.GetParameterInfo(cancellationTokenType), // CancellationToken
};

var preferredWriteAsyncParameters = new[]
{
ParameterInfo.GetParameterInfo(memoryType), // ReadOnlyMemory<byte> buffer
ParameterInfo.GetParameterInfo(cancellationTokenType), // CancellationToken
};

// Retrieve the ReadAsync/WriteSync methods available in Stream
// If we don't find them all, the Memory based overloads are not supported in this .NET version
IEnumerable<IMethodSymbol> readAsyncMethodGroup = streamType.GetMembers("ReadAsync").OfType<IMethodSymbol>();
IEnumerable<IMethodSymbol> writeAsyncMethodGroup = streamType.GetMembers("WriteAsync").OfType<IMethodSymbol>();

Expand Down Expand Up @@ -184,14 +171,18 @@ private void AnalyzeCompilationStart(CompilationStartAnalysisContext context)

// Retrieve the preferred methods, which are used for constructing the rule message
IMethodSymbol? preferredReadAsyncMethod = readAsyncMethodGroup.FirstOrDefault(x =>
x.Parameters.Count() == 2 && x.Parameters[0].Type is INamedTypeSymbol type && type.ConstructedFrom.Equals(memoryType));
x.Parameters.Count() == 2 &&
x.Parameters[0].Type is INamedTypeSymbol type &&
type.ConstructedFrom.Equals(memoryType));
if (preferredReadAsyncMethod == null)
{
return;
}

IMethodSymbol? preferredWriteAsyncMethod = writeAsyncMethodGroup.FirstOrDefault(x =>
x.Parameters.Count() == 2 && x.Parameters[0].Type is INamedTypeSymbol type && type.ConstructedFrom.Equals(readOnlyMemoryType));
x.Parameters.Count() == 2 &&
x.Parameters[0].Type is INamedTypeSymbol type &&
type.ConstructedFrom.Equals(readOnlyMemoryType));
if (preferredWriteAsyncMethod == null)
{
return;
Expand Down Expand Up @@ -227,20 +218,31 @@ private void AnalyzeCompilationStart(CompilationStartAnalysisContext context)
context.RegisterOperationAction(context =>
{
IAwaitOperation awaitOperation = (IAwaitOperation)context.Operation;
if (ShouldAnalyze(awaitOperation, configureAwaitMethod, genericConfigureAwaitMethod, streamType, out IMethodSymbol? method) && method != null)

if (ShouldAnalyze(
awaitOperation,
configureAwaitMethod,
genericConfigureAwaitMethod,
streamType,
out IInvocationOperation? invocation,
out IMethodSymbol? method) &&
invocation != null &&
method != null)
{
DiagnosticDescriptor rule;
string ruleMessageMethod;
string ruleMessagePreferredMethod;

// Verify if the method is an undesired Async overload
if (method.Equals(undesiredReadAsyncMethod) || method.Equals(undesiredReadAsyncMethodWithCancellationToken))
if (method.Equals(undesiredReadAsyncMethod) ||
method.Equals(undesiredReadAsyncMethodWithCancellationToken))
{
rule = PreferStreamReadAsyncMemoryOverloadsRule;
ruleMessageMethod = undesiredReadAsyncMethod.Name;
ruleMessagePreferredMethod = preferredReadAsyncMethod.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat);
}
else if (method.Equals(undesiredWriteAsyncMethod) || method.Equals(undesiredWriteAsyncMethodWithCancellationToken))
else if (method.Equals(undesiredWriteAsyncMethod) ||
method.Equals(undesiredWriteAsyncMethodWithCancellationToken))
{
rule = PreferStreamWriteAsyncMemoryOverloadsRule;
ruleMessageMethod = undesiredWriteAsyncMethod.Name;
Expand All @@ -252,7 +254,7 @@ private void AnalyzeCompilationStart(CompilationStartAnalysisContext context)
return;
}

context.ReportDiagnostic(awaitOperation.Operation.CreateDiagnostic(rule, ruleMessageMethod, ruleMessagePreferredMethod));
context.ReportDiagnostic(invocation.CreateDiagnostic(rule, ruleMessageMethod, ruleMessagePreferredMethod));
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
}
},
OperationKind.Await);
Expand All @@ -263,26 +265,30 @@ private static bool ShouldAnalyze(
IMethodSymbol configureAwaitMethod,
IMethodSymbol genericConfigureAwaitMethod,
INamedTypeSymbol streamType,
out IInvocationOperation? actualInvocation,
out IMethodSymbol? actualMethod)
{
actualInvocation = null;
actualMethod = null;

// The await should have a known operation child, check its kind
if (!(awaitOperation.Operation is IInvocationOperation invocation))
if (!(awaitOperation.Operation is IInvocationOperation childOperation))
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
{
return false;
}

IMethodSymbol method = invocation.TargetMethod;
actualInvocation = childOperation;
IMethodSymbol method = childOperation.TargetMethod;

// Check if the child operation of the await is ConfigureAwait
// in which case we should analyze the grandchild operation
if (method.OriginalDefinition.Equals(configureAwaitMethod) ||
method.OriginalDefinition.Equals(genericConfigureAwaitMethod))
{
if (invocation.Instance is IInvocationOperation instanceInvocation)
if (childOperation.Instance is IInvocationOperation instanceOperation)
{
method = instanceInvocation.TargetMethod;
actualInvocation = instanceOperation;
method = instanceOperation.TargetMethod;
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,34 @@
using Microsoft.CodeAnalysis.Testing;
using VerifyCS = Test.Utilities.CSharpCodeFixVerifier<
Microsoft.NetCore.Analyzers.Runtime.PreferStreamAsyncMemoryOverloads,
Microsoft.CodeAnalysis.Testing.EmptyCodeFixProvider>;
Microsoft.NetCore.Analyzers.Runtime.PreferStreamAsyncMemoryOverloadsFixer>;
using VerifyVB = Test.Utilities.VisualBasicCodeFixVerifier<
Microsoft.NetCore.Analyzers.Runtime.PreferStreamAsyncMemoryOverloads,
Microsoft.CodeAnalysis.Testing.EmptyCodeFixProvider>;
Microsoft.NetCore.Analyzers.Runtime.PreferStreamAsyncMemoryOverloadsFixer>;

#pragma warning disable CA1305 // Specify IFormatProvider in string.Format
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved

namespace Microsoft.NetCore.Analyzers.Runtime.UnitTests
{
public class PreferStreamAsyncMemoryOverloadsTestBase
{
protected static Task AnalyzeCSAsync(string source, params DiagnosticResult[] expected) =>
AnalyzeCSForVersionAsync(source, ReferenceAssemblies.NetCore.NetCoreApp50, expected);
VerifyCSForVersionAsync(source, null, ReferenceAssemblies.NetCore.NetCoreApp50, expected);

protected static Task FixCSAsync(string originalSource, string fixedSource, params DiagnosticResult[] expected) =>
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
VerifyCSForVersionAsync(originalSource, fixedSource, ReferenceAssemblies.NetCore.NetCoreApp50, expected);

protected static Task AnalyzeCSUnsupportedAsync(string source, params DiagnosticResult[] expected) =>
AnalyzeCSForVersionAsync(source, ReferenceAssemblies.NetCore.NetCoreApp20, expected);
VerifyCSForVersionAsync(source, null, ReferenceAssemblies.NetCore.NetCoreApp20, expected);

protected static Task AnalyzeVBAsync(string source, params DiagnosticResult[] expected) =>
AnalyzeVBForVersionAsync(source, ReferenceAssemblies.NetCore.NetCoreApp50, expected);
VerifyVBForVersionAsync(source, null, ReferenceAssemblies.NetCore.NetCoreApp50, expected);

protected static Task FixVBAsync(string originalSource, string fixedSource, params DiagnosticResult[] expected) =>
VerifyVBForVersionAsync(originalSource, fixedSource, ReferenceAssemblies.NetCore.NetCoreApp50, expected);

protected static Task AnalyzeVBUnsupportedAsync(string source, params DiagnosticResult[] expected) =>
AnalyzeVBForVersionAsync(source, ReferenceAssemblies.NetCore.NetCoreApp20, expected);
VerifyVBForVersionAsync(source, null, ReferenceAssemblies.NetCore.NetCoreApp20, expected);

protected static DiagnosticResult GetCSResultForRule(int startLine, int startColumn, int endLine, int endColumn, DiagnosticDescriptor rule, string methodName, string methodPreferredName)
=> VerifyCS.Diagnostic(rule)
Expand All @@ -36,27 +44,58 @@ protected static DiagnosticResult GetVBResultForRule(int startLine, int startCol
.WithSpan(startLine, startColumn, endLine, endColumn)
.WithArguments(methodName, methodPreferredName);

private static Task AnalyzeCSForVersionAsync(string source, ReferenceAssemblies version, params DiagnosticResult[] expected)
protected static string GetFormattedSourceCode(string source, string asyncMethodPrefix, string args, bool withConfigureAwait, string language)
{
string configureAwait = string.Empty;

if (withConfigureAwait)
{
char booleanArgumentInitial = 'f';
if (language == LanguageNames.VisualBasic)
{
booleanArgumentInitial = 'F';
}
configureAwait = $".ConfigureAwait({booleanArgumentInitial}alse)";
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
}

asyncMethodPrefix = string.Format("s.{0}Async({1}){2}", asyncMethodPrefix, args, configureAwait);

return string.Format(source, asyncMethodPrefix);
}

private static Task VerifyCSForVersionAsync(string originalSource, string fixedSource, ReferenceAssemblies version, params DiagnosticResult[] expected)
{
var test = new VerifyCS.Test
{
TestCode = source,
TestCode = originalSource,
ReferenceAssemblies = version,
};

if (!string.IsNullOrEmpty(fixedSource))
{
test.FixedCode = fixedSource;
}

test.ExpectedDiagnostics.AddRange(expected);

return test.RunAsync();
}

private static Task AnalyzeVBForVersionAsync(string source, ReferenceAssemblies version, params DiagnosticResult[] expected)
private static Task VerifyVBForVersionAsync(string originalSource, string fixedSource, ReferenceAssemblies version, params DiagnosticResult[] expected)
{
var test = new VerifyVB.Test
{
TestCode = source,
TestCode = originalSource,
ReferenceAssemblies = version,
};

if (!string.IsNullOrEmpty(fixedSource))
{
test.FixedCode = fixedSource;
}

test.ExpectedDiagnostics.AddRange(expected);

return test.RunAsync();
}
}
Expand Down
Loading