Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to #34056 - AOT/Query: for queries with JSON, interceptors generate code with labels that are not uniquified #34323

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private sealed record StackFrame(
Dictionary<ParameterExpression, string> Variables,
HashSet<string> VariableNames,
Dictionary<LabelTarget, string> Labels,
HashSet<string> UnnamedLabelNames);
HashSet<string> UniqueLabelNames);

private readonly Stack<StackFrame> _stack
= new([new StackFrame([], [], [], [])]);
Expand Down Expand Up @@ -160,7 +160,7 @@ protected virtual SyntaxNode TranslateCore(
Check.DebugAssert(_stack.Peek().Variables.Count == 0, "_stack.Peek().Parameters.Count == 0");
Check.DebugAssert(_stack.Peek().VariableNames.Count == 0, "_stack.Peek().ParameterNames.Count == 0");
Check.DebugAssert(_stack.Peek().Labels.Count == 0, "_stack.Peek().Labels.Count == 0");
Check.DebugAssert(_stack.Peek().UnnamedLabelNames.Count == 0, "_stack.Peek().UnnamedLabelNames.Count == 0");
Check.DebugAssert(_stack.Peek().UniqueLabelNames.Count == 0, "_stack.Peek().UniqueLabelNames.Count == 0");

foreach (var unsafeAccessor in _fieldUnsafeAccessors.Values.Concat(_methodUnsafeAccessors.Values))
{
Expand Down Expand Up @@ -714,6 +714,8 @@ static bool IsExpressionValidAsStatement(ExpressionSyntax expression)
void PreprocessLabels()
{
// LINQ label targets can be unnamed, so we need to generate names for unnamed ones and maintain a target->name mapping.
// Also labels can have duplicated names - we need to de-duplicate them before we can generate a valid c# code
// just like we do with variables/parameters
// We need to maintain this as a stack for every block which has labels.
// Normal blocks get their own labels stack frame, which gets popped when we leave the block. Expression labels add their
// labels to their parent's stack frame (since they get lifted).
Expand All @@ -726,21 +728,17 @@ void PreprocessLabels()
continue;
}

var (_, _, labels, unnamedLabelNames) = stackFrame;
var (_, _, labels, uniqueLabelNames) = stackFrame;

// Generate names for unnamed label targets and uniquify
// Generate names for unnamed label targets and uniquify (all label names)
identifier = label.Target.Name ?? "unnamedLabel";
var identifierBase = identifier;
for (var i = 0; unnamedLabelNames.Contains(identifier); i++)
for (var i = 0; uniqueLabelNames.Contains(identifier); i++)
{
identifier = identifierBase + i;
}

if (label.Target.Name is null)
{
unnamedLabelNames.Add(identifier);
}

uniqueLabelNames.Add(identifier);
labels.Add(label.Target, identifier);
}
}
Expand Down Expand Up @@ -1507,15 +1505,24 @@ protected override Expression VisitLambda<T>(Expression<T> lambda)
var expressionBody = body as ExpressionSyntax;
var blockBody = body as BlockSyntax;

// If the lambda body was an expression that had lifted statements (e.g. some block in expression context), we need to create
// a block to contain these statements
if (_liftedState.Statements.Count > 0)
{
Check.DebugAssert(lambda.ReturnType != typeof(void), "lambda.ReturnType != typeof(void)");
Check.DebugAssert(expressionBody != null, "expressionBody != null");

blockBody = Block(_liftedState.Statements.Append(ReturnStatement(expressionBody)));
expressionBody = null;
if (expressionBody != null)
{
// If the lambda body was an expression that had lifted statements (e.g. some block in expression context), we need to create
// a block to contain these statements
blockBody = Block(_liftedState.Statements.Append(ReturnStatement(expressionBody)));
expressionBody = null;
}
else
{
// If the lambda body was already a block, we just prepend lifted statements to the ones already existing in the block
Check.DebugAssert(blockBody != null, "expressionBody != null || blockBody != null");
blockBody = Block(_liftedState.Statements.Concat(blockBody.Statements));
}

_liftedState.Statements.Clear();
}

Expand Down Expand Up @@ -2734,7 +2741,7 @@ private StackFrame PushNewStackFrame()
new Dictionary<ParameterExpression, string>(previousFrame.Variables),
[..previousFrame.VariableNames],
new Dictionary<LabelTarget, string>(previousFrame.Labels),
[..previousFrame.UnnamedLabelNames]);
[..previousFrame.UniqueLabelNames]);

_stack.Push(newFrame);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
using Microsoft.EntityFrameworkCore.Query.Internal;
using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers;
using Blog = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Blog;
using Post = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Post;
using JsonRoot = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.JsonRoot;
using JsonBranch = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.JsonBranch;
namespace Microsoft.EntityFrameworkCore.Query;

