Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Contains within SQL Server aggregate functions #32478

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,14 @@ private StructuralTypeReferenceExpression BindComplexProperty(
}
}

private bool TryTranslateAggregateMethodCall(
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[EntityFrameworkInternal]
protected virtual bool TryTranslateAggregateMethodCall(
MethodCallExpression methodCallExpression,
[NotNullWhen(true)] out SqlExpression? translation)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,16 @@ public override bool IsBuffering
=> base.IsBuffering
|| (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery
&& !_multipleActiveResultSetsEnabled);

/// <summary>
/// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not
/// allow subqueries and aggregates in that context.
/// </summary>
/// <remarks>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
public virtual bool InAggregateFunction { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly int _sqlServerCompatibilityLevel;
Expand All @@ -34,7 +34,7 @@ public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQu
public SqlServerQueryableMethodTranslatingExpressionVisitor(
QueryableMethodTranslatingExpressionVisitorDependencies dependencies,
RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies,
QueryCompilationContext queryCompilationContext,
SqlServerQueryCompilationContext queryCompilationContext,
ISqlServerSingletonOptions sqlServerSingletonOptions)
: base(dependencies, relationalDependencies, queryCompilationContext)
{
Expand Down Expand Up @@ -121,6 +121,103 @@ protected override Expression VisitExtension(Expression extensionExpression)
return base.VisitExtension(extensionExpression);
}

#region Aggregate functions

// We override these for SQL Server to add tracking whether we're inside an aggregate function context, since SQL Server doesn't
// support subqueries (or aggregates) within them.

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateAverage(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateSum(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateCount(source, predicate);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateLongCount(source, predicate);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateMax(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateMin(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

#endregion Aggregate functions

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -315,6 +412,47 @@ static IEnumerable<INavigation> GetAllNavigationsInHierarchy(IEntityType entityT
.SelectMany(t => t.GetDeclaredNavigations());
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
var translatedSource = base.TranslateContains(source, item);

// SQL Server does not support subqueries inside aggregate functions (e.g. COUNT(SELECT * FROM OPENJSON(@p)...)).
// As a result, we track whether we're within an aggregate function; if we are, and we see the regular Contains translation
// (which uses IN with an OPENJSON subquery - incompatible), we transform it to the old-style IN+constants translation (as if a
// low SQL Server compatibility level were defined)
if (_queryCompilationContext.InAggregateFunction
&& translatedSource is not null
&& TryGetProjection(translatedSource, out var projection)
&& projection is InExpression
{
Item: var translatedItem,
Subquery:
{
Tables: [SqlServerOpenJsonExpression { Arguments: [SqlParameterExpression parameter] } openJsonExpression],
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
Orderings: [],
Projection: [{ Expression: ColumnExpression { Name: "value", Table: var projectionColumnTable } }]
}
}
&& projectionColumnTable == openJsonExpression)
{
var newInExpression = _sqlExpressionFactory.In(translatedItem, parameter);
return source.UpdateQueryExpression(_sqlExpressionFactory.Select(newInExpression));
}

return translatedSource;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -504,6 +642,29 @@ protected override bool IsValidSelectExpressionForExecuteUpdate(
return false;
}

private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
{
var shaperExpression = shapedQueryExpression.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
{
projection = sqlExpression;
return true;
}

projection = null;
return false;
}

private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
{
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ public SqlServerQueryableMethodTranslatingExpressionVisitorFactory(
/// </summary>
public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
=> new SqlServerQueryableMethodTranslatingExpressionVisitor(
Dependencies, RelationalDependencies, queryCompilationContext, _sqlServerSingletonOptions);
Dependencies, RelationalDependencies, (SqlServerQueryCompilationContext)queryCompilationContext, _sqlServerSingletonOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

private static readonly HashSet<string> DateTimeDataTypes
Expand Down Expand Up @@ -73,7 +73,7 @@ private static readonly MethodInfo StringContainsMethodInfo
/// </summary>
public SqlServerSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext,
SqlServerQueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
{
Expand Down Expand Up @@ -432,6 +432,28 @@ private static string EscapeLikePattern(string pattern)
return builder.ToString();
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool TryTranslateAggregateMethodCall(
MethodCallExpression methodCallExpression,
[NotNullWhen(true)] out SqlExpression? translation)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;

#pragma warning disable EF1001 // Internal EF Core API usage.
var result = base.TryTranslateAggregateMethodCall(methodCallExpression, out translation);
#pragma warning restore EF1001 // Internal EF Core API usage.

_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;

return result;
}

private Expression TranslateByteArrayElementAccess(Expression array, Expression index, Type resultType)
{
var visitedArray = Visit(array);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
=> new SqlServerSqlTranslatingExpressionVisitor(
Dependencies,
queryCompilationContext,
(SqlServerQueryCompilationContext)queryCompilationContext,
queryableMethodTranslatingExpressionVisitor);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2249,6 +2249,86 @@ public override async Task Not_Any_false(bool async)
AssertSql();
}

public override async Task Contains_inside_aggregate_function_with_GroupBy(bool async)
{
// GroupBy. Issue #17313.
await AssertTranslationFailed(() => base.Contains_inside_aggregate_function_with_GroupBy(async));

AssertSql();
}

public override async Task Contains_inside_Average_without_GroupBy(bool async)
{
await base.Contains_inside_Average_without_GroupBy(async);

AssertSql(
"""
SELECT AVG((c["City"] IN ("London", "Berlin") ? 1.0 : 0.0)) AS c
FROM root c
WHERE (c["Discriminator"] = "Customer")
""");
}

public override async Task Contains_inside_Sum_without_GroupBy(bool async)
{
await base.Contains_inside_Sum_without_GroupBy(async);

AssertSql(
"""
SELECT SUM((c["City"] IN ("London", "Berlin") ? 1 : 0)) AS c
FROM root c
WHERE (c["Discriminator"] = "Customer")
""");
}

public override async Task Contains_inside_Count_without_GroupBy(bool async)
{
await base.Contains_inside_Count_without_GroupBy(async);

AssertSql(
"""
SELECT COUNT(1) AS c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["City"] IN ("London", "Berlin"))
""");
}

public override async Task Contains_inside_LongCount_without_GroupBy(bool async)
{
await base.Contains_inside_LongCount_without_GroupBy(async);

AssertSql(
"""
SELECT COUNT(1) AS c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["City"] IN ("London", "Berlin"))
""");
}

public override async Task Contains_inside_Max_without_GroupBy(bool async)
{
await base.Contains_inside_Max_without_GroupBy(async);

AssertSql(
"""
SELECT MAX((c["City"] IN ("London", "Berlin") ? 1 : 0)) AS c
FROM root c
WHERE (c["Discriminator"] = "Customer")
""");
}

public override async Task Contains_inside_Min_without_GroupBy(bool async)
{
await base.Contains_inside_Min_without_GroupBy(async);

AssertSql(
"""
SELECT MIN((c["City"] IN ("London", "Berlin") ? 1 : 0)) AS c
FROM root c
WHERE (c["Discriminator"] = "Customer")
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Loading
Loading