diff --git a/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs b/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs index 87454d75..c6d0cd59 100644 --- a/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs +++ b/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs @@ -126,49 +126,51 @@ internal RequestState(HttpRequestMessage request, PathString pathBase, IHttpAppl } var contextFeatures = new FeatureCollection(); - contextFeatures.Set(new RequestFeature()); + var requestFeature = new RequestFeature(); + contextFeatures.Set(requestFeature); _responseFeature = new ResponseFeature(); contextFeatures.Set(_responseFeature); - Context = application.CreateContext(contextFeatures); - var httpContext = Context.HttpContext; + var requestLifetimeFeature = new HttpRequestLifetimeFeature(); + contextFeatures.Set(requestLifetimeFeature); - var serverRequest = httpContext.Request; - serverRequest.Protocol = "HTTP/" + request.Version.ToString(2); - serverRequest.Scheme = request.RequestUri.Scheme; - serverRequest.Method = request.Method.ToString(); + requestFeature.Protocol = "HTTP/" + request.Version.ToString(2); + requestFeature.Scheme = request.RequestUri.Scheme; + requestFeature.Method = request.Method.ToString(); var fullPath = PathString.FromUriComponent(request.RequestUri); PathString remainder; if (fullPath.StartsWithSegments(pathBase, out remainder)) { - serverRequest.PathBase = pathBase; - serverRequest.Path = remainder; + requestFeature.PathBase = pathBase.Value; + requestFeature.Path = remainder.Value; } else { - serverRequest.PathBase = PathString.Empty; - serverRequest.Path = fullPath; + requestFeature.PathBase = string.Empty; + requestFeature.Path = fullPath.Value; } - serverRequest.QueryString = QueryString.FromUriComponent(request.RequestUri); + requestFeature.QueryString = QueryString.FromUriComponent(request.RequestUri).Value; foreach (var header in request.Headers) { - serverRequest.Headers.Append(header.Key, header.Value.ToArray()); + requestFeature.Headers.Append(header.Key, header.Value.ToArray()); } var requestContent = request.Content; if (requestContent != null) { foreach (var header in request.Content.Headers) { - serverRequest.Headers.Append(header.Key, header.Value.ToArray()); + requestFeature.Headers.Append(header.Key, header.Value.ToArray()); } } _responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest); - httpContext.Response.Body = _responseStream; - httpContext.Response.StatusCode = 200; - httpContext.RequestAborted = _requestAbortedSource.Token; + _responseFeature.Body = _responseStream; + _responseFeature.StatusCode = 200; + requestLifetimeFeature.RequestAborted = _requestAbortedSource.Token; + + Context = application.CreateContext(contextFeatures); } public Context Context { get; private set; }