From 8b9a2e3b128869302500d314145ccc33ea53ec64 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Tue, 15 Oct 2019 13:52:37 +0200 Subject: [PATCH 1/3] Fixed maxcomplexity middleware --- examples/AspNetCore.StarWars/Startup.cs | 3 +- .../MaxComplexityMiddlewareTests.cs | 87 ++++++++++++++++++- ...ests.Validate_Multiple_Levels_Invalid.snap | 20 +++++ ...eTests.Validate_Multiple_Levels_Valid.snap | 12 +++ .../Middleware/MaxComplexityMiddleware.cs | 3 +- ...ComplexityWithMultipliersVisitorContext.cs | 7 +- 6 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Invalid.snap create mode 100644 src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Valid.snap diff --git a/examples/AspNetCore.StarWars/Startup.cs b/examples/AspNetCore.StarWars/Startup.cs index 079e52121fc..4412cbf79e4 100644 --- a/examples/AspNetCore.StarWars/Startup.cs +++ b/examples/AspNetCore.StarWars/Startup.cs @@ -42,7 +42,8 @@ public void ConfigureServices(IServiceCollection services) .Create(), new QueryExecutionOptions { - TracingPreference = TracingPreference.Always + MaxOperationComplexity = 10, + UseComplexityMultipliers = true }); diff --git a/src/Core/Core.Tests/Execution/Middleware/MaxComplexityMiddlewareTests.cs b/src/Core/Core.Tests/Execution/Middleware/MaxComplexityMiddlewareTests.cs index 67f9d26135a..ec1890202b1 100644 --- a/src/Core/Core.Tests/Execution/Middleware/MaxComplexityMiddlewareTests.cs +++ b/src/Core/Core.Tests/Execution/Middleware/MaxComplexityMiddlewareTests.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using ChilliCream.Testing; using HotChocolate.Execution.Configuration; using HotChocolate.Language; using HotChocolate.Utilities; @@ -368,5 +367,91 @@ input FooInput { new SnapshotNameExtension("complexity", count)); } } + + [Fact] + public async Task Validate_Multiple_Levels_Valid() + { + // arrange + ISchema schema = SchemaBuilder.New() + .AddDocumentFromString( + @" + type Query { + foo(i: Int = 2): Foo + @cost(complexity: 1 multipliers: [""i""]) + } + + type Foo { + bar: Bar + qux: String + } + + type Bar { + baz: String + } + ") + .Use(next => context => + { + context.Result = "baz"; + return Task.CompletedTask; + }) + .Create(); + + IQueryExecutor executor = schema.MakeExecutable(new QueryExecutionOptions + { + UseComplexityMultipliers = true, + MaxOperationComplexity = 4 + }); + + IReadOnlyQueryRequest request = QueryRequestBuilder.New() + .SetQuery("query { foo { bar { baz } } }") + .Create(); + + IExecutionResult result = await executor.ExecuteAsync(request); + + result.MatchSnapshot(); + } + + [Fact] + public async Task Validate_Multiple_Levels_Invalid() + { + // arrange + ISchema schema = SchemaBuilder.New() + .AddDocumentFromString( + @" + type Query { + foo(i: Int = 2): Foo + @cost(complexity: 1 multipliers: [""i""]) + } + + type Foo { + bar: Bar + qux: String + } + + type Bar { + baz: String + } + ") + .Use(next => context => + { + context.Result = "baz"; + return Task.CompletedTask; + }) + .Create(); + + IQueryExecutor executor = schema.MakeExecutable(new QueryExecutionOptions + { + UseComplexityMultipliers = true, + MaxOperationComplexity = 4 + }); + + IReadOnlyQueryRequest request = QueryRequestBuilder.New() + .SetQuery("query { foo(i: 2) { bar { baz } qux } }") + .Create(); + + IExecutionResult result = await executor.ExecuteAsync(request); + + result.MatchSnapshot(); + } } } diff --git a/src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Invalid.snap b/src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Invalid.snap new file mode 100644 index 00000000000..611e3367035 --- /dev/null +++ b/src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Invalid.snap @@ -0,0 +1,20 @@ +{ + "Data": {}, + "Extensions": {}, + "Errors": [ + { + "Message": "The operation that shall be executed has a complexity of 5.\nThe maximum allowed query complexity is 4.", + "Code": null, + "Path": null, + "Locations": [ + { + "Line": 1, + "Column": 1 + } + ], + "Exception": null, + "Extensions": {} + } + ], + "ContextData": {} +} diff --git a/src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Valid.snap b/src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Valid.snap new file mode 100644 index 00000000000..af1b03d18c8 --- /dev/null +++ b/src/Core/Core.Tests/Execution/Middleware/__snapshots__/MaxComplexityMiddlewareTests.Validate_Multiple_Levels_Valid.snap @@ -0,0 +1,12 @@ +{ + "Data": { + "foo": { + "bar": { + "baz": "baz" + } + } + }, + "Extensions": {}, + "Errors": [], + "ContextData": {} +} diff --git a/src/Core/Core/Execution/Middleware/MaxComplexityMiddleware.cs b/src/Core/Core/Execution/Middleware/MaxComplexityMiddleware.cs index b6ccf6e18d6..7121971a084 100644 --- a/src/Core/Core/Execution/Middleware/MaxComplexityMiddleware.cs +++ b/src/Core/Core/Execution/Middleware/MaxComplexityMiddleware.cs @@ -63,8 +63,7 @@ public Task InvokeAsync(IQueryContext context) IError error = ErrorBuilder.New() .SetMessage(string.Format( CultureInfo.InvariantCulture, - CoreResources - .MaxComplexityMiddleware_NotAllowed, + CoreResources.MaxComplexityMiddleware_NotAllowed, complexity, _options.MaxOperationComplexity)) .AddLocation(context.Operation.Definition) diff --git a/src/Core/Core/Validation/MaxComplexityWithMultipliersVisitorContext.cs b/src/Core/Core/Validation/MaxComplexityWithMultipliersVisitorContext.cs index 441b1fbbfbc..fc86d8db237 100644 --- a/src/Core/Core/Validation/MaxComplexityWithMultipliersVisitorContext.cs +++ b/src/Core/Core/Validation/MaxComplexityWithMultipliersVisitorContext.cs @@ -51,20 +51,19 @@ public override MaxComplexityVisitorContext AddField( { IDirective directive = fieldDefinition.Directives .FirstOrDefault(t => t.Type is CostDirectiveType); - int complexity; CostDirective cost = directive == null ? DefaultCost : directive.ToObject(); - complexity = Complexity + CalculateComplexity( + Complexity = Complexity + CalculateComplexity( new ComplexityContext( fieldDefinition, fieldSelection, FieldPath, _variables, cost)); - if (complexity > MaxComplexity) + if (Complexity > MaxComplexity) { - MaxComplexity = complexity; + MaxComplexity = Complexity; } return new MaxComplexityWithMultipliersVisitorContext( From 266cbdfd1cd5fe79dfe19d96f13918e1d82a0b68 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Wed, 16 Oct 2019 15:56:53 +0200 Subject: [PATCH 2/3] Added squadron --- .../PersistedQueries.Redis.Tests.csproj | 4 +++ .../RedisQueryStorageTests.cs | 25 +++-------------- .../ServiceCollectionExtensionsTests.cs | 25 +++-------------- .../RedisIntegrationTests.cs | 22 +++------------ .../Subscriptions.Redis.Tests/RedisTests.cs | 22 +++------------ .../Subscriptions.Redis.Tests.csproj | 4 +++ .../MongoFilterTests.cs | 22 ++++++++------- .../Types.Filters.Mongo.Tests.csproj | 2 +- .../MongoSortingTests.cs | 27 ++++++++++--------- .../Types.Sorting.Mongo.Tests.csproj | 2 +- 10 files changed, 52 insertions(+), 103 deletions(-) diff --git a/src/Core/PersistedQueries.Redis.Tests/PersistedQueries.Redis.Tests.csproj b/src/Core/PersistedQueries.Redis.Tests/PersistedQueries.Redis.Tests.csproj index 5f6b8a7afe5..50466573a23 100644 --- a/src/Core/PersistedQueries.Redis.Tests/PersistedQueries.Redis.Tests.csproj +++ b/src/Core/PersistedQueries.Redis.Tests/PersistedQueries.Redis.Tests.csproj @@ -20,4 +20,8 @@ + + + + diff --git a/src/Core/PersistedQueries.Redis.Tests/RedisQueryStorageTests.cs b/src/Core/PersistedQueries.Redis.Tests/RedisQueryStorageTests.cs index 92b840a4053..ff493d0d163 100644 --- a/src/Core/PersistedQueries.Redis.Tests/RedisQueryStorageTests.cs +++ b/src/Core/PersistedQueries.Redis.Tests/RedisQueryStorageTests.cs @@ -8,36 +8,19 @@ using Snapshooter.Xunit; using Xunit; using Snapshooter; +using Squadron; namespace HotChocolate.PersistedQueries.Redis { public class RedisQueryStorageTests + : IClassFixture { private ConnectionMultiplexer _connectionMultiplexer; private IDatabase _database; - public RedisQueryStorageTests() + public RedisQueryStorageTests(RedisResource redisResource) { - string endpoint = - Environment.GetEnvironmentVariable("REDIS_ENDPOINT") - ?? "localhost:6379"; - - string password = - Environment.GetEnvironmentVariable("REDIS_PASSWORD"); - - var configuration = new ConfigurationOptions - { - Ssl = !string.IsNullOrEmpty(password), - AbortOnConnectFail = false, - Password = password - }; - - configuration.EndPoints.Add(endpoint); - - _connectionMultiplexer = - ConnectionMultiplexer.Connect(configuration); - - _database = _connectionMultiplexer.GetDatabase(); + _database = redisResource.GetConnection().GetDatabase(); } [Fact] diff --git a/src/Core/PersistedQueries.Redis.Tests/ServiceCollectionExtensionsTests.cs b/src/Core/PersistedQueries.Redis.Tests/ServiceCollectionExtensionsTests.cs index ae21ef852d8..fca81198554 100644 --- a/src/Core/PersistedQueries.Redis.Tests/ServiceCollectionExtensionsTests.cs +++ b/src/Core/PersistedQueries.Redis.Tests/ServiceCollectionExtensionsTests.cs @@ -5,36 +5,19 @@ using Xunit; using HotChocolate.Utilities; using Snapshooter.Xunit; +using Squadron; namespace HotChocolate.PersistedQueries.Redis { public class ServiceCollectionExtensionsTests + : IClassFixture { private ConnectionMultiplexer _connectionMultiplexer; private IDatabase _database; - public ServiceCollectionExtensionsTests() + public ServiceCollectionExtensionsTests(RedisResource redisResource) { - string endpoint = - Environment.GetEnvironmentVariable("REDIS_ENDPOINT") - ?? "localhost:6379"; - - string password = - Environment.GetEnvironmentVariable("REDIS_PASSWORD"); - - var configuration = new ConfigurationOptions - { - Ssl = !string.IsNullOrEmpty(password), - AbortOnConnectFail = false, - Password = password - }; - - configuration.EndPoints.Add(endpoint); - - _connectionMultiplexer = - ConnectionMultiplexer.Connect(configuration); - - _database = _connectionMultiplexer.GetDatabase(); + _database = redisResource.GetConnection().GetDatabase(); } [Fact] diff --git a/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs b/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs index 5b7e7565479..bed2ad68efe 100644 --- a/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs +++ b/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs @@ -8,34 +8,20 @@ using HotChocolate.Types; using HotChocolate.Execution; using HotChocolate.Language; +using Squadron; namespace HotChocolate.Subscriptions.Redis { public class RedisIntegrationTests + : IClassFixture { private readonly IEventSender _sender; private readonly ConfigurationOptions _configuration; - public RedisIntegrationTests() + public RedisIntegrationTests(RedisResource redisResource) { - string endpoint = - Environment.GetEnvironmentVariable("REDIS_ENDPOINT") - ?? "localhost:6379"; - - string password = - Environment.GetEnvironmentVariable("REDIS_PASSWORD"); - - _configuration = new ConfigurationOptions - { - Ssl = !string.IsNullOrEmpty(password), - AbortOnConnectFail = false, - Password = password - }; - - _configuration.EndPoints.Add(endpoint); - var redisEventRegistry = new RedisEventRegistry( - ConnectionMultiplexer.Connect(_configuration), + redisResource.GetConnection(), new JsonPayloadSerializer()); _sender = redisEventRegistry; diff --git a/src/Core/Subscriptions.Redis.Tests/RedisTests.cs b/src/Core/Subscriptions.Redis.Tests/RedisTests.cs index deb0c8d1c03..eb682f96b1d 100644 --- a/src/Core/Subscriptions.Redis.Tests/RedisTests.cs +++ b/src/Core/Subscriptions.Redis.Tests/RedisTests.cs @@ -2,36 +2,22 @@ using System.Threading; using System.Threading.Tasks; using HotChocolate.Language; +using Squadron; using StackExchange.Redis; using Xunit; namespace HotChocolate.Subscriptions.Redis { public class RedisTests + : IClassFixture { private readonly IEventRegistry _registry; private readonly IEventSender _sender; - public RedisTests() + public RedisTests(RedisResource redisResource) { - string endpoint = - Environment.GetEnvironmentVariable("REDIS_ENDPOINT") - ?? "localhost:6379"; - - string password = - Environment.GetEnvironmentVariable("REDIS_PASSWORD"); - - var configuration = new ConfigurationOptions - { - Ssl = !string.IsNullOrEmpty(password), - AbortOnConnectFail = false, - Password = password - }; - - configuration.EndPoints.Add(endpoint); - var redisEventRegistry = new RedisEventRegistry( - ConnectionMultiplexer.Connect(configuration), + redisResource.GetConnection(), new JsonPayloadSerializer()); _sender = redisEventRegistry; diff --git a/src/Core/Subscriptions.Redis.Tests/Subscriptions.Redis.Tests.csproj b/src/Core/Subscriptions.Redis.Tests/Subscriptions.Redis.Tests.csproj index 9cc727da155..0350219d097 100644 --- a/src/Core/Subscriptions.Redis.Tests/Subscriptions.Redis.Tests.csproj +++ b/src/Core/Subscriptions.Redis.Tests/Subscriptions.Redis.Tests.csproj @@ -22,4 +22,8 @@ + + + + diff --git a/src/Core/Types.Filters.Mongo.Tests/MongoFilterTests.cs b/src/Core/Types.Filters.Mongo.Tests/MongoFilterTests.cs index 1e05126d286..620dc6fab7a 100644 --- a/src/Core/Types.Filters.Mongo.Tests/MongoFilterTests.cs +++ b/src/Core/Types.Filters.Mongo.Tests/MongoFilterTests.cs @@ -1,5 +1,4 @@ using System.Linq; -using System; using MongoDB.Driver; using MongoDB.Bson; using Microsoft.Extensions.DependencyInjection; @@ -8,11 +7,20 @@ using System.Threading.Tasks; using Snapshooter.Xunit; using HotChocolate.Types.Relay; +using Squadron; namespace HotChocolate.Types.Filters { public class MongoFilterTests + : IClassFixture { + private readonly MongoResource _mongoResource; + + public MongoFilterTests(MongoResource mongoResource) + { + _mongoResource = mongoResource; + } + [Fact] public async Task GetItems_NoFilter_AllItems_Are_Returned() { @@ -20,9 +28,7 @@ public async Task GetItems_NoFilter_AllItems_Are_Returned() var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton>(sp => { - MongoClient client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); var collection = database.GetCollection("col"); collection.InsertMany(new[] @@ -58,9 +64,7 @@ public async Task GetItems_EqualsFilter_FirstItems_Is_Returned() var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton>(sp => { - MongoClient client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); var collection = database.GetCollection("col"); collection.InsertMany(new[] @@ -96,9 +100,7 @@ public async Task GetItems_With_Paging_EqualsFilter_FirstItems_Is_Returned() var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton>(sp => { - MongoClient client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); var collection = database.GetCollection("col"); collection.InsertMany(new[] diff --git a/src/Core/Types.Filters.Mongo.Tests/Types.Filters.Mongo.Tests.csproj b/src/Core/Types.Filters.Mongo.Tests/Types.Filters.Mongo.Tests.csproj index 390c35046e6..c337ed835ad 100644 --- a/src/Core/Types.Filters.Mongo.Tests/Types.Filters.Mongo.Tests.csproj +++ b/src/Core/Types.Filters.Mongo.Tests/Types.Filters.Mongo.Tests.csproj @@ -16,7 +16,7 @@ - + diff --git a/src/Core/Types.Sorting.Mongo.Tests/MongoSortingTests.cs b/src/Core/Types.Sorting.Mongo.Tests/MongoSortingTests.cs index 751f57ce8ab..1458fb51bfc 100644 --- a/src/Core/Types.Sorting.Mongo.Tests/MongoSortingTests.cs +++ b/src/Core/Types.Sorting.Mongo.Tests/MongoSortingTests.cs @@ -1,4 +1,3 @@ -using System; using MongoDB.Driver; using MongoDB.Bson; using Microsoft.Extensions.DependencyInjection; @@ -7,11 +6,21 @@ using System.Threading.Tasks; using Snapshooter.Xunit; using HotChocolate.Types.Relay; +using Squadron; namespace HotChocolate.Types.Sorting { public class MongoSortingTests + : IClassFixture { + private readonly MongoResource _mongoResource; + + public MongoSortingTests(MongoResource mongoResource) + { + _mongoResource = mongoResource; + } + + [Fact] public async Task GetItems_NoSorting_AllItems_Are_Returned_Unsorted() { @@ -19,9 +28,7 @@ public async Task GetItems_NoSorting_AllItems_Are_Returned_Unsorted() var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(sp => { - var client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); IMongoCollection collection = database.GetCollection("col"); @@ -59,9 +66,7 @@ public async Task GetItems_DescSorting_AllItems_Are_Returned_DescSorted() var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(sp => { - var client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); IMongoCollection collection = database.GetCollection("col"); collection.InsertMany(new[] @@ -98,9 +103,7 @@ public async Task GetItems_With_Paging__DescSorting_AllItems_Are_Returned_DescSo var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(sp => { - var client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); IMongoCollection collection = database.GetCollection("col"); collection.InsertMany(new[] @@ -138,9 +141,7 @@ public async Task GetItems_OnRenamedField_DescSorting_AllItems_Are_Returned_Desc var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(sp => { - var client = new MongoClient(); - IMongoDatabase database = client.GetDatabase( - "db_" + Guid.NewGuid().ToString("N")); + IMongoDatabase database = _mongoResource.CreateDatabase(); IMongoCollection collection = database.GetCollection("col"); collection.InsertMany(new[] diff --git a/src/Core/Types.Sorting.Mongo.Tests/Types.Sorting.Mongo.Tests.csproj b/src/Core/Types.Sorting.Mongo.Tests/Types.Sorting.Mongo.Tests.csproj index 63191837156..453893f2732 100644 --- a/src/Core/Types.Sorting.Mongo.Tests/Types.Sorting.Mongo.Tests.csproj +++ b/src/Core/Types.Sorting.Mongo.Tests/Types.Sorting.Mongo.Tests.csproj @@ -17,7 +17,7 @@ - + From 0c5c48694e8b337827dc31b4ed6d8967a7737209 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Wed, 16 Oct 2019 17:00:19 +0200 Subject: [PATCH 3/3] Fixed redis tests --- .../RedisIntegrationTests.cs | 13 ++++--- ...SubscriptionServiceCollectionExtensions.cs | 35 ++++++++++++++++--- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs b/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs index bed2ad68efe..38f9e2704ce 100644 --- a/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs +++ b/src/Core/Subscriptions.Redis.Tests/RedisIntegrationTests.cs @@ -15,16 +15,15 @@ namespace HotChocolate.Subscriptions.Redis public class RedisIntegrationTests : IClassFixture { + private readonly ConnectionMultiplexer _connection; private readonly IEventSender _sender; - private readonly ConfigurationOptions _configuration; public RedisIntegrationTests(RedisResource redisResource) { - var redisEventRegistry = new RedisEventRegistry( - redisResource.GetConnection(), + _connection = redisResource.GetConnection(); + _sender = new RedisEventRegistry( + _connection, new JsonPayloadSerializer()); - - _sender = redisEventRegistry; } [Fact] @@ -34,7 +33,7 @@ public Task Subscribe() { // arrange var services = new ServiceCollection(); - services.AddRedisSubscriptionProvider(_configuration); + services.AddRedisSubscriptionProvider(_connection); IServiceProvider serviceProvider = services.BuildServiceProvider(); @@ -76,7 +75,7 @@ public Task Subscribe_With_ObjectValue() { // arrange var services = new ServiceCollection(); - services.AddRedisSubscriptionProvider(_configuration); + services.AddRedisSubscriptionProvider(_connection); IServiceProvider serviceProvider = services.BuildServiceProvider(); diff --git a/src/Core/Subscriptions.Redis/RedisSubscriptionServiceCollectionExtensions.cs b/src/Core/Subscriptions.Redis/RedisSubscriptionServiceCollectionExtensions.cs index 5b84a77702f..c9f6b556e99 100644 --- a/src/Core/Subscriptions.Redis/RedisSubscriptionServiceCollectionExtensions.cs +++ b/src/Core/Subscriptions.Redis/RedisSubscriptionServiceCollectionExtensions.cs @@ -10,8 +10,12 @@ public static class RedisSubscriptionServiceCollectionExtensions public static IServiceCollection AddRedisSubscriptionProvider( this IServiceCollection services, ConfigurationOptions options) => - services - .AddRedisSubscriptionProvider(options); + services.AddRedisSubscriptionProvider(options); + + public static IServiceCollection AddRedisSubscriptionProvider( + this IServiceCollection services, + ConnectionMultiplexer connection) => + services.AddRedisSubscriptionProvider(connection); public static IServiceCollection AddRedisSubscriptionProvider( this IServiceCollection services, @@ -23,10 +27,33 @@ public static IServiceCollection AddRedisSubscriptionProvider( throw new ArgumentNullException(nameof(services)); } + services.AddSingleton(sp => + ConnectionMultiplexer.Connect(options)); + AddServices(services); + return services; + } + + public static IServiceCollection AddRedisSubscriptionProvider( + this IServiceCollection services, + ConnectionMultiplexer connection) + where TSerializer : class, IPayloadSerializer + { + if (services == null) + { + throw new ArgumentNullException(nameof(services)); + } + + services.AddSingleton(sp => connection); + AddServices(services); + return services; + } + + private static IServiceCollection AddServices( + IServiceCollection services) + where TSerializer : class, IPayloadSerializer + { services .AddSingleton() - .AddSingleton(sp => - ConnectionMultiplexer.Connect(options)) .AddSingleton() .AddSingleton(sp => sp.GetRequiredService())