Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GroupBy over complex type #33493

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -735,20 +735,23 @@ protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression s

var remappedKeySelector = RemapLambdaBody(source, keySelector);
var translatedKey = TranslateGroupingKey(remappedKeySelector);
if (translatedKey == null)
switch (translatedKey)
{
// This could be group by entity type
if (remappedKeySelector is not StructuralTypeShaperExpression
// Special handling for GroupBy over entity type: get the entity projection expression out.
// For GroupBy over a complex type, we already get the projection expression out.
case StructuralTypeShaperExpression { StructuralType: IEntityType } shaper:
if (shaper.ValueBufferExpression is not ProjectionBindingExpression pbe)
{
ValueBufferExpression: ProjectionBindingExpression pbe
} shaper)
{
// ValueBufferExpression can be JsonQuery, ProjectionBindingExpression, EntityProjection
// We only allow ProjectionBindingExpression which represents a regular entity
return null;
}
// ValueBufferExpression can be JsonQuery, ProjectionBindingExpression, EntityProjection
// We only allow ProjectionBindingExpression which represents a regular entity
return null;
}

translatedKey = shaper.Update(((SelectExpression)pbe.QueryExpression).GetProjection(pbe));
translatedKey = shaper.Update(((SelectExpression)pbe.QueryExpression).GetProjection(pbe));
break;

case null:
return null;
}

if (elementSelector != null)
Expand Down Expand Up @@ -823,7 +826,7 @@ protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression s
return memberInitExpression.Update(updatedNewExpression, newBindings);

default:
var translation = TranslateExpression(expression);
var translation = TranslateProjection(expression);
if (translation == null)
{
return null;
Expand Down Expand Up @@ -1325,6 +1328,21 @@ protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression so
return translation;
}

private Expression? TranslateProjection(Expression expression, bool applyDefaultTypeMapping = true)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: at some point I'll do a cleanup here.. Most of the time we call into RelationalSqlTranslatingExpressionVisitor expecting a SqlExpression as a result, but there are various cases where we can get a StructuralTypeShaperExpression instead (like when projecting out complex types, or in this case, where the GroupBy key translator yields a complex type). We should clean up the APIs.

{
var translation = _sqlTranslator.TranslateProjection(expression, applyDefaultTypeMapping);

if (translation is null)
{
if (_sqlTranslator.TranslationErrorDetails != null)
{
AddTranslationErrorDetails(_sqlTranslator.TranslationErrorDetails);
}
}

return translation;
}

/// <summary>
/// Translates the given lambda expression for the <see cref="ShapedQueryExpression" /> source into equivalent SQL representation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,11 @@ private static void PopulateGroupByTerms(
projection.DiscriminatorExpression, groupByTerms, groupByAliases, name: DiscriminatorColumnAlias);
}

foreach (var complexProperty in projection.StructuralType.GetComplexProperties())
{
PopulateGroupByTerms(projection.BindComplexProperty(complexProperty), groupByTerms, groupByAliases, name: null);
}

break;

default:
Expand Down
33 changes: 33 additions & 0 deletions test/EFCore.Specification.Tests/Query/ComplexTypeQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -812,13 +812,46 @@ from c2 in ss.Set<Customer>()
AssertEqual(e.Complex?.Two, a.Complex?.Two);
});

#region GroupBy

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_over_property_in_nested_complex_type(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.ShippingAddress.Country.Code).Select(g => new { Code = g.Key, Count = g.Count() }),
elementSorter: g => g.Code);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_over_complex_type(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.ShippingAddress).Select(g => new { Address = g.Key, Count = g.Count() }),
elementSorter: g => g.Address.ZipCode,
elementAsserter: (e, a) =>
{
AssertEqual(e.Address, a.Address);
Assert.Equal(e.Count, a.Count);
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_over_nested_complex_type(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.ShippingAddress.Country).Select(g => new { Country = g.Key, Count = g.Count() }),
elementSorter: g => g.Country.Code);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_with_complex_type_with_group_by_and_first(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.Id).Select(x => x.First()));