public abstract class PrecompiledQueryRelationalFixture
Expand All @@ -27,9 +30,26 @@ protected override IServiceCollection AddServices(IServiceCollection serviceColl

protected override async Task SeedAsync(PrecompiledQueryRelationalTestBase.PrecompiledQueryContext context)
{
context.Blogs.AddRange(
new Blog { Id = 8, Name = "Blog1" },
new Blog { Id = 9, Name = "Blog2" });
var blog1 = new Blog { Id = 8, Name = "Blog1", Json = [] };
var blog2 = new Blog
{
Id = 9,
Name = "Blog2",
Json =
[
new JsonRoot { Number = 1, Text = "One", Inner = new JsonBranch { Date = new DateTime(2001, 1, 1) } },
new JsonRoot { Number = 2, Text = "Two", Inner = new JsonBranch { Date = new DateTime(2002, 2, 2) } },
]};

context.Blogs.AddRange(blog1, blog2);

var post11 = new Post { Id = 11, Title = "Post11", Blog = blog1 };
var post12 = new Post { Id = 12, Title = "Post12", Blog = blog1 };
var post21 = new Post { Id = 21, Title = "Post21", Blog = blog2 };
var post22 = new Post { Id = 22, Title = "Post22", Blog = blog2 };
var post23 = new Post { Id = 23, Title = "Post23", Blog = blog2 };

context.Posts.AddRange(post11, post12, post21, post22, post23);
await context.SaveChangesAsync();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ public virtual Task BinaryExpression()
=> Test("""
var id = 3;
var blogs = await context.Blogs.Where(b => b.Id > id).ToListAsync();

Assert.Equal(2, blogs.Count);
var orderedBlogs = blogs.OrderBy(x => x.Id).ToList();
var blog1 = orderedBlogs[0];
var blog2 = orderedBlogs[1];

Assert.Equal(8, blog1.Id);
Assert.Equal("Blog1", blog1.Name);
Assert.Empty(blog1.Json);

Assert.Equal(9, blog2.Id);
Assert.Equal("Blog2", blog2.Name);
Assert.Equal(2, blog2.Json.Count);

Assert.Equal(1, blog2.Json[0].Number);
Assert.Equal("One", blog2.Json[0].Text);
Assert.Equal(new DateTime(2001, 1, 1), blog2.Json[0].Inner.Date);

Assert.Equal(2, blog2.Json[1].Number);
Assert.Equal("Two", blog2.Json[1].Text);
Assert.Equal(new DateTime(2002, 2, 2), blog2.Json[1].Inner.Date);
""");

[ConditionalFact]
Expand Down Expand Up @@ -729,6 +750,23 @@ public virtual Task Terminating_ExecuteUpdateAsync()
public virtual Task Union()
=> Test(
"""
var posts = await context.Posts.Where(p => p.Id > 11)
.Union(context.Posts.Where(p => p.Id < 21))
.OrderBy(p => p.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(11, b.Id),
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id),
b => Assert.Equal(22, b.Id),
b => Assert.Equal(23, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task UnionOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Union(context.Blogs.Where(b => b.Id < 10))
.OrderBy(b => b.Id)
Expand All @@ -743,6 +781,24 @@ public virtual Task Union()
public virtual Task Concat()
=> Test(
"""
var posts = await context.Posts.Where(p => p.Id > 11)
.Concat(context.Posts.Where(p => p.Id < 21))
.OrderBy(p => p.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(11, b.Id),
b => Assert.Equal(12, b.Id),
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id),
b => Assert.Equal(22, b.Id),
b => Assert.Equal(23, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task ConcatOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Concat(context.Blogs.Where(b => b.Id < 10))
.OrderBy(b => b.Id)
Expand All @@ -759,6 +815,20 @@ public virtual Task Concat()
public virtual Task Intersect()
=> Test(
"""
var posts = await context.Posts.Where(b => b.Id > 11)
.Intersect(context.Posts.Where(b => b.Id < 22))
.OrderBy(b => b.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task IntersectOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Intersect(context.Blogs.Where(b => b.Id > 8))
.OrderBy(b => b.Id)
Expand All @@ -771,6 +841,20 @@ public virtual Task Intersect()
public virtual Task Except()
=> Test(
"""
var posts = await context.Posts.Where(b => b.Id > 11)
.Except(context.Posts.Where(b => b.Id > 21))
.OrderBy(b => b.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task ExceptOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Except(context.Blogs.Where(b => b.Id > 8))
.OrderBy(b => b.Id)
Expand Down Expand Up @@ -1066,6 +1150,20 @@ public class PrecompiledQueryContext(DbContextOptions options) : DbContext(optio
{
public DbSet<Blog> Blogs { get; set; } = null!;
public DbSet<Post> Posts { get; set; } = null!;

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
base.OnModelCreating(modelBuilder);
modelBuilder.Entity<Blog>().OwnsMany(
x => x.Json,
n =>
{
n.ToJson();
n.OwnsOne(xx => xx.Inner);
});
modelBuilder.Entity<Blog>().HasMany(x => x.Posts).WithOne(x => x.Blog).OnDelete(DeleteBehavior.Cascade);
modelBuilder.Entity<Post>().Property(x => x.Id).ValueGeneratedNever();
}
}

protected PrecompiledQueryRelationalFixture Fixture { get; }
Expand Down Expand Up @@ -1128,8 +1226,21 @@ public Blog(int id, string name)
[DatabaseGenerated(DatabaseGeneratedOption.None)]
public int Id { get; set; }
public string? Name { get; set; }

public List<Post> Posts { get; set; } = new();
public List<JsonRoot> Json { get; set; } = new();
}

public class JsonRoot
{
public int Number { get; set; }
public string? Text { get; set; }

public JsonBranch Inner { get; set; } = null!;
}

public class JsonBranch
{
public DateTime Date { get; set; }
}

public class Post
Expand Down
Loading