Skip to content

Commit

Permalink
fixes collection expression initialization when a implicit numeric co…
Browse files Browse the repository at this point in the history
…nversion is required (#256)
  • Loading branch information
adrianoc committed Sep 27, 2024
1 parent ab146cb commit 46b6e5a
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System;
using System.Text.RegularExpressions;
using Cecilifier.Core.Tests.Framework;
using NUnit.Framework;

Expand Down Expand Up @@ -51,5 +53,21 @@ public void ListOfT()
foreach(var c in list.ToArray()) System.Console.Write(c);
""",
"CECIL");
}

[Test]
public void ImplicitNumericConversions_Are_Applied([Values("List<long>", "long[]", "Span<long>")] string targetType, [Values("[2, 1]", "[5, 4, 3, 2, 1]")] string items)
{
AssertOutput(
$"""
using System.Collections.Generic;
using System;
{targetType} items = {items};
foreach(var c in items) System.Console.Write(c);
""",
Regex.Replace(items, @"\s+|\[|\]|,", ""),
"ReturnPtrToStack" // This is required only for spans.
);
}
}
9 changes: 8 additions & 1 deletion Cecilifier.Core/AST/ArrayInitializationProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
using System;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Mono.Cecil.Cil;

using Cecilifier.Core.CodeGeneration;
using Cecilifier.Core.Extensions;
using Microsoft.CodeAnalysis.Operations;

namespace Cecilifier.Core.AST;

public class ArrayInitializationProcessor
{
internal static void InitializeUnoptimized<TElement>(ExpressionVisitor visitor, ITypeSymbol elementType, SeparatedSyntaxList<TElement>? elements) where TElement : CSharpSyntaxNode
internal static void InitializeUnoptimized<TElement>(ExpressionVisitor visitor, ITypeSymbol elementType, SeparatedSyntaxList<TElement>? elements, IOperation parentOperation = null) where TElement : CSharpSyntaxNode
{
var context = visitor.Context;
var stelemOpCode = elementType.StelemOpCode();
Expand All @@ -22,11 +24,16 @@ internal static void InitializeUnoptimized<TElement>(ExpressionVisitor visitor,
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Ldc_I4, i);
elements.Value[i].Accept(visitor);

//TODO: Refactor to extract all this into some common method to apply conversions.
var itemType = context.SemanticModel.GetTypeInfo(elements.Value[i]);
if (elementType.IsReferenceType && itemType.Type != null && itemType.Type.IsValueType)
{
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Box, context.TypeResolver.Resolve(itemType.Type));
}
else if(parentOperation != null && parentOperation is ICollectionExpressionOperation ceo && ceo.Elements[i] is IConversionOperation conversion)
{
context.TryApplyNumericConversion(visitor.ILVariable, conversion.Operand.Type, conversion.Type);
}

context.EmitCilInstruction(visitor.ILVariable, stelemOpCode, stelemOpCode == OpCodes.Stelem_Any ? resolvedElementType : null);
}
Expand Down
22 changes: 17 additions & 5 deletions Cecilifier.Core/AST/CollectionExpressionProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
using Cecilifier.Core.Naming;
using Cecilifier.Core.Variables;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Operations;
using Mono.Cecil.Cil;

namespace Cecilifier.Core.AST;
Expand Down Expand Up @@ -74,15 +76,14 @@ private static void HandleAssignmentToList(ExpressionVisitor visitor, Collection
var spanGetItemMethod = GetSpanIndexerGetMethod(context, resolvedListTypeArgument);
var stindOpCode = listOfTTypeSymbol.TypeArguments[0].StindOpCodeFor();

var collectionExpressionOperation = context.SemanticModel.GetOperation(node).EnsureNotNull<IOperation, ICollectionExpressionOperation>();
foreach (var element in node.Elements)
{
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Ldloca_S, spanToList.VariableName);
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Ldc_I4, index);
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Call, spanGetItemMethod);
visitor.Visit(element);
//TODO: Inject conversions if needed. Test with
// List<object> l = [1, 2, 3];
// List<long> l = [1, 2, 3];
ApplyNumericConversions(context, visitor.ILVariable, collectionExpressionOperation.Elements[index]);
context.EmitCilInstruction(visitor.ILVariable, stindOpCode);
index++;
}
Expand Down Expand Up @@ -111,14 +112,15 @@ private static void HandleAssignmentToSpan(ExpressionVisitor visitor, Collection
.MakeGenericInstanceMethod(context, "InlineArrayElementRef", [$"{inlineArrayLocalVar}.VariableType", context.TypeResolver.Resolve(spanTypeSymbol.TypeArguments[0])]);

var storeOpCode = inlineArrayElementType.StindOpCodeFor();
var collectionExpressionOperation = context.SemanticModel.GetOperation(node).EnsureNotNull<IOperation, ICollectionExpressionOperation>();
var index = 0;
foreach (var element in node.Elements)
{
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Ldloca_S, inlineArrayLocalVar);
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Ldc_I4, index);
context.EmitCilInstruction(visitor.ILVariable, OpCodes.Call, inlineArrayElementRefMethodVar);
visitor.Visit(element);

ApplyNumericConversions(context, visitor.ILVariable, collectionExpressionOperation.Elements[index]);
context.EmitCilInstruction(visitor.ILVariable, storeOpCode, storeOpCode == OpCodes.Ldobj ? context.TypeResolver.Resolve(inlineArrayElementType) : null);
index++;
}
Expand Down Expand Up @@ -196,7 +198,7 @@ private static void HandleAssignmentToArray(ExpressionVisitor visitor, Collectio
if (PrivateImplementationDetailsGenerator.IsApplicableTo(node))
ArrayInitializationProcessor.InitializeOptimized(visitor, arrayTypeSymbol.ElementType, node.Elements);
else
ArrayInitializationProcessor.InitializeUnoptimized<CollectionElementSyntax>(visitor, arrayTypeSymbol.ElementType, node.Elements);
ArrayInitializationProcessor.InitializeUnoptimized<CollectionElementSyntax>(visitor, arrayTypeSymbol.ElementType, node.Elements, visitor.Context.SemanticModel.GetOperation(node));
}

private static string GetSpanIndexerGetMethod(IVisitorContext context, string typeArgument)
Expand All @@ -212,4 +214,14 @@ private static string GetSpanIndexerGetMethod(IVisitorContext context, string ty

return methodVar;
}

static void ApplyNumericConversions(IVisitorContext context, string ilVar, IOperation operation)
{
if (operation is not IConversionOperation { Conversion.IsNumeric: true } elementConversion)
return;

var result = context.TryApplyNumericConversion(ilVar, operation.Type, elementConversion.Type);
if (!result)
throw new Exception();
}
}
36 changes: 4 additions & 32 deletions Cecilifier.Core/AST/ExpressionVisitor.Conversions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Diagnostics;
using System.Linq;
using Cecilifier.Core.Extensions;
using Cecilifier.Core.Misc;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand All @@ -12,10 +11,11 @@ namespace Cecilifier.Core.AST;

partial class ExpressionVisitor
{
private void InjectRequiredConversions(ExpressionSyntax expression, Action loadArrayIntoStack = null)
public void InjectRequiredConversions(ExpressionSyntax expression, Action loadArrayIntoStack = null)
{
var typeInfo = ModelExtensions.GetTypeInfo(Context.SemanticModel, expression);
if (typeInfo.Type == null) return;

var conversion = Context.SemanticModel.GetConversion(expression);
if (conversion.IsImplicit)
{
Expand All @@ -30,37 +30,9 @@ private void InjectRequiredConversions(ExpressionSyntax expression, Action loadA

if (conversion.IsNumeric)
{
Debug.Assert(typeInfo.ConvertedType != null);
switch (typeInfo.ConvertedType.SpecialType)
if (!Context.TryApplyNumericConversion(ilVar, typeInfo.Type, typeInfo.ConvertedType))
{
case SpecialType.System_Single:
Context.EmitCilInstruction(ilVar, OpCodes.Conv_R4);
return;
case SpecialType.System_Double:
Context.EmitCilInstruction(ilVar, OpCodes.Conv_R8);
return;
case SpecialType.System_Byte:
Context.EmitCilInstruction(ilVar, OpCodes.Conv_I1);
return;
case SpecialType.System_Int16:
Context.EmitCilInstruction(ilVar, OpCodes.Conv_I2);
return;
case SpecialType.System_Int32:
// byte/char are pushed as Int32 by the runtime
if (typeInfo.Type.SpecialType != SpecialType.System_SByte && typeInfo.Type.SpecialType != SpecialType.System_Byte && typeInfo.Type.SpecialType != SpecialType.System_Char)
Context.EmitCilInstruction(ilVar, OpCodes.Conv_I4);
return;
case SpecialType.System_Int64:
var convOpCode = typeInfo.Type.SpecialType == SpecialType.System_Char || typeInfo.Type.SpecialType == SpecialType.System_Byte ? OpCodes.Conv_U8 : OpCodes.Conv_I8;
Context.EmitCilInstruction(ilVar, convOpCode);
return;
case SpecialType.System_Decimal:
var operand = typeInfo.ConvertedType.GetMembers().OfType<IMethodSymbol>()
.Single(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length == 1 && m.Parameters[0].Type.SpecialType == typeInfo.Type.SpecialType);
Context.EmitCilInstruction(ilVar, OpCodes.Newobj, operand.MethodResolverExpression(Context));
return;
default:
throw new Exception($"Conversion from {typeInfo.Type} to {typeInfo.ConvertedType} not implemented.");
throw new Exception($"Conversion from {typeInfo.Type} to {typeInfo.ConvertedType} not implemented.");
}
}

Expand Down
15 changes: 11 additions & 4 deletions Cecilifier.Core/AST/StatementVisitor.ForEach.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using System;
using System.Linq;
using Cecilifier.Core.Extensions;
using Cecilifier.Core.Misc;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Mono.Cecil.Cil;

Expand Down Expand Up @@ -78,9 +78,9 @@ void ProcessForEachOverEnumerable()
{
ProcessWithInTryCatchFinallyBlock(
_ilVar,
context => ProcessForEach((ForEachHandlerContext) context, node),
Array.Empty<CatchClauseSyntax>(),
context => ProcessForEachFinally((ForEachHandlerContext) context),
context => ProcessForEach(context, node),
[],
ProcessForEachFinally,
context);
}
else
Expand Down Expand Up @@ -129,6 +129,13 @@ private void ProcessForEach(ForEachHandlerContext forEachHandlerContext, ForEach

Context.EmitCilInstruction(_ilVar, loadOpCode, forEachHandlerContext.EnumeratorVariableName);
AddMethodCall(_ilVar, forEachHandlerContext.EnumeratorCurrentMethod);

var info = Context.SemanticModel.GetForEachStatementInfo(node);
if (!node.Type.IsKind(SyntaxKind.RefType) && info.CurrentProperty.ReturnsByRef)
{
Context.EmitCilInstruction(_ilVar, info.ElementType.LdindOpCodeFor());
}

Context.EmitCilInstruction(_ilVar, OpCodes.Stloc, foreachCurrentValueVarName);

// process body of foreach
Expand Down
6 changes: 3 additions & 3 deletions Cecilifier.Core/AST/StatementVisitor.TryCatchFinally.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ namespace Cecilifier.Core.AST
{
internal partial class StatementVisitor
{
private void ProcessTryCatchFinallyBlock(string ilVar, CSharpSyntaxNode tryStatement, CatchClauseSyntax[] catches, Action<object> finallyBlockHandler, object state = null)
private void ProcessTryCatchFinallyBlock<TState>(string ilVar, CSharpSyntaxNode tryStatement, CatchClauseSyntax[] catches, Action<TState> finallyBlockHandler, TState state = default)
{
ProcessWithInTryCatchFinallyBlock(ilVar, _ => tryStatement.Accept(this), catches, finallyBlockHandler, state);
}

private void ProcessWithInTryCatchFinallyBlock(string ilVar, Action<object> toProcess, CatchClauseSyntax[] catches, Action<object> finallyBlockHandler, object state)
private void ProcessWithInTryCatchFinallyBlock<TState>(string ilVar, Action<TState> toProcess, CatchClauseSyntax[] catches, Action<TState> finallyBlockHandler, TState state)
{
var exceptionHandlerTable = new ExceptionHandlerEntry[catches.Length + (finallyBlockHandler != null ? 1 : 0)];

Expand Down Expand Up @@ -88,7 +88,7 @@ private void HandleCatchClause(string ilVar, CatchClauseSyntax node, ExceptionHa
Context.EmitCilInstruction(ilVar, OpCodes.Leave, firstInstructionAfterTryCatchBlock);
}

private void HandleFinallyClause(string ilVar, Action<object> finallyBlockHandler, ExceptionHandlerEntry[] exceptionHandlerTable, object state)
private void HandleFinallyClause<TState>(string ilVar, Action<TState> finallyBlockHandler, ExceptionHandlerEntry[] exceptionHandlerTable, TState state)
{
if (finallyBlockHandler == null)
return;
Expand Down
4 changes: 2 additions & 2 deletions Cecilifier.Core/AST/StatementVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public override void VisitUsingStatement(UsingStatementSyntax node)
localVarDef = StoreTopOfStackInLocalVariable(Context, _ilVar, "tmp", usingType).VariableName;
}

void FinallyBlockHandler(object _)
void FinallyBlockHandler(ForEachHandlerContext _)
{
string? lastFinallyInstructionLabel = null;
if (usingType.TypeKind == TypeKind.TypeParameter || usingType.IsValueType)
Expand All @@ -278,7 +278,7 @@ void FinallyBlockHandler(object _)
AddCecilExpression($"{_ilVar}.Append({lastFinallyInstructionLabel});");
}

ProcessTryCatchFinallyBlock(_ilVar, node.Statement, Array.Empty<CatchClauseSyntax>(), FinallyBlockHandler);
ProcessTryCatchFinallyBlock<ForEachHandlerContext>(_ilVar, node.Statement, Array.Empty<CatchClauseSyntax>(), FinallyBlockHandler);
}

public override void VisitLocalFunctionStatement(LocalFunctionStatementSyntax node) => node.Accept(new MethodDeclarationVisitor(Context));
Expand Down
2 changes: 1 addition & 1 deletion Cecilifier.Core/AST/SyntaxWalkerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ bool HandleCallOnValueType()
// in this case we need to call System.Index.GetOffset(int32) on a value type (System.Index)
// which requires the address of the value type.
var isSystemIndexUsedAsIndex = IsSystemIndexUsedAsIndex(symbol, parentNode);
var usageResult = parentNode.Accept(UsageVisitor.GetInstance(Context));
var usageResult = parentNode.Accept(UsageVisitor.GetInstance(Context).WithTargetNode(node));
if (isSystemIndexUsedAsIndex || parentNode.IsKind(SyntaxKind.AddressOfExpression) || IsPseudoAssignmentToValueType() || node.IsMemberAccessOnElementAccess() || usageResult.Kind == UsageKind.CallTarget)
{
Context.EmitCilInstruction(ilVar, opCode, operand);
Expand Down
28 changes: 26 additions & 2 deletions Cecilifier.Core/AST/UsageVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@ public static UsageVisitor GetInstance(IVisitorContext context)
return _instance;
}

public UsageVisitor WithTargetNode(CSharpSyntaxNode node)
{
_targetNode = node;
return this;
}

internal static void ResetInstance() => _instance = null;

private static UsageVisitor _instance;

private readonly IVisitorContext context;
private CSharpSyntaxNode _targetNode;

private UsageVisitor(IVisitorContext context)
{
Expand All @@ -31,20 +38,37 @@ private UsageVisitor(IVisitorContext context)
public override UsageResult VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
{
var t = context.SemanticModel.GetSymbolInfo(node);
if (node?.Parent.IsKind(SyntaxKind.InvocationExpression) == true)
if (node.Parent.IsKind(SyntaxKind.InvocationExpression))
return new UsageResult(UsageKind.CallTarget, t.Symbol);

var kind = t.Symbol?.Kind is SymbolKind.Property or SymbolKind.Event or SymbolKind.Method
? UsageKind.CallTarget
: UsageKind.None;

return new UsageResult(kind, t.Symbol);
return NewUsageResult(kind, t.Symbol);
}

public override UsageResult VisitElementAccessExpression(ElementAccessExpressionSyntax node)
{
var symbol = context.SemanticModel.GetSymbolInfo(node).Symbol as IPropertySymbol;
var kind = symbol?.IsIndexer == true ? UsageKind.CallTarget : UsageKind.None;
return NewUsageResult(kind, symbol);
}

public override UsageResult VisitForEachStatement(ForEachStatementSyntax node)
{
if (node.Expression != _targetNode)
return NewUsageResult(UsageKind.None, null);

// if _targetNode is the enumerable (i.e, the `Expression`) in the foreach
// it means we will end up calling `GetEnumerator()` on it.
var symbol = context.SemanticModel.GetForEachStatementInfo(node);
return NewUsageResult(UsageKind.CallTarget, symbol.GetEnumeratorMethod);
}

private UsageResult NewUsageResult(UsageKind kind, ISymbol symbol)
{
_targetNode = null;
return new UsageResult(kind, symbol);
}
}
40 changes: 40 additions & 0 deletions Cecilifier.Core/Extensions/CecilifierContextExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using System;
using System.Linq;
using Cecilifier.Core.AST;
using Cecilifier.Core.Misc;
using Cecilifier.Core.Naming;
using Cecilifier.Core.Variables;
using Microsoft.CodeAnalysis;
using Mono.Cecil.Cil;

namespace Cecilifier.Core.Extensions;

Expand Down Expand Up @@ -35,4 +38,41 @@ internal static DefinitionVariable AddLocalVariableToMethod(this IVisitorContext

return context.DefinitionVariables.RegisterNonMethod(string.Empty, localVarName, VariableMemberKind.LocalVariable, cecilVarDeclName);
}

internal static bool TryApplyNumericConversion(this IVisitorContext context, string ilVar, ITypeSymbol source, ITypeSymbol target)
{
switch (target.SpecialType)
{
case SpecialType.System_Single:
context.EmitCilInstruction(ilVar, OpCodes.Conv_R4);
break;
case SpecialType.System_Double:
context.EmitCilInstruction(ilVar, OpCodes.Conv_R8);
break;
case SpecialType.System_Byte:
context.EmitCilInstruction(ilVar, OpCodes.Conv_I1);
break;
case SpecialType.System_Int16:
context.EmitCilInstruction(ilVar, OpCodes.Conv_I2);
break;
case SpecialType.System_Int32:
// byte/char are pushed as Int32 by the runtime
if (source.SpecialType != SpecialType.System_SByte && source.SpecialType != SpecialType.System_Byte && source.SpecialType != SpecialType.System_Char)
context.EmitCilInstruction(ilVar, OpCodes.Conv_I4);
break;
case SpecialType.System_Int64:
var convOpCode = source.SpecialType == SpecialType.System_Char || source.SpecialType == SpecialType.System_Byte ? OpCodes.Conv_U8 : OpCodes.Conv_I8;
context.EmitCilInstruction(ilVar, convOpCode);
break;
case SpecialType.System_Decimal:
var operand = target.GetMembers().OfType<IMethodSymbol>()
.Single(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length == 1 && m.Parameters[0].Type.SpecialType == source.SpecialType);
context.EmitCilInstruction(ilVar, OpCodes.Newobj, operand.MethodResolverExpression(context));
break;

default: return false;
}

return true;
}
}

0 comments on commit 46b6e5a

Please sign in to comment.