Skip to content

Commit

Permalink
Fix conditional expressions not being translated into hql
Browse files Browse the repository at this point in the history
  • Loading branch information
maca88 committed Mar 20, 2019
1 parent 656a88b commit f07db92
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 6 deletions.
71 changes: 71 additions & 0 deletions src/NHibernate.Test/Async/Linq/SelectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,62 @@ public async Task CanSelectWithAnySubQueryAsync()
Assert.AreEqual(1, list.Count(t => !t.HasEntries));
}

[Test]
public async Task CanSelectConditionalAsync()
{
using (var sqlLog = new SqlLogSpy())
{
var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test")
.Select(o => o.ShippedTo.Contains("test") ? o.ShippedTo : o.Customer.CompanyName)
.OrderBy(o => o)
.Distinct()
.ToListAsync());

Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(2));
}

using (var sqlLog = new SqlLogSpy())
{
var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test")
.Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate)
.FirstOrDefaultAsync());

Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(1));
}

using (var sqlLog = new SqlLogSpy())
{
var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test")
.Select(o => new
{
Value = o.OrderDate.HasValue
? o.Customer.CompanyName
: (o.ShippingDate.HasValue
? o.Shipper.CompanyName
: o.ShippedTo)
})
.FirstOrDefaultAsync());

var log = sqlLog.GetWholeLog();
Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1));
}
}

[Test]
public async Task CanSelectConditionalSubQueryAsync()
{
var orders = await (db.Customers
.Select(c => new
{
Date = db.Orders.Where(o => o.Customer.CustomerId == c.CustomerId)
.Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate)
.First()
})
.ToListAsync());

Assert.That(orders, Has.Count.GreaterThan(0));
}

[Test, KnownBug("NH-3045")]
public async Task CanSelectFirstElementFromChildCollectionAsync()
{
Expand Down Expand Up @@ -458,5 +514,20 @@ public class Wrapper<T>
public T item;
public string message;
}

private int FindAllOccurrences(string source, string substring)
{
if (source == null)
{
return 0;
}
int n = 0, count = 0;
while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1)
{
n += substring.Length;
++count;
}
return count;
}
}
}
71 changes: 71 additions & 0 deletions src/NHibernate.Test/Linq/SelectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,62 @@ public void CanSelectWithAggregateSubQuery()
Assert.AreEqual(4, timesheets[2].EntryCount);
}

[Test]
public void CanSelectConditional()
{
using (var sqlLog = new SqlLogSpy())
{
var q = db.Orders.Where(o => o.Customer.CustomerId == "test")
.Select(o => o.ShippedTo.Contains("test") ? o.ShippedTo : o.Customer.CompanyName)
.OrderBy(o => o)
.Distinct()
.ToList();

Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(2));
}

using (var sqlLog = new SqlLogSpy())
{
var q = db.Orders.Where(o => o.Customer.CustomerId == "test")
.Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate)
.FirstOrDefault();

Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(1));
}

using (var sqlLog = new SqlLogSpy())
{
var q = db.Orders.Where(o => o.Customer.CustomerId == "test")
.Select(o => new
{
Value = o.OrderDate.HasValue
? o.Customer.CompanyName
: (o.ShippingDate.HasValue
? o.Shipper.CompanyName
: o.ShippedTo)
})
.FirstOrDefault();

var log = sqlLog.GetWholeLog();
Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1));
}
}

[Test]
public void CanSelectConditionalSubQuery()
{
var orders = db.Customers
.Select(c => new
{
Date = db.Orders.Where(o => o.Customer.CustomerId == c.CustomerId)
.Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate)
.First()
})
.ToList();

Assert.That(orders, Has.Count.GreaterThan(0));
}

[Test, KnownBug("NH-3045")]
public void CanSelectFirstElementFromChildCollection()
{
Expand Down Expand Up @@ -497,5 +553,20 @@ public class Wrapper<T>
public T item;
public string message;
}

private int FindAllOccurrences(string source, string substring)
{
if (source == null)
{
return 0;
}
int n = 0, count = 0;
while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1)
{
n += substring.Length;
++count;
}
return count;
}
}
}
24 changes: 24 additions & 0 deletions src/NHibernate/Hql/Ast/HqlTreeNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,30 @@ internal HqlIdent(IASTFactory factory, System.Type type)
throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name));
}
}

internal static bool SupportsType(System.Type type)
{
type = type.UnwrapIfNullable();
switch (System.Type.GetTypeCode(type))
{
case TypeCode.Boolean:
case TypeCode.Int16:
case TypeCode.Int32:
case TypeCode.Int64:
case TypeCode.Decimal:
case TypeCode.Single:
case TypeCode.DateTime:
case TypeCode.String:
case TypeCode.Double:
return true;
default:
return new[]
{
typeof(Guid),
typeof(DateTimeOffset)
}.Contains(type);
}
}
}

public class HqlRange : HqlStatement
Expand Down
18 changes: 12 additions & 6 deletions src/NHibernate/Linq/Visitors/SelectClauseNominator.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
using NHibernate.Linq.Expressions;
using NHibernate.Util;
Expand All @@ -15,6 +16,7 @@ namespace NHibernate.Linq.Visitors
class SelectClauseHqlNominator : RelinqExpressionVisitor
{
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
private readonly VisitorParameters _parameters;

/// <summary>
/// The expression parts that can be converted to pure HQL.
Expand All @@ -34,6 +36,7 @@ class SelectClauseHqlNominator : RelinqExpressionVisitor

public SelectClauseHqlNominator(VisitorParameters parameters)
{
_parameters = parameters;
_functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry;
}

Expand Down Expand Up @@ -145,7 +148,7 @@ private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool proj
}

// Constants will only be evaluated in HQL if they're inside a method call
if (expression.NodeType == ExpressionType.Constant)
if (expression is ConstantExpression constantExpression && _parameters.ConstantToParameterMap.TryGetValue(constantExpression, out _))
{
return projectConstantsInHql;
}
Expand All @@ -156,14 +159,17 @@ private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool proj
return IsRegisteredFunction(expression);
}

if (expression.NodeType == ExpressionType.Conditional)
if (expression is ConditionalExpression conditionalExpression)
{
// Theoretically, any conditional that returns a CAST-able primitive should be constructable in HQL.
// The type needs to be CAST-able because HQL wraps the CASE clause in a CAST and only supports
// certain types (as defined by the HqlIdent constructor that takes a System.Type as the second argument).
// However, this may still not cover all cases, so to limit the nomination of conditional expressions,
// we will only consider those which are already getting constants projected into them.
return projectConstantsInHql;
return
HqlIdent.SupportsType(conditionalExpression.IfFalse.Type) &&
HqlCandidates.Contains(conditionalExpression.IfFalse) &&
HqlIdent.SupportsType(conditionalExpression.IfTrue.Type) &&
HqlCandidates.Contains(conditionalExpression.IfTrue) &&
HqlCandidates.Contains(conditionalExpression.Test);
}

// Assume all is good
Expand All @@ -175,4 +181,4 @@ private static bool CanBeEvaluatedInHqlStatementShortcut(Expression expression)
return expression is NhCountExpression;
}
}
}
}

0 comments on commit f07db92

Please sign in to comment.