Skip to content

Commit

Permalink
Remove unneeded Cosmos 3-value logic compensation for InExpression
Browse files Browse the repository at this point in the history
Closes #31063
  • Loading branch information
roji committed Jun 5, 2024
1 parent 0773a30 commit f020387
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 93 deletions.
6 changes: 0 additions & 6 deletions src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions src/EFCore.Cosmos/Properties/CosmosStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,6 @@
<data name="OneOfTwoValuesMustBeSet" xml:space="preserve">
<value>Exactly one of '{param1}' or '{param2}' must be set.</value>
</data>
<data name="OnlyConstantsAndParametersAllowedInContains" xml:space="preserve">
<value>Only constants or parameters are currently allowed in Contains.</value>
</data>
<data name="OrphanedNestedDocument" xml:space="preserve">
<value>The entity of type '{entityType}' is mapped as a part of the document mapped to '{missingEntityType}', but there is no tracked entity of this type with the corresponding key value. Consider using 'DbContextOptionsBuilder.EnableSensitiveDataLogging' to see the key values.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,99 +10,47 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;

public partial class CosmosShapedQueryCompilingExpressionVisitor
{
private sealed class InExpressionValuesExpandingExpressionVisitor : ExpressionVisitor
private sealed class InExpressionValuesExpandingExpressionVisitor(
ISqlExpressionFactory sqlExpressionFactory,
IReadOnlyDictionary<string, object> parametersValues)
: ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IReadOnlyDictionary<string, object> _parametersValues;

public InExpressionValuesExpandingExpressionVisitor(
ISqlExpressionFactory sqlExpressionFactory,
IReadOnlyDictionary<string, object> parametersValues)
{
_sqlExpressionFactory = sqlExpressionFactory;
_parametersValues = parametersValues;
}

public override Expression Visit(Expression expression)
protected override Expression VisitExtension(Expression expression)
{
if (expression is InExpression inExpression)
{
var inValues = new List<SqlExpression>();
var hasNullValue = false;
IReadOnlyList<SqlExpression> values;

switch (inExpression)
{
case { ValuesParameter: SqlParameterExpression valuesParameter }:
{
var typeMapping = valuesParameter.TypeMapping;

foreach (var value in (IEnumerable)_parametersValues[valuesParameter.Name])
{
if (value is null)
{
hasNullValue = true;
continue;
}

inValues.Add(_sqlExpressionFactory.Constant(value, typeMapping));
}

case { Values: IReadOnlyList<SqlExpression> values2 }:
values = values2;
break;
}

case { Values: IReadOnlyList<SqlExpression> values }:
// TODO: IN with subquery (return immediately, nothing to do here)

case { ValuesParameter: SqlParameterExpression valuesParameter }:
{
foreach (var value in values)
var typeMapping = valuesParameter.TypeMapping;
var mutableValues = new List<SqlExpression>();
foreach (var value in (IEnumerable)parametersValues[valuesParameter.Name])
{
if (value is not (SqlConstantExpression or SqlParameterExpression))
{
throw new InvalidOperationException(CosmosStrings.OnlyConstantsAndParametersAllowedInContains);
}

if (IsNull(value))
{
hasNullValue = true;
continue;
}

inValues.Add(value);
mutableValues.Add(sqlExpressionFactory.Constant(value, typeMapping));
}

values = mutableValues;
break;
}

default:
throw new UnreachableException();
}

var updatedInExpression = inValues.Count > 0
? _sqlExpressionFactory.In((SqlExpression)Visit(inExpression.Item), inValues)
: null;

var nullCheckExpression = hasNullValue
? _sqlExpressionFactory.IsNull(inExpression.Item)
: null;

if (updatedInExpression != null
&& nullCheckExpression != null)
{
return _sqlExpressionFactory.OrElse(updatedInExpression, nullCheckExpression);
}

if (updatedInExpression == null
&& nullCheckExpression == null)
{
return _sqlExpressionFactory.Equal(_sqlExpressionFactory.Constant(true), _sqlExpressionFactory.Constant(false));
}

return (SqlExpression)updatedInExpression ?? nullCheckExpression;
return values.Count == 0
? sqlExpressionFactory.ApplyDefaultTypeMapping(sqlExpressionFactory.Constant(false))
: sqlExpressionFactory.In((SqlExpression)Visit(inExpression.Item), values);
}

return base.Visit(expression);
return base.VisitExtension(expression);
}

private bool IsNull(SqlExpression expression)
=> expression is SqlConstantExpression { Value: null }
|| expression is SqlParameterExpression { Name: string parameterName } && _parametersValues[parameterName] is null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1806,9 +1806,8 @@ public override Task Contains_with_local_ordered_read_only_collection_all_null(b
"""
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND (c["CustomerID"] = null))
"""
);
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN (null, null))
""");
});

public override Task Contains_with_local_read_only_collection_inline(bool async)
Expand Down Expand Up @@ -1969,7 +1968,7 @@ public override Task Contains_with_local_collection_empty_inline(bool async)
"""
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND NOT((true = false)))
WHERE ((c["Discriminator"] = "Customer") AND NOT(false))
""");
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4158,7 +4158,7 @@ public override Task Entity_equality_contains_with_list_of_null(bool async)
"""
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND (c["CustomerID"] IN ("ALFKI") OR (c["CustomerID"] = null)))
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN (null, "ALFKI"))
""");
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public override Task Inline_collection_of_nullable_ints_Contains_null(bool async
"""
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND (c["NullableInt"] IN (999) OR (c["NullableInt"] = null)))
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND c["NullableInt"] IN (null, 999))
""");
});

Expand Down Expand Up @@ -137,7 +137,7 @@ public override Task Inline_collection_Contains_with_zero_values(bool async)
"""
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND (true = false))
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND false)
""");
});

Expand Down Expand Up @@ -230,13 +230,37 @@ FROM root c
""");
});

// TODO: Remove incorrect null semantics compensation for Cosmos: #31063
public override Task Inline_collection_Contains_with_mixed_value_types(bool async)
=> Assert.ThrowsAsync<InvalidOperationException>(() => base.Inline_collection_Contains_with_mixed_value_types(async));
=> CosmosTestHelpers.Instance.NoSyncTest(
async, async a =>
{
await base.Inline_collection_Contains_with_mixed_value_types(a);

AssertSql(
"""
@__i_0='11'
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND c["Int"] IN (999, @__i_0, c["Id"], (c["Id"] + c["Int"])))
""");
});

// TODO: Remove incorrect null semantics compensation for Cosmos: #31063
public override Task Inline_collection_List_Contains_with_mixed_value_types(bool async)
=> Assert.ThrowsAsync<InvalidOperationException>(() => base.Inline_collection_List_Contains_with_mixed_value_types(async));
=> CosmosTestHelpers.Instance.NoSyncTest(
async, async a =>
{
await base.Inline_collection_List_Contains_with_mixed_value_types(a);

AssertSql(
"""
@__i_0='11'
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND c["Int"] IN (999, @__i_0, c["Id"], (c["Id"] + c["Int"])))
""");
});

public override Task Inline_collection_Contains_as_Any_with_predicate(bool async)
=> CosmosTestHelpers.Instance.NoSyncTest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public virtual async Task Inline_collection_List_Contains_with_mixed_value_types

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => new List<int>() { 999, i, c.Id, c.Id + c.Int }.Contains(c.Int)));
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => new List<int> { 999, i, c.Id, c.Id + c.Int }.Contains(c.Int)));
}

[ConditionalTheory]
Expand Down

0 comments on commit f020387

Please sign in to comment.