Skip to content

Commit

Permalink
Make maxItemCount required
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jun 29, 2024
1 parent 9f94fed commit 3ada916
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ internal static readonly MethodInfo ToPageMethodInfo
/// <returns>A <see cref="CosmosPage{T}" /> containing at most <paramref name="maxItemCount" /> results.</returns>
public static CosmosPage<TSource> ToPage<TSource>(
this IQueryable<TSource> source,
int maxItemCount,
string? continuationToken = null,
int? maxItemCount = null,
int? responseContinuationTokenLimitInKb = null)
=> source.Provider.Execute<CosmosPage<TSource>>(
Expression.Call(
Expand All @@ -221,8 +221,8 @@ public static CosmosPage<TSource> ToPage<TSource>(
arguments:
[
source.Expression,
Expression.Constant(maxItemCount, typeof(int)),
Expression.Constant(continuationToken, typeof(string)),
Expression.Constant(maxItemCount, typeof(int?)),
Expression.Constant(responseContinuationTokenLimitInKb, typeof(int?))
]));

Expand All @@ -244,8 +244,8 @@ public static CosmosPage<TSource> ToPage<TSource>(
/// <returns>A <see cref="CosmosPage{T}" /> containing at most <paramref name="maxItemCount" /> results.</returns>
public static Task<CosmosPage<TSource>> ToPageAsync<TSource>(
this IQueryable<TSource> source,
int maxItemCount,
string? continuationToken = null,
int? maxItemCount = null,
int? responseContinuationTokenLimitInKb = null,
CancellationToken cancellationToken = default)
{
Expand All @@ -261,8 +261,8 @@ public static Task<CosmosPage<TSource>> ToPageAsync<TSource>(
arguments:
[
source.Expression,
Expression.Constant(maxItemCount, typeof(int)),
Expression.Constant(continuationToken, typeof(string)),
Expression.Constant(maxItemCount, typeof(int?)),
Expression.Constant(responseContinuationTokenLimitInKb, typeof(int?)),
Expression.Constant(default(CancellationToken), typeof(CancellationToken))
]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ public override Expression Translate(Expression expression)
if (arguments is not
[
_, // source
ParameterExpression continuationToken,
ParameterExpression maxItemCount,
ParameterExpression continuationToken,
ParameterExpression responseContinuationTokenLimitInKb,
..
]
|| _sqlTranslator.Translate(continuationToken) is not SqlParameterExpression translatedContinuationToken
|| _sqlTranslator.Translate(maxItemCount) is not SqlParameterExpression translatedMaxItemCount
|| _sqlTranslator.Translate(continuationToken) is not SqlParameterExpression translatedContinuationToken
|| _sqlTranslator.Translate(responseContinuationTokenLimitInKb) is not SqlParameterExpression
translatedResponseContinuationTokenLimitInKb)
{
Expand All @@ -135,8 +135,8 @@ public override Expression Translate(Expression expression)
return shapedQuery
.UpdateShaperExpression(new PagingExpression(
shapedQuery.ShaperExpression,
translatedContinuationToken,
translatedMaxItemCount,
translatedContinuationToken,
translatedResponseContinuationTokenLimitInKb,
typeof(CosmosPage<>).MakeGenericType(shapedQuery.ShaperExpression.Type)))
.UpdateResultCardinality(ResultCardinality.Single);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ private sealed class PagingQueryingEnumerable<T> : IEnumerable<CosmosPage<T>>, I
private readonly IDiagnosticsLogger<DbLoggerCategory.Database.Command> _commandLogger;
private readonly bool _standAloneStateManager;
private readonly bool _threadSafetyChecksEnabled;
private readonly string _continuationTokenParameterName;
private readonly string _maxItemCountParameterName;
private readonly string _continuationTokenParameterName;
private readonly string _responseContinuationTokenLimitInKbParameterName;

public PagingQueryingEnumerable(
Expand All @@ -48,8 +48,8 @@ public PagingQueryingEnumerable(
PartitionKey partitionKeyValueFromExtension,
bool standAloneStateManager,
bool threadSafetyChecksEnabled,
string continuationTokenParameterName,
string maxItemCountParameterName,
string continuationTokenParameterName,
string responseContinuationTokenLimitInKbParameterName)
{
_cosmosQueryContext = cosmosQueryContext;
Expand All @@ -62,8 +62,8 @@ public PagingQueryingEnumerable(
_commandLogger = cosmosQueryContext.CommandLogger;
_standAloneStateManager = standAloneStateManager;
_threadSafetyChecksEnabled = threadSafetyChecksEnabled;
_continuationTokenParameterName = continuationTokenParameterName;
_maxItemCountParameterName = maxItemCountParameterName;
_continuationTokenParameterName = continuationTokenParameterName;
_responseContinuationTokenLimitInKbParameterName = responseContinuationTokenLimitInKbParameterName;

var partitionKey = selectExpression.GetPartitionKeyValue(cosmosQueryContext.ParameterValues);
Expand Down Expand Up @@ -159,9 +159,9 @@ private async Task<bool> MoveNextCore()

_hasExecuted = true;

var maxItemCount = (int)_cosmosQueryContext.ParameterValues[_queryingEnumerable._maxItemCountParameterName];
var continuationToken =
(string)_cosmosQueryContext.ParameterValues[_queryingEnumerable._continuationTokenParameterName];
var maxItemCount = (int?)_cosmosQueryContext.ParameterValues[_queryingEnumerable._maxItemCountParameterName];
var responseContinuationTokenLimitInKb = (int?)
_cosmosQueryContext.ParameterValues[_queryingEnumerable._responseContinuationTokenLimitInKbParameterName];

Expand All @@ -183,9 +183,9 @@ private async Task<bool> MoveNextCore()
_commandLogger.ExecutingSqlQuery(_cosmosContainer, _cosmosPartitionKeyValue, sqlQuery);
_cosmosQueryContext.InitializeStateManager(_standAloneStateManager);

var results = maxItemCount.HasValue ? new List<T>(maxItemCount.Value) : [];
var results = new List<T>(maxItemCount);

while (maxItemCount is null or > 0)
while (maxItemCount > 0)
{
queryRequestOptions.MaxItemCount = maxItemCount;
using var feedIterator = cosmosClient.CreateQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery

var shaperBody = shapedQueryExpression.ShaperExpression;

var (paging, continuationToken, maxItemCount, responseContinuationTokenLimitInKb) =
var (paging, maxItemCount, continuationToken, responseContinuationTokenLimitInKb) =
(false, (SqlParameterExpression)null, (SqlParameterExpression)null, (SqlParameterExpression)null);

// If the query is terminated ToPageAsync(), CosmosQueryableMethodTranslatingExpressionVisitor composed a PagingExpression on top
Expand All @@ -54,8 +54,8 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery
if (shaperBody is PagingExpression pagingExpression)
{
paging = true;
continuationToken = pagingExpression.ContinuationToken;
maxItemCount = pagingExpression.MaxItemCount;
continuationToken = pagingExpression.ContinuationToken;
responseContinuationTokenLimitInKb = pagingExpression.ResponseContinuationTokenLimitInKb;

shaperBody = pagingExpression.Expression;
Expand Down Expand Up @@ -113,8 +113,8 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery
Constant(_partitionKeyValueFromExtension, typeof(PartitionKey)),
standAloneStateManagerConstant,
threadSafetyConstant,
Constant(continuationToken.Name),
Constant(maxItemCount.Name),
Constant(continuationToken.Name),
Constant(responseContinuationTokenLimitInKb.Name)),

_ => New(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal.Expressions;
/// </summary>
public class PagingExpression(
Expression expression,
SqlParameterExpression continuationToken,
SqlParameterExpression maxItemCount,
SqlParameterExpression continuationToken,
SqlParameterExpression responseContinuationTokenLimitInKb,
Type type)
: Expression, IPrintableExpression
Expand Down Expand Up @@ -48,15 +48,15 @@ public sealed override ExpressionType NodeType
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual SqlParameterExpression ContinuationToken { get; } = continuationToken;
public virtual SqlParameterExpression MaxItemCount { get; } = maxItemCount;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual SqlParameterExpression MaxItemCount { get; } = maxItemCount;
public virtual SqlParameterExpression ContinuationToken { get; } = continuationToken;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5340,7 +5340,7 @@ public virtual async Task ToPageAsync()

var page2 = await context.Set<Customer>()
.OrderBy(c => c.CustomerID)
.ToPageAsync(continuationToken: page1.ContinuationToken, maxItemCount: 2);
.ToPageAsync(maxItemCount: 2, page1.ContinuationToken);

Assert.Collection(
page2.Values,
Expand All @@ -5349,7 +5349,7 @@ public virtual async Task ToPageAsync()

var page3 = await context.Set<Customer>()
.OrderBy(c => c.CustomerID)
.ToPageAsync(continuationToken: page2.ContinuationToken);
.ToPageAsync(maxItemCount: totalCustomers, page2.ContinuationToken);

Assert.Equal(totalCustomers - 3, page3.Values.Count);
Assert.Null(page3.ContinuationToken);
Expand Down Expand Up @@ -5383,6 +5383,36 @@ ORDER BY c["CustomerID"]
""");
}

[ConditionalFact]
public virtual async Task ToPageAsync_with_exact_maxItemCount()
{
await using var context = CreateContext();

var totalCustomers = await context.Set<Customer>().CountAsync();

var onlyPage = await context.Set<Customer>()
.OrderBy(c => c.CustomerID)
.ToPageAsync(maxItemCount: totalCustomers);

Assert.Equal("ALFKI", onlyPage.Values[0].CustomerID);
Assert.Equal("WOLZA", onlyPage.Values[^1].CustomerID);
Assert.Null(onlyPage.ContinuationToken);

AssertSql(
"""
SELECT COUNT(1) AS c
FROM root c
WHERE (c["Discriminator"] = "Customer")
""",
//
"""
SELECT c
FROM root c
WHERE (c["Discriminator"] = "Customer")
ORDER BY c["CustomerID"]
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit 3ada916

Please sign in to comment.