Skip to content

Commit

Permalink
Merge pull request #63 from win7user10/feature/ISS-44-support-of-stri…
Browse files Browse the repository at this point in the history
…ng-enum-values

Feature/Support of string enum values
  • Loading branch information
win7user10 authored Oct 29, 2022
2 parents 59380f0 + 9a3c02d commit 75fc5ef
Show file tree
Hide file tree
Showing 29 changed files with 326 additions and 102 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using Laraue.EfCoreTriggers.Common.Migrations;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.Extensions.DependencyInjection;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ public static IServiceCollection AddDefaultServices(this IServiceCollection serv

.AddMethodCallConverter<CountVisitor>()

.AddScoped<VisitingInfo>()

.AddScoped<IUpdateExpressionVisitor, UpdateExpressionVisitor>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
</PropertyGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard2.1' ">
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="5.0.14" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="5.0.14">
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="5.0.17" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="5.0.17">
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="5.0.14" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="5.0.17" />
</ItemGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'net6.0' ">
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="6.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="6.0.2">
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="6.0.10" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="6.0.10">
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="6.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="6.0.10" />
</ItemGroup>

</Project>
16 changes: 16 additions & 0 deletions src/Laraue.EfCoreTriggers.Common/Services/IDbSchemaRetriever.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,21 @@ public interface IDbSchemaRetriever
/// <returns></returns>
PropertyInfo[] GetPrimaryKeyMembers(Type type);

/// <summary>
/// Get info about cases participating in relations between two types.
/// </summary>
/// <param name="type"></param>
/// <param name="type2"></param>
/// <returns></returns>
KeyInfo[] GetForeignKeyMembers(Type type, Type type2);

/// <summary>
/// Some type can be overriden, for example Enum can be store as string in the DB.
/// In these cases clr type will be returned from this function.
/// </summary>
/// <param name="type">Entity type.</param>
/// <param name="memberInfo">Entity member.</param>
/// <param name="clrType">Actual type if it annotation was found.</param>
/// <returns></returns>
Type GetActualClrType(Type type, MemberInfo memberInfo);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ public class EfCoreDbSchemaRetriever : IDbSchemaRetriever
/// <summary>
/// Cached column names for entity properties.
/// </summary>
private static readonly Dictionary<MemberInfo, string> ColumnNamesCache = new();
private readonly Dictionary<MemberInfo, string> _columnNamesCache = new();

/// <summary>
/// Cached table names for entities.
/// </summary>
private static readonly Dictionary<Type, string> TableNamesCache = new();
private readonly Dictionary<Type, string> _tableNamesCache = new();

/// <summary>
/// Cached database schema names.
/// </summary>
private static readonly Dictionary<Type, string> TableSchemasCache = new();
private readonly Dictionary<Type, string> _tableSchemasCache = new();

/// <summary>
/// Model used for generating SQL. From this model takes column names, table names and other meta information.
Expand All @@ -42,38 +42,49 @@ public EfCoreDbSchemaRetriever(IModel model)
/// <inheritdoc />
public string GetColumnName(Type type, MemberInfo memberInfo)
{
if (!ColumnNamesCache.ContainsKey(memberInfo))
if (!_columnNamesCache.ContainsKey(memberInfo))
{
var entityType = Model.FindEntityType(type);

if (entityType == null)
{
throw new InvalidOperationException($"DbSet<{type}> should be added to the DbContext");
}

var property = entityType.FindProperty(memberInfo.Name);
var entityType = GetEntityType(type);
var property = GetColumn(type, memberInfo);
var identifier = (StoreObjectIdentifier)StoreObjectIdentifier.Create(entityType, StoreObjectType.Table);
ColumnNamesCache.Add(memberInfo, property.GetColumnName(identifier));
_columnNamesCache.Add(memberInfo, property.GetColumnName(identifier));
}

if (!ColumnNamesCache.TryGetValue(memberInfo, out var columnName))
if (!_columnNamesCache.TryGetValue(memberInfo, out var columnName))
{
throw new InvalidOperationException($"Column name for member {memberInfo.Name} is not defined in model");
}

return columnName;
}

private IProperty GetColumn(Type type, MemberInfo memberInfo)
{
return GetEntityType(type).FindProperty(memberInfo.Name);
}

private IEntityType GetEntityType(Type type)
{
var entityType = Model.FindEntityType(type);

if (entityType == null)
{
throw new InvalidOperationException($"DbSet<{type}> should be added to the DbContext");
}

return entityType;
}

/// <inheritdoc />
public string GetTableName(Type entity)
{
if (!TableNamesCache.ContainsKey(entity))
if (!_tableNamesCache.ContainsKey(entity))
{
var entityType = Model.FindEntityType(entity);
TableNamesCache.Add(entity, entityType.GetTableName());
_tableNamesCache.Add(entity, entityType.GetTableName());
}

if (!TableNamesCache.TryGetValue(entity, out var tableName))
if (!_tableNamesCache.TryGetValue(entity, out var tableName))
{
throw new InvalidOperationException($"Table name for entity {entity.FullName} is not defined in model.");
}
Expand All @@ -88,10 +99,10 @@ public string GetTableName(Type entity)
/// <inheritdoc />
public string GetFunctionName(Type entity, string name)
{
if (!TableNamesCache.ContainsKey(entity))
if (!_tableNamesCache.ContainsKey(entity))
{
var entityType = Model.FindEntityType(entity);
TableNamesCache.Add(entity, entityType.GetTableName());
_tableNamesCache.Add(entity, entityType.GetTableName());
}

var schemaName = GetTableSchemaName(entity);
Expand All @@ -108,13 +119,13 @@ public string GetFunctionName(Type entity, string name)
/// <returns></returns>
protected virtual string GetTableSchemaName(Type entity)
{
if (!TableSchemasCache.ContainsKey(entity))
if (!_tableSchemasCache.ContainsKey(entity))
{
var entityType = Model.FindEntityType(entity);
TableSchemasCache.Add(entity, entityType.GetSchema());
_tableSchemasCache.Add(entity, entityType.GetSchema());
}

if (!TableSchemasCache.TryGetValue(entity, out var schemaName))
if (!_tableSchemasCache.TryGetValue(entity, out var schemaName))
{
throw new InvalidOperationException($"Schema for entity {entity.FullName} is not defined in model.");
}
Expand Down Expand Up @@ -154,4 +165,11 @@ public KeyInfo[] GetForeignKeyMembers(Type type, Type type2)

return keys;
}

public Type GetActualClrType(Type type, MemberInfo memberInfo)
{
var columnType = GetColumn(type, memberInfo);

return columnType.FindAnnotation("ProviderClrType")?.Value as Type ?? columnType.ClrType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ public class BinaryExpressionVisitor : BaseExpressionVisitor<BinaryExpression>
{
private readonly IExpressionVisitorFactory _factory;
private readonly ISqlGenerator _generator;
private readonly IDbSchemaRetriever _schemaRetriever;

/// <inheritdoc />
public BinaryExpressionVisitor(IExpressionVisitorFactory factory, ISqlGenerator generator)
public BinaryExpressionVisitor(IExpressionVisitorFactory factory, ISqlGenerator generator, IDbSchemaRetriever schemaRetriever)
{
_factory = factory;
_generator = generator;
_schemaRetriever = schemaRetriever;
}

/// <inheritdoc />
Expand All @@ -25,21 +27,40 @@ public override SqlBuilder Visit(
ArgumentTypes argumentTypes,
VisitedMembers visitedMembers)
{
// Convert(charValue, Int32) == 122 -> charValue == 'z'
if (expression.Left is UnaryExpression
{
NodeType: ExpressionType.Convert,
Operand: MemberExpression memberExpression
}
&& memberExpression.Type == typeof(char)
}
&& expression.Right is ConstantExpression constantExpression)
{
var memberSql = _factory.Visit(memberExpression, argumentTypes, visitedMembers);
var constantSql = _factory.Visit(Expression.Constant(Convert.ToChar(constantExpression.Value)), argumentTypes, visitedMembers);
// Convert(enumValue, Int32) == 1 when enum is stores as string -> enumValue == Enum.Value
var clrType = _schemaRetriever.GetActualClrType(
memberExpression.Member.DeclaringType,
memberExpression.Member);

if (memberExpression.Type.IsEnum && clrType == typeof(string))
{
var valueName = Enum.GetValues(memberExpression.Type)
.Cast<object>()
.First(x => (int)x == (int)constantExpression.Value)
.ToString();

var sb = _factory.Visit(memberExpression, argumentTypes, visitedMembers);
sb.Append($" = {_generator.GetSql(valueName)}");
return sb;
}

// Convert(charValue, Int32) == 122 -> charValue == 'z'
if (memberExpression.Type == typeof(char))
{
var memberSql = _factory.Visit(memberExpression, argumentTypes, visitedMembers);
var constantSql = _factory.Visit(Expression.Constant(Convert.ToChar(constantExpression.Value)), argumentTypes, visitedMembers);

return memberSql
.Append(" = ")
.Append(constantSql);
return memberSql
.Append(" = ")
.Append(constantSql);
}
}

if (expression.Method?.Name == nameof(string.Concat))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ namespace Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
public class ExpressionVisitorFactory : IExpressionVisitorFactory
{
private readonly IServiceProvider _provider;
private readonly VisitingInfo _visitingInfo;

/// <summary>
/// Initializes a new instance of <see cref="ExpressionVisitorFactory"/>.
/// </summary>
/// <param name="provider"></param>
public ExpressionVisitorFactory(IServiceProvider provider)
/// <param name="visitingInfo"></param>
public ExpressionVisitorFactory(IServiceProvider provider, VisitingInfo visitingInfo)
{
_provider = provider;
_visitingInfo = visitingInfo;
}

/// <inheritdoc />
Expand All @@ -27,14 +30,21 @@ public SqlBuilder Visit(Expression expression, ArgumentTypes argumentTypes, Visi
{
BinaryExpression binary => Visit(binary, argumentTypes, visitedMembers),
ConstantExpression constant => Visit(constant, argumentTypes, visitedMembers),
MemberExpression member => Visit(member, argumentTypes, visitedMembers),
MemberExpression member => VisitAndRememberMember(member, argumentTypes, visitedMembers),
MethodCallExpression methodCall => Visit(methodCall, argumentTypes, visitedMembers),
UnaryExpression unary => Visit(unary, argumentTypes, visitedMembers),
NewExpression @new => Visit(@new, argumentTypes, visitedMembers),
LambdaExpression lambda => Visit(lambda, argumentTypes, visitedMembers),
_ => throw new NotSupportedException($"Expression of type {expression.GetType()} is not supported")
};
}

private SqlBuilder VisitAndRememberMember(MemberExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
return _visitingInfo.ExecuteWithChangingMember(
expression.Member,
() => Visit(expression, argumentTypes, visitedMembers));
}

private SqlBuilder Visit<TExpression>(TExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
where TExpression : Expression
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services.Impl.SetExpressionVisitors;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@ public class UnaryExpressionVisitor : BaseExpressionVisitor<UnaryExpression>
{
private readonly IExpressionVisitorFactory _factory;
private readonly ISqlGenerator _generator;
private readonly IDbSchemaRetriever _dbSchemaRetriever;

/// <inheritdoc />
public UnaryExpressionVisitor(IExpressionVisitorFactory factory, ISqlGenerator generator)
public UnaryExpressionVisitor(
IExpressionVisitorFactory factory,
ISqlGenerator generator,
IDbSchemaRetriever dbSchemaRetriever)
{
_factory = factory;
_generator = generator;
_dbSchemaRetriever = dbSchemaRetriever;
}

/// <inheritdoc />
Expand Down Expand Up @@ -58,8 +63,14 @@ public override SqlBuilder Visit(UnaryExpression expression, ArgumentTypes argum
/// <returns></returns>
protected virtual bool IsNeedConversion(UnaryExpression unaryExpression)
{
var clrType1 = unaryExpression.Operand.Type;
var clrType2 = unaryExpression.Type;
// Do not execute conversion Type? -> Type, it is actual for CLR only
if (Nullable.GetUnderlyingType(unaryExpression.Type) != null)
{
return false;
}

var clrType1 = GetActualClrType(unaryExpression.Operand);
var clrType2 = GetActualClrType(unaryExpression);
if (clrType1 == typeof(object) || clrType2 == typeof(object))
{
return false;
Expand All @@ -84,4 +95,16 @@ protected virtual string GetConvertExpressionSql(UnaryExpression unaryExpression
? $"CAST({member} AS {sqlType})"
: throw new NotSupportedException($"Converting of type {unaryExpression.Type} is not supported");
}

private Type GetActualClrType(Expression expression)
{
if (expression is not MemberExpression {Expression: ParameterExpression parameterExpression} memberExpression)
{
return EfCoreTriggersHelper.GetNotNullableType(expression.Type);
}

return _dbSchemaRetriever.GetActualClrType(
EfCoreTriggersHelper.GetNotNullableType(parameterExpression.Type),
memberExpression.Member);
}
}
Loading

0 comments on commit 75fc5ef

Please sign in to comment.