Skip to content

Commit

Permalink
Rename ISQLFunctionExtended interface to ISQLFunctionExtended
Browse files Browse the repository at this point in the history
  • Loading branch information
maca88 committed Feb 21, 2020
1 parent 0b5826e commit 89e9933
Show file tree
Hide file tree
Showing 13 changed files with 131 additions and 87 deletions.
4 changes: 1 addition & 3 deletions src/NHibernate.Test/Async/QueryTest/CountFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
using NHibernate.Cfg;
using NHibernate.Dialect.Function;
using NHibernate.DomainModel;
using NHibernate.Engine;
using NHibernate.Type;
using NUnit.Framework;
using Environment=NHibernate.Cfg.Environment;

Expand Down Expand Up @@ -57,4 +55,4 @@ public async Task OverriddenAsync()
await (sf.CloseAsync());
}
}
}
}
3 changes: 1 addition & 2 deletions src/NHibernate.Test/Hql/SimpleFunctionsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ public void ClassicSum()
Assert.Throws<QueryException>(() => csf.Render(args, factoryImpl));
}

// Since v5.3
[Test, Obsolete]
[Test]
public void ClassicCount()
{
//ANSI-SQL92 definition
Expand Down
17 changes: 1 addition & 16 deletions src/NHibernate.Test/QueryTest/CountFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
using NHibernate.Cfg;
using NHibernate.Dialect.Function;
using NHibernate.DomainModel;
using NHibernate.Engine;
using NHibernate.Type;
using NUnit.Framework;
using Environment=NHibernate.Cfg.Environment;

Expand Down Expand Up @@ -46,17 +44,4 @@ public void Overridden()
sf.Close();
}
}

[Serializable]
internal class ClassicCountFunction : ClassicAggregateFunction
{
public ClassicCountFunction() : base("count", true)
{
}

public override IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int32;
}
}
}
}
28 changes: 12 additions & 16 deletions src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
using System;
using System.Collections;
using System.Text;
using System.Collections.Generic;
using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
using NHibernate.Util;

namespace NHibernate.Dialect.Function
{
[Serializable]
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLAggregateFunction
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
{
private IType returnType = null;
private readonly string name;
Expand Down Expand Up @@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping)
return returnType ?? columnType;
}

/// <inheritdoc />
public virtual IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
{
return ReturnType(argumentTypes.FirstOrDefault(), mapping);
}

/// <inheritdoc />
public string FunctionName => name;

public bool HasArguments
{
get { return true; }
Expand Down Expand Up @@ -110,18 +119,5 @@ bool IFunctionGrammar.IsKnownArgument(string token)
}

#endregion

#region ISQLAggregateFunction Members

/// <inheritdoc />
public string FunctionName => name;

/// <inheritdoc />
public virtual IType GetActualReturnType(IType argumentType, IMapping mapping)
{
return ReturnType(argumentType, mapping);
}

#endregion
}
}
4 changes: 1 addition & 3 deletions src/NHibernate/Dialect/Function/ClassicCountFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ namespace NHibernate.Dialect.Function
/// <summary>
/// Classic COUNT sqlfunction that return types as it was done in Hibernate 3.1
/// </summary>
// Since v5.3
[Obsolete("This class has no more usages in NHibernate and will be removed in a future version.")]
[Serializable]
public class ClassicCountFunction : ClassicAggregateFunction
{
Expand All @@ -21,4 +19,4 @@ public override IType ReturnType(IType columnType, IMapping mapping)
return NHibernateUtil.Int32;
}
}
}
}
22 changes: 0 additions & 22 deletions src/NHibernate/Dialect/Function/ISQLAggregateFunction.cs

This file was deleted.

45 changes: 45 additions & 0 deletions src/NHibernate/Dialect/Function/ISQLFunction.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
Expand Down Expand Up @@ -41,4 +43,47 @@ public interface ISQLFunction
/// <returns>SQL fragment for the function.</returns>
SqlString Render(IList args, ISessionFactoryImplementor factory);
}

// 6.0 TODO: Remove
internal static class SQLFunctionExtensions
{
/// <summary>
/// Get the type that will be effectively returned by the underlying database.
/// </summary>
/// <param name="sqlFunction">The sql function.</param>
/// <param name="argumentTypes">The types of arguments.</param>
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
/// is invalid or they are not supported.</returns>
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
/// number of arguments is invalid or they are not supported.</exception>
public static IType GetEffectiveReturnType(
this ISQLFunction sqlFunction,
IEnumerable<IType> argumentTypes,
IMapping mapping,
bool throwOnError)
{
if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
{
try
{
#pragma warning disable 618
return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
#pragma warning restore 618
}
catch (QueryException)
{
if (throwOnError)
{
throw;
}

return null;
}
}

return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError);
}
}
}
27 changes: 27 additions & 0 deletions src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System.Collections.Generic;
using NHibernate.Engine;
using NHibernate.Type;

