Skip to content

Commit

Permalink
Cosmos: use ReadItem for more Find cases
Browse files Browse the repository at this point in the history
This change stops using `ReadItemExpression`, since it does not do any shaper processing, and instead uses `SelectExpression` as for other queries. This allows processing of auto-Includes for owned types, even if Find is being translated to `ReadItem`.

This is somewhat hacky now, but it likely to change again as the processing for Includes and/or complex types is changed. For now, it's a reasonable way to get an important feature working.

Fixes #24202

In addition, the pattern matching has been updated to detect calls to non-generic Find.

Fixes #33881
  • Loading branch information
ajcvickers committed Jun 4, 2024
1 parent 19a7059 commit 481b19a
Show file tree
Hide file tree
Showing 16 changed files with 690 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethod
private readonly IMethodCallTranslatorProvider _methodCallTranslatorProvider;
private readonly CosmosSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly CosmosProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private ReadItemInfo? _readItemExpression;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -80,56 +81,85 @@ protected CosmosQueryableMethodTranslatingExpressionVisitor(
[return: NotNullIfNotNull(nameof(expression))]
public override Expression? Visit(Expression? expression)
{
if (expression is MethodCallExpression
if (_queryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll // Issue #33893
&& expression is MethodCallExpression
{
Method: { Name: nameof(Queryable.FirstOrDefault), IsGenericMethod: true },
Arguments:
[
MethodCallExpression
{
Method: { Name: nameof(Queryable.Where), IsGenericMethod: true },
Arguments:
[
EntityQueryRootExpression { EntityType: var entityType },
UnaryExpression { Operand: LambdaExpression lambdaExpression, NodeType: ExpressionType.Quote }
]
} whereMethodCall
]
} firstOrDefaultMethodCall
&& firstOrDefaultMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.FirstOrDefaultWithoutPredicate
&& whereMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.Where)
Arguments: [MethodCallExpression innerMethodCall]
})
{
var queryProperties = new List<IProperty>();
var parameterNames = new List<string>();

if (ExtractPartitionKeyFromPredicate(entityType, lambdaExpression.Body, queryProperties, parameterNames))
var clrType = innerMethodCall.Type.TryGetSequenceType() ?? typeof(object);
if (innerMethodCall is
{
Method: { Name: nameof(Queryable.Select), IsGenericMethod: true },
Arguments:
[
MethodCallExpression innerInnerMethodCall,
UnaryExpression { NodeType: ExpressionType.Quote } unaryExpression
]
})
{
var entityTypePrimaryKeyProperties = entityType.FindPrimaryKey()!.Properties;
var idProperty = entityType.GetProperties()
.First(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyJsonName);
var partitionKeyProperties = entityType.GetPartitionKeyProperties();

if (entityTypePrimaryKeyProperties.SequenceEqual(queryProperties)
&& (!partitionKeyProperties.Any()
|| partitionKeyProperties.All(p => entityTypePrimaryKeyProperties.Contains(p)))
&& (idProperty.GetValueGeneratorFactory() != null
|| entityTypePrimaryKeyProperties.Contains(idProperty)))
if (unaryExpression.Operand is LambdaExpression)
{
var propertyParameterList = queryProperties.Zip(
parameterNames,
(property, parameter) => (property, parameter))
.ToDictionary(tuple => tuple.property, tuple => tuple.parameter);
innerMethodCall = innerInnerMethodCall;
}
}

var readItemExpression = new ReadItemExpression(entityType, propertyParameterList);
if (innerMethodCall is
{
Method: { Name: nameof(Queryable.Where), IsGenericMethod: true },
Arguments:
[
EntityQueryRootExpression { EntityType: var entityType },
UnaryExpression { Operand: LambdaExpression lambdaExpression, NodeType: ExpressionType.Quote }
]
})
{
var queryProperties = new List<IProperty>();
var parameterNames = new List<string>();

if (ExtractPartitionKeyFromPredicate(entityType, lambdaExpression.Body, queryProperties, parameterNames))
{
var entityTypePrimaryKeyProperties = entityType.FindPrimaryKey()!.Properties;
var idProperty = entityType.GetProperties()
.First(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyJsonName);
var partitionKeyProperties = entityType.GetPartitionKeyProperties();

if (entityTypePrimaryKeyProperties.SequenceEqual(queryProperties)
&& (!partitionKeyProperties.Any()
|| partitionKeyProperties.All(p => entityTypePrimaryKeyProperties.Contains(p)))
// This should ideally only be looking for properties with the `IdValueGeneratorFactory` generator. since
// this is how the `id` property will be generated from other key values.
&& ((idProperty.GetValueGeneratorFactory() != null
// If we can't create an instance, then we might not be able to construct the resource id.
&& CanCreateEmptyInstance(entityType))
|| entityTypePrimaryKeyProperties.Contains(idProperty)))
{
var propertyParameterList = queryProperties.Zip(
parameterNames,
(property, parameter) => (property, parameter))
.ToDictionary(tuple => tuple.property, tuple => tuple.parameter);

return CreateShapedQueryExpression(entityType, readItemExpression)
.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
_readItemExpression = new ReadItemInfo(entityType, propertyParameterList, clrType);
}
}
}
}

return base.Visit(expression);

static bool CanCreateEmptyInstance(IEntityType entityType)
{
var binding = entityType.ServiceOnlyConstructorBinding;
if (binding == null)
{
_ = entityType.ConstructorBinding;
binding = entityType.ServiceOnlyConstructorBinding;
}

return binding != null;
}

static bool ExtractPartitionKeyFromPredicate(
IEntityType entityType,
Expression joinCondition,
Expand Down Expand Up @@ -229,7 +259,11 @@ public override ShapedQueryExpression TranslateSubquery(Expression expression)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
=> CreateShapedQueryExpression(entityType, _sqlExpressionFactory.Select(entityType));
=> CreateShapedQueryExpression(
entityType,
_readItemExpression == null
? _sqlExpressionFactory.Select(entityType)
: _sqlExpressionFactory.ReadItem(entityType, _readItemExpression));

private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, Expression queryExpression)
=> new(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor
private sealed class ReadItemQueryingEnumerable<T> : IEnumerable<T>, IAsyncEnumerable<T>, IQueryingEnumerable
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly ReadItemExpression _readItemExpression;
private readonly ReadItemInfo _readItemInfo;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
Expand All @@ -31,14 +31,14 @@ private sealed class ReadItemQueryingEnumerable<T> : IEnumerable<T>, IAsyncEnume

public ReadItemQueryingEnumerable(
CosmosQueryContext cosmosQueryContext,
ReadItemExpression readItemExpression,
ReadItemInfo readItemInfo,
Func<CosmosQueryContext, JObject, T> shaper,
Type contextType,
bool standAloneStateManager,
bool threadSafetyChecksEnabled)
{
_cosmosQueryContext = cosmosQueryContext;
_readItemExpression = readItemExpression;
_readItemInfo = readItemInfo;
_shaper = shaper;
_contextType = contextType;
_queryLogger = _cosmosQueryContext.QueryLogger;
Expand All @@ -64,7 +64,7 @@ public string ToQueryString()

private bool TryGetPartitionKey(out PartitionKey partitionKeyValue)
{
var properties = _readItemExpression.EntityType.GetPartitionKeyProperties();
var properties = _readItemInfo.EntityType.GetPartitionKeyProperties();
if (!properties.Any())
{
partitionKeyValue = PartitionKey.None;
Expand Down Expand Up @@ -92,7 +92,7 @@ private bool TryGetPartitionKey(out PartitionKey partitionKeyValue)

private bool TryGetResourceId(out string resourceId)
{
var idProperty = _readItemExpression.EntityType.GetProperties()
var idProperty = _readItemInfo.EntityType.GetProperties()
.FirstOrDefault(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyJsonName);

if (TryGetParameterValue(idProperty, out var value))
Expand Down Expand Up @@ -121,7 +121,7 @@ private bool TryGetResourceId(out string resourceId)
private bool TryGetParameterValue(IProperty property, out object value)
{
value = null;
return _readItemExpression.PropertyParameters.TryGetValue(property, out var parameterName)
return _readItemInfo.PropertyParameters.TryGetValue(property, out var parameterName)
&& _cosmosQueryContext.ParameterValues.TryGetValue(parameterName, out value);
}

Expand All @@ -136,40 +136,36 @@ private static string GetString(IProperty property, object value)

private bool TryGenerateIdFromKeys(IProperty idProperty, out object value)
{
var entityEntry = Activator.CreateInstance(_readItemExpression.EntityType.ClrType);

#pragma warning disable EF1001 // Internal EF Core API usage.
// The idea here is that if a `IdValueGeneratorFactory` has been configured to generate an `id` value from the
// values of other properties, then we need an entity instance to use with the value generator.
var entityInstance = _readItemInfo.EntityType.GetOrCreateEmptyMaterializer(_cosmosQueryContext.EntityMaterializerSource)
(new MaterializationContext(ValueBuffer.Empty, _cosmosQueryContext.Context));

var internalEntityEntry = new InternalEntityEntry(
_cosmosQueryContext.Context.GetDependencies().StateManager, _readItemExpression.EntityType, entityEntry);
#pragma warning restore EF1001 // Internal EF Core API usage.
_cosmosQueryContext.Context.GetDependencies().StateManager, _readItemInfo.EntityType, entityInstance);

foreach (var keyProperty in _readItemExpression.EntityType.FindPrimaryKey().Properties)
foreach (var keyProperty in _readItemInfo.EntityType.FindPrimaryKey().Properties)
{
var property = _readItemExpression.EntityType.FindProperty(keyProperty.Name);
var property = _readItemInfo.EntityType.FindProperty(keyProperty.Name);

if (TryGetParameterValue(property, out var parameterValue))
{
#pragma warning disable EF1001 // Internal EF Core API usage.
internalEntityEntry[property] = parameterValue;
#pragma warning restore EF1001 // Internal EF Core API usage.
}
}

#pragma warning disable EF1001 // Internal EF Core API usage.
internalEntityEntry.SetEntityState(EntityState.Added);

value = internalEntityEntry[idProperty];

internalEntityEntry.SetEntityState(EntityState.Detached);
#pragma warning restore EF1001 // Internal EF Core API usage.

return value != null;
#pragma warning restore EF1001 // Internal EF Core API usage.
}

private sealed class Enumerator : IEnumerator<T>, IAsyncEnumerator<T>
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly ReadItemExpression _readItemExpression;
private readonly ReadItemInfo _readItemInfo;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
Expand All @@ -185,7 +181,7 @@ private sealed class Enumerator : IEnumerator<T>, IAsyncEnumerator<T>
public Enumerator(ReadItemQueryingEnumerable<T> readItemEnumerable, CancellationToken cancellationToken = default)
{
_cosmosQueryContext = readItemEnumerable._cosmosQueryContext;
_readItemExpression = readItemEnumerable._readItemExpression;
_readItemInfo = readItemEnumerable._readItemInfo;
_shaper = readItemEnumerable._shaper;
_contextType = readItemEnumerable._contextType;
_queryLogger = readItemEnumerable._queryLogger;
Expand Down Expand Up @@ -227,7 +223,7 @@ public bool MoveNext()
EntityFrameworkEventSource.Log.QueryExecuting();

_item = _cosmosQueryContext.CosmosClient.ExecuteReadItem(
_readItemExpression.Container,
_readItemInfo.Container,
partitionKeyValue,
resourceId);

Expand Down Expand Up @@ -279,7 +275,7 @@ public async ValueTask<bool> MoveNextAsync()
EntityFrameworkEventSource.Log.QueryExecuting();

_item = await _cosmosQueryContext.CosmosClient.ExecuteReadItemAsync(
_readItemExpression.Container,
_readItemInfo.Container,
partitionKeyValue,
resourceId,
_cancellationToken)
Expand Down
Loading

0 comments on commit 481b19a

Please sign in to comment.