Skip to content

Commit

Permalink
Collection literals: type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
cston committed Jun 21, 2023
1 parent 6bd8524 commit 019790a
Show file tree
Hide file tree
Showing 6 changed files with 752 additions and 60 deletions.
6 changes: 3 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ private BoundExpression ConvertCollectionLiteralExpression(
bool wasCompilerGenerated,
BindingDiagnosticBag diagnostics)
{
TypeSymbol? elementType;
TypeWithAnnotations elementType;
BoundCollectionLiteralExpression collectionLiteral;
var collectionTypeKind = ConversionsBase.GetCollectionLiteralTypeKind(Compilation, targetType, out elementType);
switch (collectionTypeKind)
Expand All @@ -437,10 +437,10 @@ private BoundExpression ConvertCollectionLiteralExpression(
case CollectionLiteralTypeKind.Array:
case CollectionLiteralTypeKind.Span:
case CollectionLiteralTypeKind.ReadOnlySpan:
collectionLiteral = BindArrayOrSpanCollectionLiteral(node, targetType, wasCompilerGenerated: wasCompilerGenerated, collectionTypeKind, elementType!, diagnostics);
collectionLiteral = BindArrayOrSpanCollectionLiteral(node, targetType, wasCompilerGenerated: wasCompilerGenerated, collectionTypeKind, elementType.Type!, diagnostics);
break;
case CollectionLiteralTypeKind.ListInterface:
collectionLiteral = BindListInterfaceCollectionLiteral(node, targetType, wasCompilerGenerated: wasCompilerGenerated, elementType!, diagnostics);
collectionLiteral = BindListInterfaceCollectionLiteral(node, targetType, wasCompilerGenerated: wasCompilerGenerated, elementType.Type!, diagnostics);
break;
default:
throw ExceptionUtilities.UnexpectedValue(collectionTypeKind);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down Expand Up @@ -1584,15 +1583,15 @@ private static bool HasAnonymousFunctionConversion(BoundExpression source, TypeS
return IsAnonymousFunctionCompatibleWithType((UnboundLambda)source, destination) == LambdaConversionResult.Success;
}

internal static CollectionLiteralTypeKind GetCollectionLiteralTypeKind(CSharpCompilation compilation, TypeSymbol destination, out TypeSymbol? elementType)
internal static CollectionLiteralTypeKind GetCollectionLiteralTypeKind(CSharpCompilation compilation, TypeSymbol destination, out TypeWithAnnotations elementType)
{
Debug.Assert(compilation is { });

if (destination is ArrayTypeSymbol arrayType)
{
if (arrayType.IsSZArray)
{
elementType = arrayType.ElementType;
elementType = arrayType.ElementTypeWithAnnotations;
return CollectionLiteralTypeKind.Array;
}
}
Expand All @@ -1604,32 +1603,31 @@ internal static CollectionLiteralTypeKind GetCollectionLiteralTypeKind(CSharpCom
{
return CollectionLiteralTypeKind.ReadOnlySpan;
}
else if (implementsIEnumerable(compilation, destination))
else if (implementsIEnumerable(compilation, destination, out elementType))
{
elementType = null;
return CollectionLiteralTypeKind.CollectionInitializer;
}
else if (isListInterface(compilation, destination, out elementType))
{
return CollectionLiteralTypeKind.ListInterface;
}

elementType = null;
elementType = default;
return CollectionLiteralTypeKind.None;

static bool isSpanType(CSharpCompilation compilation, TypeSymbol targetType, WellKnownType spanType, [NotNullWhen(true)] out TypeSymbol? elementType)
static bool isSpanType(CSharpCompilation compilation, TypeSymbol targetType, WellKnownType spanType, out TypeWithAnnotations elementType)
{
if (targetType is NamedTypeSymbol { Arity: 1 } namedType
&& areEqual(namedType.OriginalDefinition, compilation.GetWellKnownType(spanType)))
{
elementType = namedType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0].Type;
elementType = namedType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
return true;
}
elementType = null;
elementType = default;
return false;
}

static bool implementsIEnumerable(CSharpCompilation compilation, TypeSymbol targetType)
static bool implementsIEnumerable(CSharpCompilation compilation, TypeSymbol targetType, out TypeWithAnnotations elementType)
{
ImmutableArray<NamedTypeSymbol> allInterfaces;
switch (targetType.TypeKind)
Expand All @@ -1642,6 +1640,7 @@ static bool implementsIEnumerable(CSharpCompilation compilation, TypeSymbol targ
allInterfaces = ((TypeParameterSymbol)targetType).AllEffectiveInterfacesNoUseSiteDiagnostics;
break;
default:
elementType = default;
return false;
}

Expand All @@ -1651,10 +1650,24 @@ static bool implementsIEnumerable(CSharpCompilation compilation, TypeSymbol targ
// PROTOTYPE: Perhaps adjust the behavior of Binder.CollectionInitializerTypeImplementsIEnumerable()
// and use that method instead.
var ienumerableType = compilation.GetSpecialType(SpecialType.System_Collections_IEnumerable);
return allInterfaces.Any(static (a, b) => areEqual(a, b), ienumerableType);
if (allInterfaces.Any(static (a, b) => areEqual(a, b), ienumerableType))
{
// PROTOTYPE: Spec states target element type should be determined based on "foreach iteration type".
// That includes pattern-based GetEnumerator() etc., not just IEnumerable<T>.
// PROTOTYPE: Test with collection type that implements multiple times: IEnumerable<T>, IEnumerable<U>.
var ienumerableTypeOfT = compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T);
var implementation = allInterfaces.FirstOrDefault(static (a, b) => areEqual(a.OriginalDefinition, b), ienumerableTypeOfT);
elementType = implementation is null
? TypeWithAnnotations.Create(compilation.GetSpecialType(SpecialType.System_Object))
: implementation.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
return true;
}

elementType = default;
return false;
}

static bool isListInterface(CSharpCompilation compilation, TypeSymbol targetType, [NotNullWhen(true)] out TypeSymbol? elementType)
static bool isListInterface(CSharpCompilation compilation, TypeSymbol targetType, out TypeWithAnnotations elementType)
{
if (targetType is NamedTypeSymbol { TypeKind: TypeKind.Interface, Arity: 1 } namedType)
{
Expand All @@ -1667,12 +1680,12 @@ static bool isListInterface(CSharpCompilation compilation, TypeSymbol targetType
// Is the implementation with type argument T?
areEqual(listType.TypeParameters[0], listInterface.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0].Type))
{
elementType = namedType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0].Type;
elementType = namedType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
return true;
}
}
}
elementType = null;
elementType = default;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ private void InferTypeArgsFirstPhase(ref CompoundUseSiteInfo<AssemblySymbol> use
}
}