namespace NHibernate.Dialect.Function
{
// 6.0 TODO: Merge into ISQLFunction
internal interface ISQLFunctionExtended : ISQLFunction
{
/// <summary>
/// The function name or <see langword="null"/> when multiple functions/operators/statements are used.
/// </summary>
string FunctionName { get; }

/// <summary>
/// Get the type that will be effectively returned by the underlying database.
/// </summary>
/// <param name="argumentTypes">The types of arguments.</param>
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
/// is invalid or they are not supported.</returns>
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
/// number of arguments is invalid or they are not supported.</exception>
IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError);
}
}
5 changes: 2 additions & 3 deletions src/NHibernate/Dialect/MsSql2000Dialect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ protected class CountBigQueryFunction : ClassicAggregateFunction
{
public CountBigQueryFunction() : base("count_big", true) { }

public override IType ReturnType(IType columnType, IMapping mapping)
public override IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
{
return NHibernateUtil.Int64;
}
Expand All @@ -710,8 +710,7 @@ public override IType ReturnType(IType columnType, IMapping mapping)
[Serializable]
private class CountQueryFunction : CountQueryFunctionInfo
{
/// <inheritdoc />
public override IType GetActualReturnType(IType columnType, IMapping mapping)
public override IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
{
return NHibernateUtil.Int32;
}
Expand Down
3 changes: 1 addition & 2 deletions src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,7 @@ private void EndFunctionTemplate(IASTNode m)
private void OutAggregateFunctionName(IASTNode m)
{
var aggregateNode = (AggregateNode) m;
var template = aggregateNode.SqlFunction;
Out(template == null ? aggregateNode.Text : template.FunctionName);
Out(aggregateNode.SqlFunction?.FunctionName ?? aggregateNode.Text);
}

private void CommaBetweenParameters(String comma)
Expand Down
2 changes: 1 addition & 1 deletion src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public AggregateNode(IToken token)
{
}

internal ISQLAggregateFunction SqlFunction => SessionFactoryHelper.FindSQLFunction(Text) as ISQLAggregateFunction;
internal ISQLFunctionExtended SqlFunction => SessionFactoryHelper.FindSQLFunction(Text) as ISQLFunctionExtended;

public override IType DataType
{
Expand Down
35 changes: 16 additions & 19 deletions src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using NHibernate.Dialect.Function;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast;
using NHibernate.Hql.Ast.ANTLR;
Expand Down Expand Up @@ -255,16 +256,22 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)

protected HqlTreeNode VisitNhCount(NhCountExpression expression)
{
string functionName;
HqlExpression countHqlExpression;
if (expression is NhLongCountExpression)
{
return IsCastRequired(expression.Type, "count_big", out _)
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()), expression.Type)
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()), expression.Type);
functionName = "count_big";
countHqlExpression = _hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression());
}
else
{
functionName = "count";
countHqlExpression = _hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression());
}

return IsCastRequired(expression.Type, "count", out _)
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type)
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type);
return IsCastRequired(functionName, expression.Expression, expression.Type)
? (HqlTreeNode) _hqlTreeBuilder.Cast(countHqlExpression, expression.Type)
: _hqlTreeBuilder.TransparentCast(countHqlExpression, expression.Type);
}

protected HqlTreeNode VisitNhMin(NhMinExpression expression)
Expand Down Expand Up @@ -606,7 +613,7 @@ private bool IsCastRequired(Expression expression, System.Type toType, out bool
{
existType = false;
return toType != typeof(object) &&
IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
}

private bool IsCastRequired(IType type, IType toType, out bool existType)
Expand Down Expand Up @@ -650,7 +657,7 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)

private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType)
{
var argumentType = GetType(argumentExpression);
var argumentType = ExpressionsHelper.GetType(_parameters, argumentExpression);
if (argumentType == null || returnType == typeof(object))
{
return false;
Expand All @@ -668,18 +675,8 @@ private bool IsCastRequired(string sqlFunctionName, Expression argumentExpressio
return true; // Fallback to the old behavior
}

var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory);
var fnReturnType = sqlFunction.GetEffectiveReturnType(new[] {argumentType}, _parameters.SessionFactory, false);
return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _);
}

private IType GetType(Expression expression)
{
// Try to get the mapped type for the member as it may be a non default one
return expression.Type == typeof(object)
? null
: (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _)
? type
: TypeFactory.GetDefaultTypeFor(expression.Type));
}
}
}
23 changes: 23 additions & 0 deletions src/NHibernate/Util/ExpressionsHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,29 @@ public static MemberInfo DecodeMemberAccessExpression<TEntity, TResult>(Expressi
return ((MemberExpression)expression.Body).Member;
}

/// <summary>
/// Get the mapped type for the given expression.
/// </summary>
/// <param name="parameters">The query parameters.</param>
/// <param name="expression">The expression.</param>
/// <returns>The mapped type of the expression or <see langword="null"/> when the mapped type was not
/// found and the <paramref name="expression"/> type is <see cref="object"/>.</returns>
internal static IType GetType(VisitorParameters parameters, Expression expression)
{
if (expression is ConstantExpression constantExpression &&
parameters.ConstantToParameterMap.TryGetValue(constantExpression, out var param))
{
return param.Type;
}

if (TryGetMappedType(parameters.SessionFactory, expression, out var type, out _, out _, out _))
{
return type;
}

return expression.Type == typeof(object) ? null : TypeFactory.HeuristicType(expression.Type);
}

/// <summary>
/// Try to get the mapped nullability from the given expression.
/// </summary>
Expand Down

0 comments on commit 89e9933

Please sign in to comment.