diff --git a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerFeature.cs b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerFeature.cs new file mode 100644 index 00000000..4edc1339 --- /dev/null +++ b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerFeature.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNet.Http.Features.Internal; +using Microsoft.Framework.DependencyInjection; + +namespace Microsoft.AspNet.Hosting.Internal +{ + public class RequestServicesFeature : IServiceProvidersFeature, IDisposable + { + private IServiceProvider _appServices; + private IServiceProvider _requestServices; + private IServiceScope _scope; + private bool _requestServicesSet; + + public RequestServicesFeature(IServiceProvider applicationServices) + { + if (applicationServices == null) + { + throw new ArgumentNullException(nameof(applicationServices)); + } + + ApplicationServices = applicationServices; + } + + public IServiceProvider ApplicationServices + { + get + { + return _appServices; + } + set + { + if (value == null) + { + throw new ArgumentNullException(nameof(value)); + } + _appServices = value; + } + } + + public IServiceProvider RequestServices + { + get + { + if (!_requestServicesSet) + { + _scope = ApplicationServices.GetRequiredService().CreateScope(); + _requestServices = _scope.ServiceProvider; + _requestServicesSet = true; + } + return _requestServices; + } + + set + { + _requestServicesSet = true; + RequestServices = value; + } + } + + public void Dispose() + { + _scope?.Dispose(); + _scope = null; + _requestServices = null; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs index ab3fda31..25949db5 100644 --- a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs +++ b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs @@ -5,6 +5,8 @@ using System.Threading.Tasks; using Microsoft.AspNet.Builder; using Microsoft.AspNet.Http; +using Microsoft.AspNet.Http.Features; +using Microsoft.AspNet.Http.Features.Internal; using Microsoft.Framework.DependencyInjection; namespace Microsoft.AspNet.Hosting.Internal @@ -20,7 +22,6 @@ public RequestServicesContainerMiddleware(RequestDelegate next, IServiceProvider { throw new ArgumentNullException(nameof(next)); } - if (services == null) { throw new ArgumentNullException(nameof(services)); @@ -37,32 +38,26 @@ public async Task Invoke(HttpContext httpContext) throw new ArgumentNullException(nameof(httpContext)); } - // All done if there request services is set - if (httpContext.RequestServices != null) + var existingFeature = httpContext.Features.Get(); + + // All done if request services is set + if (existingFeature?.RequestServices != null) { await _next.Invoke(httpContext); return; } - var priorApplicationServices = httpContext.ApplicationServices; - var serviceProvider = priorApplicationServices ?? _services; - var scopeFactory = serviceProvider.GetRequiredService(); - - try + using (var feature = new RequestServicesFeature(_services)) { - // Creates the scope and temporarily swap services - using (var scope = scopeFactory.CreateScope()) + try { - httpContext.ApplicationServices = serviceProvider; - httpContext.RequestServices = scope.ServiceProvider; - + httpContext.Features.Set(feature); await _next.Invoke(httpContext); } - } - finally - { - httpContext.RequestServices = null; - httpContext.ApplicationServices = priorApplicationServices; + finally + { + httpContext.Features.Set(existingFeature); + } } } } diff --git a/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs b/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs index ed4bc90a..a41a9141 100644 --- a/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs +++ b/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs @@ -11,6 +11,8 @@ using Microsoft.AspNet.Hosting; using Microsoft.AspNet.Hosting.Startup; using Microsoft.AspNet.Http; +using Microsoft.AspNet.Http.Features; +using Microsoft.AspNet.Http.Features.Internal; using Microsoft.Framework.Configuration; using Microsoft.Framework.DependencyInjection; using Microsoft.Framework.Logging; @@ -133,6 +135,103 @@ public async Task ExistingRequestServicesWillNotBeReplaced() Assert.Equal("Found:True", result); } + [Fact] + public async Task SettingApplicationServicesOnFeatureToNullThrows() + { + var server = TestServer.Create(app => + { + app.Run(context => + { + var feature = context.Features.Get(); + Assert.Throws(() => feature.ApplicationServices = null); + return context.Response.WriteAsync("Success"); + }); + }); + string result = await server.CreateClient().GetStringAsync("/path"); + Assert.Equal("Success", result); + } + + public class ReplaceServiceProvidersFeatureFilter : IStartupFilter, IServiceProvidersFeature + { + public ReplaceServiceProvidersFeatureFilter(IServiceProvider appServices, IServiceProvider requestServices) + { + ApplicationServices = appServices; + RequestServices = requestServices; + } + + public IServiceProvider ApplicationServices { get; set; } + + public IServiceProvider RequestServices { get; set; } + + public Action Configure(Action next) + { + return app => + { + app.Use(async (context, nxt) => + { + context.Features.Set(this); + await nxt(); + }); + next(app); + }; + } + } + + [Fact] + public async Task ExistingServiceProviderFeatureWillNotBeReplaced() + { + var appServices = new ServiceCollection().BuildServiceProvider(); + var server = TestServer.Create(app => + { + app.Run(context => + { + Assert.Equal(appServices, context.ApplicationServices); + Assert.Equal(appServices, context.RequestServices); + return context.Response.WriteAsync("Success"); + }); + }, + services => services.AddInstance(new ReplaceServiceProvidersFeatureFilter(appServices, appServices))); + var result = await server.CreateClient().GetStringAsync("/path"); + Assert.Equal("Success", result); + } + + public class NullServiceProvidersFeatureFilter : IStartupFilter, IServiceProvidersFeature + { + public IServiceProvider ApplicationServices { get; set; } + + public IServiceProvider RequestServices { get; set; } + + public Action Configure(Action next) + { + return app => + { + app.Use(async (context, nxt) => + { + context.Features.Set(this); + await nxt(); + }); + next(app); + }; + } + } + + [Fact] + public async Task WillReplaceServiceProviderFeatureWithNullRequestServices() + { + var server = TestServer.Create(app => + { + app.Run(context => + { + Assert.NotNull(context.ApplicationServices); + Assert.NotNull(context.RequestServices); + return context.Response.WriteAsync("Success"); + }); + }, + services => services.AddTransient()); + var result = await server.CreateClient().GetStringAsync("/path"); + Assert.Equal("Success", result); + } + public class EnsureApplicationServicesFilter : IStartupFilter { public Action Configure(Action next)