Skip to content

Commit

Permalink
Async-iterator method can return IAsyncEnumerator<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
jcouv committed Nov 18, 2018
1 parent 71e1473 commit 09ed675
Show file tree
Hide file tree
Showing 9 changed files with 433 additions and 58 deletions.
12 changes: 9 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2482,10 +2482,16 @@ protected bool IsGenericTaskReturningAsyncMethod()
return symbol?.Kind == SymbolKind.Method && ((MethodSymbol)symbol).IsGenericTaskReturningAsync(this.Compilation);
}

protected bool IsIAsyncEnumerableReturningAsyncMethod()
protected bool IsIAsyncEnumerableOrIAsyncEnumeratorReturningAsyncMethod()
{
var symbol = this.ContainingMemberOrLambda;
return symbol?.Kind == SymbolKind.Method && ((MethodSymbol)symbol).IsIAsyncEnumerableReturningAsync(this.Compilation);
if (symbol?.Kind == SymbolKind.Method)
{
var method = (MethodSymbol)symbol;
return method.IsIAsyncEnumerableReturningAsync(this.Compilation) ||
method.IsIAsyncEnumeratorReturningAsync(this.Compilation);
}
return false;
}

protected virtual TypeSymbol GetCurrentReturnType(out RefKind refKind)
Expand Down Expand Up @@ -2547,7 +2553,7 @@ private BoundStatement BindReturn(ReturnStatementSyntax syntax, DiagnosticBag di
diagnostics.Add(ErrorCode.ERR_MustNotHaveRefReturn, syntax.ReturnKeyword.GetLocation());
hasErrors = true;
}
else if (IsIAsyncEnumerableReturningAsyncMethod())
else if (IsIAsyncEnumerableOrIAsyncEnumeratorReturningAsyncMethod())
{
diagnostics.Add(ErrorCode.ERR_ReturnInIterator, syntax.ReturnKeyword.GetLocation());
hasErrors = true;
Expand Down
8 changes: 5 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/InMethodBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ internal override TypeSymbol GetIteratorElementType(YieldStatementSyntax node, D
}
else if (!returnType.IsErrorType())
{
Error(elementTypeDiagnostics, ErrorCode.ERR_BadIteratorReturn, _methodSymbol.Locations[0], _methodSymbol, returnType);
Error(elementTypeDiagnostics, ErrorCode.ERR_BadIteratorReturn, _methodSymbol.Locations[0], _methodSymbol, returnType);
}
elementType = CreateErrorType();
}
Expand Down Expand Up @@ -184,7 +184,8 @@ internal static TypeSymbolWithAnnotations GetIteratorElementTypeFromReturnType(C
{
if (refKind == RefKind.None && returnType.Kind == SymbolKind.NamedType)
{
switch (returnType.OriginalDefinition.SpecialType)
TypeSymbol originalDefinition = returnType.OriginalDefinition;
switch (originalDefinition.SpecialType)
{
case SpecialType.System_Collections_IEnumerable:
case SpecialType.System_Collections_IEnumerator:
Expand All @@ -200,7 +201,8 @@ internal static TypeSymbolWithAnnotations GetIteratorElementTypeFromReturnType(C
return ((NamedTypeSymbol)returnType).TypeArgumentsNoUseSiteDiagnostics[0];
}

if (returnType.OriginalDefinition == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T))
if (originalDefinition == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T) ||
originalDefinition == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerator_T))
{
return ((NamedTypeSymbol)returnType).TypeArgumentsNoUseSiteDiagnostics[0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ private sealed class AsyncIteratorRewriter : AsyncRewriter
private FieldSymbol _promiseOfValueOrEndField; // this struct implements the IValueTaskSource logic
private FieldSymbol _currentField; // stores the current/yielded value

// true if the iterator implements IAsyncEnumerable<T>,
// false if it implements IAsyncEnumerator<T>
private readonly bool _isEnumerable;

internal AsyncIteratorRewriter(
BoundStatement body,
MethodSymbol method,
Expand All @@ -29,14 +33,21 @@ internal AsyncIteratorRewriter(
: base(body, method, methodOrdinal, stateMachineType, slotAllocatorOpt, compilationState, diagnostics)
{
Debug.Assert(method.IteratorElementType != null);

_isEnumerable = method.IsIAsyncEnumerableReturningAsync(method.DeclaringCompilation);
}

protected override void VerifyPresenceOfRequiredAPIs(DiagnosticBag bag)
{
base.VerifyPresenceOfRequiredAPIs(bag);
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerable_T__GetAsyncEnumerator, bag);

if (_isEnumerable)
{
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerable_T__GetAsyncEnumerator, bag);
}
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerator_T__MoveNextAsync, bag);
EnsureWellKnownMember(WellKnownMember.System_Collections_Generic_IAsyncEnumerator_T__get_Current, bag);

EnsureWellKnownMember(WellKnownMember.System_IAsyncDisposable__DisposeAsync, bag);
EnsureWellKnownMember(WellKnownMember.System_Threading_Tasks_ValueTask_T__ctor, bag);

Expand All @@ -62,8 +73,11 @@ protected override void GenerateMethodImplementations()
// IAsyncStateMachine methods and constructor
base.GenerateMethodImplementations();

// IAsyncEnumerable
GenerateIAsyncEnumerableImplementation_GetAsyncEnumerator();
if (_isEnumerable)
{
// IAsyncEnumerable
GenerateIAsyncEnumerableImplementation_GetAsyncEnumerator();
}

// IAsyncEnumerator
GenerateIAsyncEnumeratorImplementation_MoveNextAsync();
Expand All @@ -81,6 +95,9 @@ protected override void GenerateMethodImplementations()
GenerateIAsyncDisposable_DisposeAsync();
}

protected override bool PreserveInitialParameterValuesAndThreadId
=> _isEnumerable;

protected override void GenerateControlFields()
{
// the fields are initialized from entry-point method (which replaces the async-iterator method), so they need to be public
Expand Down Expand Up @@ -153,8 +170,8 @@ protected override void GenerateConstructor()

protected override void InitializeStateMachine(ArrayBuilder<BoundStatement> bodyBuilder, NamedTypeSymbol frameType, LocalSymbol stateMachineLocal)
{
// var stateMachineLocal = new {StateMachineType}(FinishedStateMachine)
int initialState = StateMachineStates.FinishedStateMachine;
// var stateMachineLocal = new {StateMachineType}({initialState})
int initialState = _isEnumerable ? StateMachineStates.FinishedStateMachine : StateMachineStates.NotStartedStateMachine;
bodyBuilder.Add(
F.Assignment(
F.Local(stateMachineLocal),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ private Symbol EnsureWellKnownMember(WellKnownMember member, DiagnosticBag bag)
return Binder.GetWellKnownTypeMember(F.Compilation, member, bag, body.Syntax.Location);
}

// Should only be true for async-enumerables, not async-enumerators. Tracked by https://github.com/dotnet/roslyn/issues/31057
protected override bool PreserveInitialParameterValuesAndThreadId
=> method.IsIterator;
=> false;

protected override void GenerateControlFields()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ public AsyncStateMachine(VariableSlotAllocator variableAllocatorOpt, TypeCompila
var elementType = TypeMap.SubstituteType(asyncMethod.IteratorElementType).TypeSymbol;
this.IteratorElementType = elementType;

// IAsyncEnumerable<TResult>
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T).Construct(elementType));
bool isEnumerable = asyncMethod.IsIAsyncEnumerableReturningAsync(asyncMethod.DeclaringCompilation);
if (isEnumerable)
{
// IAsyncEnumerable<TResult>
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T).Construct(elementType));
}

// IAsyncEnumerator<TResult>
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerator_T).Construct(elementType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,15 @@ public static bool IsIAsyncEnumerableReturningAsync(this MethodSymbol method, CS
&& method.ReturnType.TypeSymbol.IsIAsyncEnumerableType(compilation);
}

/// <summary>
/// Returns whether this method is async and returns an IAsyncEnumerator`1.
/// </summary>
public static bool IsIAsyncEnumeratorReturningAsync(this MethodSymbol method, CSharpCompilation compilation)
{
return method.IsAsync
&& method.ReturnType.TypeSymbol.IsIAsyncEnumeratorType(compilation);
}

internal static CSharpSyntaxNode ExtractReturnTypeSyntax(this MethodSymbol method)
{
method = method.PartialDefinitionPart ?? method;
Expand Down
14 changes: 13 additions & 1 deletion src/Compilers/CSharp/Portable/Symbols/TypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,17 @@ internal static bool IsIAsyncEnumerableType(this TypeSymbol type, CSharpCompilat
return (object)namedType.ConstructedFrom == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T);
}

internal static bool IsIAsyncEnumeratorType(this TypeSymbol type, CSharpCompilation compilation)
{
var namedType = type as NamedTypeSymbol;
if ((object)namedType == null || namedType.Arity != 1)
{
return false;
}

return (object)namedType.ConstructedFrom == compilation.GetWellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerator_T);
}

/// <summary>
/// Returns true if the type is generic or non-generic custom task-like type due to the
/// [AsyncMethodBuilder(typeof(B))] attribute. It returns the "B".
Expand Down Expand Up @@ -1697,7 +1708,8 @@ public static bool IsBadAsyncReturn(this TypeSymbol returnType, CSharpCompilatio
return returnType.SpecialType != SpecialType.System_Void &&
!returnType.IsNonGenericTaskType(declaringCompilation) &&
!returnType.IsGenericTaskType(declaringCompilation) &&
!returnType.IsIAsyncEnumerableType(declaringCompilation);
!returnType.IsIAsyncEnumerableType(declaringCompilation) &&
!returnType.IsIAsyncEnumeratorType(declaringCompilation);
}
}
}
Loading

0 comments on commit 09ed675

Please sign in to comment.