Skip to content

Commit

Permalink
Query: translate stringColumn.FirstOrDefault() using SUBSTRING on Sql…
Browse files Browse the repository at this point in the history
…Server (#10912) (#20840)

Summary of changes:
- Added translation of FirstOrDefault, LastOrDefault for Sql Server, SqLite, Cosmos.

The corresponding tests have been implemented, but tests for the Cosmos  are  disabled, for a while.
  • Loading branch information
Alexey Dvoretskiy authored May 15, 2020
1 parent 7d5c2aa commit 78bc42b
Show file tree
Hide file tree
Showing 14 changed files with 564 additions and 328 deletions.
22 changes: 22 additions & 0 deletions src/EFCore.Cosmos/Query/Internal/StringMethodTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Utilities;
Expand All @@ -26,6 +27,16 @@ private static readonly MethodInfo _startsWithMethodInfo
private static readonly MethodInfo _endsWithMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) });

private static readonly MethodInfo _firstOrDefaultMethodInfoWithoutArgs
= typeof(Enumerable).GetRuntimeMethods().Single(
m => m.Name == nameof(Enumerable.FirstOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(new[] { typeof(char) });

private static readonly MethodInfo _lastOrDefaultMethodInfoWithoutArgs
= typeof(Enumerable).GetRuntimeMethods().Single(
m => m.Name == nameof(Enumerable.LastOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(new[] { typeof(char) });

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
Expand Down Expand Up @@ -55,6 +66,17 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
return TranslateSystemFunction("CONTAINS", instance, arguments[0], typeof(bool));
}

if (_firstOrDefaultMethodInfoWithoutArgs.Equals(method))
{
return TranslateSystemFunction("LEFT", arguments[0], _sqlExpressionFactory.Constant(1), typeof(char));
}


if (_lastOrDefaultMethodInfoWithoutArgs.Equals(method))
{
return TranslateSystemFunction("RIGHT", arguments[0], _sqlExpressionFactory.Constant(1), typeof(char));
}

if (_startsWithMethodInfo.Equals(method))
{
return TranslateSystemFunction("STARTSWITH", instance, arguments[0], typeof(bool));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ private static readonly MethodInfo _containsMethodInfo
private static readonly MethodInfo _endsWithMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) });

private static readonly MethodInfo _firstOrDefaultMethodInfoWithoutArgs
= typeof(Enumerable).GetRuntimeMethods().Single(
m => m.Name == nameof(Enumerable.FirstOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(new[] { typeof(char) });

private static readonly MethodInfo _lastOrDefaultMethodInfoWithoutArgs
= typeof(Enumerable).GetRuntimeMethods().Single(
m => m.Name == nameof(Enumerable.LastOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(new[] { typeof(char) });

private readonly ISqlExpressionFactory _sqlExpressionFactory;

private const char LikeEscapeChar = '\\';
Expand Down Expand Up @@ -109,7 +119,7 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new [] { true, true },
argumentsPropagateNullability: new[] { true, true },
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
Expand Down Expand Up @@ -273,7 +283,7 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method

if (pattern is SqlConstantExpression constantPattern)
{
// Intentionally string.Empty since we don't want to match nulls here.
// Intentionally string.Empty since we don't want to match nulls here.
#pragma warning disable CA1820 // Test for empty strings using string length
if ((string)constantPattern.Value == string.Empty)
#pragma warning restore CA1820 // Test for empty strings using string length
Expand Down Expand Up @@ -305,6 +315,36 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
_sqlExpressionFactory.Constant(0)));
}

if (_firstOrDefaultMethodInfoWithoutArgs.Equals(method))
{
var argument = arguments[0];
return _sqlExpressionFactory.Function(
"SUBSTRING",
new[] { argument, _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1) },
nullable: true,
argumentsPropagateNullability: new[] { true, true, true },
method.ReturnType);
}


if (_lastOrDefaultMethodInfoWithoutArgs.Equals(method))
{
var argument = arguments[0];
return _sqlExpressionFactory.Function(
"SUBSTRING",
new[] { argument,
_sqlExpressionFactory.Function(
"LEN",
new[] { argument },
nullable: true,
argumentsPropagateNullability: new[] { true },
typeof(int)),
_sqlExpressionFactory.Constant(1) },
nullable: true,
argumentsPropagateNullability: new[] { true, true, true },
method.ReturnType);
}

if (_startsWithMethodInfo.Equals(method))
{
return TranslateStartsEndsWith(instance, arguments[0], true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ private static readonly MethodInfo _containsMethodInfo
private static readonly MethodInfo _endsWithMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) });

private static readonly MethodInfo _firstOrDefaultMethodInfoWithoutArgs
= typeof(Enumerable).GetRuntimeMethods().Single(
m => m.Name == nameof(Enumerable.FirstOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(new[] { typeof(char) });

private static readonly MethodInfo _lastOrDefaultMethodInfoWithoutArgs
= typeof(Enumerable).GetRuntimeMethods().Single(
m => m.Name == nameof(Enumerable.LastOrDefault)
&& m.GetParameters().Length == 1).MakeGenericMethod(new[] { typeof(char) });

private readonly ISqlExpressionFactory _sqlExpressionFactory;
private const char LikeEscapeChar = '\\';

Expand Down Expand Up @@ -224,6 +234,36 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
_sqlExpressionFactory.Constant(0)));
}

if (_firstOrDefaultMethodInfoWithoutArgs.Equals(method))
{
var argument = arguments[0];
return _sqlExpressionFactory.Function(
"substr",
new[] { argument, _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1) },
nullable: true,
argumentsPropagateNullability: new[] { true, true, true },
method.ReturnType);
}


