diff --git a/src/NHibernate.Test/Async/Linq/ConstantTest.cs b/src/NHibernate.Test/Async/Linq/ConstantTest.cs index 9c2558a6902..601c5980f07 100644 --- a/src/NHibernate.Test/Async/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Async/Linq/ConstantTest.cs @@ -8,6 +8,7 @@ //------------------------------------------------------------------------------ +using System; using System.Collections.Generic; using System.Linq; using System.Reflection; @@ -136,7 +137,7 @@ public async Task ConstantNonCachedInMemberInitExpressionWithConditionAsync() return db.Shippers.Where(o => o.ShipperId == id) .Select(o => new ShipperDto {Number = id, CompanyName = o.CompanyName}).SingleAsync(cancellationToken); } - catch (System.Exception ex) + catch (Exception ex) { return Task.FromException(ex); } @@ -323,7 +324,7 @@ public async Task DmlPlansAreCachedAsync() } [Test] - public async Task PlansWithNonParameterizedConstantsAreNotCachedAsync() + public async Task PlansWithNonParameterizedConstantsAreCachedAsync() { var queryPlanCacheType = typeof(QueryPlanCache); @@ -338,8 +339,90 @@ public async Task PlansWithNonParameterizedConstantsAreNotCachedAsync() select new { c.CustomerId, c.ContactName, Constant = 1 }).FirstAsync()); Assert.That( cache, - Has.Count.EqualTo(0), - "Query plan should not be cached."); + Has.Count.EqualTo(1), + "Query should be cached."); + } + + [Test] + public async Task PlansWithConstantExpressionsAreNotCachedAsync() + { + var queryPlanCacheType = typeof(QueryPlanCache); + var cache = (SoftLimitMRUCache) queryPlanCacheType.GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(Sfi.QueryPlanCache); + cache.Clear(); + + var input = new { Name = "ALFKI" }; + await ((from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, DummyBooleanColumn = c.CustomerId == input.Name }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(1), "Query should be cached"); + + await ((from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, DummyBooleanColumn = c.CustomerId != "ALFKI" }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(2), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyBooleanColumn = p.Discontinued && true }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(3), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyBooleanColumn = p.Discontinued || true }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(4), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight > 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(5), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight < 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(6), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight >= 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(7), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight <= 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(8), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight > 0 ? DateTime.Now : default(DateTime?) }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(9), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyColumn = p.UnitPrice + 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(10), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight - 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(11), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight * 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(12), "Query should be cached"); + + await ((from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight / 10 }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(13), "Query should be cached"); + + await ((from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, DummyColumn = c.CustomerId ?? "TEST" }).FirstAsync()); + + Assert.That(cache, Has.Count.EqualTo(14), "Query should be cached"); } [Test] diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 74b20f13afb..70120af6963 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using NHibernate.Criterion; @@ -340,7 +341,7 @@ public void DmlPlansAreCached() } [Test] - public void PlansWithNonParameterizedConstantsAreNotCached() + public void PlansWithNonParameterizedConstantsAreCached() { var queryPlanCacheType = typeof(QueryPlanCache); @@ -355,8 +356,90 @@ public void PlansWithNonParameterizedConstantsAreNotCached() select new { c.CustomerId, c.ContactName, Constant = 1 }).First(); Assert.That( cache, - Has.Count.EqualTo(0), - "Query plan should not be cached."); + Has.Count.EqualTo(1), + "Query should be cached."); + } + + [Test] + public void PlansWithConstantExpressionsAreNotCached() + { + var queryPlanCacheType = typeof(QueryPlanCache); + var cache = (SoftLimitMRUCache) queryPlanCacheType.GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(Sfi.QueryPlanCache); + cache.Clear(); + + var input = new { Name = "ALFKI" }; + (from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, DummyBooleanColumn = c.CustomerId == input.Name }).First(); + + Assert.That(cache, Has.Count.EqualTo(1), "Query should be cached"); + + (from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, DummyBooleanColumn = c.CustomerId != "ALFKI" }).First(); + + Assert.That(cache, Has.Count.EqualTo(2), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyBooleanColumn = p.Discontinued && true }).First(); + + Assert.That(cache, Has.Count.EqualTo(3), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyBooleanColumn = p.Discontinued || true }).First(); + + Assert.That(cache, Has.Count.EqualTo(4), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight > 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(5), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight < 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(6), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight >= 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(7), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyBooleanColumn = p.ShippingWeight <= 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(8), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight > 0 ? DateTime.Now : default(DateTime?) }).First(); + + Assert.That(cache, Has.Count.EqualTo(9), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyColumn = p.UnitPrice + 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(10), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight - 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(11), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight * 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(12), "Query should be cached"); + + (from p in db.Products + select new { p.Name, DummyColumn = p.ShippingWeight / 10 }).First(); + + Assert.That(cache, Has.Count.EqualTo(13), "Query should be cached"); + + (from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, DummyColumn = c.CustomerId ?? "TEST" }).First(); + + Assert.That(cache, Has.Count.EqualTo(14), "Query should be cached"); } [Test] diff --git a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs index 52bf954101d..1bd2a22a87c 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs @@ -7,6 +7,9 @@ using NHibernate.Param; using NHibernate.Util; using Remotion.Linq.Parsing; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses; +using NHibernate.Linq.Clauses; namespace NHibernate.Linq.Visitors { @@ -49,7 +52,7 @@ internal Expression Nominate(Expression expression) ContainsUntranslatedMethodCalls = false; _canBeCandidate = true; _stateStack = new Stack(); - _stateStack.Push(false); + _stateStack.Push(true); return Visit(expression); } @@ -67,16 +70,14 @@ public override Expression Visit(Expression expression) return innerExpression; } - var projectConstantsInHql = _stateStack.Peek() || expression.NodeType == ExpressionType.Equal || IsRegisteredFunction(expression); + var isRegisteredFunction = IsRegisteredFunction(expression); + var projectConstantsInHql = (_stateStack.Peek() && IsConstantExpression(expression)) || isRegisteredFunction; // Set some flags, unless we already have proper values for them: // projectConstantsInHql if they are inside a method call executed server side. // ContainsUntranslatedMethodCalls if a method call must be executed locally. - var isMethodCall = expression.NodeType == ExpressionType.Call; - if (isMethodCall && (!projectConstantsInHql || !ContainsUntranslatedMethodCalls)) + if (expression.NodeType == ExpressionType.Call && (!projectConstantsInHql || !ContainsUntranslatedMethodCalls)) { - var isRegisteredFunction = IsRegisteredFunction(expression); - projectConstantsInHql = projectConstantsInHql || isRegisteredFunction; ContainsUntranslatedMethodCalls = ContainsUntranslatedMethodCalls || !isRegisteredFunction; } @@ -96,7 +97,7 @@ public override Expression Visit(Expression expression) if (_canBeCandidate) { - if (CanBeEvaluatedInHqlSelectStatement(expression, projectConstantsInHql)) + if (CanBeEvaluatedInHqlSelectStatement(expression, projectConstantsInHql, isRegisteredFunction)) { HqlCandidates.Add(expression); } @@ -115,6 +116,80 @@ public override Expression Visit(Expression expression) return expression; } + private static bool IsAllowedToProjectInHql(System.Type type) + { + return (type.IsValueType || type == typeof(string)) && typeof(DateTime) != type && typeof(DateTime?) != type && typeof(TimeSpan) != type && typeof(TimeSpan?) != type; + } + + private static bool IsValueType(System.Type type) + { + return type.IsValueType || type == typeof(string); + } + + private static bool IsConstantExpression(Expression expression) + { + //if (expression.NodeType != ExpressionType.Equal) return false; + + if(expression == null) return true; + + switch (expression.NodeType) + { + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.Not: + case ExpressionType.MemberInit: + case ExpressionType.New: + case ExpressionType.NewArrayInit: + case ExpressionType.ListInit: + return true; + case ExpressionType.Extension: + if( expression is QuerySourceReferenceExpression extension && + (extension.ReferencedQuerySource is NhClauseBase || + (extension.ReferencedQuerySource is MainFromClause fromClause && + fromClause.FromExpression.Type.IsGenericType && + fromClause.FromExpression.Type.GetGenericTypeDefinition() == typeof(NhQueryable<>)))) + { + return true; + } + return false; + case ExpressionType.Convert: + var convert = (UnaryExpression) expression; + return convert.Method == null && IsAllowedToProjectInHql(convert.Operand.Type) && IsConstantExpression(convert.Operand); + case ExpressionType.Constant: + var constant = (ConstantExpression) expression; + return constant.Value == null || IsValueType(expression.Type); + case ExpressionType.MemberAccess: + var member = (MemberExpression) expression; + return IsConstantExpression(member.Expression); + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.And: + case ExpressionType.AndAlso: + case ExpressionType.Or: + case ExpressionType.OrElse: + case ExpressionType.Add: + case ExpressionType.Subtract: + case ExpressionType.Multiply: + case ExpressionType.Divide: + case ExpressionType.Modulo: + var binary = (BinaryExpression) expression; + return IsAllowedToProjectInHql(binary.Left.Type) && IsConstantExpression(binary.Left) && + IsAllowedToProjectInHql(binary.Right.Type) && IsConstantExpression(binary.Right); + case ExpressionType.Coalesce: + var coalesce = (BinaryExpression) expression; + return IsConstantExpression(coalesce.Left) && IsConstantExpression(coalesce.Right); + case ExpressionType.Conditional: + var conditional = (ConditionalExpression) expression; + return IsConstantExpression(conditional.Test) && + IsValueType(conditional.IfTrue.Type) && IsConstantExpression(conditional.IfTrue) && + IsValueType(conditional.IfFalse.Type) && IsConstantExpression(conditional.IfFalse); + default: + return false; + } + } + private bool IsRegisteredFunction(Expression expression) { if (expression.NodeType == ExpressionType.Call) @@ -141,7 +216,7 @@ expression is NhMaxExpression || return false; } - private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool projectConstantsInHql) + private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool projectConstantsInHql, bool isRegisteredFunction) { // HQL can't do New or Member Init if (expression.NodeType == ExpressionType.MemberInit || @@ -166,7 +241,7 @@ private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool proj if (expression.NodeType == ExpressionType.Call) { // Depends if it's in the function registry - return IsRegisteredFunction(expression); + return isRegisteredFunction; } if (expression.NodeType == ExpressionType.Conditional)