From ac40c6151135df6f043a19b6cc1a0e08400876a9 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 27 Apr 2023 17:21:23 +0200 Subject: [PATCH] WIP Closes #30732 Closes #30734 --- .../Properties/RelationalStrings.Designer.cs | 18 ++- .../Properties/RelationalStrings.resx | 9 +- .../Query/ISqlExpressionFactory.cs | 15 +- .../Query/Internal/ContainsTranslator.cs | 31 +++- ...lExpressionSimplifyingExpressionVisitor.cs | 21 +-- .../Query/QuerySqlGenerator.cs | 24 ++- .../Query/RelationalQueryRootProcessor.cs | 2 +- ...yableMethodTranslatingExpressionVisitor.cs | 71 ++++---- ...lationalSqlTranslatingExpressionVisitor.cs | 94 ++++++----- .../Query/SqlExpressionFactory.cs | 103 ++++++++---- .../Query/SqlExpressions/InExpression.cs | 152 ++++++++++++------ .../Query/SqlExpressions/ValuesExpression.cs | 20 +-- .../Query/SqlNullabilityProcessor.cs | 135 ++++++++-------- ...rchConditionConvertingExpressionVisitor.cs | 30 +++- ...yableMethodTranslatingExpressionVisitor.cs | 2 +- src/EFCore/Query/InlineQueryRootExpression.cs | 13 +- .../ParameterExtractingExpressionVisitor.cs | 25 ++- src/EFCore/Query/QueryRootProcessor.cs | 77 ++++----- .../Query/UdfDbFunctionTestBase.cs | 6 +- .../PrimitiveCollectionsQuerySqlServerTest.cs | 12 +- 20 files changed, 526 insertions(+), 334 deletions(-) diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs index 0fe96122904..882bdef1e60 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs +++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs @@ -53,6 +53,14 @@ public static string BadSequenceType public static string CannotChangeWhenOpen => GetString("CannotChangeWhenOpen"); + /// + /// The query contained a new array expression containing non-constant elements, which could not be translated: '{newArrayExpression}'. + /// + public static string CannotTranslateNonConstantNewArrayExpression(object? newArrayExpression) + => string.Format( + GetString("CannotTranslateNonConstantNewArrayExpression", nameof(newArrayExpression)), + newArrayExpression); + /// /// Can't configure a trigger on entity type '{entityType}', which is in a TPH hierarchy and isn't the root. Configure the trigger on the TPH root entity type '{rootEntityType}' instead. /// @@ -622,12 +630,12 @@ public static string DuplicateSeedDataSensitive(object? entityType, object? keyV entityType, keyValue, table); /// - /// Either {param1} or {param2} must be null. + /// Exactly one of '{param1}', '{param2}' or '{param3}' must be set. /// - public static string EitherOfTwoValuesMustBeNull(object? param1, object? param2) + public static string OneOfThreeValuesMustBeSet(object? param1, object? param2, object? param3) => string.Format( - GetString("EitherOfTwoValuesMustBeNull", nameof(param1), nameof(param2)), - param1, param2); + GetString("OneOfThreeValuesMustBeSet", nameof(param1), nameof(param2), nameof(param3)), + param1, param2, param3); /// /// Empty collections are not supported as constant query roots. @@ -1310,7 +1318,7 @@ public static string NoDbCommand => GetString("NoDbCommand"); /// - /// Expression of type '{type}' isn't supported as the Values of an InExpression; only constants and parameters are supported. + /// Expression of type '{type}' isn't supported in the Values of an InExpression; only constants and parameters are supported. /// public static string NonConstantOrParameterAsInExpressionValues(object? type) => string.Format( diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx index 938c9f6d455..bc819373c59 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.resx +++ b/src/EFCore.Relational/Properties/RelationalStrings.resx @@ -130,6 +130,9 @@ The instance of DbConnection is currently in use. The connection can only be changed when the existing connection is not being used. + + The query contained a new array expression containing non-constant elements, which could not be translated: '{newArrayExpression}'. + Can't configure a trigger on entity type '{entityType}', which is in a TPH hierarchy and isn't the root. Configure the trigger on the TPH root entity type '{rootEntityType}' instead. @@ -346,8 +349,8 @@ A seed entity for entity type '{entityType}' has the same key value {keyValue} as another seed entity mapped to the same table '{table}'. Key values should be unique across seed entities. - - Either {param1} or {param2} must be null. + + Exactly one of '{param1}', '{param2}' or '{param3}' must be set. Empty collections are not supported as inline query roots. @@ -912,7 +915,7 @@ Cannot create a DbCommand for a non-relational query. - Expression of type '{type}' isn't supported as the Values of an InExpression; only constants and parameters are supported. + Expression of type '{type}' isn't supported in the Values of an InExpression; only constants and parameters are supported. 'FindMapping' was called on a 'RelationalTypeMappingSource' with a non-relational 'TypeMappingInfo'. diff --git a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs index 8780256848d..06537d10c0e 100644 --- a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs @@ -403,6 +403,15 @@ SqlFunctionExpression NiladicFunction( /// An expression representing an EXISTS operation in a SQL tree. ExistsExpression Exists(SelectExpression subquery, bool negated); + /// + /// Creates a new which represents an IN operation in a SQL tree. + /// + /// An item to look into values. + /// A subquery in which item is searched. + /// A value indicating if the item should be present in the values or absent. + /// An expression representing an IN operation in a SQL tree. + InExpression In(SqlExpression item, SelectExpression subquery, bool negated); + /// /// Creates a new which represents an IN operation in a SQL tree. /// @@ -410,16 +419,16 @@ SqlFunctionExpression NiladicFunction( /// A list of values in which item is searched. /// A value indicating if the item should be present in the values or absent. /// An expression representing an IN operation in a SQL tree. - InExpression In(SqlExpression item, SqlExpression values, bool negated); + InExpression In(SqlExpression item, IReadOnlyList values, bool negated); /// /// Creates a new which represents an IN operation in a SQL tree. /// /// An item to look into values. - /// A subquery in which item is searched. + /// A parameterized list of values in which the item is searched. /// A value indicating if the item should be present in the values or absent. /// An expression representing an IN operation in a SQL tree. - InExpression In(SqlExpression item, SelectExpression subquery, bool negated); + InExpression In(SqlExpression item, SqlParameterExpression valuesParameter, bool negated); /// /// Creates a new which represents a LIKE in a SQL tree. diff --git a/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs b/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs index 0e523bf70a9..523a9405c6d 100644 --- a/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query.Internal; @@ -38,11 +39,13 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { + SqlExpression? itemExpression = null, valuesExpression = null; + // SqlExpression? values = null; if (method.IsGenericMethod - && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains) + && method.GetGenericMethodDefinition() == EnumerableMethods.Contains && ValidateValues(arguments[0])) { - return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[1]), arguments[0], negated: false); + (itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[1]), arguments[0]); } if (arguments.Count == 1 @@ -50,14 +53,34 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) && instance != null && ValidateValues(instance)) { - return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[0]), instance, negated: false); + (itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[0]), instance); + } + + if (itemExpression is not null && valuesExpression is not null) + { + switch (valuesExpression) + { + case SqlParameterExpression parameter: + return _sqlExpressionFactory.In(itemExpression, parameter, negated: false); + + case SqlConstantExpression { Value: IEnumerable values }: + var valuesExpressions = new List(); + + foreach (var value in values) + { + // TODO: Type mapping? + valuesExpressions.Add(_sqlExpressionFactory.Constant(value, itemExpression.TypeMapping)); + } + + return _sqlExpressionFactory.In(itemExpression, valuesExpressions, negated: false); + } } return null; } private static bool ValidateValues(SqlExpression values) - => values is SqlConstantExpression || values is SqlParameterExpression; + => values is SqlConstantExpression or SqlParameterExpression; private static SqlExpression RemoveObjectConvert(SqlExpression expression) => expression is SqlUnaryExpression sqlUnaryExpression diff --git a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs index 35f566a35e9..a5638e81505 100644 --- a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs @@ -245,18 +245,11 @@ or ExpressionType.LessThan && leftCandidateInfo.ColumnExpression == rightCandidateInfo.ColumnExpression && leftCandidateInfo.OperationType == rightCandidateInfo.OperationType) { - var leftConstantIsEnumerable = leftCandidateInfo.ConstantValue is IEnumerable - && !(leftCandidateInfo.ConstantValue is string) - && !(leftCandidateInfo.ConstantValue is byte[]); - - var rightConstantIsEnumerable = rightCandidateInfo.ConstantValue is IEnumerable - && !(rightCandidateInfo.ConstantValue is string) - && !(rightCandidateInfo.ConstantValue is byte[]); - - if ((leftCandidateInfo.OperationType == ExpressionType.Equal - && sqlBinaryExpression.OperatorType == ExpressionType.OrElse) - || (leftCandidateInfo.OperationType == ExpressionType.NotEqual - && sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)) + var leftConstantIsEnumerable = leftCandidateInfo.ConstantValue is IEnumerable and not string and not byte[]; + var rightConstantIsEnumerable = rightCandidateInfo.ConstantValue is IEnumerable and not string and not byte[]; + + if ((leftCandidateInfo.OperationType, sqlBinaryExpression.OperatorType) is + (ExpressionType.Equal, ExpressionType.OrElse) or (ExpressionType.NotEqual, ExpressionType.AndAlso)) { object leftValue; object rightValue; @@ -309,7 +302,7 @@ or ExpressionType.LessThan return _sqlExpressionFactory.In( leftCandidateInfo.ColumnExpression, - _sqlExpressionFactory.Constant(resultArray, leftCandidateInfo.TypeMapping), + resultArray.Select(r => _sqlExpressionFactory.Constant(r, leftCandidateInfo.TypeMapping)).ToArray(), leftCandidateInfo.OperationType == ExpressionType.NotEqual); } @@ -323,7 +316,7 @@ or ExpressionType.LessThan return _sqlExpressionFactory.In( leftCandidateInfo.ColumnExpression, - _sqlExpressionFactory.Constant(resultArray, leftCandidateInfo.TypeMapping), + resultArray.Select(r => _sqlExpressionFactory.Constant(r, leftCandidateInfo.TypeMapping)).ToArray(), leftCandidateInfo.OperationType == ExpressionType.NotEqual); } } diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index efac1cd435b..b42892cdebf 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -927,31 +927,29 @@ protected override Expression VisitExists(ExistsExpression existsExpression) /// protected override Expression VisitIn(InExpression inExpression) { - if (inExpression.Values != null) + Check.DebugAssert(inExpression.ValuesParameter is null, "inExpression.ValuesParameter is null"); + + Visit(inExpression.Item); + _relationalCommandBuilder.Append(inExpression.IsNegated ? " NOT IN (" : " IN ("); + + if (inExpression.Values is not null) { - Visit(inExpression.Item); - _relationalCommandBuilder.Append(inExpression.IsNegated ? " NOT IN " : " IN "); - _relationalCommandBuilder.Append("("); - var valuesConstant = (SqlConstantExpression)inExpression.Values; - var valuesList = ((IEnumerable)valuesConstant.Value!) - .Select(v => new SqlConstantExpression(Expression.Constant(v), valuesConstant.TypeMapping)).ToList(); - GenerateList(valuesList, e => Visit(e)); - _relationalCommandBuilder.Append(")"); + GenerateList(inExpression.Values, e => Visit(e)); } else { - Visit(inExpression.Item); - _relationalCommandBuilder.Append(inExpression.IsNegated ? " NOT IN " : " IN "); - _relationalCommandBuilder.AppendLine("("); + _relationalCommandBuilder.AppendLine(); using (_relationalCommandBuilder.Indent()) { Visit(inExpression.Subquery); } - _relationalCommandBuilder.AppendLine().Append(")"); + _relationalCommandBuilder.AppendLine(); } + _relationalCommandBuilder.Append(")"); + return inExpression; } diff --git a/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs b/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs index d2c218bb5c9..10ba014ce94 100644 --- a/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs @@ -30,7 +30,7 @@ public RelationalQueryRootProcessor( /// Indicates that a can be converted to a ; /// this will later be translated to a SQL . /// - protected override bool ShouldConvertToInlineQueryRoot(ConstantExpression constantExpression) + protected override bool ShouldConvertToInlineQueryRoot(NewArrayExpression newArrayExpression) => true; /// diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 536749f75ce..297792169a3 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; +using System.Reflection.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -207,8 +208,8 @@ when entityQueryRootExpression.GetType() == typeof(EntityQueryRootExpression) return new ShapedQueryExpression(selectExpression, shaperExpression); } - case InlineQueryRootExpression constantQueryRootExpression: - return VisitInlineQueryRoot(constantQueryRootExpression) ?? base.VisitExtension(extensionExpression); + case InlineQueryRootExpression inlineQueryRootExpression: + return VisitInlineQueryRoot(inlineQueryRootExpression) ?? base.VisitExtension(extensionExpression); case ParameterQueryRootExpression parameterQueryRootExpression: var sqlParameterExpression = @@ -319,26 +320,34 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp for (var i = 0; i < inlineQueryRootExpression.Values.Count; i++) { - var value = inlineQueryRootExpression.Values[i]; - - // We currently support constants only; supporting non-constant values in VALUES is tracked by #30734. - if (value is not ConstantExpression constantExpression) + // Note that we specifically don't apply the default type mapping to the translation, to allow it to get inferred later based + // on usage. + if (TranslateExpression(inlineQueryRootExpression.Values[i], applyDefaultTypeMapping: false) + is not SqlExpression translatedValue) { - AddTranslationErrorDetails(RelationalStrings.OnlyConstantsSupportedInInlineCollectionQueryRoots); return null; } - if (constantExpression.Value is null) + // We currently support only constants and parameters in VALUES, see #30734 + if (translatedValue is not SqlConstantExpression and not SqlParameterExpression) { - encounteredNull = true; + AddTranslationErrorDetails(RelationalStrings.NonConstantOrParameterAsInExpressionValues(translatedValue.GetType().Name)); + return null; } + // TODO: Poor man's null semantics - we currently only support constants and parameters in SqlNullabilityProcessor, where we + // can see if there's actually a null value or not. When we allow arbitrary expressions ()the ValuesExpression projects out a non-nullable column if we see only non-null constants + // or nullable columns. + // This should be handled properly, possibly in SqlNullabilityProcessor (e.g. any complex expression is assumed to be nullable). + encounteredNull |= + translatedValue is not SqlConstantExpression { Value: not null } and not ColumnExpression { IsNullable: false }; + rowExpressions.Add(new RowValueExpression(new[] { // Since VALUES may not guarantee row ordering, we add an _ord value by which we'll order. _sqlExpressionFactory.Constant(i, intTypeMapping), // Note that for the actual value, we must leave the type mapping null to allow it to get inferred later based on usage - _sqlExpressionFactory.Constant(constantExpression.Value, elementType, typeMapping: null) + translatedValue })); } @@ -1763,10 +1772,13 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( /// Translates the given expression into equivalent SQL representation. /// /// An expression to translate. + /// + /// Whether to apply the default type mapping on the top-most element if it has none. Defaults to . + /// /// A which is translation of given expression or . - protected virtual SqlExpression? TranslateExpression(Expression expression) + protected virtual SqlExpression? TranslateExpression(Expression expression, bool applyDefaultTypeMapping = true) { - var translation = _sqlTranslator.Translate(expression); + var translation = _sqlTranslator.Translate(expression, applyDefaultTypeMapping); if (translation is null) { @@ -1802,7 +1814,7 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( protected virtual Expression ApplyInferredTypeMappings( Expression expression, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping> inferredTypeMappings) - => new RelationalInferredTypeMappingApplier(inferredTypeMappings).Visit(expression); + => new RelationalInferredTypeMappingApplier(_sqlExpressionFactory, inferredTypeMappings).Visit(expression); /// /// Determines whether the given is ordered, typically because orderings have been added to it. @@ -1870,22 +1882,14 @@ private bool TrySimplifyValuesToInExpression( return false; } - var values = new object?[valuesExpression.RowValues.Count]; + var values = new SqlExpression[valuesExpression.RowValues.Count]; for (var i = 0; i < values.Length; i++) { // Skip the first value (_ord), which is irrelevant for Contains - if (valuesExpression.RowValues[i].Values[1] is SqlConstantExpression { Value: var constantValue }) - { - values[i] = constantValue; - } - else - { - simplifiedQuery = null; - return false; - } + values[i] = valuesExpression.RowValues[i].Values[1]; } - var inExpression = _sqlExpressionFactory.In(item, _sqlExpressionFactory.Constant(values), isNegated); + var inExpression = _sqlExpressionFactory.In(item, values, isNegated); simplifiedQuery = source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression); return true; } @@ -2751,6 +2755,7 @@ private void RegisterInferredTypeMapping(ColumnExpression columnExpression, Rela /// protected class RelationalInferredTypeMappingApplier : ExpressionVisitor { + private readonly ISqlExpressionFactory _sqlExpressionFactory; private SelectExpression? _currentSelectExpression; /// @@ -2761,10 +2766,15 @@ protected class RelationalInferredTypeMappingApplier : ExpressionVisitor /// /// Creates a new instance of the class. /// + /// The SQL expression factory. /// The inferred type mappings to be applied back on their query roots. public RelationalInferredTypeMappingApplier( + ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping> inferredTypeMappings) - => InferredTypeMappings = inferredTypeMappings; + { + _sqlExpressionFactory = sqlExpressionFactory; + InferredTypeMappings = inferredTypeMappings; + } /// protected override Expression VisitExtension(Expression expression) @@ -2831,30 +2841,27 @@ protected virtual ValuesExpression ApplyTypeMappingsOnValuesExpression(ValuesExp var newValues = new SqlExpression[newColumnNames.Count]; for (var j = 0; j < valuesExpression.ColumnNames.Count; j++) { - Check.DebugAssert(rowValue.Values[j] is SqlConstantExpression, "Non-constant SqlExpression in ValuesExpression"); - if (j == 0 && stripOrdering) { continue; } - var value = (SqlConstantExpression)rowValue.Values[j]; - SqlExpression newValue = value; + var value = rowValue.Values[j]; var inferredTypeMapping = inferredTypeMappings[j]; if (inferredTypeMapping is not null && value.TypeMapping is null) { - newValue = new SqlConstantExpression(Expression.Constant(value.Value, value.Type), inferredTypeMapping); + value = _sqlExpressionFactory.ApplyTypeMapping(value, inferredTypeMapping); // We currently add explicit conversions on the first row, to ensure that the inferred types are properly typed. // See #30605 for removing that when not needed. if (i == 0) { - newValue = new SqlUnaryExpression(ExpressionType.Convert, newValue, newValue.Type, newValue.TypeMapping); + value = new SqlUnaryExpression(ExpressionType.Convert, value, value.Type, value.TypeMapping); } } - newValues[j - (stripOrdering ? 1 : 0)] = newValue; + newValues[j - (stripOrdering ? 1 : 0)] = value; } newRowValues[i] = new RowValueExpression(newValues); diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 38569b6d852..3c945e11873 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -118,33 +118,38 @@ protected virtual void AddTranslationErrorDetails(string details) /// Translates an expression to an equivalent SQL representation. /// /// An expression to translate. + /// + /// Whether to apply the default type mapping on the top-most element if it has none. Defaults to . + /// /// A SQL translation of the given expression. - public virtual SqlExpression? Translate(Expression expression) + public virtual SqlExpression? Translate(Expression expression, bool applyDefaultTypeMapping = true) { TranslationErrorDetails = null; - return TranslateInternal(expression); + return TranslateInternal(expression, applyDefaultTypeMapping); } - private SqlExpression? TranslateInternal(Expression expression) + private SqlExpression? TranslateInternal(Expression expression, bool applyDefaultTypeMapping = true) { var result = Visit(expression); if (result is SqlExpression translation) { - if (translation is SqlUnaryExpression sqlUnaryExpression - && sqlUnaryExpression.OperatorType == ExpressionType.Convert + if (translation is SqlUnaryExpression { OperatorType: ExpressionType.Convert } sqlUnaryExpression && sqlUnaryExpression.Type == typeof(object)) { translation = sqlUnaryExpression.Operand; } - translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(translation); - - if (translation.TypeMapping == null) + if (applyDefaultTypeMapping) { - // The return type is not-mappable hence return null - return null; + translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(translation); + + if (translation.TypeMapping == null) + { + // The return type is not-mappable hence return null + return null; + } } return translation; @@ -738,7 +743,9 @@ protected override Expression VisitMember(MemberExpression memberExpression) /// protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) - => GetConstantOrNotTranslated(memberInitExpression); + => TryEvaluateToConstant(memberInitExpression, out var sqlConstantExpression) + ? sqlConstantExpression + : QueryCompilationContext.NotTranslatedExpression; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -997,11 +1004,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override Expression VisitNew(NewExpression newExpression) - => GetConstantOrNotTranslated(newExpression); + => TryEvaluateToConstant(newExpression, out var sqlConstantExpression) + ? sqlConstantExpression + : QueryCompilationContext.NotTranslatedExpression; /// protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) - => QueryCompilationContext.NotTranslatedExpression; + { + if (TryEvaluateToConstant(newArrayExpression, out var sqlConstantExpression)) + { + return sqlConstantExpression; + } + + AddTranslationErrorDetails(RelationalStrings.CannotTranslateNonConstantNewArrayExpression(newArrayExpression.Print())); + return QueryCompilationContext.NotTranslatedExpression; + } /// protected override Expression VisitParameter(ParameterExpression parameterExpression) @@ -1090,7 +1107,7 @@ SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionEx _sqlExpressionFactory.Constant(discriminatorValues[0])) : _sqlExpressionFactory.In( entityProjectionExpression.DiscriminatorExpression!, - _sqlExpressionFactory.Constant(discriminatorValues), + discriminatorValues.Select(d => _sqlExpressionFactory.Constant(d)).ToArray(), negated: false); } } @@ -1110,8 +1127,7 @@ SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionEx _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) : _sqlExpressionFactory.In( discriminatorColumn, - _sqlExpressionFactory.Constant( - concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), + concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray(), negated: false); } } @@ -1568,16 +1584,23 @@ private static Expression ConvertObjectArrayEqualityComparison(Expression left, .Aggregate((a, b) => Expression.AndAlso(a, b)); } - private static Expression GetConstantOrNotTranslated(Expression expression) - => CanEvaluate(expression) - ? new SqlConstantExpression( + private static bool TryEvaluateToConstant(Expression expression, [NotNullWhen(true)] out SqlConstantExpression? sqlConstantExpression) + { + if (CanEvaluate(expression)) + { + sqlConstantExpression = new SqlConstantExpression( Expression.Constant( Expression.Lambda>(Expression.Convert(expression, typeof(object))) .Compile(preferInterpretation: true) .Invoke(), expression.Type), - null) - : QueryCompilationContext.NotTranslatedExpression; + null); + return true; + } + + sqlConstantExpression = null; + return false; + } private bool TryRewriteContainsEntity(Expression source, Expression item, [NotNullWhen(true)] out Expression? result) { @@ -1861,26 +1884,15 @@ when memberInitExpression.Bindings.SingleOrDefault( } private static bool CanEvaluate(Expression expression) - { -#pragma warning disable IDE0066 // Convert switch statement to expression - switch (expression) -#pragma warning restore IDE0066 // Convert switch statement to expression - { - case ConstantExpression: - return true; - - case NewExpression newExpression: - return newExpression.Arguments.All(e => CanEvaluate(e)); - - case MemberInitExpression memberInitExpression: - return CanEvaluate(memberInitExpression.NewExpression) - && memberInitExpression.Bindings.All( - mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); - - default: - return false; - } - } + => expression switch + { + ConstantExpression => true, + NewExpression e => e.Arguments.All(CanEvaluate), + NewArrayExpression e => e.Expressions.All(CanEvaluate), + MemberInitExpression e => CanEvaluate(e.NewExpression) + && e.Bindings.All(mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)), + _ => false + }; private static bool IsNullSqlConstantExpression(Expression expression) => expression is SqlConstantExpression sqlConstant && sqlConstant.Value == null; diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 6dc46ca1058..a79de39ada5 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -229,28 +229,72 @@ private SqlExpression ApplyTypeMappingOnSqlBinary( resultTypeMapping); } - private SqlExpression ApplyTypeMappingOnIn(InExpression inExpression) + private InExpression ApplyTypeMappingOnIn(InExpression inExpression) { - var itemTypeMapping = (inExpression.Values != null - ? ExpressionExtensions.InferTypeMapping(inExpression.Item, inExpression.Values) - : inExpression.Subquery != null - ? ExpressionExtensions.InferTypeMapping(inExpression.Item, inExpression.Subquery.Projection[0].Expression) - : inExpression.Item.TypeMapping) - ?? Dependencies.TypeMappingSource.FindMapping(inExpression.Item.Type, Dependencies.Model); - - var item = ApplyTypeMapping(inExpression.Item, itemTypeMapping); - if (inExpression.Values != null) + var missingTypeMappingInValues = false; + + RelationalTypeMapping? valuesTypeMapping = null; + switch (inExpression) { - var values = ApplyTypeMapping(inExpression.Values, itemTypeMapping); + case { Subquery: SelectExpression subquery }: + valuesTypeMapping = subquery.Projection[0].Expression.TypeMapping; + break; + + case { ValuesParameter: SqlParameterExpression parameter }: + valuesTypeMapping = parameter.TypeMapping; + break; + + case { Values: IReadOnlyList values }: + // Note: there could be conflicting type mappings inside the values; we take the first. + foreach (var value in values) + { + if (value.TypeMapping is null) + { + missingTypeMappingInValues = true; + } + else + { + valuesTypeMapping = value.TypeMapping; + } + } + + break; - return item != inExpression.Item || values != inExpression.Values || inExpression.TypeMapping != _boolTypeMapping - ? new InExpression(item, values, inExpression.IsNegated, _boolTypeMapping) - : inExpression; + default: + throw new ArgumentOutOfRangeException(); } - return item != inExpression.Item || inExpression.TypeMapping != _boolTypeMapping - ? new InExpression(item, inExpression.Subquery!, inExpression.IsNegated, _boolTypeMapping) - : inExpression; + var item = ApplyTypeMapping( + inExpression.Item, + valuesTypeMapping ?? Dependencies.TypeMappingSource.FindMapping(inExpression.Item.Type, Dependencies.Model)); + + switch (inExpression) + { + case { Subquery: SelectExpression subquery }: + return inExpression.Update(item, subquery, values: null, valuesParameter: null); + + case { ValuesParameter: SqlParameterExpression parameter }: + return inExpression.Update( + item, subquery: null, values: null, (SqlParameterExpression)ApplyTypeMapping(parameter, item.TypeMapping)); + + case { Values: IReadOnlyList values }: + SqlExpression[]? newValues = null; + + if (missingTypeMappingInValues) + { + newValues = new SqlExpression[values.Count]; + + for (var i = 0; i < newValues.Length; i++) + { + newValues[i] = ApplyTypeMapping(values[i], item.TypeMapping); + } + } + + return inExpression.Update(item, subquery: null, newValues ?? values, valuesParameter: null); + + default: + throw new ArgumentOutOfRangeException(); + } } /// @@ -541,25 +585,16 @@ public virtual ExistsExpression Exists(SelectExpression subquery, bool negated) => new(subquery, negated, _boolTypeMapping); /// - public virtual InExpression In(SqlExpression item, SqlExpression values, bool negated) - { - var typeMapping = item.TypeMapping ?? Dependencies.TypeMappingSource.FindMapping(item.Type, Dependencies.Model); - - item = ApplyTypeMapping(item, typeMapping); - values = ApplyTypeMapping(values, typeMapping); - - return new InExpression(item, values, negated, _boolTypeMapping); - } + public virtual InExpression In(SqlExpression item, SelectExpression subquery, bool negated) + => ApplyTypeMappingOnIn(new InExpression(item, subquery, negated, _boolTypeMapping)); /// - public virtual InExpression In(SqlExpression item, SelectExpression subquery, bool negated) - { - var sqlExpression = subquery.Projection.Single().Expression; - var typeMapping = sqlExpression.TypeMapping; + public virtual InExpression In(SqlExpression item, IReadOnlyList values, bool negated) + => ApplyTypeMappingOnIn(new InExpression(item, values, negated, _boolTypeMapping)); - item = ApplyTypeMapping(item, typeMapping); - return new InExpression(item, subquery, negated, _boolTypeMapping); - } + /// + public virtual InExpression In(SqlExpression item, SqlParameterExpression valuesParameter, bool negated) + => ApplyTypeMappingOnIn(new InExpression(item, valuesParameter, negated, _boolTypeMapping)); /// public virtual LikeExpression Like(SqlExpression match, SqlExpression pattern, SqlExpression? escapeChar = null) @@ -622,7 +657,7 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); var predicate = concreteEntityTypes.Count == 1 ? (SqlExpression)Equal(discriminatorColumn, Constant(concreteEntityTypes[0].GetDiscriminatorValue())) - : In(discriminatorColumn, Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), negated: false); + : In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(et.GetDiscriminatorValue())).ToArray(), negated: false); selectExpression.ApplyPredicate(predicate); diff --git a/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs index af0777b47bb..4e2b2439f20 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs @@ -17,10 +17,10 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; public class InExpression : SqlExpression { /// - /// Creates a new instance of the class which represents a IN subquery expression. + /// Creates a new instance of the class, representing a SQL IN expression with a subquery. /// /// An item to look into values. - /// A subquery in which item is searched. + /// A subquery in which the item is searched. /// A value indicating if the item should be present in the values or absent. /// The associated with the expression. public InExpression( @@ -28,30 +28,49 @@ public InExpression( SelectExpression subquery, bool negated, RelationalTypeMapping typeMapping) - : this(item, null, subquery, negated, typeMapping) + : this(item, subquery, values: null, valuesParameter: null, negated, typeMapping) { } /// - /// Creates a new instance of the class which represents a IN values expression. + /// Creates a new instance of the class, representing a SQL IN expression with a given list + /// of values. /// /// An item to look into values. - /// A list of values in which item is searched. + /// A list of values in which the item is searched. /// A value indicating if the item should be present in the values or absent. /// The associated with the expression. public InExpression( SqlExpression item, - SqlExpression values, + IReadOnlyList values, bool negated, RelationalTypeMapping typeMapping) - : this(item, values, null, negated, typeMapping) + : this(item, subquery: null, values, valuesParameter: null, negated, typeMapping) + { + } + + /// + /// Creates a new instance of the class, representing a SQL IN expression with a given + /// parameterized list of values. + /// + /// An item to look into values. + /// A parameterized list of values in which the item is searched. + /// A value indicating if the item should be present in the values or absent. + /// The associated with the expression. + public InExpression( + SqlExpression item, + SqlParameterExpression valuesParameter, + bool negated, + RelationalTypeMapping typeMapping) + : this(item, subquery: null, values: null, valuesParameter, negated, typeMapping) { } private InExpression( SqlExpression item, - SqlExpression? values, SelectExpression? subquery, + IReadOnlyList? values, + SqlParameterExpression? valuesParameter, bool negated, RelationalTypeMapping? typeMapping) : base(typeof(bool), typeMapping) @@ -65,6 +84,7 @@ private InExpression( Item = item; Subquery = subquery; Values = values; + ValuesParameter = valuesParameter; IsNegated = negated; } @@ -79,23 +99,54 @@ private InExpression( public virtual bool IsNegated { get; } /// - /// The list of values to search item in. + /// The subquery to search the item in. /// - public virtual SqlExpression? Values { get; } + public virtual SelectExpression? Subquery { get; } /// - /// The subquery to search item in. + /// The list of values to search the item in. /// - public virtual SelectExpression? Subquery { get; } + public virtual IReadOnlyList? Values { get; } + + /// + /// A parameter containing the list of values to search the item in. The parameterized list get expanded to the actual value + /// before the query SQL is generated. + /// + public virtual SqlParameterExpression? ValuesParameter { get; } /// protected override Expression VisitChildren(ExpressionVisitor visitor) { var item = (SqlExpression)visitor.Visit(Item); var subquery = (SelectExpression?)visitor.Visit(Subquery); - var values = (SqlExpression?)visitor.Visit(Values); - return Update(item, values, subquery); + SqlExpression[]? values = null; + if (Values is not null) + { + for (var i = 0; i < Values.Count; i++) + { + var value = Values[i]; + var newValue = (SqlExpression)visitor.Visit(value); + + if (newValue != value && values is null) + { + values = new SqlExpression[Values.Count]; + for (var j = 0; j < i; j++) + { + values[j] = Values[j]; + } + } + + if (values is not null) + { + values[i] = newValue; + } + } + } + + var valuesParameter = (SqlParameterExpression?)visitor.Visit(ValuesParameter); + + return Update(item, subquery, values ?? Values, valuesParameter); } /// @@ -103,30 +154,32 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) /// /// An expression which is negated form of this expression. public virtual InExpression Negate() - => new(Item, Values, Subquery, !IsNegated, TypeMapping); + => new(Item, Subquery, Values, ValuesParameter, !IsNegated, TypeMapping); /// /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will /// return this expression. /// /// The property of the result. - /// The property of the result. /// The property of the result. + /// The property of the result. + /// The property of the result. /// This expression if no children changed, or an expression with the updated children. public virtual InExpression Update( SqlExpression item, - SqlExpression? values, - SelectExpression? subquery) + SelectExpression? subquery, + IReadOnlyList? values, + SqlParameterExpression? valuesParameter) { - if (values != null - && subquery != null) + if ((subquery is null ? 0 : 1) + (values is null ? 0 : 1) + (valuesParameter is null ? 0 : 1) != 1) { - throw new ArgumentException(RelationalStrings.EitherOfTwoValuesMustBeNull(nameof(values), nameof(subquery))); + throw new ArgumentException( + RelationalStrings.OneOfThreeValuesMustBeSet(nameof(subquery), nameof(values), nameof(valuesParameter))); } - return item != Item || subquery != Subquery || values != Values - ? new InExpression(item, values, subquery, IsNegated, TypeMapping) - : this; + return item == Item && subquery == Subquery && values == Values && valuesParameter == ValuesParameter + ? this + : new InExpression(item, subquery, values, valuesParameter, IsNegated, TypeMapping); } /// @@ -136,31 +189,35 @@ protected override void Print(ExpressionPrinter expressionPrinter) expressionPrinter.Append(IsNegated ? " NOT IN " : " IN "); expressionPrinter.Append("("); - if (Subquery != null) - { - using (expressionPrinter.Indent()) - { - expressionPrinter.Visit(Subquery); - } - } - else if (Values is SqlConstantExpression constantValuesExpression - && constantValuesExpression.Value is IEnumerable constantValues) + switch (this) { - var first = true; - foreach (var item in constantValues) - { - if (!first) + case { Subquery: not null }: + using (expressionPrinter.Indent()) { - expressionPrinter.Append(", "); + expressionPrinter.Visit(Subquery); } - first = false; - expressionPrinter.Append(constantValuesExpression.TypeMapping?.GenerateSqlLiteral(item) ?? item?.ToString() ?? "NULL"); - } - } - else - { - expressionPrinter.Visit(Values); + break; + + case { Values: not null }: + for (var i = 0; i < Values.Count; i++) + { + if (i > 0) + { + expressionPrinter.Append(", "); + } + + expressionPrinter.Visit(Values[i]); + } + + break; + + case { ValuesParameter: not null}: + expressionPrinter.Visit(ValuesParameter); + break; + + default: + throw new ArgumentOutOfRangeException(); } expressionPrinter.Append(")"); @@ -177,10 +234,11 @@ private bool Equals(InExpression inExpression) => base.Equals(inExpression) && Item.Equals(inExpression.Item) && IsNegated.Equals(inExpression.IsNegated) + && (Subquery?.Equals(inExpression.Subquery) ?? inExpression.Subquery == null) && (Values?.Equals(inExpression.Values) ?? inExpression.Values == null) - && (Subquery?.Equals(inExpression.Subquery) ?? inExpression.Subquery == null); + && (ValuesParameter?.Equals(inExpression.ValuesParameter) ?? inExpression.ValuesParameter == null); /// public override int GetHashCode() - => HashCode.Combine(base.GetHashCode(), Item, IsNegated, Values, Subquery); + => HashCode.Combine(base.GetHashCode(), Item, IsNegated, Values, Subquery, ValuesParameter); } diff --git a/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs index 56fcde09c15..02de0686118 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs @@ -42,23 +42,9 @@ public ValuesExpression( { Check.NotEmpty(rowValues, nameof(rowValues)); -#if DEBUG - if (rowValues.Any(rv => rv.Values.Count != columnNames.Count)) - { - throw new ArgumentException("All number of all row values doesn't match the number of column names"); - } - - if (rowValues.SelectMany(rv => rv.Values).Any( - v => v is not SqlConstantExpression and not SqlUnaryExpression - { - Operand: SqlConstantExpression, - OperatorType: ExpressionType.Convert - })) - { - // See #30734 for non-constants - throw new ArgumentException("Only constant expressions are supported in ValuesExpression"); - } -#endif + Check.DebugAssert( + rowValues.All(rv => rv.Values.Count == columnNames.Count), + "All row values must have a value count matching the number of column names"); RowValues = rowValues; ColumnNames = columnNames; diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index bf39f8f5947..734b2636f0f 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -671,32 +671,27 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt && subquery.Projection[0].Expression is ColumnExpression columnProjection && !columnProjection.IsNullable); - return inExpression.Update(item, values: null, subquery); + return inExpression.Update(item, subquery, values: null, valuesParameter: null); } // for relational null semantics we don't need to extract null values from the array - if (UseRelationalNulls - || !(inExpression.Values is SqlConstantExpression || inExpression.Values is SqlParameterExpression)) + if (UseRelationalNulls) { - var (valuesExpression, valuesList, _) = ProcessInExpressionValues(inExpression.Values!, extractNullValues: false); + var (processedInExpression2, _) = ProcessInExpressionValues(inExpression, extractNullValues: false); nullable = false; - return valuesList.Count == 0 + return processedInExpression2.Values!.Count == 0 ? _sqlExpressionFactory.Constant(false, inExpression.TypeMapping) - : SimplifyInExpression( - inExpression.Update(item, valuesExpression, subquery: null), - valuesExpression, - valuesList); + : SimplifyInExpression(processedInExpression2); } // for c# null semantics we need to remove nulls from Values and add IsNull/IsNotNull when necessary - var (inValuesExpression, inValuesList, hasNullValue) = ProcessInExpressionValues(inExpression.Values, extractNullValues: true); + var (processedInExpression, hasNullValue) = ProcessInExpressionValues(inExpression, extractNullValues: true); + nullable = false; // either values array is empty or only contains null - if (inValuesList.Count == 0) + if (processedInExpression.Values!.Count == 0) { - nullable = false; - // a IN () -> false // non_nullable IN (NULL) -> false // a NOT IN () -> true @@ -712,15 +707,11 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt : _sqlExpressionFactory.IsNull(item); } - var simplifiedInExpression = SimplifyInExpression( - inExpression.Update(item, inValuesExpression, subquery: null), - inValuesExpression, - inValuesList); + var simplifiedInExpression = SimplifyInExpression(processedInExpression); if (!itemNullable || (allowOptimizedExpansion && !inExpression.IsNegated && !hasNullValue)) { - nullable = false; // non_nullable IN (1, 2) -> non_nullable IN (1, 2) // non_nullable IN (1, 2, NULL) -> non_nullable IN (1, 2) @@ -730,8 +721,6 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt return simplifiedInExpression; } - nullable = false; - // nullable IN (1, 2) -> nullable IN (1, 2) AND nullable IS NOT NULL (full) // nullable IN (1, 2, NULL) -> nullable IN (1, 2) OR nullable IS NULL (full) // nullable NOT IN (1, 2) -> nullable NOT IN (1, 2) OR nullable IS NULL (full) @@ -744,61 +733,81 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt simplifiedInExpression, _sqlExpressionFactory.IsNull(item)); - (SqlConstantExpression ProcessedValuesExpression, List ProcessedValuesList, bool HasNullValue) - ProcessInExpressionValues(SqlExpression valuesExpression, bool extractNullValues) + (InExpression ProcessedInExpression, bool HasNullValue) ProcessInExpressionValues(InExpression inExpression, bool extractNullValues) { - var inValues = new List(); + List? processedValues = null; var hasNullValue = false; - RelationalTypeMapping? typeMapping; - IEnumerable values; - switch (valuesExpression) + if (inExpression.ValuesParameter is SqlParameterExpression valuesParameter) { - case SqlConstantExpression sqlConstant: - typeMapping = sqlConstant.TypeMapping; - values = (IEnumerable)sqlConstant.Value!; - break; - - case SqlParameterExpression sqlParameter: - DoNotCache(); - typeMapping = sqlParameter.TypeMapping; - values = (IEnumerable?)ParameterValues[sqlParameter.Name] ?? Array.Empty(); - break; - - default: - throw new InvalidOperationException( - RelationalStrings.NonConstantOrParameterAsInExpressionValues(valuesExpression.GetType().Name)); - } + // The InExpression has a values parameter. Expand it out, embedding its values as constants into the SQL; disable SQL + // caching. + DoNotCache(); + var typeMapping = inExpression.ValuesParameter.TypeMapping; + var values = (IEnumerable?)ParameterValues[valuesParameter.Name] ?? Array.Empty(); - foreach (var value in values) - { - if (value == null && extractNullValues) + processedValues = new List(); + + foreach (var value in values) { - hasNullValue = true; - continue; - } + if (value == null && extractNullValues) + { + hasNullValue = true; + continue; + } - inValues.Add(value); + processedValues.Add(_sqlExpressionFactory.Constant(value, typeMapping)); + } } + else + { + Check.DebugAssert(inExpression.Values is not null, "inExpression.Values is not null"); - var processedValuesExpression = _sqlExpressionFactory.Constant(inValues, typeMapping); + for (var i = 0; i < inExpression.Values.Count; i++) + { + var valueExpression = inExpression.Values[i]; + + var value = valueExpression switch + { + SqlConstantExpression c => c.Value, + SqlParameterExpression p => ParameterValues[p.Name], + + _ => throw new InvalidOperationException( + RelationalStrings.NonConstantOrParameterAsInExpressionValues(valueExpression.GetType().Name)) + }; - return (processedValuesExpression, inValues, hasNullValue); + if (value is null && extractNullValues) + { + hasNullValue = true; + + if (processedValues is null) + { + processedValues = new List(inExpression.Values.Count - 1); + for (var j = 0; j < i; j++) + { + processedValues.Add(inExpression.Values[j]); + } + } + + // Skip the NULL value + continue; + } + + processedValues?.Add(valueExpression); + } + } + + var processedInExpression = inExpression.Update( + inExpression.Item, subquery: null, values: processedValues ?? inExpression.Values, valuesParameter: null); + return (processedInExpression, hasNullValue); } - SqlExpression SimplifyInExpression( - InExpression inExpression, - SqlConstantExpression inValuesExpression, - List inValuesList) - => inValuesList.Count == 1 - ? inExpression.IsNegated - ? (SqlExpression)_sqlExpressionFactory.NotEqual( - inExpression.Item, - _sqlExpressionFactory.Constant(inValuesList[0], inValuesExpression.TypeMapping)) - : _sqlExpressionFactory.Equal( - inExpression.Item, - _sqlExpressionFactory.Constant(inValuesList[0], inExpression.Values!.TypeMapping)) - : inExpression; + SqlExpression SimplifyInExpression(InExpression inExpression) + => inExpression.Values is not [var valueExpression] + ? inExpression + : inExpression.IsNegated + ? _sqlExpressionFactory.NotEqual(inExpression.Item, valueExpression) + : _sqlExpressionFactory.Equal(inExpression.Item, valueExpression); } /// diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs index 7ea65b2626a..150640dbe0c 100644 --- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs @@ -223,10 +223,36 @@ protected override Expression VisitIn(InExpression inExpression) _isSearchCondition = false; var item = (SqlExpression)Visit(inExpression.Item); var subquery = (SelectExpression?)Visit(inExpression.Subquery); - var values = (SqlExpression?)Visit(inExpression.Values); + + var values = inExpression.Values; + SqlExpression[]? newValues = null; + if (values is not null) + { + for (var i = 0; i < values.Count; i++) + { + var value = values[i]; + var newValue = (SqlExpression)Visit(value); + + if (newValue != value && newValues is null) + { + newValues = new SqlExpression[values.Count]; + for (var j = 0; j < i; j++) + { + newValues[j] = values[j]; + } + } + + if (newValues is not null) + { + newValues[i] = newValue; + } + } + } + + var valuesParameter = (SqlParameterExpression?)Visit(inExpression.ValuesParameter); _isSearchCondition = parentSearchCondition; - return ApplyConversion(inExpression.Update(item, values, subquery), condition: true); + return ApplyConversion(inExpression.Update(item, subquery, newValues ?? values, valuesParameter), condition: true); } /// diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs index 604c95e619a..29d745ca900 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs @@ -331,7 +331,7 @@ public SqlServerInferredTypeMappingApplier( IRelationalTypeMappingSource typeMappingSource, ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping> inferredTypeMappings) - : base(inferredTypeMappings) + : base(sqlExpressionFactory, inferredTypeMappings) => (_typeMappingSource, _sqlExpressionFactory) = (typeMappingSource, sqlExpressionFactory); /// diff --git a/src/EFCore/Query/InlineQueryRootExpression.cs b/src/EFCore/Query/InlineQueryRootExpression.cs index 3106f618c8a..dbbb0b470c2 100644 --- a/src/EFCore/Query/InlineQueryRootExpression.cs +++ b/src/EFCore/Query/InlineQueryRootExpression.cs @@ -45,6 +45,17 @@ public InlineQueryRootExpression(IReadOnlyList values, Type elementT public override Expression DetachQueryProvider() => new InlineQueryRootExpression(Values, ElementType); + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual InlineQueryRootExpression Update(IReadOnlyList values) + => ReferenceEquals(values, Values) || values.SequenceEqual(Values) + ? this + : new InlineQueryRootExpression(values, ElementType); + /// protected override Expression VisitChildren(ExpressionVisitor visitor) { @@ -70,7 +81,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) } } - return newValues is null ? this : new InlineQueryRootExpression(newValues, Type); + return newValues is null ? this : Update(newValues); } /// diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs index 43ce2d91a3a..8a9fbe4eb7c 100644 --- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs @@ -515,7 +515,7 @@ public IDictionary Find(Expression expression) var parentEvaluatable = _evaluatable; var parentContainsClosure = _containsClosure; - _evaluatable = IsEvaluatableNodeType(expression) + _evaluatable = IsEvaluatableNodeType(expression, out var preferNoEvaluation) // Extension point to disable funcletization && _evaluatableExpressionFilter.IsEvaluatableExpression(expression, _model) // Don't evaluate QueryableMethods if in compiled query @@ -524,7 +524,7 @@ public IDictionary Find(Expression expression) base.Visit(expression); - if (_evaluatable) + if (_evaluatable && !preferNoEvaluation) { // Force parameterization when not in lambda _evaluatableExpressions[expression] = _containsClosure || !_inLambda; @@ -643,10 +643,23 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio return base.VisitConstant(constantExpression); } - private static bool IsEvaluatableNodeType(Expression expression) - => expression.NodeType != ExpressionType.Extension - || expression.CanReduce - && IsEvaluatableNodeType(expression.ReduceAndCheck()); + private static bool IsEvaluatableNodeType(Expression expression, out bool preferNoEvaluation) + { + switch (expression.NodeType) + { + case ExpressionType.NewArrayInit: + preferNoEvaluation = true; + return true; + + case ExpressionType.Extension: + preferNoEvaluation = false; + return expression.CanReduce && IsEvaluatableNodeType(expression.ReduceAndCheck(), out preferNoEvaluation); + + default: + preferNoEvaluation = false; + return true; + } + } private static bool IsQueryableMethod(Expression expression) => expression is MethodCallExpression methodCallExpression diff --git a/src/EFCore/Query/QueryRootProcessor.cs b/src/EFCore/Query/QueryRootProcessor.cs index ac98c75b4f9..307f53fbee1 100644 --- a/src/EFCore/Query/QueryRootProcessor.cs +++ b/src/EFCore/Query/QueryRootProcessor.cs @@ -50,48 +50,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var argument = methodCallExpression.Arguments[i]; var parameterType = parameters[i].ParameterType; - Expression? visitedArgument = null; - // This converts collections over constants and parameters to query roots, for later translation of LINQ operators over them. // The element type doesn't have to be directly mappable; we allow unknown CLR types in order to support value convertors // (the precise type mapping - with the value converter - will be inferred later based on LINQ operators composed on the root). // However, we do exclude element CLR types which are associated to entity types in our model, since Contains over entity // collections isn't yet supported (#30712). - if (parameterType.IsGenericType + var visitedArgument = parameterType.IsGenericType && (parameterType.GetGenericTypeDefinition() == typeof(IEnumerable<>) || parameterType.GetGenericTypeDefinition() == typeof(IQueryable<>)) && parameterType.GetGenericArguments()[0] is var elementClrType - && !_model.FindEntityTypes(elementClrType).Any()) - { - switch (argument) - { - case ConstantExpression { Value: IEnumerable values } constantExpression - when ShouldConvertToInlineQueryRoot(constantExpression): - - var valueExpressions = new List(); - foreach (var value in values) - { - valueExpressions.Add(Expression.Constant(value, elementClrType)); - } - visitedArgument = new InlineQueryRootExpression(valueExpressions, elementClrType); - break; - - // TODO: Support NewArrayExpression, see #30734. - - case ParameterExpression parameterExpression - when parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) - == true - && ShouldConvertToParameterQueryRoot(parameterExpression): - visitedArgument = new ParameterQueryRootExpression(parameterExpression.Type.GetSequenceType(), parameterExpression); - break; - - default: - visitedArgument = null; - break; - } - } - - visitedArgument ??= Visit(argument); + && !_model.FindEntityTypes(elementClrType).Any() + ? VisitQueryRootCandidate(argument, elementClrType) + : Visit(argument); if (visitedArgument != argument) { @@ -117,12 +87,47 @@ when ShouldConvertToInlineQueryRoot(constantExpression): : methodCallExpression.Update(methodCallExpression.Object, newArguments); } + private Expression VisitQueryRootCandidate(Expression expression, Type elementClrType) + { + switch (expression) + { + // An array containing only constants is represented as a ConstantExpression with the array as the value. + // Convert that into a NewArrayExpression for use with InlineQueryRootExpression + case ConstantExpression { Value: IEnumerable values }: + var valueExpressions = new List(); + foreach (var value in values) + { + valueExpressions.Add(Expression.Constant(value, elementClrType)); + } + + if (ShouldConvertToInlineQueryRoot(Expression.NewArrayInit(elementClrType, valueExpressions))) + { + return new InlineQueryRootExpression(valueExpressions, elementClrType); + } + + goto default; + + case NewArrayExpression newArrayExpression + when ShouldConvertToInlineQueryRoot(newArrayExpression): + return new InlineQueryRootExpression(newArrayExpression.Expressions, elementClrType); + + case ParameterExpression parameterExpression + when parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) + == true + && ShouldConvertToParameterQueryRoot(parameterExpression): + return new ParameterQueryRootExpression(parameterExpression.Type.GetSequenceType(), parameterExpression); + + default: + return Visit(expression); + } + } + /// /// Determines whether a should be converted to a . /// This handles cases inline expressions whose elements are all constants. /// - /// The constant expression that's a candidate for conversion to a query root. - protected virtual bool ShouldConvertToInlineQueryRoot(ConstantExpression constantExpression) + /// The new array expression that's a candidate for conversion to a query root. + protected virtual bool ShouldConvertToInlineQueryRoot(NewArrayExpression newArrayExpression) => false; /// diff --git a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs index 980b287dfcd..bbbf4d97cbe 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs @@ -296,7 +296,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) .HasTranslation( args => new InExpression( args.First(), - new SqlConstantExpression(Expression.Constant(abc), typeMapping: null), // args.First().TypeMapping), + new[] { new SqlConstantExpression(Expression.Constant(abc), typeMapping: null) }, // args.First().TypeMapping) negated: false, typeMapping: null)); @@ -306,10 +306,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) args => new InExpression( new InExpression( args.First(), - new SqlConstantExpression(Expression.Constant(abc), args.First().TypeMapping), + new[] { new SqlConstantExpression(Expression.Constant(abc), args.First().TypeMapping) }, negated: false, typeMapping: null), - new SqlConstantExpression(Expression.Constant(trueFalse), typeMapping: null), + new[] { new SqlConstantExpression(Expression.Constant(trueFalse), typeMapping: null) }, negated: false, typeMapping: null)); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs index 8a06980d06e..cf4ec6df7c3 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs @@ -10,7 +10,7 @@ public PrimitiveCollectionsQuerySqlServerTest(PrimitiveCollectionsQuerySqlServer : base(fixture) { Fixture.TestSqlLoggerFactory.Clear(); - // Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); + Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); } public override async Task Inline_collection_of_ints_Contains(bool async) @@ -144,18 +144,14 @@ public override async Task Inline_collection_Contains_with_all_parameters(bool a { await base.Inline_collection_Contains_with_all_parameters(async); - // See #30732 for making this better - AssertSql( """ -@__p_0='[2,999]' (Size = 4000) +@__i_0='2' +@__j_1='999' SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] -WHERE EXISTS ( - SELECT 1 - FROM OpenJson(@__p_0) AS [p0] - WHERE CAST([p0].[value] AS int) = [p].[Id]) +WHERE [p].[Id] IN (@__i_0, @__j_1) """); }