#endregion GroupBy

protected DbContext CreateContext()
=> Fixture.CreateContext();
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private static IReadOnlyList<Customer> CreateCustomers()
AddressLine1 = "804 S. Lakeshore Road",
ZipCode = 38654,
Country = new Country { FullName = "United States", Code = "US" },
Tags = new List<string> { "foo", "bar" }
Tags = ["foo", "bar"]
};

var customer1 = new Customer
Expand All @@ -71,19 +71,14 @@ private static IReadOnlyList<Customer> CreateCustomers()
AddressLine1 = "72 Hickory Rd.",
ZipCode = 07728,
Country = new Country { FullName = "Germany", Code = "DE" },
Tags = new List<string> { "baz" }
Tags = ["baz"]
},
BillingAddress = new Address
{
AddressLine1 = "79 Main St.",
ZipCode = 29293,
Country = new Country { FullName = "Germany", Code = "DE" },
Tags = new List<string>
{
"a1",
"a2",
"a3"
}
Tags = ["a1", "a2", "a3"]
}
};

Expand All @@ -92,7 +87,7 @@ private static IReadOnlyList<Customer> CreateCustomers()
AddressLine1 = "79 Main St.",
ZipCode = 29293,
Country = new Country { FullName = "Germany", Code = "DE" },
Tags = new List<string> { "foo", "moo" }
Tags = ["foo", "moo"]
};

var customer3 = new Customer
Expand All @@ -103,12 +98,7 @@ private static IReadOnlyList<Customer> CreateCustomers()
BillingAddress = address3
};

return new List<Customer>
{
customer1,
customer2,
customer3
};
return [customer1, customer2, customer3];
}

private static IReadOnlyList<CustomerGroup> CreateCustomerGroups(IReadOnlyList<Customer> customers)
Expand All @@ -134,12 +124,7 @@ private static IReadOnlyList<CustomerGroup> CreateCustomerGroups(IReadOnlyList<C
OptionalCustomer = null
};

return new List<CustomerGroup>
{
group1,
group2,
group3
};
return [group1, group2, group3];
}

private static IReadOnlyList<ValuedCustomer> CreateValuedCustomers()
Expand Down Expand Up @@ -192,12 +177,7 @@ private static IReadOnlyList<ValuedCustomer> CreateValuedCustomers()
BillingAddress = address3
};

return new List<ValuedCustomer>
{
customer1,
customer2,
customer3
};
return [customer1, customer2, customer3];
}

private static IReadOnlyList<ValuedCustomerGroup> CreateValuedCustomerGroups(IReadOnlyList<ValuedCustomer> customers)
Expand All @@ -223,12 +203,7 @@ private static IReadOnlyList<ValuedCustomerGroup> CreateValuedCustomerGroups(IRe
OptionalCustomer = null
};

return new List<ValuedCustomerGroup>
{
group1,
group2,
group3
};
return [group1, group2, group3];
}

public static Task SeedAsync(PoolableDbContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1131,12 +1131,50 @@ public override async Task Same_complex_type_projected_twice_with_pushdown_as_pa
AssertSql("");
}

#region GroupBy

public override async Task GroupBy_over_property_in_nested_complex_type(bool async)
{
await base.GroupBy_over_property_in_nested_complex_type(async);

AssertSql(
"""
SELECT [c].[ShippingAddress_Country_Code] AS [Code], COUNT(*) AS [Count]
FROM [Customer] AS [c]
GROUP BY [c].[ShippingAddress_Country_Code]
""");
}

public override async Task GroupBy_over_complex_type(bool async)
{
await base.GroupBy_over_complex_type(async);

AssertSql(
"""
SELECT [c].[ShippingAddress_AddressLine1], [c].[ShippingAddress_AddressLine2], [c].[ShippingAddress_Tags], [c].[ShippingAddress_ZipCode], [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName], COUNT(*) AS [Count]
FROM [Customer] AS [c]
GROUP BY [c].[ShippingAddress_AddressLine1], [c].[ShippingAddress_AddressLine2], [c].[ShippingAddress_Tags], [c].[ShippingAddress_ZipCode], [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName]
""");
}