#nullable enable
private void MakeExplicitParameterTypeInferences(BoundExpression argument, TypeWithAnnotations target, ExactOrBoundsKind kind, ref CompoundUseSiteInfo<AssemblySymbol> useSiteInfo)
{
// SPEC: * If Ei is an anonymous function, and Ti is a delegate type or expression tree type,
Expand All @@ -609,6 +610,10 @@ private void MakeExplicitParameterTypeInferences(BoundExpression argument, TypeW
ExplicitParameterTypeInference(argument, target, ref useSiteInfo);
ExplicitReturnTypeInference(argument, target, ref useSiteInfo);
}
else if (argument.Kind == BoundKind.UnconvertedCollectionLiteralExpression)
{
MakeCollectionLiteralTypeInferences((BoundUnconvertedCollectionLiteralExpression)argument, target, ref useSiteInfo);
}
else if (argument.Kind != BoundKind.TupleLiteral ||
!MakeExplicitParameterTypeInferences((BoundTupleLiteral)argument, target, kind, ref useSiteInfo))
{
Expand All @@ -626,6 +631,42 @@ private void MakeExplicitParameterTypeInferences(BoundExpression argument, TypeW
}
}

private void MakeCollectionLiteralTypeInferences(BoundUnconvertedCollectionLiteralExpression argument, TypeWithAnnotations target, ref CompoundUseSiteInfo<AssemblySymbol> useSiteInfo)
{
if (target.Type is null)
{
return;
}

if (argument.Elements.Length == 0)
{
return;
}

// PROTOTYPE: Test all these cases, and types that are not constructible collection types.
if (ConversionsBase.GetCollectionLiteralTypeKind(_compilation, target.Type, out var targetElementType) == CollectionLiteralTypeKind.None)
{
return;
}

foreach (var element in argument.Elements)
{
if (element.Kind == BoundKind.UnconvertedCollectionLiteralExpression)
{
MakeCollectionLiteralTypeInferences((BoundUnconvertedCollectionLiteralExpression)element, targetElementType, ref useSiteInfo);
}
else
{
var elementType = _extensions.GetTypeWithAnnotations(element); // PROTOTYPE: Test cases where _extensions returns something other than element.Type.
if (IsReallyAType(elementType.Type)) // PROTOTYPE: Test cases where IsReallyAType() returns false.
{
LowerBoundInference(elementType, targetElementType, ref useSiteInfo);
}
}
}
}
#nullable disable

private bool MakeExplicitParameterTypeInferences(BoundTupleLiteral argument, TypeWithAnnotations target, ExactOrBoundsKind kind, ref CompoundUseSiteInfo<AssemblySymbol> useSiteInfo)
{
// try match up element-wise to the destination tuple (or underlying type)
Expand Down
4 changes: 3 additions & 1 deletion src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7228,9 +7228,11 @@ private static NullableAnnotation GetNullableAnnotation(BoundExpression expr)
case BoundKind.UnboundLambda:
case BoundKind.UnconvertedObjectCreationExpression:
case BoundKind.ConvertedTupleLiteral:
case BoundKind.UnconvertedCollectionLiteralExpression:
return NullableAnnotation.NotAnnotated;
default:
Debug.Assert(false); // unexpected value
// PROTOTYPE: Re-enable
//Debug.Assert(false); // unexpected value
return NullableAnnotation.Oblivious;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal sealed partial class LocalRewriter
case CollectionLiteralTypeKind.Span:
case CollectionLiteralTypeKind.ReadOnlySpan:
Debug.Assert(elementType is { });
return VisitArrayOrSpanCollectionLiteralExpression(node, elementType);
return VisitArrayOrSpanCollectionLiteralExpression(node, elementType.Type);
case CollectionLiteralTypeKind.ListInterface:
return VisitListInterfaceCollectionLiteralExpression(node);
default:
Expand Down
Loading

0 comments on commit 019790a

Please sign in to comment.