if (_lastOrDefaultMethodInfoWithoutArgs.Equals(method))
{
var argument = arguments[0];
return _sqlExpressionFactory.Function(
"substr",
new[] { argument,
_sqlExpressionFactory.Function(
"length",
new[] { argument },
nullable: true,
argumentsPropagateNullability: new[] { true },
typeof(int)),
_sqlExpressionFactory.Constant(1) },
nullable: true,
argumentsPropagateNullability: new[] { true, true, true },
method.ReturnType);
}

if (_startsWithMethodInfo.Equals(method))
{
return TranslateStartsEndsWith(instance, arguments[0], true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1374,12 +1374,12 @@ public override Task Contains_over_entityType_with_null_should_rewrite_to_false(
return base.Contains_over_entityType_with_null_should_rewrite_to_false(async);
}

public override async Task String_FirstOrDefault_in_projection_does_client_eval(bool async)
public override async Task String_FirstOrDefault_in_projection_does_not_do_client_eval(bool async)
{
await base.String_FirstOrDefault_in_projection_does_client_eval(async);
await base.String_FirstOrDefault_in_projection_does_not_do_client_eval(async);

AssertSql(
@"SELECT c[""CustomerID""]
@"SELECT LEFT(c[""CustomerID""], 1) AS c
FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,26 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], c[""ContactName""]))");
}

[ConditionalTheory(Skip = "Issue #16919")]
public override async Task String_FirstOrDefault_MethodCall(bool async)
{
await base.String_FirstOrDefault_MethodCall(async);
AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (LEFT(c[""ContactName""], 1) = ""A""))");
}

[ConditionalTheory(Skip = "Issue #16919")]
public override async Task String_LastOrDefault_MethodCall(bool async)
{
await base.String_LastOrDefault_MethodCall(async);
AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (RIGHT(c[""ContactName""], 1) = ""s""))");
}

public override async Task String_Contains_MethodCall(bool async)
{
await base.String_Contains_MethodCall(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,17 @@ public virtual Task Empty_subquery_with_contains_negated_returns_true(bool async
ss => ss.Set<NullSemanticsEntity1>().Where(e => !ss.Set<NullSemanticsEntity2>().Where(x => false).Select(x => x.NullableIntA).Contains(e.NullableIntA)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Nullable_string_FirstOrDefault_compared_to_nullable_string_LastOrDefault(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA.FirstOrDefault() == e.NullableStringB.LastOrDefault()),
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA.MaybeScalar(x => x.FirstOrDefault())
== e.NullableStringB.MaybeScalar(x => x.LastOrDefault())));
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
25 changes: 25 additions & 0 deletions test/EFCore.Specification.Tests/Query/FunkyDataQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,31 @@ public virtual Task String_ends_with_not_equals_nullable_column(bool async)
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_FirstOrDefault_and_LastOrDefault(bool async)
{
return AssertQuery(
async,
ss => ss.Set<FunkyCustomer>().OrderBy(e => e.Id).Select(e => new
{
first = (char?)e.FirstName.FirstOrDefault(),
last = (char?)e.FirstName.LastOrDefault()
}),
ss => ss.Set<FunkyCustomer>().OrderBy(e => e.Id).Select(e => new
{
first = e.FirstName.MaybeScalar(x => x.FirstOrDefault()),
last = e.FirstName.MaybeScalar(x => x.LastOrDefault())
}),
assertOrder: true,
elementAsserter: (e, a) =>
{
AssertEqual(e.first, a.first);
AssertEqual(e.last, a.last);
});
}

protected FunkyDataContext CreateContext() => Fixture.CreateContext();

protected virtual void ClearLog()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,7 @@ public virtual Task Contains_over_entityType_should_materialize_when_composite2(

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_FirstOrDefault_in_projection_does_client_eval(bool async)
public virtual Task String_FirstOrDefault_in_projection_does_not_do_client_eval(bool async)
{
return AssertQueryScalar(
async,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ public virtual Task String_Contains_Column(bool async)
entryCount: 91);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_FirstOrDefault_MethodCall(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => c.ContactName.FirstOrDefault() == 'A'),
entryCount: 10);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_LastOrDefault_MethodCall(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => c.ContactName.LastOrDefault() == 's'),
entryCount: 9);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Contains_MethodCall(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,16 @@ ELSE CAST(0 AS bit)
END <> [f].[NullableBool]) OR [f].[NullableBool] IS NULL");
}

public override async Task String_FirstOrDefault_and_LastOrDefault(bool async)
{
await base.String_FirstOrDefault_and_LastOrDefault(async);

AssertSql(
@"SELECT SUBSTRING([f].[FirstName], 1, 1) AS [first], SUBSTRING([f].[FirstName], LEN([f].[FirstName]), 1) AS [last]
FROM [FunkyCustomers] AS [f]
ORDER BY [f].[Id]");
}

protected override void ClearLog()
=> Fixture.TestSqlLoggerFactory.Clear();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1279,12 +1279,12 @@ FROM [Order Details] AS [o0]
WHERE ([o0].[OrderID] > 42) AND (([o0].[OrderID] = [o].[OrderID]) AND ([o0].[ProductID] = [o].[ProductID])))");
}

public override async Task String_FirstOrDefault_in_projection_does_client_eval(bool async)
public override async Task String_FirstOrDefault_in_projection_does_not_do_client_eval(bool async)
{
await base.String_FirstOrDefault_in_projection_does_client_eval(async);
await base.String_FirstOrDefault_in_projection_does_not_do_client_eval(async);

AssertSql(
@"SELECT [c].[CustomerID]
@"SELECT SUBSTRING([c].[CustomerID], 1, 1)
FROM [Customers] AS [c]");
}

Expand Down
Loading

0 comments on commit 78bc42b

Please sign in to comment.