public override async Task GroupBy_over_nested_complex_type(bool async)
{
await base.GroupBy_over_nested_complex_type(async);

AssertSql(
"""
SELECT [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName], COUNT(*) AS [Count]
FROM [Customer] AS [c]
GROUP BY [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName]
""");
}

public override async Task Entity_with_complex_type_with_group_by_and_first(bool async)
{
await base.Entity_with_complex_type_with_group_by_and_first(async);

AssertSql(
"""
"""
SELECT [c3].[Id], [c3].[Name], [c3].[BillingAddress_AddressLine1], [c3].[BillingAddress_AddressLine2], [c3].[BillingAddress_Tags], [c3].[BillingAddress_ZipCode], [c3].[BillingAddress_Country_Code], [c3].[BillingAddress_Country_FullName], [c3].[ShippingAddress_AddressLine1], [c3].[ShippingAddress_AddressLine2], [c3].[ShippingAddress_Tags], [c3].[ShippingAddress_ZipCode], [c3].[ShippingAddress_Country_Code], [c3].[ShippingAddress_Country_FullName]
FROM (
SELECT [c].[Id]
Expand All @@ -1154,6 +1192,8 @@ FROM [Customer] AS [c0]
""");
}

#endregion GroupBy

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,12 +1014,50 @@ public override async Task Same_complex_type_projected_twice_with_pushdown_as_pa
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Same_complex_type_projected_twice_with_pushdown_as_part_of_another_projection(async))).Message);

#region GroupBy

public override async Task GroupBy_over_property_in_nested_complex_type(bool async)
{
await base.GroupBy_over_property_in_nested_complex_type(async);

AssertSql(
"""
SELECT "c"."ShippingAddress_Country_Code" AS "Code", COUNT(*) AS "Count"
FROM "Customer" AS "c"
GROUP BY "c"."ShippingAddress_Country_Code"
""");
}

public override async Task GroupBy_over_complex_type(bool async)
{
await base.GroupBy_over_complex_type(async);

AssertSql(
"""
SELECT "c"."ShippingAddress_AddressLine1", "c"."ShippingAddress_AddressLine2", "c"."ShippingAddress_Tags", "c"."ShippingAddress_ZipCode", "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName", COUNT(*) AS "Count"
FROM "Customer" AS "c"
GROUP BY "c"."ShippingAddress_AddressLine1", "c"."ShippingAddress_AddressLine2", "c"."ShippingAddress_Tags", "c"."ShippingAddress_ZipCode", "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName"
""");
}

public override async Task GroupBy_over_nested_complex_type(bool async)
{
await base.GroupBy_over_nested_complex_type(async);

AssertSql(
"""
SELECT "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName", COUNT(*) AS "Count"
FROM "Customer" AS "c"
GROUP BY "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName"
""");
}

public override async Task Entity_with_complex_type_with_group_by_and_first(bool async)
{
await base.Entity_with_complex_type_with_group_by_and_first(async);

AssertSql(
"""
"""
SELECT "c3"."Id", "c3"."Name", "c3"."BillingAddress_AddressLine1", "c3"."BillingAddress_AddressLine2", "c3"."BillingAddress_Tags", "c3"."BillingAddress_ZipCode", "c3"."BillingAddress_Country_Code", "c3"."BillingAddress_Country_FullName", "c3"."ShippingAddress_AddressLine1", "c3"."ShippingAddress_AddressLine2", "c3"."ShippingAddress_Tags", "c3"."ShippingAddress_ZipCode", "c3"."ShippingAddress_Country_Code", "c3"."ShippingAddress_Country_FullName"
FROM (
SELECT "c"."Id"
Expand All @@ -1037,6 +1075,8 @@ LEFT JOIN (
""");
}

#endregion GroupBy

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Loading