Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async-iterator method can return IAsyncEnumerator<T> #31114

Merged
merged 3 commits into from
Nov 28, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2482,10 +2482,16 @@ protected bool IsGenericTaskReturningAsyncMethod()
return symbol?.Kind == SymbolKind.Method && ((MethodSymbol)symbol).IsGenericTaskReturningAsync(this.Compilation);
}

protected bool IsIAsyncEnumerableReturningAsyncMethod()
protected bool IsIAsyncEnumerableOrIAsyncEnumeratorReturningAsyncMethod()
{
var symbol = this.ContainingMemberOrLambda;
return symbol?.Kind == SymbolKind.Method && ((MethodSymbol)symbol).IsIAsyncEnumerableReturningAsync(this.Compilation);
if (symbol?.Kind == SymbolKind.Method)
{
var method = (MethodSymbol)symbol;
return method.IsIAsyncEnumerableReturningAsync(this.Compilation) ||
method.IsIAsyncEnumeratorReturningAsync(this.Compilation);
}
return false;
}

protected virtual TypeSymbol GetCurrentReturnType(out RefKind refKind)
Expand Down Expand Up @@ -2547,7 +2553,7 @@ private BoundStatement BindReturn(ReturnStatementSyntax syntax, DiagnosticBag di
diagnostics.Add(ErrorCode.ERR_MustNotHaveRefReturn, syntax.ReturnKeyword.GetLocation());
hasErrors = true;
}
else if (IsIAsyncEnumerableReturningAsyncMethod())
else if (IsIAsyncEnumerableOrIAsyncEnumeratorReturningAsyncMethod())
{
diagnostics.Add(ErrorCode.ERR_ReturnInIterator, syntax.ReturnKeyword.GetLocation());
hasErrors = true;
Expand Down
8 changes: 5 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/InMethodBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ internal override TypeSymbol GetIteratorElementType(YieldStatementSyntax node, D
}
else if (!returnType.IsErrorType())
{
Error(elementTypeDiagnostics, ErrorCode.ERR_BadIteratorReturn, _methodSymbol.Locations[0], _methodSymbol, returnType);
Error(elementTypeDiagnostics, ErrorCode.ERR_BadIteratorReturn, _methodSymbol.Locations[0], _methodSymbol, returnType);
}
elementType = CreateErrorType();
}
Expand Down Expand Up @@ -184,7 +184,8 @@ internal static TypeSymbolWithAnnotations GetIteratorElementTypeFromReturnType(C
{
if (refKind == RefKind.None && returnType.Kind == SymbolKind.NamedType)
{
switch (returnType.OriginalDefinition.SpecialType)
TypeSymbol originalDefinition = returnType.OriginalDefinition;
switch (originalDefinition.SpecialType)
{
case SpecialType.System_Collections_IEnumerable:
case SpecialType.System_Collections_IEnumerator:
Expand All @@ -200,7 +201,8 @@ internal static TypeSymbolWithAnnotations GetIteratorElementTypeFromReturnType(C
return ((NamedTypeSymbol)returnType).TypeArgumentsNoUseSiteDiagnostics[0];
}

if (returnType.OriginalDefinition == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T))
if (originalDefinition == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T) ||
originalDefinition == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerator_T))
{
return ((NamedTypeSymbol)returnType).TypeArgumentsNoUseSiteDiagnostics[0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ private sealed class AsyncIteratorRewriter : AsyncRewriter
private FieldSymbol _promiseOfValueOrEndField; // this struct implements the IValueTaskSource logic
private FieldSymbol _currentField; // stores the current/yielded value

// true if the iterator implements IAsyncEnumerable<T>,
// false if it implements IAsyncEnumerator<T>
private readonly bool _isEnumerable;

internal AsyncIteratorRewriter(
BoundStatement body,
MethodSymbol method,
Expand All @@ -29,14 +33,21 @@ internal AsyncIteratorRewriter(
: base(body, method, methodOrdinal, stateMachineType, slotAllocatorOpt, compilationState, diagnostics)
{
Debug.Assert(method.IteratorElementType != null);

_isEnumerable = method.IsIAsyncEnumerableReturningAsync(method.DeclaringCompilation);
}

protected override void VerifyPresenceOfRequiredAPIs(DiagnosticBag bag)
{
base.VerifyPresenceOfRequiredAPIs(bag);
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerable_T__GetAsyncEnumerator, bag);

if (_isEnumerable)
{
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerable_T__GetAsyncEnumerator, bag);
}
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerator_T__MoveNextAsync, bag);
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerator_T__get_Current, bag);

EnsureWellKnownMember(WellKnownMember.System_IAsyncDisposable__DisposeAsync, bag);
EnsureWellKnownMember(WellKnownMember.System_Threading_Tasks_ValueTask_T__ctor, bag);

Expand All @@ -62,8 +73,11 @@ protected override void GenerateMethodImplementations()
// IAsyncStateMachine methods and constructor
base.GenerateMethodImplementations();

// IAsyncEnumerable
GenerateIAsyncEnumerableImplementation_GetAsyncEnumerator();
if (_isEnumerable)
{
// IAsyncEnumerable
GenerateIAsyncEnumerableImplementation_GetAsyncEnumerator();
}

// IAsyncEnumerator
GenerateIAsyncEnumeratorImplementation_MoveNextAsync();
Expand All @@ -81,6 +95,9 @@ protected override void GenerateMethodImplementations()
GenerateIAsyncDisposable_DisposeAsync();
}

protected override bool PreserveInitialParameterValuesAndThreadId
=> _isEnumerable;

protected override void GenerateControlFields()
{
// the fields are initialized from entry-point method (which replaces the async-iterator method), so they need to be public
Expand Down Expand Up @@ -153,8 +170,8 @@ protected override void GenerateConstructor()

protected override void InitializeStateMachine(ArrayBuilder<BoundStatement> bodyBuilder, NamedTypeSymbol frameType, LocalSymbol stateMachineLocal)
{
// var stateMachineLocal = new {StateMachineType}(FinishedStateMachine)
int initialState = StateMachineStates.FinishedStateMachine;
// var stateMachineLocal = new {StateMachineType}({initialState})
int initialState = _isEnumerable ? StateMachineStates.FinishedStateMachine : StateMachineStates.NotStartedStateMachine;
bodyBuilder.Add(
F.Assignment(
F.Local(stateMachineLocal),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ private Symbol EnsureWellKnownMember(WellKnownMember member, DiagnosticBag bag)
return Binder.GetWellKnownTypeMember(F.Compilation, member, bag, body.Syntax.Location);
}

// Should only be true for async-enumerables, not async-enumerators. Tracked by https://github.com/dotnet/roslyn/issues/31057
protected override bool PreserveInitialParameterValuesAndThreadId
=> method.IsIterator;
=> false;

protected override void GenerateControlFields()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ public AsyncStateMachine(VariableSlotAllocator variableAllocatorOpt, TypeCompila
var elementType = TypeMap.SubstituteType(asyncMethod.IteratorElementType).TypeSymbol;
this.IteratorElementType = elementType;

// IAsyncEnumerable<TResult>
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T).Construct(elementType));
bool isEnumerable = asyncMethod.IsIAsyncEnumerableReturningAsync(asyncMethod.DeclaringCompilation);
Copy link
Member

@cston cston Nov 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asyncMethod.DeclaringCompilation [](start = 81, length = 32)

compilation #Resolved

if (isEnumerable)
{
// IAsyncEnumerable<TResult>
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T).Construct(elementType));
}

// IAsyncEnumerator<TResult>
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerator_T).Construct(elementType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,15 @@ public static bool IsIAsyncEnumerableReturningAsync(this MethodSymbol method, CS
&& method.ReturnType.TypeSymbol.IsIAsyncEnumerableType(compilation);
}

/// <summary>
/// Returns whether this method is async and returns an IAsyncEnumerator`1.
/// </summary>
public static bool IsIAsyncEnumeratorReturningAsync(this MethodSymbol method, CSharpCompilation compilation)
{
return method.IsAsync
&& method.ReturnType.TypeSymbol.IsIAsyncEnumeratorType(compilation);
}

internal static CSharpSyntaxNode ExtractReturnTypeSyntax(this MethodSymbol method)
{
method = method.PartialDefinitionPart ?? method;
Expand Down
14 changes: 13 additions & 1 deletion src/Compilers/CSharp/Portable/Symbols/TypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,17 @@ internal static bool IsIAsyncEnumerableType(this TypeSymbol type, CSharpCompilat
return (object)namedType.ConstructedFrom == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T);
}

internal static bool IsIAsyncEnumeratorType(this TypeSymbol type, CSharpCompilation compilation)
{
var namedType = type as NamedTypeSymbol;
if ((object)namedType == null || namedType.Arity != 1)
{
return false;
}

return (object)namedType.ConstructedFrom == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerator_T);
}

/// <summary>
/// Returns true if the type is generic or non-generic custom task-like type due to the
/// [AsyncMethodBuilder(typeof(B))] attribute. It returns the "B".
Expand Down Expand Up @@ -1697,7 +1708,8 @@ public static bool IsBadAsyncReturn(this TypeSymbol returnType, CSharpCompilatio
return returnType.SpecialType != SpecialType.System_Void &&
!returnType.IsNonGenericTaskType(declaringCompilation) &&
!returnType.IsGenericTaskType(declaringCompilation) &&
!returnType.IsIAsyncEnumerableType(declaringCompilation);
!returnType.IsIAsyncEnumerableType(declaringCompilation) &&
!returnType.IsIAsyncEnumeratorType(declaringCompilation);
}
}
}
Loading