From c746a3a9bdb48b5d5ec11324be9e101173b51331 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Thu, 15 Jul 2021 11:41:42 -0700 Subject: [PATCH 1/8] Support optionality via nullability and default values --- .../src/RequestDelegateFactory.cs | 153 +++++- ...ft.AspNetCore.Http.Extensions.Tests.csproj | 2 + .../test/RequestDelegateFactoryTests.cs | 482 +++++++++++++++++- 3 files changed, 594 insertions(+), 43 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index c30b31a8f8ee..f041801bbcd0 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -22,6 +22,8 @@ namespace Microsoft.AspNetCore.Http /// public static partial class RequestDelegateFactory { + private static readonly NullabilityInfoContext nullabilityContext = new NullabilityInfoContext(); + private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -31,12 +33,16 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ExecuteValueResultTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteObjectReturnMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteObjectReturn), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!; + private static readonly MethodInfo GetServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!; private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = GetMethodInfo>((response, text) => HttpResponseWritingExtensions.WriteAsync(response, text, default)); private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default)); private static readonly MethodInfo LogParameterBindingFailureMethod = GetMethodInfo>((httpContext, parameterType, parameterName, sourceValue) => Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue)); + private static readonly MethodInfo LogRequiredParameterNotProvidedMethod = GetMethodInfo>((httpContext, parameterType, parameterName) => + Log.RequiredParameterNotProvided(httpContext, parameterType, parameterName)); + private static readonly ParameterExpression TargetExpr = Expression.Parameter(typeof(object), "target"); private static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue"); @@ -217,11 +223,11 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext } else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } bodyAttribute) { - return BindParameterFromBody(parameter.ParameterType, bodyAttribute.AllowEmpty, factoryContext); + return BindParameterFromBody(parameter, bodyAttribute.AllowEmpty, factoryContext); } else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType))) { - return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); + return BindParameterFromService(parameter); } else if (parameter.ParameterType == typeof(HttpContext)) { @@ -256,16 +262,30 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext } else { + + var nullability = nullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) { - // If the parameter resolves as a service then get it from services - if (serviceProviderIsService.IsService(parameter.ParameterType)) + // If the parameter is required + if (!isOptional) { - return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); + // And we are able to resolve a service for it + return serviceProviderIsService.IsService(parameter.ParameterType) + ? Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) // Then get it from the DI + : BindParameterFromBody(parameter, allowEmpty: false, factoryContext); // Otherwise try to find it in the body + } + // If the parameter is optional + else + { + // Then try to resolve it as an optional service and fallback to a body otherwise + return Expression.Coalesce( + Expression.Call(GetServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr), + BindParameterFromBody(parameter, allowEmpty: false, factoryContext)); } } - return BindParameterFromBody(parameter.ParameterType, allowEmpty: false, factoryContext); + return BindParameterFromBody(parameter, allowEmpty: false, factoryContext); } } @@ -479,13 +499,9 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, return async (target, httpContext) => { - object? bodyValue; + object? bodyValue = defaultBodyValue; - if (factoryContext.AllowEmptyRequestBody && httpContext.Request.ContentLength == 0) - { - bodyValue = defaultBodyValue; - } - else + if (httpContext.Request.ContentLength != 0 && httpContext.Request.HasJsonContentType()) { try { @@ -516,21 +532,53 @@ private static Expression GetValueFromProperty(Expression sourceExpression, stri return Expression.Convert(indexExpression, typeof(string)); } + private static Expression BindParameterFromService(ParameterInfo parameter) + { + var nullability = nullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + + return isOptional + ? Expression.Call(GetServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) + : Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); + } + private static Expression BindParameterFromValue(ParameterInfo parameter, Expression valueExpression, FactoryContext factoryContext) { + var nullability = nullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + if (parameter.ParameterType == typeof(string)) { - if (!parameter.HasDefaultValue) + factoryContext.UsingTempSourceString = true; + + if (!isOptional) { - return valueExpression; + var checkRequiredStringParameterBlock = Expression.Block( + Expression.Assign(TempSourceStringExpr, valueExpression), + Expression.IfThen(Expression.Not(TempSourceStringNotNullExpr), + Expression.Block( + Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) + ) + ) + ); + + factoryContext.TryParseParams.Add((TempSourceStringExpr, checkRequiredStringParameterBlock)); + return Expression.Block(TempSourceStringExpr); + } + + // Allow nullable parameters that don't have a default value + if (nullability.ReadState == NullabilityState.Nullable && !parameter.HasDefaultValue) + { + return Expression.Block(Expression.Assign(TempSourceStringExpr, valueExpression)); } - factoryContext.UsingTempSourceString = true; return Expression.Block( Expression.Assign(TempSourceStringExpr, valueExpression), Expression.Condition(TempSourceStringNotNullExpr, TempSourceStringExpr, - Expression.Constant(parameter.DefaultValue))); + Expression.Convert(Expression.Constant(parameter.DefaultValue), parameter.ParameterType))); } factoryContext.UsingTempSourceString = true; @@ -598,6 +646,17 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres var tryParseCall = tryParseMethodCall(parsedValue); + // If the parameter is required, fail to parse and log an error + var checkRequiredParaseableParameterBlock = Expression.Block( + Expression.IfThen(Expression.Not(TempSourceStringNotNullExpr), + Expression.Block( + Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, parameterTypeNameConstant, parameterNameConstant) + ) + ) + ); + // If the parameter is nullable, we need to assign the "parsedValue" local to the nullable parameter on success. Expression tryParseExpression = isNotNullable ? Expression.IfThen(Expression.Not(tryParseCall), failBlock) : @@ -612,11 +671,18 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres tryParseExpression, Expression.Assign(argument, Expression.Constant(parameter.DefaultValue))); - var fullTryParseBlock = Expression.Block( - // tempSourceString = httpContext.RequestValue["id"]; - Expression.Assign(TempSourceStringExpr, valueExpression), - // if (tempSourceString != null) { ... } - ifNotNullTryParse); + var fullTryParseBlock = !isOptional + ? Expression.Block( + // tempSourceString = httpContext.RequestValue["id"]; + Expression.Assign(TempSourceStringExpr, valueExpression), + checkRequiredParaseableParameterBlock, + // if (tempSourceString != null) { ... } + ifNotNullTryParse) + : Expression.Block( + // tempSourceString = httpContext.RequestValue["id"]; + Expression.Assign(TempSourceStringExpr, valueExpression), + // if (tempSourceString != null) { ... } + ifNotNullTryParse); factoryContext.TryParseParams.Add((argument, fullTryParseBlock)); @@ -633,17 +699,46 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext); } - private static Expression BindParameterFromBody(Type parameterType, bool allowEmpty, FactoryContext factoryContext) + private static Expression BindParameterFromBody(ParameterInfo parameter, bool allowEmpty, FactoryContext factoryContext) { if (factoryContext.JsonRequestBodyType is not null) { throw new InvalidOperationException("Action cannot have more than one FromBody attribute."); } - factoryContext.JsonRequestBodyType = parameterType; - factoryContext.AllowEmptyRequestBody = allowEmpty; + var nullability = nullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + + factoryContext.JsonRequestBodyType = parameter.ParameterType; + factoryContext.AllowEmptyRequestBody = allowEmpty || isOptional; + + var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); - return Expression.Convert(BodyValueExpr, parameterType); + if (!isOptional && !allowEmpty) + { + var checkRequiredBodyBlock = Expression.Block( + Expression.Assign(argument, Expression.Convert(BodyValueExpr, parameter.ParameterType)), + Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)), + Expression.Block( + Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) + ) + ) + ); + factoryContext.TryParseParams.Add((argument, checkRequiredBodyBlock)); + return argument; + } + + if (parameter.HasDefaultValue) + { + // Convert(bodyValue ?? SomeDefault, Todo) + return Expression.Convert( + Expression.Coalesce(BodyValueExpr, Expression.Constant(parameter.DefaultValue)), + parameter.ParameterType); + } + + return Expression.Convert(BodyValueExpr, parameter.ParameterType); } private static MethodInfo GetMethodInfo(Expression expr) @@ -847,11 +942,19 @@ public static void RequestBodyInvalidDataException(HttpContext httpContext, Inva public static void ParameterBindingFailed(HttpContext httpContext, string parameterTypeName, string parameterName, string sourceValue) => ParameterBindingFailed(GetLogger(httpContext), parameterTypeName, parameterName, sourceValue); + public static void RequiredParameterNotProvided(HttpContext httpContext, string parameterTypeName, string parameterName) + => RequiredParameterNotProvided(GetLogger(httpContext), parameterTypeName, parameterName); + [LoggerMessage(3, LogLevel.Debug, @"Failed to bind parameter ""{ParameterType} {ParameterName}"" from ""{SourceValue}"".", EventName = "ParamaterBindingFailed")] private static partial void ParameterBindingFailed(ILogger logger, string parameterType, string parameterName, string sourceValue); + [LoggerMessage(4, LogLevel.Debug, + @"Required parameter ""{ParameterType} {ParameterName}"" was not provided.", + EventName = "RequiredParameterNotProvided")] + private static partial void RequiredParameterNotProvided(ILogger logger, string parameterType, string parameterName); + private static ILogger GetLogger(HttpContext httpContext) { var loggerFactory = httpContext.RequestServices.GetRequiredService(); diff --git a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj index fffeeb054b84..25444590f57a 100644 --- a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj +++ b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj @@ -2,6 +2,8 @@ $(DefaultNetCoreTargetFramework) + + $(Features.Replace('nullablePublicOnly', '') diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 840319aed9c7..9ba05911af95 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -360,7 +360,7 @@ void TestAction([FromRoute(Name = specifiedName)] int foo) } [Fact] - public async Task UsesDefaultValueIfNoMatchingRouteValue() + public async Task Returns400IfNoMatchingRouteValueForRequiredParam() { const string unmatchedName = "value"; const int unmatchedRouteParam = 42; @@ -375,11 +375,15 @@ void TestAction([FromRoute] int foo) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[unmatchedName] = unmatchedRouteParam.ToString(NumberFormatInfo.InvariantInfo); + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(TestAction); await requestDelegate(httpContext); - Assert.Equal(0, deserializedRouteParam); + Assert.Equal(400, httpContext.Response.StatusCode); } public static object?[][] TryParsableParameters @@ -424,7 +428,6 @@ static void Store(HttpContext httpContext, T tryParsable) new object[] { (Action)Store, "42", 42 }, new object[] { (Action)Store, "ValueB", MyEnum.ValueB }, new object[] { (Action)Store, "https://example.org", new MyTryParsableRecord(new Uri("https://example.org")) }, - new object?[] { (Action)Store, null, 0 }, new object?[] { (Action)Store, null, null }, }; } @@ -454,6 +457,10 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues["tryParsable"] = routeValue; + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(action); await requestDelegate(httpContext); @@ -471,6 +478,10 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQ ["tryParsable"] = routeValue }); + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(action); await requestDelegate(httpContext); @@ -663,10 +674,13 @@ public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) }; var httpContext = new DefaultHttpContext(); - httpContext.Request.Headers["Content-Type"] = "application/json"; var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); - httpContext.Request.Body = new MemoryStream(requestBodyBytes); + var stream = new MemoryStream(requestBodyBytes); ; + httpContext.Request.Body = stream; + + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); var jsonOptions = new JsonOptions(); jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); @@ -699,9 +713,15 @@ public async Task RequestDelegateRejectsEmptyBodyGivenFromBodyParameter(Delegate httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(action); - await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + await requestDelegate(httpContext); + + Assert.Equal(400, httpContext.Response.StatusCode); } [Fact] @@ -765,6 +785,7 @@ void TestAction([FromBody] Todo todo) var httpContext = new DefaultHttpContext(); httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = "1"; httpContext.Request.Body = new IOExceptionThrowingRequestBodyStream(ioException); httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); @@ -798,6 +819,7 @@ void TestAction([FromBody] Todo todo) var httpContext = new DefaultHttpContext(); httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = "1"; httpContext.Request.Body = new IOExceptionThrowingRequestBodyStream(invalidDataException); httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); @@ -892,18 +914,6 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt Assert.Same(myOriginalService, httpContext.Items["service"]); } - [Theory] - [MemberData(nameof(FromServiceActions))] - public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action) - { - var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = new EmptyServiceProvider(); - - var requestDelegate = RequestDelegateFactory.Create(action); - - await Assert.ThrowsAsync(() => requestDelegate(httpContext)); - } - [Fact] public async Task RequestDelegatePopulatesHttpContextParameterWithoutAttribute() { @@ -1354,6 +1364,442 @@ public async Task RequestDelegateWritesNullReturnNullValue(Delegate @delegate) Assert.Equal("null", responseBody); } + public static IEnumerable QueryParamOptionalityData + { + get + { + string requiredQueryParam(string name) => $"Hello {name}!"; + string defaultValueQueryParam(string name = "DefaultName") => $"Hello {name}!"; + string nullableQueryParam(string? name) => $"Hello {name}!"; + string requiredParseableQueryParam(int age) => $"Age: {age}"; + string defaultValueParseableQueryParam(int age = 12) => $"Age: {age}"; + string nullableQueryParseableParam(int? age) => $"Age: {age}"; + + return new List + { + new object?[] { (Func)requiredQueryParam, "name", null, true, null}, + new object?[] { (Func)requiredQueryParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)defaultValueQueryParam, "name", null, false, "Hello DefaultName!" }, + new object?[] { (Func)defaultValueQueryParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)nullableQueryParam, "name", null, false, "Hello !" }, + new object?[] { (Func)nullableQueryParam, "name", "TestName", false, "Hello TestName!"}, + + new object?[] { (Func)requiredParseableQueryParam, "age", null, true, null}, + new object?[] { (Func)requiredParseableQueryParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)defaultValueParseableQueryParam, "age", null, false, "Age: 12" }, + new object?[] { (Func)defaultValueParseableQueryParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)nullableQueryParseableParam, "age", null, false, "Age: " }, + new object?[] { (Func)nullableQueryParseableParam, "age", "42", false, "Age: 42"}, + }; + } + } + + [Theory] + [MemberData(nameof(QueryParamOptionalityData))] + public async Task RequestDelegateHandlesQueryParamOptionality(Delegate @delegate, string paramName, string? queryParam, bool isInvalid, string? expectedResponse) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (queryParam is not null) + { + httpContext.Request.Query = new QueryCollection(new Dictionary + { + [paramName] = queryParam + }); + } + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (isInvalid) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + var expectedType = paramName == "age" ? "Int32 age" : "String name"; + Assert.Equal($@"Required parameter ""{expectedType}"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + } + + public static IEnumerable RouteParamOptionalityData + { + get + { + string requiredRouteParam(string name) => $"Hello {name}!"; + string defaultValueRouteParam(string name = "DefaultName") => $"Hello {name}!"; + string nullableRouteParam(string? name) => $"Hello {name}!"; + string requiredParseableRouteParam(int age) => $"Age: {age}"; + string defaultValueParseableRouteParam(int age = 12) => $"Age: {age}"; + string nullableParseableRouteParam(int? age) => $"Age: {age}"; + + return new List + { + new object?[] { (Func)requiredRouteParam, "name", null, true, null}, + new object?[] { (Func)requiredRouteParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)defaultValueRouteParam, "name", null, false, "Hello DefaultName!" }, + new object?[] { (Func)defaultValueRouteParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)nullableRouteParam, "name", null, false, "Hello !" }, + new object?[] { (Func)nullableRouteParam, "name", "TestName", false, "Hello TestName!" }, + + new object?[] { (Func)requiredParseableRouteParam, "age", null, true, null}, + new object?[] { (Func)requiredParseableRouteParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)defaultValueParseableRouteParam, "age", null, false, "Age: 12" }, + new object?[] { (Func)defaultValueParseableRouteParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)nullableParseableRouteParam, "age", null, false, "Age: " }, + new object?[] { (Func)nullableParseableRouteParam, "age", "42", false, "Age: 42"}, + }; + } + } + + [Theory] + [MemberData(nameof(RouteParamOptionalityData))] + public async Task RequestDelegateHandlesRouteParamOptionality(Delegate @delegate, string paramName, string? routeParam, bool isInvalid, string? expectedResponse) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (routeParam is not null) + { + httpContext.Request.RouteValues[paramName] = routeParam; + } + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (isInvalid) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + var expectedType = paramName == "age" ? "Int32 age" : "String name"; + Assert.Equal($@"Required parameter ""{expectedType}"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + } + + public static IEnumerable BodyParamOptionalityData + { + get + { + string requiredBodyParam(Todo todo) => $"Todo: {todo.Name}"; + string defaultValueBodyParam(Todo? todo = null) => $"Todo: {todo?.Name}"; + string nullableBodyParam(Todo? todo) => $"Todo: {todo?.Name}"; + + return new List + { + new object?[] { (Func)requiredBodyParam, false, true, null }, + new object?[] { (Func)requiredBodyParam, true, false, "Todo: Default Todo"}, + new object?[] { (Func)defaultValueBodyParam, false, false, "Todo: "}, + new object?[] { (Func)defaultValueBodyParam, true, false, "Todo: Default Todo"}, + new object?[] { (Func)nullableBodyParam, false, false, "Todo: " }, + new object?[] { (Func)nullableBodyParam, true, false, "Todo: Default Todo" }, + }; + } + } + + [Theory] + [MemberData(nameof(BodyParamOptionalityData))] + public async Task RequestDelegateHandlesBodyParamOptionality(Delegate @delegate, bool hasBody, bool isInvalid, string? expectedResponse) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (hasBody) + { + var todo = new Todo() { Name = "Default Todo" }; + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo); + var stream = new MemoryStream(requestBodyBytes); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.ContentLength = stream.Length; + } + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + serviceCollection.AddSingleton(Options.Create(jsonOptions)); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (isInvalid) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + Assert.Equal(@"Required parameter ""Todo todo"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + } + + public static IEnumerable ServiceParamOptionalityData + { + get + { + string requiredExplicitService([FromService] MyService service) => $"Service: {service}"; + string defaultValueExplicitServiceParam([FromService] MyService? service = null) => $"Service: {service}"; + string nullableExplicitServiceParam([FromService] MyService? service) => $"Service: {service}"; + + return new List + { + new object?[] { (Func)requiredExplicitService, false, true}, + new object?[] { (Func)requiredExplicitService, true, false}, + + new object?[] { (Func)defaultValueExplicitServiceParam, false, false}, + new object?[] { (Func)defaultValueExplicitServiceParam, true, false}, + + new object?[] { (Func)nullableExplicitServiceParam, false, false}, + new object?[] { (Func)nullableExplicitServiceParam, true, false}, + }; + } + } + + [Theory] + [MemberData(nameof(ServiceParamOptionalityData))] + public async Task RequestDelegateHandlesServiceParamOptionality(Delegate @delegate, bool hasService, bool isInvalid) + { + var httpContext = new DefaultHttpContext(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + if (hasService) + { + var service = new MyService(); + + serviceCollection.AddSingleton(service); + } + var services = serviceCollection.BuildServiceProvider(); + httpContext.RequestServices = services; + RequestDelegateFactoryOptions options = new() { ServiceProvider = services }; + + var requestDelegate = RequestDelegateFactory.Create(@delegate, options); + + if (!isInvalid) + { + await requestDelegate(httpContext); + Assert.Equal(200, httpContext.Response.StatusCode); + } + else + { + await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + } + } + + public static IEnumerable ImplicitServiceParamOptionalityData + { + get + { + string requiredImplicitService(MyService name) => $"Hello {name}!"; + string defaultValueImplicitServiceParam(MyService? name = null) => $"Hello {name}!"; + string nullableImplicitServiceParam(MyService? name) => $"Hello {name}!"; + + return new List + { + new object?[] { (Func)requiredImplicitService, false, true}, + new object?[] { (Func)requiredImplicitService, true, false}, + + new object?[] { (Func)defaultValueImplicitServiceParam, false, false}, + new object?[] { (Func)defaultValueImplicitServiceParam, true, false}, + + new object?[] { (Func)nullableImplicitServiceParam, false, false}, + new object?[] { (Func)nullableImplicitServiceParam, true, false} + }; + } + } + + [Theory] + [MemberData(nameof(ImplicitServiceParamOptionalityData))] + public async Task RequestDelegateHandlesImplicitServiceParamOptionality(Delegate @delegate, bool hasService, bool isInvalid) + { + var httpContext = new DefaultHttpContext(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + if (hasService) + { + var service = new MyService(); + serviceCollection.AddSingleton(service); + } + var services = serviceCollection.BuildServiceProvider(); + httpContext.RequestServices = services; + RequestDelegateFactoryOptions options = new() { ServiceProvider = services }; + + var requestDelegate = RequestDelegateFactory.Create(@delegate, options); + + await requestDelegate(httpContext); + Assert.Equal(isInvalid ? 400 : 200, httpContext.Response.StatusCode); + } + + [Fact] + public async Task RequestDelegateHandlesRequiredAmbiguousValueFromBody() + { + var invoked = false; + void TestAction(Todo todo) + { + invoked = true; + } + + var httpContext = new DefaultHttpContext(); + + httpContext.Request.Headers["Content-Type"] = "application/json"; + + var todo = new Todo() { Name = "Default Todo" }; + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo); + var stream = new MemoryStream(requestBodyBytes); + httpContext.Request.Body = stream; + httpContext.Request.ContentLength = stream.Length; + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + serviceCollection.AddSingleton(Options.Create(jsonOptions)); + var services = serviceCollection.BuildServiceProvider(); + httpContext.RequestServices = services; + + var requestDelegate = RequestDelegateFactory.Create(TestAction, new() { ServiceProvider = services }); + + await requestDelegate(httpContext); + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.True(invoked); + } + + [Fact] + public async Task RequestDelegateHandlesRequiredAmbiguousValueFromService() + { + var invoked = false; + void TestAction(Todo todo) + { + invoked = true; + } + + var httpContext = new DefaultHttpContext(); + + var todo = new Todo() { Name = "Default Todo" }; + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + serviceCollection.AddSingleton(todo); + var services = serviceCollection.BuildServiceProvider(); + httpContext.RequestServices = services; + + var requestDelegate = RequestDelegateFactory.Create(TestAction, new() { ServiceProvider = services }); + + await requestDelegate(httpContext); + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.True(invoked); + } + +#nullable disable + + [Theory] + [InlineData(true, "Hello TestName!")] + [InlineData(false, "Hello !")] + public async Task CanSetStringParamAsOptionalWithNullabilityDisability(bool provideValue, string expectedResponse) + { + string optionalQueryParam(string name = null) => $"Hello {name}!"; + + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (provideValue) + { + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + } + + var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + + [Theory] + [InlineData(true, "Age: 42")] + [InlineData(false, "Age: 0")] + public async Task CanSetParseableStringParamAsOptionalWithNullabilityDisability(bool provideValue, string expectedResponse) + { + string optionalQueryParam(int age = default(int)) => $"Age: {age}"; + + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (provideValue) + { + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["age"] = "42" + }); + } + + var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + +#nullable enable + private class Todo : ITodo { public int Id { get; set; } From e2396b163dfe824cba9f2978d532f3fe9650191c Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 20 Jul 2021 09:54:56 -0700 Subject: [PATCH 2/8] Clean-up, comments, and tests --- .../src/RequestDelegateFactory.cs | 45 +++++++++++++++-- .../test/RequestDelegateFactoryTests.cs | 50 +++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index f041801bbcd0..d788814ab47c 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -61,6 +61,7 @@ public static partial class RequestDelegateFactory private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo>(() => Task.CompletedTask)); private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); + private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); /// /// Creates a implementation for . @@ -279,6 +280,8 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext else { // Then try to resolve it as an optional service and fallback to a body otherwise + // Note: if the parameter provides a default value that value will be parsed + // as part of the body, not as the service instance return Expression.Coalesce( Expression.Call(GetServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr), BindParameterFromBody(parameter, allowEmpty: false, factoryContext)); @@ -553,9 +556,17 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres if (!isOptional) { + // The following is produced if the parameter is required: + // + // tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; + // if (tempSourceString == null) + // { + // wasTryParseFailure = true; + // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); + // } var checkRequiredStringParameterBlock = Expression.Block( Expression.Assign(TempSourceStringExpr, valueExpression), - Expression.IfThen(Expression.Not(TempSourceStringNotNullExpr), + Expression.IfThen(TempSourceStringNullExpr, Expression.Block( Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, @@ -574,6 +585,12 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres return Expression.Block(Expression.Assign(TempSourceStringExpr, valueExpression)); } + // The following is produced if the parameter is optional. Note that we convert the + // default value to the target ParameterType to address scenarios where the user is + // is setting null as the default value in a context where nullability is disabled. + // + // tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; + // tempSourceString != null ? tempSourceString : Convert("3", Int32) return Expression.Block( Expression.Assign(TempSourceStringExpr, valueExpression), Expression.Condition(TempSourceStringNotNullExpr, @@ -646,9 +663,16 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres var tryParseCall = tryParseMethodCall(parsedValue); - // If the parameter is required, fail to parse and log an error + // The following code is generated if the parameter is required and + // the method should not be matched. + // + // if (tempSourceString == null) + // { + // wasTryParseFailure = true; + // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); + // } var checkRequiredParaseableParameterBlock = Expression.Block( - Expression.IfThen(Expression.Not(TempSourceStringNotNullExpr), + Expression.IfThen(TempSourceStringNullExpr, Expression.Block( Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, @@ -675,6 +699,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres ? Expression.Block( // tempSourceString = httpContext.RequestValue["id"]; Expression.Assign(TempSourceStringExpr, valueExpression), + // if (tempSourceString == null) { ... } only produced when parameter is required checkRequiredParaseableParameterBlock, // if (tempSourceString != null) { ... } ifNotNullTryParse) @@ -714,8 +739,17 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); - if (!isOptional && !allowEmpty) - { + if (!factoryContext.AllowEmptyRequestBody) + { + // If the parameter is required or the user has not explicitly + // set allowBody to be empty then validate that it is required. + // + // ToDo body_local = Convert(bodyValue, ToDo); + // if (body_local == null) + // { + // wasTryParseFailure = true; + // Log.RequiredParameterNotProvided(httpContext, "Todo", "body") + // } var checkRequiredBodyBlock = Expression.Block( Expression.Assign(argument, Expression.Convert(BodyValueExpr, parameter.ParameterType)), Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)), @@ -738,6 +772,7 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al parameter.ParameterType); } + // Convert(bodyValue, Todo) return Expression.Convert(BodyValueExpr, parameter.ParameterType); } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 9ba05911af95..bc2bc038f83a 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -1677,6 +1677,56 @@ public async Task RequestDelegateHandlesImplicitServiceParamOptionality(Delegate Assert.Equal(isInvalid ? 400 : 200, httpContext.Response.StatusCode); } + public static IEnumerable AllowEmptyData + { + get + { + string disallowEmptyAndNonOptional([FromBody(AllowEmpty = false)] Todo todo) => $"{todo}"; + string allowEmptyAndNonOptional([FromBody(AllowEmpty = true)] Todo todo) => $"{todo}"; + string allowEmptyAndOptional([FromBody(AllowEmpty = true)] Todo? todo = null) => $"{todo}"; + string disallowEmptyAndOptional([FromBody(AllowEmpty = false)] Todo? todo = null) => $"{todo}"; + + return new List + { + new object?[] { (Func)disallowEmptyAndNonOptional, false }, + new object?[] { (Func)allowEmptyAndNonOptional, true }, + new object?[] { (Func)allowEmptyAndOptional, true }, + new object?[] { (Func)disallowEmptyAndOptional, true } + }; + } + } + + [Theory] + [MemberData(nameof(AllowEmptyData))] + public async Task AllowEmptyOverridesOptionality(Delegate @delegate, bool allowsEmptyRequest) + { + var httpContext = new DefaultHttpContext(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (!allowsEmptyRequest) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + Assert.Equal(@"Required parameter ""Todo todo"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + } + } + [Fact] public async Task RequestDelegateHandlesRequiredAmbiguousValueFromBody() { From c071cfec9b7b32c62109373312445d6f340124a1 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 20 Jul 2021 14:21:35 -0700 Subject: [PATCH 3/8] Address some feedback --- .../src/RequestDelegateFactory.cs | 47 +++++++++---------- .../test/RequestDelegateFactoryTests.cs | 5 ++ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index d788814ab47c..f152c31dbb8c 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Http /// public static partial class RequestDelegateFactory { - private static readonly NullabilityInfoContext nullabilityContext = new NullabilityInfoContext(); + private static readonly NullabilityInfoContext NullabilityContext = new NullabilityInfoContext(); private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -264,7 +264,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext else { - var nullability = nullabilityContext.Create(parameter); + var nullability = NullabilityContext.Create(parameter); var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) { @@ -537,7 +537,7 @@ private static Expression GetValueFromProperty(Expression sourceExpression, stri private static Expression BindParameterFromService(ParameterInfo parameter) { - var nullability = nullabilityContext.Create(parameter); + var nullability = NullabilityContext.Create(parameter); var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; return isOptional @@ -547,13 +547,13 @@ private static Expression BindParameterFromService(ParameterInfo parameter) private static Expression BindParameterFromValue(ParameterInfo parameter, Expression valueExpression, FactoryContext factoryContext) { - var nullability = nullabilityContext.Create(parameter); + var nullability = NullabilityContext.Create(parameter); var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); + if (parameter.ParameterType == typeof(string)) { - factoryContext.UsingTempSourceString = true; - if (!isOptional) { // The following is produced if the parameter is required: @@ -565,8 +565,8 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); // } var checkRequiredStringParameterBlock = Expression.Block( - Expression.Assign(TempSourceStringExpr, valueExpression), - Expression.IfThen(TempSourceStringNullExpr, + Expression.Assign(argument, valueExpression), + Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)), Expression.Block( Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, @@ -575,26 +575,25 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres ) ); - factoryContext.TryParseParams.Add((TempSourceStringExpr, checkRequiredStringParameterBlock)); - return Expression.Block(TempSourceStringExpr); + factoryContext.TryParseParams.Add((argument, checkRequiredStringParameterBlock)); + return argument; } // Allow nullable parameters that don't have a default value if (nullability.ReadState == NullabilityState.Nullable && !parameter.HasDefaultValue) { - return Expression.Block(Expression.Assign(TempSourceStringExpr, valueExpression)); + return valueExpression; } // The following is produced if the parameter is optional. Note that we convert the // default value to the target ParameterType to address scenarios where the user is // is setting null as the default value in a context where nullability is disabled. // - // tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; - // tempSourceString != null ? tempSourceString : Convert("3", Int32) + // param1_local = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; + // param1_local != null ? param1_local : Convert(null, Int32) return Expression.Block( - Expression.Assign(TempSourceStringExpr, valueExpression), - Expression.Condition(TempSourceStringNotNullExpr, - TempSourceStringExpr, + Expression.Condition(Expression.NotEqual(valueExpression, Expression.Constant(null)), + valueExpression, Expression.Convert(Expression.Constant(parameter.DefaultValue), parameter.ParameterType))); } @@ -648,8 +647,6 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // param2_local = 42; // } - var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); - // If the parameter is nullable, create a "parsedValue" local to TryParse into since we cannot the parameter directly. var parsedValue = isNotNullable ? argument : Expression.Variable(nonNullableParameterType, "parsedValue"); @@ -731,12 +728,13 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al throw new InvalidOperationException("Action cannot have more than one FromBody attribute."); } - var nullability = nullabilityContext.Create(parameter); + var nullability = NullabilityContext.Create(parameter); var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; factoryContext.JsonRequestBodyType = parameter.ParameterType; factoryContext.AllowEmptyRequestBody = allowEmpty || isOptional; + var convertedBodyValue = Expression.Convert(BodyValueExpr, parameter.ParameterType); var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); if (!factoryContext.AllowEmptyRequestBody) @@ -744,15 +742,16 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al // If the parameter is required or the user has not explicitly // set allowBody to be empty then validate that it is required. // - // ToDo body_local = Convert(bodyValue, ToDo); + // Todo body_local = Convert(bodyValue, ToDo); // if (body_local == null) // { // wasTryParseFailure = true; - // Log.RequiredParameterNotProvided(httpContext, "Todo", "body") + // Log.RequiredParameterNotProvided(httpContext, "Todo", "body"); // } var checkRequiredBodyBlock = Expression.Block( - Expression.Assign(argument, Expression.Convert(BodyValueExpr, parameter.ParameterType)), - Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)), + Expression.Assign(argument, convertedBodyValue), + Expression.IfThen( + Expression.Equal(argument, Expression.Constant(null)), Expression.Block( Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, @@ -773,7 +772,7 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al } // Convert(bodyValue, Todo) - return Expression.Convert(BodyValueExpr, parameter.ParameterType); + return convertedBodyValue; } private static MethodInfo GetMethodInfo(Expression expr) diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index bc2bc038f83a..820e2113260a 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -668,6 +668,11 @@ void TestImpliedFromBodyInterface(HttpContext httpContext, ITodo myService) [MemberData(nameof(FromBodyActions))] public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) { + // while (!System.Diagnostics.Debugger.IsAttached) + // { + // System.Console.WriteLine($"Waiting to attach on ${Environment.ProcessId}"); + // System.Threading.Thread.Sleep(1000); + // } Todo originalTodo = new() { Name = "Write more tests!" From a1a209f219f5afcf900b7aff2d1db4044766317e Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 20 Jul 2021 16:19:17 -0700 Subject: [PATCH 4/8] Polish up some things --- .../src/RequestDelegateFactory.cs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index f152c31dbb8c..7c2ec519dd32 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -263,21 +263,12 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext } else { - var nullability = NullabilityContext.Create(parameter); var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) { - // If the parameter is required - if (!isOptional) - { - // And we are able to resolve a service for it - return serviceProviderIsService.IsService(parameter.ParameterType) - ? Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) // Then get it from the DI - : BindParameterFromBody(parameter, allowEmpty: false, factoryContext); // Otherwise try to find it in the body - } // If the parameter is optional - else + if (!isOptional) { // Then try to resolve it as an optional service and fallback to a body otherwise // Note: if the parameter provides a default value that value will be parsed @@ -286,6 +277,14 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext Expression.Call(GetServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr), BindParameterFromBody(parameter, allowEmpty: false, factoryContext)); } + // If the parameter is required + else + { + // And we are able to resolve a service for it + return serviceProviderIsService.IsService(parameter.ParameterType) + ? Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) // Then get it from the DI + : BindParameterFromBody(parameter, allowEmpty: false, factoryContext); // Otherwise try to find it in the body + } } return BindParameterFromBody(parameter, allowEmpty: false, factoryContext); From 4c694d6034d531a70ae094269cae3228166dd496 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 20 Jul 2021 17:10:51 -0700 Subject: [PATCH 5/8] Fix typo in condition --- src/Http/Http.Extensions/src/RequestDelegateFactory.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 7c2ec519dd32..0752f2723c0c 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -268,7 +268,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) { // If the parameter is optional - if (!isOptional) + if (isOptional) { // Then try to resolve it as an optional service and fallback to a body otherwise // Note: if the parameter provides a default value that value will be parsed From c9f0ef5787b6dba3487800da8498212250762f44 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 20 Jul 2021 17:41:23 -0700 Subject: [PATCH 6/8] Remove optionality support for implicit services --- .../src/RequestDelegateFactory.cs | 20 +--- .../test/RequestDelegateFactoryTests.cs | 105 ------------------ 2 files changed, 2 insertions(+), 123 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 0752f2723c0c..9e8d5aa15b1a 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -263,27 +263,11 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext } else { - var nullability = NullabilityContext.Create(parameter); - var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) { - // If the parameter is optional - if (isOptional) + if (serviceProviderIsService.IsService(parameter.ParameterType)) { - // Then try to resolve it as an optional service and fallback to a body otherwise - // Note: if the parameter provides a default value that value will be parsed - // as part of the body, not as the service instance - return Expression.Coalesce( - Expression.Call(GetServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr), - BindParameterFromBody(parameter, allowEmpty: false, factoryContext)); - } - // If the parameter is required - else - { - // And we are able to resolve a service for it - return serviceProviderIsService.IsService(parameter.ParameterType) - ? Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) // Then get it from the DI - : BindParameterFromBody(parameter, allowEmpty: false, factoryContext); // Otherwise try to find it in the body + return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) ; } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 820e2113260a..0c8b9ddcdabd 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -1637,50 +1637,6 @@ public async Task RequestDelegateHandlesServiceParamOptionality(Delegate @delega } } - public static IEnumerable ImplicitServiceParamOptionalityData - { - get - { - string requiredImplicitService(MyService name) => $"Hello {name}!"; - string defaultValueImplicitServiceParam(MyService? name = null) => $"Hello {name}!"; - string nullableImplicitServiceParam(MyService? name) => $"Hello {name}!"; - - return new List - { - new object?[] { (Func)requiredImplicitService, false, true}, - new object?[] { (Func)requiredImplicitService, true, false}, - - new object?[] { (Func)defaultValueImplicitServiceParam, false, false}, - new object?[] { (Func)defaultValueImplicitServiceParam, true, false}, - - new object?[] { (Func)nullableImplicitServiceParam, false, false}, - new object?[] { (Func)nullableImplicitServiceParam, true, false} - }; - } - } - - [Theory] - [MemberData(nameof(ImplicitServiceParamOptionalityData))] - public async Task RequestDelegateHandlesImplicitServiceParamOptionality(Delegate @delegate, bool hasService, bool isInvalid) - { - var httpContext = new DefaultHttpContext(); - - var serviceCollection = new ServiceCollection(); - serviceCollection.AddSingleton(LoggerFactory); - if (hasService) - { - var service = new MyService(); - serviceCollection.AddSingleton(service); - } - var services = serviceCollection.BuildServiceProvider(); - httpContext.RequestServices = services; - RequestDelegateFactoryOptions options = new() { ServiceProvider = services }; - - var requestDelegate = RequestDelegateFactory.Create(@delegate, options); - - await requestDelegate(httpContext); - Assert.Equal(isInvalid ? 400 : 200, httpContext.Response.StatusCode); - } public static IEnumerable AllowEmptyData { @@ -1732,67 +1688,6 @@ public async Task AllowEmptyOverridesOptionality(Delegate @delegate, bool allows } } - [Fact] - public async Task RequestDelegateHandlesRequiredAmbiguousValueFromBody() - { - var invoked = false; - void TestAction(Todo todo) - { - invoked = true; - } - - var httpContext = new DefaultHttpContext(); - - httpContext.Request.Headers["Content-Type"] = "application/json"; - - var todo = new Todo() { Name = "Default Todo" }; - var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo); - var stream = new MemoryStream(requestBodyBytes); - httpContext.Request.Body = stream; - httpContext.Request.ContentLength = stream.Length; - - var jsonOptions = new JsonOptions(); - jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); - - var serviceCollection = new ServiceCollection(); - serviceCollection.AddSingleton(LoggerFactory); - serviceCollection.AddSingleton(Options.Create(jsonOptions)); - var services = serviceCollection.BuildServiceProvider(); - httpContext.RequestServices = services; - - var requestDelegate = RequestDelegateFactory.Create(TestAction, new() { ServiceProvider = services }); - - await requestDelegate(httpContext); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.True(invoked); - } - - [Fact] - public async Task RequestDelegateHandlesRequiredAmbiguousValueFromService() - { - var invoked = false; - void TestAction(Todo todo) - { - invoked = true; - } - - var httpContext = new DefaultHttpContext(); - - var todo = new Todo() { Name = "Default Todo" }; - - var serviceCollection = new ServiceCollection(); - serviceCollection.AddSingleton(LoggerFactory); - serviceCollection.AddSingleton(todo); - var services = serviceCollection.BuildServiceProvider(); - httpContext.RequestServices = services; - - var requestDelegate = RequestDelegateFactory.Create(TestAction, new() { ServiceProvider = services }); - - await requestDelegate(httpContext); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.True(invoked); - } - #nullable disable [Theory] From b2333d75bdfd4ee49a23cdb5019091892e473039 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Wed, 21 Jul 2021 10:53:41 -0700 Subject: [PATCH 7/8] Add support for IHttpRequestBodyDetectionFeature --- .../src/RequestDelegateFactory.cs | 4 +++- .../test/RequestDelegateFactoryTests.cs | 21 ++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 9e8d5aa15b1a..726d06ae987f 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -10,6 +10,7 @@ using System.Security.Claims; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Metadata; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; @@ -487,7 +488,8 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, { object? bodyValue = defaultBodyValue; - if (httpContext.Request.ContentLength != 0 && httpContext.Request.HasJsonContentType()) + var feature = httpContext.Features.Get(); + if (feature?.CanHaveBody == true) { try { diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 0c8b9ddcdabd..678a4de4cd8a 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -668,11 +668,6 @@ void TestImpliedFromBodyInterface(HttpContext httpContext, ITodo myService) [MemberData(nameof(FromBodyActions))] public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) { - // while (!System.Diagnostics.Debugger.IsAttached) - // { - // System.Console.WriteLine($"Waiting to attach on ${Environment.ProcessId}"); - // System.Threading.Thread.Sleep(1000); - // } Todo originalTodo = new() { Name = "Write more tests!" @@ -686,6 +681,7 @@ public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); var jsonOptions = new JsonOptions(); jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); @@ -717,6 +713,7 @@ public async Task RequestDelegateRejectsEmptyBodyGivenFromBodyParameter(Delegate var httpContext = new DefaultHttpContext(); httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; + httpContext.Features.Set(new RequestBodyDetectionFeature(false)); var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(LoggerFactory); @@ -793,6 +790,7 @@ void TestAction([FromBody] Todo todo) httpContext.Request.Headers["Content-Length"] = "1"; httpContext.Request.Body = new IOExceptionThrowingRequestBodyStream(ioException); httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); var requestDelegate = RequestDelegateFactory.Create(TestAction); @@ -826,7 +824,9 @@ void TestAction([FromBody] Todo todo) httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "1"; httpContext.Request.Body = new IOExceptionThrowingRequestBodyStream(invalidDataException); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); var requestDelegate = RequestDelegateFactory.Create(TestAction); @@ -1550,6 +1550,7 @@ public async Task RequestDelegateHandlesBodyParamOptionality(Delegate @delegate, httpContext.Request.Body = stream; httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.ContentLength = stream.Length; + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); } var jsonOptions = new JsonOptions(); @@ -1947,5 +1948,15 @@ public void Abort() _requestAbortedCts.Cancel(); } } + + private class RequestBodyDetectionFeature : IHttpRequestBodyDetectionFeature + { + public RequestBodyDetectionFeature(bool canHaveBody) + { + CanHaveBody = canHaveBody; + } + + public bool CanHaveBody { get; } + } } } From 75dca602e4ed13aa27c70785aa689891062abebf Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Mon, 26 Jul 2021 13:03:10 -0500 Subject: [PATCH 8/8] Address more feedback from peer review --- .../src/RequestDelegateFactory.cs | 65 ++++++++++--------- .../test/RequestDelegateFactoryTests.cs | 14 +++- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 726d06ae987f..fe890d38b114 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -47,7 +47,7 @@ public static partial class RequestDelegateFactory private static readonly ParameterExpression TargetExpr = Expression.Parameter(typeof(object), "target"); private static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue"); - private static readonly ParameterExpression WasTryParseFailureExpr = Expression.Variable(typeof(bool), "wasTryParseFailure"); + private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure"); private static readonly ParameterExpression TempSourceStringExpr = TryParseMethodCache.TempSourceStringExpr; private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.RequestServices)); @@ -168,8 +168,8 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func 0 ? - CreateTryParseCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) : + var responseWritingMethodCall = factoryContext.CheckParams.Count > 0 ? + CreateParamCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) : CreateResponseWritingMethodCall(methodInfo, targetExpression, arguments); if (factoryContext.UsingTempSourceString) @@ -268,7 +268,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext { if (serviceProviderIsService.IsService(parameter.ParameterType)) { - return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) ; + return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); } } @@ -287,13 +287,14 @@ private static Expression CreateResponseWritingMethodCall(MethodInfo methodInfo, return AddResponseWritingToMethodCall(callMethod, methodInfo.ReturnType); } - // If we're calling TryParse and wasTryParseFailure indicates it failed, set a 400 StatusCode instead of calling the method. - private static Expression CreateTryParseCheckingResponseWritingMethodCall( + // If we're calling TryParse or validating parameter optionality and + // wasParamCheckFailure indicates it failed, set a 400 StatusCode instead of calling the method. + private static Expression CreateParamCheckingResponseWritingMethodCall( MethodInfo methodInfo, Expression? target, Expression[] arguments, FactoryContext factoryContext) { // { // string tempSourceString; - // bool wasTryParseFailure = false; + // bool wasParamCheckFailure = false; // // // Assume "int param1" is the first parameter, "[FromRoute] int? param2 = 42" is the second parameter ... // int param1_local; @@ -306,7 +307,7 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( // { // if (!int.TryParse(tempSourceString, out param1_local)) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString) // } // } @@ -314,7 +315,7 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( // tempSourceString = httpContext.RouteValue["param2"]; // // ... // - // return wasTryParseFailure ? + // return wasParamCheckFailure ? // { // httpContext.Response.StatusCode = 400; // return Task.CompletedTask; @@ -324,15 +325,15 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( // }; // } - var localVariables = new ParameterExpression[factoryContext.TryParseParams.Count + 1]; - var tryParseAndCallMethod = new Expression[factoryContext.TryParseParams.Count + 1]; + var localVariables = new ParameterExpression[factoryContext.CheckParams.Count + 1]; + var checkParamAndCallMethod = new Expression[factoryContext.CheckParams.Count + 1]; - for (var i = 0; i < factoryContext.TryParseParams.Count; i++) + for (var i = 0; i < factoryContext.CheckParams.Count; i++) { - (localVariables[i], tryParseAndCallMethod[i]) = factoryContext.TryParseParams[i]; + (localVariables[i], checkParamAndCallMethod[i]) = factoryContext.CheckParams[i]; } - localVariables[factoryContext.TryParseParams.Count] = WasTryParseFailureExpr; + localVariables[factoryContext.CheckParams.Count] = WasParamCheckFailureExpr; var set400StatusAndReturnCompletedTask = Expression.Block( Expression.Assign(StatusCodeExpr, Expression.Constant(400)), @@ -340,13 +341,13 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( var methodCall = CreateMethodCall(methodInfo, target, arguments); - var checkWasTryParseFailure = Expression.Condition(WasTryParseFailureExpr, + var checkWasParamCheckFailure = Expression.Condition(WasParamCheckFailureExpr, set400StatusAndReturnCompletedTask, AddResponseWritingToMethodCall(methodCall, methodInfo.ReturnType)); - tryParseAndCallMethod[factoryContext.TryParseParams.Count] = checkWasTryParseFailure; + checkParamAndCallMethod[factoryContext.CheckParams.Count] = checkWasParamCheckFailure; - return Expression.Block(localVariables, tryParseAndCallMethod); + return Expression.Block(localVariables, checkParamAndCallMethod); } private static Expression AddResponseWritingToMethodCall(Expression methodCall, Type returnType) @@ -546,21 +547,21 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; // if (tempSourceString == null) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); // } var checkRequiredStringParameterBlock = Expression.Block( Expression.Assign(argument, valueExpression), Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)), Expression.Block( - Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) ) ) ); - factoryContext.TryParseParams.Add((argument, checkRequiredStringParameterBlock)); + factoryContext.CheckParams.Add((argument, checkRequiredStringParameterBlock)); return argument; } @@ -596,7 +597,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres } // string tempSourceString; - // bool wasTryParseFailure = false; + // bool wasParamCheckFailure = false; // // // Assume "int param1" is the first parameter and "[FromRoute] int? param2 = 42" is the second parameter. // int param1_local; @@ -608,7 +609,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // { // if (!int.TryParse(tempSourceString, out param1_local)) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString) // } // } @@ -623,7 +624,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // } // else // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString) // } // } @@ -639,7 +640,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres var parameterNameConstant = Expression.Constant(parameter.Name); var failBlock = Expression.Block( - Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), Expression.Call(LogParameterBindingFailureMethod, HttpContextExpr, parameterTypeNameConstant, parameterNameConstant, TempSourceStringExpr)); @@ -650,13 +651,13 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // // if (tempSourceString == null) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); // } var checkRequiredParaseableParameterBlock = Expression.Block( Expression.IfThen(TempSourceStringNullExpr, Expression.Block( - Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, HttpContextExpr, parameterTypeNameConstant, parameterNameConstant) ) @@ -677,7 +678,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres tryParseExpression, Expression.Assign(argument, Expression.Constant(parameter.DefaultValue))); - var fullTryParseBlock = !isOptional + var fullParamCheckBlock = !isOptional ? Expression.Block( // tempSourceString = httpContext.RequestValue["id"]; Expression.Assign(TempSourceStringExpr, valueExpression), @@ -691,7 +692,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // if (tempSourceString != null) { ... } ifNotNullTryParse); - factoryContext.TryParseParams.Add((argument, fullTryParseBlock)); + factoryContext.CheckParams.Add((argument, fullParamCheckBlock)); return argument; } @@ -730,7 +731,7 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al // Todo body_local = Convert(bodyValue, ToDo); // if (body_local == null) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.RequiredParameterNotProvided(httpContext, "Todo", "body"); // } var checkRequiredBodyBlock = Expression.Block( @@ -738,13 +739,13 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al Expression.IfThen( Expression.Equal(argument, Expression.Constant(null)), Expression.Block( - Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), Expression.Call(LogRequiredParameterNotProvidedMethod, HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) ) ) ); - factoryContext.TryParseParams.Add((argument, checkRequiredBodyBlock)); + factoryContext.CheckParams.Add((argument, checkRequiredBodyBlock)); return argument; } @@ -941,7 +942,7 @@ private class FactoryContext public List? RouteParameters { get; set; } public bool UsingTempSourceString { get; set; } - public List<(ParameterExpression, Expression)> TryParseParams { get; } = new(); + public List<(ParameterExpression, Expression)> CheckParams { get; } = new(); } private static partial class Log diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 678a4de4cd8a..bdc8014407ce 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -895,6 +895,18 @@ void TestImpliedFromServiceBasedOnContainer(HttpContext httpContext, MyService m } } + [Theory] + [MemberData(nameof(FromServiceActions))] + public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action) + { + var httpContext = new DefaultHttpContext(); + httpContext.RequestServices = new EmptyServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(action); + + await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + } + [Theory] [MemberData(nameof(FromServiceActions))] public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAttribute(Delegate action) @@ -1686,7 +1698,7 @@ public async Task AllowEmptyOverridesOptionality(Delegate @delegate, bool allows { Assert.Equal(200, httpContext.Response.StatusCode); Assert.False(httpContext.RequestAborted.IsCancellationRequested); - } + } } #nullable disable