diff --git a/src/EFCore.PG/NpgsqlDbFunctionsExtensions.cs b/src/EFCore.PG/NpgsqlDbFunctionsExtensions.cs index 6742bd994..a83558265 100644 --- a/src/EFCore.PG/NpgsqlDbFunctionsExtensions.cs +++ b/src/EFCore.PG/NpgsqlDbFunctionsExtensions.cs @@ -136,5 +136,15 @@ var regexPattern RegexOptions.Singleline, _regexTimeout); } + + public static string[] HStoreKeys([CanBeNull] this DbFunctions _, [NotNull] IDictionary hstore) + { + return hstore.Keys.ToArray(); + } + + public static string[] HStoreValues([CanBeNull] this DbFunctions _, [NotNull] IDictionary hstore) + { + return hstore.Values.ToArray(); + } } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs index 691c7b958..a79ec9bd5 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs @@ -47,7 +47,9 @@ public class NpgsqlCompositeMethodCallTranslator : RelationalCompositeMethodCall new NpgsqlStringTrimTranslator(), new NpgsqlStringTrimEndTranslator(), new NpgsqlStringTrimStartTranslator(), - new NpgsqlRegexIsMatchTranslator() + new NpgsqlRegexIsMatchTranslator(), + new NpgsqlDictionaryIndexTranslator(), + new NpgsqlHStoreKeysTranslator() }; public NpgsqlCompositeMethodCallTranslator( diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDictionaryIndexTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDictionaryIndexTranslator.cs new file mode 100644 index 000000000..8e736b3ff --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDictionaryIndexTranslator.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; +using System.Linq; +using System.Reflection; + +namespace Microsoft.EntityFrameworkCore.Query.ExpressionTranslators.Internal +{ + public class NpgsqlDictionaryIndexTranslator : IMethodCallTranslator + { + static readonly PropertyInfo DictionaryPropertyIndexAccessor = + typeof(IDictionary).GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Single(p => p.Name == "Item"); + + public Expression Translate(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.NodeType == ExpressionType.Call + && methodCallExpression.Method.Name == "get_Item" + && typeof(IDictionary).IsAssignableFrom(methodCallExpression.Method.DeclaringType)) + { + return Expression.MakeIndex(methodCallExpression.Object, DictionaryPropertyIndexAccessor, methodCallExpression.Arguments); + } + return null; + } + } +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlHStoreKeysTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlHStoreKeysTranslator.cs new file mode 100644 index 000000000..576b42ff8 --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlHStoreKeysTranslator.cs @@ -0,0 +1,41 @@ +using Microsoft.EntityFrameworkCore.Query.Expressions; +using Microsoft.EntityFrameworkCore.Query.Expressions.Internal; +using System; +using System.Linq; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; + +namespace Microsoft.EntityFrameworkCore.Query.ExpressionTranslators.Internal +{ + public class NpgsqlHStoreKeysTranslator : IMethodCallTranslator + { + public Expression Translate(MethodCallExpression methodCallExpression) + { + if(methodCallExpression.Method == + typeof(NpgsqlDbFunctionsExtensions) + .GetRuntimeMethod( + nameof(NpgsqlDbFunctionsExtensions.HStoreKeys), new Type[] { typeof(DbFunctions), typeof(IDictionary) })) + { + return new SqlFunctionExpression("akeys", typeof(string[]), new Expression[] { methodCallExpression.Arguments[1] }); + } + + if (methodCallExpression.Method == + typeof(NpgsqlDbFunctionsExtensions) + .GetRuntimeMethod( + nameof(NpgsqlDbFunctionsExtensions.HStoreValues), new Type[] { typeof(DbFunctions), typeof(IDictionary) })) + { + return new SqlFunctionExpression("avals", typeof(string[]), new Expression[] { methodCallExpression.Arguments[1] }); + } + + if(typeof(IDictionary).IsAssignableFrom(methodCallExpression.Method.DeclaringType) + && methodCallExpression.Method.Name == nameof(IDictionary.ContainsKey)) + { + return new DictionaryContainsKeyExpression(methodCallExpression.Arguments.Single(), methodCallExpression.Object); + } + + return null; + } + } +} diff --git a/src/EFCore.PG/Query/Expressions/Internal/DictionaryContainsKeyExpression.cs b/src/EFCore.PG/Query/Expressions/Internal/DictionaryContainsKeyExpression.cs new file mode 100644 index 000000000..3c61ce8ac --- /dev/null +++ b/src/EFCore.PG/Query/Expressions/Internal/DictionaryContainsKeyExpression.cs @@ -0,0 +1,141 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Utilities; +using System.Diagnostics; +using Microsoft.EntityFrameworkCore.Query.Sql.Internal; + +namespace Microsoft.EntityFrameworkCore.Query.Expressions.Internal +{ + public class DictionaryContainsKeyExpression : Expression + { + /// + /// Creates a new instance of expression that is used by EF translator to fetch + /// HStore keys from database. + /// + /// The key. + /// The dictionary. + public DictionaryContainsKeyExpression( + [NotNull] Expression key, + [NotNull] Expression dictionary) + { + Check.NotNull(dictionary, nameof(dictionary)); + Check.NotNull(key, nameof(key)); + Debug.Assert(typeof(IDictionary).IsAssignableFrom(dictionary.Type)); + + Dictionary = dictionary; + Key = key; + } + + /// + /// Gets the dictionary. + /// + /// + /// The dictionary. + /// + public virtual Expression Dictionary { get; } + + /// + /// Gets the key. + /// + /// + /// The key. + /// + public virtual Expression Key { get; } + + /// + /// Returns the node type of this . (Inherited from .) + /// + /// The that represents this expression. + public override ExpressionType NodeType => ExpressionType.Extension; + + /// + /// Gets the static type of the expression that this represents. (Inherited from .) + /// + /// The that represents the static type of the expression. + public override Type Type => typeof(bool); + + /// + /// Dispatches to the specific visit method for this node type. + /// + protected override Expression Accept(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + return visitor is NpgsqlQuerySqlGenerator npsgqlGenerator + ? npsgqlGenerator.VisitDictionaryContainsKey(this) + : base.Accept(visitor); + } + + /// + /// Reduces the node and then calls the method passing the + /// reduced expression. + /// Throws an exception if the node isn't reducible. + /// + /// An instance of . + /// The expression being visited, or an expression which should replace it in the tree. + /// + /// Override this method to provide logic to walk the node's children. + /// A typical implementation will call visitor.Visit on each of its + /// children, and if any of them change, should return a new copy of + /// itself with the modified children. + /// + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var newDictionary = visitor.Visit(Dictionary); + var newKey = visitor.Visit(Key); + + return newKey != Key || newDictionary != Dictionary + ? new DictionaryContainsKeyExpression(newKey, newDictionary) + : this; + } + + /// + /// Tests if this object is considered equal to another. + /// + /// The object to compare with the current object. + /// + /// true if the objects are considered equal, false if they are not. + /// + public override bool Equals(object obj) + { + if (ReferenceEquals(null, obj)) + { + return false; + } + + if (ReferenceEquals(this, obj)) + { + return true; + } + + return obj.GetType() == GetType() && Equals((DictionaryContainsKeyExpression)obj); + } + + bool Equals(DictionaryContainsKeyExpression other) + => Dictionary.Equals(other.Dictionary) && Key.Equals(other.Key); + + /// + /// Returns a hash code for this object. + /// + /// + /// A hash code for this object. + /// + public override int GetHashCode() + { + unchecked + { + return (Key.GetHashCode() * 397) ^ Dictionary.GetHashCode(); + } + } + + /// + /// Creates a representation of the Expression. + /// + /// A representation of the Expression. + public override string ToString() => $"{Dictionary} ? {Key}"; + } +} diff --git a/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs b/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs index 5b865f6e0..be0a7e70d 100644 --- a/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs +++ b/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs @@ -30,6 +30,7 @@ using Microsoft.EntityFrameworkCore.Query.Expressions.Internal; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; +using System.Collections.Generic; namespace Microsoft.EntityFrameworkCore.Query.Sql.Internal { @@ -138,6 +139,21 @@ protected override Expression VisitUnary(UnaryExpression expression) return base.VisitUnary(expression); } + protected override Expression VisitIndex(IndexExpression expression) + { + if (expression.Indexer.Name == "Item" + && typeof(IDictionary).IsAssignableFrom(expression.Object.Type)) + { + Debug.Assert(expression.Arguments.Count == 1); + var expr = Visit(expression.Object); + Sql.Append(" -> "); + Visit(expression.Arguments[0]); + return null; + } + + return base.VisitIndex(expression); + } + void GenerateArrayIndex([NotNull] BinaryExpression expression) { Debug.Assert(expression.NodeType == ExpressionType.ArrayIndex); @@ -176,6 +192,15 @@ public Expression VisitArrayAny(ArrayAnyExpression arrayAnyExpression) return arrayAnyExpression; } + public Expression VisitDictionaryContainsKey(DictionaryContainsKeyExpression expr) + { + Visit(expr.Dictionary); + Sql.Append(" ? "); + Visit(expr.Key); + + return expr; + } + // PostgreSQL array indexing is 1-based. If the index happens to be a constant, // just increment it. Otherwise, append a +1 in the SQL. Expression GenerateOneBasedIndexExpression(Expression expression) diff --git a/test/EFCore.PG.FunctionalTests/Query/HStoreQueryTest.cs b/test/EFCore.PG.FunctionalTests/Query/HStoreQueryTest.cs new file mode 100644 index 000000000..ddf9d4309 --- /dev/null +++ b/test/EFCore.PG.FunctionalTests/Query/HStoreQueryTest.cs @@ -0,0 +1,237 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Utilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.FunctionalTests.Query +{ + public class HStoreQueryTest : IClassFixture + { + [Fact] + public void HStore_key_value_selector() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.Tags["kind"] == "big").ToList(); + + Assert.Equal(2, actual.Count); + AssertContainsInSql(@"WHERE ""e"".""Tags"" -> 'kind' = 'big'"); + } + } + + [Fact] + public void HStore_projection() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities + .Select(e => new + { + Kind = e.Tags["kind"] + }).ToList(); + + AssertContainsInSql(@"SELECT ""e"".""Tags"" -> 'kind' AS ""Kind"""); + } + } + + [Fact] + public void HStore_add_value() + { + //EF does not detect changes on single entries in dictionary + using (var ctx = CreateContext()) + { + var entity = ctx.SomeEntities.First(e => e.Id == 1); + entity.Tags.Add("m", "d"); + var numberOfSavedEntities = ctx.SaveChanges(); + + ctx.Entry(entity).State = EntityState.Detached; + + entity = ctx.SomeEntities.First(e => e.Id == 1); + + Assert.Equal(1, numberOfSavedEntities); + Assert.True(entity.Tags.ContainsKey("m")); + Assert.Equal("d", entity.Tags["m"]); + } + } + + [Fact] + public void HStore_update_value() + { + //EF does not detect changes on single entries in dictionary + using (var ctx = CreateContext()) + { + var entity = ctx.SomeEntities.First(e => e.Id == 1); + entity.Tags["kind"] = "thick"; + var numberOfSavedEntities = ctx.SaveChanges(); + + ctx.Entry(entity).State = EntityState.Detached; + + entity = ctx.SomeEntities.First(e => e.Id == 1); + + Assert.Equal(1, numberOfSavedEntities); + Assert.Equal("thick", entity.Tags["kind"]); + } + } + + [Fact] + public void HStore_fetch_keys() + { + using (var ctx = CreateContext()) + { + ctx.SomeEntities.Select(e => new + { + TagNames = EF.Functions.HStoreKeys(e.Tags) + }).ToList(); + + AssertContainsInSql(@" akeys(""e"".""Tags"") "); + } + } + + [Fact] + public void HStore_fetch_values() + { + using (var ctx = CreateContext()) + { + ctx.SomeEntities.Select(e => new + { + TagValues = EF.Functions.HStoreValues(e.Tags) + }).ToList(); + + AssertContainsInSql(@" avals(""e"".""Tags"") "); + } + } + + [Fact] + public void HStore_contains_key() + { + using (var ctx = CreateContext()) + { + var selected = ctx.SomeEntities.Where(e => e.Tags.ContainsKey("type")).ToList(); + + Assert.Equal(1, selected.Count); + AssertContainsInSql(@" ""e"".""Tags"" ? 'type' "); + } + } + + [Fact] + public void HStore_key_contains_all_values_from_collection() + { + using (var ctx = CreateContext()) + { + string[] values = new string[] { "big", "small" }; + ctx.SomeEntities.Where(e => values.All(v => e.Tags.Keys.Contains(v))).ToList(); + + AssertContainsInSql(@"WHERE ""e"".""Tags"" ?& ARRAY [ 'big', 'small' "); + } + } + + [Fact] + public void HStore_key_contains_any_value_from_collection() + { + using (var ctx = CreateContext()) + { + string[] values = new string[] { "big", "small" }; + ctx.SomeEntities.Where(e => values.Any(v => e.Tags.Keys.Contains(v))).ToList(); + + AssertContainsInSql(@"WHERE ""e"".""Tags"" ?| ARRAY [ 'big', 'small' "); + } + } + + #region Support + + HStoreFixture Fixture { get; } + + public HStoreQueryTest(HStoreFixture fixture) + { + Fixture = fixture; + Fixture.TestSqlLoggerFactory.Clear(); + } + + HStoreContext CreateContext() => Fixture.CreateContext(); + + void AssertContainsInSql(string expected) + => Assert.Contains(expected, Fixture.TestSqlLoggerFactory.Sql); + + void AssertDoesNotContainInSql(string expected) + => Assert.DoesNotContain(expected, Fixture.TestSqlLoggerFactory.Sql); + + #endregion Support + } + + public class HStoreContext : DbContext + { + public DbSet SomeEntities { get; set; } + public HStoreContext(DbContextOptions options) : base(options) { } + protected override void OnModelCreating(ModelBuilder builder) + { + builder.HasPostgresExtension("hstore"); + } + } + + public class SomeEntity + { + public int Id { get; set; } + public Dictionary Tags { get; set; } + } + + public class HStoreFixture : IDisposable + { + readonly DbContextOptions _options; + public TestSqlLoggerFactory TestSqlLoggerFactory { get; } = new TestSqlLoggerFactory(); + + public HStoreFixture() + { + _testStore = NpgsqlTestStore.CreateScratch(); + _options = new DbContextOptionsBuilder() + .UseNpgsql(_testStore.Connection, b => b.ApplyConfiguration()) + .UseInternalServiceProvider( + new ServiceCollection() + .AddEntityFrameworkNpgsql() + .AddSingleton(TestSqlLoggerFactory) + .BuildServiceProvider()) + .Options; + + using (var ctx = CreateContext()) + { + ctx.Database.EnsureCreated(); + ctx.SomeEntities.Add(new SomeEntity + { + Id = 1, + Tags = new Dictionary(ImmutableDictionary.Empty.AddRange(new KeyValuePair[] + { + new KeyValuePair("kind", "big"), + new KeyValuePair("type", "car") + })) + }); + ctx.SomeEntities.Add(new SomeEntity + { + Id = 2, + Tags = new Dictionary(ImmutableDictionary.Empty.AddRange(new KeyValuePair[] + { + new KeyValuePair("kind", "big") + })) + }); + ctx.SomeEntities.Add(new SomeEntity + { + Id = 3, + Tags = new Dictionary(ImmutableDictionary.Empty.AddRange(new KeyValuePair[] + { + new KeyValuePair("kind", "small") + })) + }); + ctx.SaveChanges(); + } + } + + readonly NpgsqlTestStore _testStore; + public HStoreContext CreateContext() => new HStoreContext(_options); + public void Dispose() => _testStore.Dispose(); + } +}