Skip to content

Commit

Permalink
Fix expression cloning when table changes in SelectExpression.VisitCh…
Browse files Browse the repository at this point in the history
…ildren

Fixes dotnet#32234
  • Loading branch information
roji committed Nov 29, 2023
1 parent 74cc719 commit b8156da
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,53 @@ public TableReferenceUpdatingExpressionVisitor(SelectExpression oldSelect, Selec
}
}

// Note: this is conceptually the same as ColumnExpressionReplacingExpressionVisitor; I duplicated it since this is for a patch,
// and we want to limit the potential risk (note that this calls the special SelectExpression.VisitChildren() with updateColumns: false,
// to avoid infinite recursion).
private sealed class ColumnTableReferenceUpdater : ExpressionVisitor
{
private readonly SelectExpression _oldSelect;
private readonly SelectExpression _newSelect;

public ColumnTableReferenceUpdater(SelectExpression oldSelect, SelectExpression newSelect)
{
_oldSelect = oldSelect;
_newSelect = newSelect;
}

[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (expression is ConcreteColumnExpression columnExpression
&& _oldSelect._tableReferences.Find(t => ReferenceEquals(t.Table, columnExpression.Table)) is TableReferenceExpression
oldTableReference
&& _newSelect._tableReferences.Find(t => t.Alias == columnExpression.TableAlias) is TableReferenceExpression
newTableReference
&& newTableReference != oldTableReference)
{
return new ConcreteColumnExpression(
columnExpression.Name,
newTableReference,
columnExpression.Type,
columnExpression.TypeMapping!,
columnExpression.IsNullable);
}

return base.Visit(expression);
}

protected override Expression VisitExtension(Expression node)
{
if (node is SelectExpression select)
{
Check.DebugAssert(!select._mutable, "Visiting mutable select expression in ColumnTableReferenceUpdater");
return select.VisitChildren(this, updateColumns: false);
}

return base.VisitExtension(node);
}
}

private sealed class IdentifierComparer : IEqualityComparer<(ColumnExpression Column, ValueComparer Comparer)>
{
public bool Equals((ColumnExpression Column, ValueComparer Comparer) x, (ColumnExpression Column, ValueComparer Comparer) y)
Expand Down
37 changes: 32 additions & 5 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4609,6 +4609,9 @@ private static string GenerateUniqueAlias(HashSet<string> usedAliases, string cu

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> VisitChildren(visitor, updateColumns: true);

private Expression VisitChildren(ExpressionVisitor visitor, bool updateColumns)
{
if (_mutable)
{
Expand Down Expand Up @@ -4797,14 +4800,38 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
newSelectExpression._childIdentifiers.AddRange(
childIdentifier.Zip(_childIdentifiers).Select(e => (e.First, e.Second.Comparer)));

// Remap tableReferences in new select expression
foreach (var tableReference in newTableReferences)
// We duplicated the SelectExpression, and must therefore also update all table reference expressions to point to it.
// If any tables have changed, we must duplicate the TableReferenceExpressions and replace all ColumnExpressions to use
// them; otherwise we end up two SelectExpressions sharing the same TableReferenceExpression instance, and if that's later
// mutated, both SelectExpressions are affected (this happened in AliasUniquifier, see #32234).

// Otherwise, if no tables have changed, we mutate the TableReferenceExpressions (this was the previous code, left it for
// a more low-risk fix). Note that updateColumns is true only if we're already being called from ColumnTableReferenceUpdater
// to replace the ColumnExpressions, in which case we avoid infinite recursion.
if (tablesChanged && updateColumns)
{
tableReference.UpdateTableReference(this, newSelectExpression);
for (var i = 0; i < newTableReferences.Count; i++)
{
newTableReferences[i] = new TableReferenceExpression(newSelectExpression, _tableReferences[i].Alias);
}

var columnTableReferenceUpdater = new ColumnTableReferenceUpdater(this, newSelectExpression);
newSelectExpression = (SelectExpression)columnTableReferenceUpdater.Visit(newSelectExpression);
}
else
{
// Remap tableReferences in new select expression
foreach (var tableReference in newTableReferences)
{
tableReference.UpdateTableReference(this, newSelectExpression);
}

var tableReferenceUpdatingExpressionVisitor = new TableReferenceUpdatingExpressionVisitor(this, newSelectExpression);
tableReferenceUpdatingExpressionVisitor.Visit(newSelectExpression);
// TODO: Why does need to be done? We've already updated all table references on the new select just above, and
// no ColumnExpression in the query is every supposed to reference a TableReferenceExpression that isn't in the
// select's list... The same thing is done in all other places where TableReferenceUpdatingExpressionVisitor is used.
var tableReferenceUpdatingExpressionVisitor = new TableReferenceUpdatingExpressionVisitor(this, newSelectExpression);
tableReferenceUpdatingExpressionVisitor.Visit(newSelectExpression);
}

return newSelectExpression;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5652,4 +5652,28 @@ public virtual Task Collection_navigation_equal_to_null_for_subquery_using_Eleme
ss => ss.Set<Customer>().Where(c => c.Orders.OrderBy(o => o.OrderID).ElementAtOrDefault(prm).OrderDetails == null),
ss => ss.Set<Customer>().Where(c => c.Orders.OrderBy(o => o.OrderID).ElementAtOrDefault(prm) == null));
}

[ConditionalTheory] // #32234
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
{
var ids = new[] { 10248, 10249 };

var query = (ISetSource ss) => ss.Set<OrderDetail>()
.Where(e => ids.Contains(e.OrderID))
.GroupBy(e => e.Quantity)
.Select(g => new { g.Key, MaxTimestamp = g.Select(e => e.Order.OrderDate).Max() })
.OrderBy(x => x.MaxTimestamp)
.Select(x => x);

#if DEBUG
// GroupBy debug assert. Issue #26104.
Assert.StartsWith(
"Missing alias in the list",
(await Assert.ThrowsAsync<InvalidOperationException>(
() => AssertQuery(async, query))).Message);
#else
await AssertQuery(async, query);
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7260,6 +7260,44 @@ ORDER BY [o].[OrderID]
""");
}

public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
{
await base.Parameter_collection_Contains_with_projection_and_ordering(async);

#if DEBUG
// GroupBy debug assert. Issue #26104.
AssertSql();
#else
AssertSql(
"""
@__ids_0='[10248,10249]' (Size = 4000)
SELECT [o].[Quantity] AS [Key], (
SELECT MAX([o3].[OrderDate])
FROM [Order Details] AS [o2]
INNER JOIN [Orders] AS [o3] ON [o2].[OrderID] = [o3].[OrderID]
WHERE [o2].[OrderID] IN (
SELECT [i1].[value]
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i1]
) AND [o].[Quantity] = [o2].[Quantity]) AS [MaxTimestamp]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i]
)
GROUP BY [o].[Quantity]
ORDER BY (
SELECT MAX([o3].[OrderDate])
FROM [Order Details] AS [o2]
INNER JOIN [Orders] AS [o3] ON [o2].[OrderID] = [o3].[OrderID]
WHERE [o2].[OrderID] IN (
SELECT [i0].[value]
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i0]
) AND [o].[Quantity] = [o2].[Quantity])
""");
#endif
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit b8156da

Please sign in to comment.