Skip to content

Commit

Permalink
Clean up container management in Cosmos (#33898)
Browse files Browse the repository at this point in the history
  • Loading branch information
roji authored Jun 5, 2024
1 parent 59bf9dd commit 0773a30
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 79 deletions.
8 changes: 8 additions & 0 deletions src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore.Cosmos/Properties/CosmosStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@
<data name="MissingOrderingInSelectExpression" xml:space="preserve">
<value>'Reverse' could not be translated to the server because there is no ordering on the server side.</value>
</data>
<data name="MultipleContainersReferencedInQuery" xml:space="preserve">
<value>Cosmos container '{container1}' is referenced by the query, but '{container2}' is already being referenced. A query can only reference a single Cosmos container.</value>
</data>
<data name="NavigationPropertyIsNotAnEmbeddedEntity" xml:space="preserve">
<value>Navigation '{entityType}.{navigationName}' doesn't point to an embedded entity.</value>
</data>
Expand Down
11 changes: 11 additions & 0 deletions src/EFCore.Cosmos/Query/Internal/CosmosQueryCompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
public class CosmosQueryCompilationContext(QueryCompilationContextDependencies dependencies, bool async)
: QueryCompilationContext(dependencies, async)
{
/// <summary>
/// The name of the Cosmos container against which this query will be executed.
/// </summary>
/// <remarks>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
public virtual string? CosmosContainer { get; internal set; }

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// </summary>
public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly CosmosQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly ITypeMappingSource _typeMappingSource;
private readonly IMemberTranslatorProvider _memberTranslatorProvider;
Expand All @@ -31,7 +31,7 @@ public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethod
/// </summary>
public CosmosQueryableMethodTranslatingExpressionVisitor(
QueryableMethodTranslatingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext,
CosmosQueryCompilationContext queryCompilationContext,
ISqlExpressionFactory sqlExpressionFactory,
ITypeMappingSource typeMappingSource,
IMemberTranslatorProvider memberTranslatorProvider,
Expand Down Expand Up @@ -258,13 +258,29 @@ protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVis
protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
=> CreateShapedQueryExpression(entityType, _sqlExpressionFactory.Select(entityType));

private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, Expression queryExpression)
=> new(
private ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, Expression queryExpression)
{
if (!entityType.IsOwned())
{
var cosmosContainer = entityType.GetContainer();
var existingContainer = _queryCompilationContext.CosmosContainer;
Check.DebugAssert(cosmosContainer is not null, "Non-owned entity type without a Cosmos container");

if (existingContainer is not null && existingContainer != cosmosContainer)
{
throw new InvalidOperationException(CosmosStrings.MultipleContainersReferencedInQuery(cosmosContainer, existingContainer));
}

_queryCompilationContext.CosmosContainer = cosmosContainer;
}

return new ShapedQueryExpression(
queryExpression,
new StructuralTypeShaperExpression(
entityType,
new ProjectionBindingExpression(queryExpression, new ProjectionMember(), typeof(ValueBuffer)),
false));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -304,7 +320,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
}

var translation = _sqlExpressionFactory.Exists(subquery);
var selectExpression = new SelectExpression(subquery.Container, translation);
var selectExpression = new SelectExpression(translation);

return source.Update(
selectExpression,
Expand Down Expand Up @@ -368,18 +384,17 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
{
// Simplify x.Array.Contains[1] => ARRAY_CONTAINS(x.Array, 1) insert of IN+subquery
if (CosmosQueryUtils.TryExtractBareArray(source, out var array, ignoreOrderings: true)
&& TranslateExpression(item) is SqlExpression translatedItem
&& source.QueryExpression is SelectExpression { Container: var container })
&& TranslateExpression(item) is SqlExpression translatedItem)
{
if (array is ArrayConstantExpression arrayConstant)
{
var inExpression = _sqlExpressionFactory.In(translatedItem, arrayConstant.Items);
return source.Update(new SelectExpression(container, inExpression), source.ShaperExpression);
return source.Update(new SelectExpression(inExpression), source.ShaperExpression);
}

(translatedItem, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(translatedItem, array);
var simplifiedTranslation = _sqlExpressionFactory.Function("ARRAY_CONTAINS", new[] { array, translatedItem }, typeof(bool));
return source.UpdateQueryExpression(new SelectExpression(container, simplifiedTranslation));
return source.UpdateQueryExpression(new SelectExpression(simplifiedTranslation));
}

// TODO: Translation to IN, with scalars and with subquery
Expand All @@ -396,11 +411,10 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
{
// Simplify x.Array.Count() => ARRAY_LENGTH(x.Array) instead of (SELECT COUNT(1) FROM i IN x.Array))
if (predicate is null
&& CosmosQueryUtils.TryExtractBareArray(source, out var array, ignoreOrderings: true)
&& source.QueryExpression is SelectExpression { Container: var container })
&& CosmosQueryUtils.TryExtractBareArray(source, out var array, ignoreOrderings: true))
{
var simplifiedTranslation = _sqlExpressionFactory.Function("ARRAY_LENGTH", new[] { array }, typeof(int));
return source.UpdateQueryExpression(new SelectExpression(container, simplifiedTranslation));
return source.UpdateQueryExpression(new SelectExpression(simplifiedTranslation));
}

var selectExpression = (SelectExpression)source.QueryExpression;
Expand Down Expand Up @@ -470,12 +484,11 @@ protected override ShapedQueryExpression TranslateDistinct(ShapedQueryExpression
// Simplify x.Array[1] => x.Array[1] (using the Cosmos array subscript operator) instead of a subquery with LIMIT/OFFSET
if (!returnDefault
&& CosmosQueryUtils.TryExtractBareArray(source, out var array, out var projectedScalarReference)
&& TranslateExpression(index) is { } translatedIndex
&& source.QueryExpression is SelectExpression { Container: var container })
&& TranslateExpression(index) is { } translatedIndex)
{
var arrayIndex = _sqlExpressionFactory.ArrayIndex(
array, translatedIndex, projectedScalarReference.Type, projectedScalarReference.TypeMapping);
return source.UpdateQueryExpression(new SelectExpression(container, arrayIndex));
return source.UpdateQueryExpression(new SelectExpression(arrayIndex));
}

// Note that Cosmos doesn't support OFFSET/LIMIT in subqueries, so this translation is for top-level entity querying only.
Expand Down Expand Up @@ -1252,8 +1265,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
var innerSelect = new SelectExpression(
[new ProjectionExpression(inlineArray, null!)],
sources: [],
orderings: [],
container: null!)
orderings: [])
{
UsesSingleValueProjection = true
};
Expand Down Expand Up @@ -1289,8 +1301,7 @@ [new ProjectionExpression(inlineArray, null!)],
var innerSelect = new SelectExpression(
[new ProjectionExpression(sqlParameterExpression, null!)],
sources: [],
orderings: [],
container: null!)
orderings: [])
{
UsesSingleValueProjection = true
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class CosmosQueryableMethodTranslatingExpressionVisitorFactory(
public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
=> new CosmosQueryableMethodTranslatingExpressionVisitor(
Dependencies,
queryCompilationContext,
(CosmosQueryCompilationContext)queryCompilationContext,
sqlExpressionFactory,
typeMappingSource,
memberTranslatorProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ private sealed class QueryingEnumerable<T> : IEnumerable<T>, IAsyncEnumerable<T>
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly PartitionKey _partitionKeyValue;
private readonly string _cosmosContainer;
private readonly PartitionKey _cosmosPartitionKeyValue;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
private readonly bool _standAloneStateManager;
private readonly bool _threadSafetyChecksEnabled;
Expand All @@ -39,6 +40,7 @@ public QueryingEnumerable(
SelectExpression selectExpression,
Func<CosmosQueryContext, JObject, T> shaper,
Type contextType,
string cosmosContainer,
PartitionKey partitionKeyValueFromExtension,
bool standAloneStateManager,
bool threadSafetyChecksEnabled)
Expand All @@ -61,7 +63,8 @@ public QueryingEnumerable(
throw new InvalidOperationException(CosmosStrings.PartitionKeyMismatch(partitionKeyValueFromExtension, partitionKey));
}

_partitionKeyValue = partitionKey != PartitionKey.None ? partitionKey : partitionKeyValueFromExtension;
_cosmosPartitionKeyValue = partitionKey != PartitionKey.None ? partitionKey : partitionKeyValueFromExtension;
_cosmosContainer = cosmosContainer;
}

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -107,10 +110,10 @@ private sealed class Enumerator : IEnumerator<T>
{
private readonly QueryingEnumerable<T> _queryingEnumerable;
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
private readonly PartitionKey _partitionKeyValue;
private readonly string _cosmosContainer;
private readonly PartitionKey _cosmosPartitionKeyValue;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
private readonly bool _standAloneStateManager;
private readonly IConcurrencyDetector _concurrencyDetector;
Expand All @@ -123,9 +126,9 @@ public Enumerator(QueryingEnumerable<T> queryingEnumerable)
_queryingEnumerable = queryingEnumerable;
_cosmosQueryContext = queryingEnumerable._cosmosQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_contextType = queryingEnumerable._contextType;
_partitionKeyValue = queryingEnumerable._partitionKeyValue;
_cosmosContainer = queryingEnumerable._cosmosContainer;
_cosmosPartitionKeyValue = queryingEnumerable._cosmosPartitionKeyValue;
_queryLogger = queryingEnumerable._queryLogger;
_standAloneStateManager = queryingEnumerable._standAloneStateManager;
_exceptionDetector = _cosmosQueryContext.ExceptionDetector;
Expand Down Expand Up @@ -156,8 +159,8 @@ public bool MoveNext()

_enumerator = _cosmosQueryContext.CosmosClient
.ExecuteSqlQuery(
_selectExpression.Container,
_partitionKeyValue,
_cosmosContainer,
_cosmosPartitionKeyValue,
sqlQuery)
.GetEnumerator();
_cosmosQueryContext.InitializeStateManager(_standAloneStateManager);
Expand Down Expand Up @@ -206,10 +209,10 @@ private sealed class AsyncEnumerator : IAsyncEnumerator<T>
{
private readonly QueryingEnumerable<T> _queryingEnumerable;
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
private readonly PartitionKey _partitionKeyValue;
private readonly string _cosmosContainer;
private readonly PartitionKey _cosmosPartitionKeyValue;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
private readonly bool _standAloneStateManager;
private readonly CancellationToken _cancellationToken;
Expand All @@ -223,9 +226,9 @@ public AsyncEnumerator(QueryingEnumerable<T> queryingEnumerable, CancellationTok
_queryingEnumerable = queryingEnumerable;
_cosmosQueryContext = queryingEnumerable._cosmosQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_contextType = queryingEnumerable._contextType;
_partitionKeyValue = queryingEnumerable._partitionKeyValue;
_cosmosContainer = queryingEnumerable._cosmosContainer;
_cosmosPartitionKeyValue = queryingEnumerable._cosmosPartitionKeyValue;
_queryLogger = queryingEnumerable._queryLogger;
_standAloneStateManager = queryingEnumerable._standAloneStateManager;
_exceptionDetector = _cosmosQueryContext.ExceptionDetector;
Expand Down Expand Up @@ -254,8 +257,8 @@ public async ValueTask<bool> MoveNextAsync()

_enumerator = _cosmosQueryContext.CosmosClient
.ExecuteSqlQueryAsync(
_selectExpression.Container,
_partitionKeyValue,
_cosmosContainer,
_cosmosPartitionKeyValue,
sqlQuery)
.GetAsyncEnumerator(_cancellationToken);
_cosmosQueryContext.InitializeStateManager(_standAloneStateManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor
private sealed class ReadItemQueryingEnumerable<T> : IEnumerable<T>, IAsyncEnumerable<T>, IQueryingEnumerable
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly string _cosmosContainer;
private readonly ReadItemExpression _readItemExpression;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
Expand All @@ -31,13 +32,15 @@ private sealed class ReadItemQueryingEnumerable<T> : IEnumerable<T>, IAsyncEnume

public ReadItemQueryingEnumerable(
CosmosQueryContext cosmosQueryContext,
string cosmosContainer,
ReadItemExpression readItemExpression,
Func<CosmosQueryContext, JObject, T> shaper,
Type contextType,
bool standAloneStateManager,
bool threadSafetyChecksEnabled)
{
_cosmosQueryContext = cosmosQueryContext;
_cosmosContainer = cosmosContainer;
_readItemExpression = readItemExpression;
_shaper = shaper;
_contextType = contextType;
Expand Down Expand Up @@ -169,7 +172,7 @@ private bool TryGenerateIdFromKeys(IProperty idProperty, out object value)
private sealed class Enumerator : IEnumerator<T>, IAsyncEnumerator<T>
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly ReadItemExpression _readItemExpression;
private readonly string _cosmosContainer;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
Expand All @@ -185,7 +188,7 @@ private sealed class Enumerator : IEnumerator<T>, IAsyncEnumerator<T>
public Enumerator(ReadItemQueryingEnumerable<T> readItemEnumerable, CancellationToken cancellationToken = default)
{
_cosmosQueryContext = readItemEnumerable._cosmosQueryContext;
_readItemExpression = readItemEnumerable._readItemExpression;
_cosmosContainer = readItemEnumerable._cosmosContainer;
_shaper = readItemEnumerable._shaper;
_contextType = readItemEnumerable._contextType;
_queryLogger = readItemEnumerable._queryLogger;
Expand Down Expand Up @@ -227,7 +230,7 @@ public bool MoveNext()
EntityFrameworkEventSource.Log.QueryExecuting();

_item = _cosmosQueryContext.CosmosClient.ExecuteReadItem(
_readItemExpression.Container,
_cosmosContainer,
partitionKeyValue,
resourceId);

Expand Down Expand Up @@ -279,7 +282,7 @@ public async ValueTask<bool> MoveNextAsync()
EntityFrameworkEventSource.Log.QueryExecuting();

_item = await _cosmosQueryContext.CosmosClient.ExecuteReadItemAsync(
_readItemExpression.Container,
_cosmosContainer,
partitionKeyValue,
resourceId,
_cancellationToken)
Expand Down
Loading

0 comments on commit 0773a30

Please sign in to comment.