diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 03be58ea4203..e658ab7caa8d 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -10,6 +10,7 @@ using System.Security.Claims; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Metadata; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; @@ -22,6 +23,8 @@ namespace Microsoft.AspNetCore.Http /// public static partial class RequestDelegateFactory { + private static readonly NullabilityInfoContext NullabilityContext = new NullabilityInfoContext(); + private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -31,16 +34,20 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ExecuteValueResultTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteObjectReturnMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteObjectReturn), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!; + private static readonly MethodInfo GetServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!; private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default)); private static readonly MethodInfo LogParameterBindingFailureMethod = GetMethodInfo>((httpContext, parameterType, parameterName, sourceValue) => Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue)); + private static readonly MethodInfo LogRequiredParameterNotProvidedMethod = GetMethodInfo>((httpContext, parameterType, parameterName) => + Log.RequiredParameterNotProvided(httpContext, parameterType, parameterName)); + private static readonly ParameterExpression TargetExpr = Expression.Parameter(typeof(object), "target"); private static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue"); - private static readonly ParameterExpression WasTryParseFailureExpr = Expression.Variable(typeof(bool), "wasTryParseFailure"); + private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure"); private static readonly ParameterExpression TempSourceStringExpr = TryParseMethodCache.TempSourceStringExpr; private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.RequestServices)); @@ -55,6 +62,7 @@ public static partial class RequestDelegateFactory private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo>(() => Task.CompletedTask)); private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); + private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); /// /// Creates a implementation for . @@ -160,8 +168,8 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func 0 ? - CreateTryParseCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) : + var responseWritingMethodCall = factoryContext.CheckParams.Count > 0 ? + CreateParamCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) : CreateResponseWritingMethodCall(methodInfo, targetExpression, arguments); if (factoryContext.UsingTempSourceString) @@ -217,11 +225,11 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext } else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } bodyAttribute) { - return BindParameterFromBody(parameter.ParameterType, bodyAttribute.AllowEmpty, factoryContext); + return BindParameterFromBody(parameter, bodyAttribute.AllowEmpty, factoryContext); } else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType))) { - return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); + return BindParameterFromService(parameter); } else if (parameter.ParameterType == typeof(HttpContext)) { @@ -258,14 +266,13 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext { if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) { - // If the parameter resolves as a service then get it from services if (serviceProviderIsService.IsService(parameter.ParameterType)) { return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); } } - return BindParameterFromBody(parameter.ParameterType, allowEmpty: false, factoryContext); + return BindParameterFromBody(parameter, allowEmpty: false, factoryContext); } } @@ -280,13 +287,14 @@ private static Expression CreateResponseWritingMethodCall(MethodInfo methodInfo, return AddResponseWritingToMethodCall(callMethod, methodInfo.ReturnType); } - // If we're calling TryParse and wasTryParseFailure indicates it failed, set a 400 StatusCode instead of calling the method. - private static Expression CreateTryParseCheckingResponseWritingMethodCall( + // If we're calling TryParse or validating parameter optionality and + // wasParamCheckFailure indicates it failed, set a 400 StatusCode instead of calling the method. + private static Expression CreateParamCheckingResponseWritingMethodCall( MethodInfo methodInfo, Expression? target, Expression[] arguments, FactoryContext factoryContext) { // { // string tempSourceString; - // bool wasTryParseFailure = false; + // bool wasParamCheckFailure = false; // // // Assume "int param1" is the first parameter, "[FromRoute] int? param2 = 42" is the second parameter ... // int param1_local; @@ -299,7 +307,7 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( // { // if (!int.TryParse(tempSourceString, out param1_local)) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString) // } // } @@ -307,7 +315,7 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( // tempSourceString = httpContext.RouteValue["param2"]; // // ... // - // return wasTryParseFailure ? + // return wasParamCheckFailure ? // { // httpContext.Response.StatusCode = 400; // return Task.CompletedTask; @@ -317,15 +325,15 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( // }; // } - var localVariables = new ParameterExpression[factoryContext.TryParseParams.Count + 1]; - var tryParseAndCallMethod = new Expression[factoryContext.TryParseParams.Count + 1]; + var localVariables = new ParameterExpression[factoryContext.CheckParams.Count + 1]; + var checkParamAndCallMethod = new Expression[factoryContext.CheckParams.Count + 1]; - for (var i = 0; i < factoryContext.TryParseParams.Count; i++) + for (var i = 0; i < factoryContext.CheckParams.Count; i++) { - (localVariables[i], tryParseAndCallMethod[i]) = factoryContext.TryParseParams[i]; + (localVariables[i], checkParamAndCallMethod[i]) = factoryContext.CheckParams[i]; } - localVariables[factoryContext.TryParseParams.Count] = WasTryParseFailureExpr; + localVariables[factoryContext.CheckParams.Count] = WasParamCheckFailureExpr; var set400StatusAndReturnCompletedTask = Expression.Block( Expression.Assign(StatusCodeExpr, Expression.Constant(400)), @@ -333,13 +341,13 @@ private static Expression CreateTryParseCheckingResponseWritingMethodCall( var methodCall = CreateMethodCall(methodInfo, target, arguments); - var checkWasTryParseFailure = Expression.Condition(WasTryParseFailureExpr, + var checkWasParamCheckFailure = Expression.Condition(WasParamCheckFailureExpr, set400StatusAndReturnCompletedTask, AddResponseWritingToMethodCall(methodCall, methodInfo.ReturnType)); - tryParseAndCallMethod[factoryContext.TryParseParams.Count] = checkWasTryParseFailure; + checkParamAndCallMethod[factoryContext.CheckParams.Count] = checkWasParamCheckFailure; - return Expression.Block(localVariables, tryParseAndCallMethod); + return Expression.Block(localVariables, checkParamAndCallMethod); } private static Expression AddResponseWritingToMethodCall(Expression methodCall, Type returnType) @@ -479,13 +487,10 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, return async (target, httpContext) => { - object? bodyValue; + object? bodyValue = defaultBodyValue; - if (factoryContext.AllowEmptyRequestBody && httpContext.Request.ContentLength == 0) - { - bodyValue = defaultBodyValue; - } - else + var feature = httpContext.Features.Get(); + if (feature?.CanHaveBody == true) { try { @@ -516,21 +521,66 @@ private static Expression GetValueFromProperty(Expression sourceExpression, stri return Expression.Convert(indexExpression, typeof(string)); } + private static Expression BindParameterFromService(ParameterInfo parameter) + { + var nullability = NullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + + return isOptional + ? Expression.Call(GetServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr) + : Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); + } + private static Expression BindParameterFromValue(ParameterInfo parameter, Expression valueExpression, FactoryContext factoryContext) { + var nullability = NullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + + var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); + if (parameter.ParameterType == typeof(string)) { - if (!parameter.HasDefaultValue) + if (!isOptional) + { + // The following is produced if the parameter is required: + // + // tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; + // if (tempSourceString == null) + // { + // wasParamCheckFailure = true; + // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); + // } + var checkRequiredStringParameterBlock = Expression.Block( + Expression.Assign(argument, valueExpression), + Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)), + Expression.Block( + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) + ) + ) + ); + + factoryContext.CheckParams.Add((argument, checkRequiredStringParameterBlock)); + return argument; + } + + // Allow nullable parameters that don't have a default value + if (nullability.ReadState == NullabilityState.Nullable && !parameter.HasDefaultValue) { return valueExpression; } - factoryContext.UsingTempSourceString = true; + // The following is produced if the parameter is optional. Note that we convert the + // default value to the target ParameterType to address scenarios where the user is + // is setting null as the default value in a context where nullability is disabled. + // + // param1_local = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"]; + // param1_local != null ? param1_local : Convert(null, Int32) return Expression.Block( - Expression.Assign(TempSourceStringExpr, valueExpression), - Expression.Condition(TempSourceStringNotNullExpr, - TempSourceStringExpr, - Expression.Constant(parameter.DefaultValue))); + Expression.Condition(Expression.NotEqual(valueExpression, Expression.Constant(null)), + valueExpression, + Expression.Convert(Expression.Constant(parameter.DefaultValue), parameter.ParameterType))); } factoryContext.UsingTempSourceString = true; @@ -547,7 +597,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres } // string tempSourceString; - // bool wasTryParseFailure = false; + // bool wasParamCheckFailure = false; // // // Assume "int param1" is the first parameter and "[FromRoute] int? param2 = 42" is the second parameter. // int param1_local; @@ -559,7 +609,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // { // if (!int.TryParse(tempSourceString, out param1_local)) // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString) // } // } @@ -574,7 +624,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // } // else // { - // wasTryParseFailure = true; + // wasParamCheckFailure = true; // Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString) // } // } @@ -583,8 +633,6 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // param2_local = 42; // } - var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); - // If the parameter is nullable, create a "parsedValue" local to TryParse into since we cannot the parameter directly. var parsedValue = isNotNullable ? argument : Expression.Variable(nonNullableParameterType, "parsedValue"); @@ -592,12 +640,30 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres var parameterNameConstant = Expression.Constant(parameter.Name); var failBlock = Expression.Block( - Expression.Assign(WasTryParseFailureExpr, Expression.Constant(true)), + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), Expression.Call(LogParameterBindingFailureMethod, HttpContextExpr, parameterTypeNameConstant, parameterNameConstant, TempSourceStringExpr)); var tryParseCall = tryParseMethodCall(parsedValue); + // The following code is generated if the parameter is required and + // the method should not be matched. + // + // if (tempSourceString == null) + // { + // wasParamCheckFailure = true; + // Log.RequiredParameterNotProvided(httpContext, "Int32", "param1"); + // } + var checkRequiredParaseableParameterBlock = Expression.Block( + Expression.IfThen(TempSourceStringNullExpr, + Expression.Block( + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, parameterTypeNameConstant, parameterNameConstant) + ) + ) + ); + // If the parameter is nullable, we need to assign the "parsedValue" local to the nullable parameter on success. Expression tryParseExpression = isNotNullable ? Expression.IfThen(Expression.Not(tryParseCall), failBlock) : @@ -612,13 +678,21 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres tryParseExpression, Expression.Assign(argument, Expression.Constant(parameter.DefaultValue))); - var fullTryParseBlock = Expression.Block( - // tempSourceString = httpContext.RequestValue["id"]; - Expression.Assign(TempSourceStringExpr, valueExpression), - // if (tempSourceString != null) { ... } - ifNotNullTryParse); + var fullParamCheckBlock = !isOptional + ? Expression.Block( + // tempSourceString = httpContext.RequestValue["id"]; + Expression.Assign(TempSourceStringExpr, valueExpression), + // if (tempSourceString == null) { ... } only produced when parameter is required + checkRequiredParaseableParameterBlock, + // if (tempSourceString != null) { ... } + ifNotNullTryParse) + : Expression.Block( + // tempSourceString = httpContext.RequestValue["id"]; + Expression.Assign(TempSourceStringExpr, valueExpression), + // if (tempSourceString != null) { ... } + ifNotNullTryParse); - factoryContext.TryParseParams.Add((argument, fullTryParseBlock)); + factoryContext.CheckParams.Add((argument, fullParamCheckBlock)); return argument; } @@ -633,17 +707,58 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext); } - private static Expression BindParameterFromBody(Type parameterType, bool allowEmpty, FactoryContext factoryContext) + private static Expression BindParameterFromBody(ParameterInfo parameter, bool allowEmpty, FactoryContext factoryContext) { if (factoryContext.JsonRequestBodyType is not null) { throw new InvalidOperationException("Action cannot have more than one FromBody attribute."); } - factoryContext.JsonRequestBodyType = parameterType; - factoryContext.AllowEmptyRequestBody = allowEmpty; + var nullability = NullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; - return Expression.Convert(BodyValueExpr, parameterType); + factoryContext.JsonRequestBodyType = parameter.ParameterType; + factoryContext.AllowEmptyRequestBody = allowEmpty || isOptional; + + var convertedBodyValue = Expression.Convert(BodyValueExpr, parameter.ParameterType); + var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); + + if (!factoryContext.AllowEmptyRequestBody) + { + // If the parameter is required or the user has not explicitly + // set allowBody to be empty then validate that it is required. + // + // Todo body_local = Convert(bodyValue, ToDo); + // if (body_local == null) + // { + // wasParamCheckFailure = true; + // Log.RequiredParameterNotProvided(httpContext, "Todo", "body"); + // } + var checkRequiredBodyBlock = Expression.Block( + Expression.Assign(argument, convertedBodyValue), + Expression.IfThen( + Expression.Equal(argument, Expression.Constant(null)), + Expression.Block( + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) + ) + ) + ); + factoryContext.CheckParams.Add((argument, checkRequiredBodyBlock)); + return argument; + } + + if (parameter.HasDefaultValue) + { + // Convert(bodyValue ?? SomeDefault, Todo) + return Expression.Convert( + Expression.Coalesce(BodyValueExpr, Expression.Constant(parameter.DefaultValue)), + parameter.ParameterType); + } + + // Convert(bodyValue, Todo) + return convertedBodyValue; } private static MethodInfo GetMethodInfo(Expression expr) @@ -838,7 +953,7 @@ private class FactoryContext public List? RouteParameters { get; set; } public bool UsingTempSourceString { get; set; } - public List<(ParameterExpression, Expression)> TryParseParams { get; } = new(); + public List<(ParameterExpression, Expression)> CheckParams { get; } = new(); } private static partial class Log @@ -858,11 +973,19 @@ public static void RequestBodyInvalidDataException(HttpContext httpContext, Inva public static void ParameterBindingFailed(HttpContext httpContext, string parameterTypeName, string parameterName, string sourceValue) => ParameterBindingFailed(GetLogger(httpContext), parameterTypeName, parameterName, sourceValue); + public static void RequiredParameterNotProvided(HttpContext httpContext, string parameterTypeName, string parameterName) + => RequiredParameterNotProvided(GetLogger(httpContext), parameterTypeName, parameterName); + [LoggerMessage(3, LogLevel.Debug, @"Failed to bind parameter ""{ParameterType} {ParameterName}"" from ""{SourceValue}"".", EventName = "ParamaterBindingFailed")] private static partial void ParameterBindingFailed(ILogger logger, string parameterType, string parameterName, string sourceValue); + [LoggerMessage(4, LogLevel.Debug, + @"Required parameter ""{ParameterType} {ParameterName}"" was not provided.", + EventName = "RequiredParameterNotProvided")] + private static partial void RequiredParameterNotProvided(ILogger logger, string parameterType, string parameterName); + private static ILogger GetLogger(HttpContext httpContext) { var loggerFactory = httpContext.RequestServices.GetRequiredService(); diff --git a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj index fffeeb054b84..25444590f57a 100644 --- a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj +++ b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj @@ -2,6 +2,8 @@ $(DefaultNetCoreTargetFramework) + + $(Features.Replace('nullablePublicOnly', '') diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 8156083bdbc2..b9f81f328b37 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -360,7 +360,7 @@ void TestAction([FromRoute(Name = specifiedName)] int foo) } [Fact] - public async Task UsesDefaultValueIfNoMatchingRouteValue() + public async Task Returns400IfNoMatchingRouteValueForRequiredParam() { const string unmatchedName = "value"; const int unmatchedRouteParam = 42; @@ -375,11 +375,15 @@ void TestAction([FromRoute] int foo) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[unmatchedName] = unmatchedRouteParam.ToString(NumberFormatInfo.InvariantInfo); + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(TestAction); await requestDelegate(httpContext); - Assert.Equal(0, deserializedRouteParam); + Assert.Equal(400, httpContext.Response.StatusCode); } public static object?[][] TryParsableParameters @@ -424,7 +428,6 @@ static void Store(HttpContext httpContext, T tryParsable) new object[] { (Action)Store, "42", 42 }, new object[] { (Action)Store, "ValueB", MyEnum.ValueB }, new object[] { (Action)Store, "https://example.org", new MyTryParsableRecord(new Uri("https://example.org")) }, - new object?[] { (Action)Store, null, 0 }, new object?[] { (Action)Store, null, null }, }; } @@ -454,6 +457,10 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues["tryParsable"] = routeValue; + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(action); await requestDelegate(httpContext); @@ -471,6 +478,10 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQ ["tryParsable"] = routeValue }); + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var requestDelegate = RequestDelegateFactory.Create(action); await requestDelegate(httpContext); @@ -663,10 +674,14 @@ public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) }; var httpContext = new DefaultHttpContext(); - httpContext.Request.Headers["Content-Type"] = "application/json"; var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); - httpContext.Request.Body = new MemoryStream(requestBodyBytes); + var stream = new MemoryStream(requestBodyBytes); ; + httpContext.Request.Body = stream; + + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); var jsonOptions = new JsonOptions(); jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); @@ -698,10 +713,17 @@ public async Task RequestDelegateRejectsEmptyBodyGivenFromBodyParameter(Delegate var httpContext = new DefaultHttpContext(); httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; + httpContext.Features.Set(new RequestBodyDetectionFeature(false)); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); var requestDelegate = RequestDelegateFactory.Create(action); - await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + await requestDelegate(httpContext); + + Assert.Equal(400, httpContext.Response.StatusCode); } [Fact] @@ -765,8 +787,10 @@ void TestAction([FromBody] Todo todo) var httpContext = new DefaultHttpContext(); httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = "1"; httpContext.Request.Body = new IOExceptionThrowingRequestBodyStream(ioException); httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); var requestDelegate = RequestDelegateFactory.Create(TestAction); @@ -798,8 +822,11 @@ void TestAction([FromBody] Todo todo) var httpContext = new DefaultHttpContext(); httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = "1"; httpContext.Request.Body = new IOExceptionThrowingRequestBodyStream(invalidDataException); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); var requestDelegate = RequestDelegateFactory.Create(TestAction); @@ -868,6 +895,18 @@ void TestImpliedFromServiceBasedOnContainer(HttpContext httpContext, MyService m } } + [Theory] + [MemberData(nameof(FromServiceActions))] + public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action) + { + var httpContext = new DefaultHttpContext(); + httpContext.RequestServices = new EmptyServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(action); + + await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + } + [Theory] [MemberData(nameof(FromServiceActions))] public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAttribute(Delegate action) @@ -892,18 +931,6 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt Assert.Same(myOriginalService, httpContext.Items["service"]); } - [Theory] - [MemberData(nameof(FromServiceActions))] - public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action) - { - var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = new EmptyServiceProvider(); - - var requestDelegate = RequestDelegateFactory.Create(action); - - await Assert.ThrowsAsync(() => requestDelegate(httpContext)); - } - [Fact] public async Task RequestDelegatePopulatesHttpContextParameterWithoutAttribute() { @@ -1369,6 +1396,388 @@ public async Task RequestDelegateWritesNullReturnNullValue(Delegate @delegate) Assert.Equal("null", responseBody); } + public static IEnumerable QueryParamOptionalityData + { + get + { + string requiredQueryParam(string name) => $"Hello {name}!"; + string defaultValueQueryParam(string name = "DefaultName") => $"Hello {name}!"; + string nullableQueryParam(string? name) => $"Hello {name}!"; + string requiredParseableQueryParam(int age) => $"Age: {age}"; + string defaultValueParseableQueryParam(int age = 12) => $"Age: {age}"; + string nullableQueryParseableParam(int? age) => $"Age: {age}"; + + return new List + { + new object?[] { (Func)requiredQueryParam, "name", null, true, null}, + new object?[] { (Func)requiredQueryParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)defaultValueQueryParam, "name", null, false, "Hello DefaultName!" }, + new object?[] { (Func)defaultValueQueryParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)nullableQueryParam, "name", null, false, "Hello !" }, + new object?[] { (Func)nullableQueryParam, "name", "TestName", false, "Hello TestName!"}, + + new object?[] { (Func)requiredParseableQueryParam, "age", null, true, null}, + new object?[] { (Func)requiredParseableQueryParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)defaultValueParseableQueryParam, "age", null, false, "Age: 12" }, + new object?[] { (Func)defaultValueParseableQueryParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)nullableQueryParseableParam, "age", null, false, "Age: " }, + new object?[] { (Func)nullableQueryParseableParam, "age", "42", false, "Age: 42"}, + }; + } + } + + [Theory] + [MemberData(nameof(QueryParamOptionalityData))] + public async Task RequestDelegateHandlesQueryParamOptionality(Delegate @delegate, string paramName, string? queryParam, bool isInvalid, string? expectedResponse) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (queryParam is not null) + { + httpContext.Request.Query = new QueryCollection(new Dictionary + { + [paramName] = queryParam + }); + } + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (isInvalid) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + var expectedType = paramName == "age" ? "Int32 age" : "String name"; + Assert.Equal($@"Required parameter ""{expectedType}"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + } + + public static IEnumerable RouteParamOptionalityData + { + get + { + string requiredRouteParam(string name) => $"Hello {name}!"; + string defaultValueRouteParam(string name = "DefaultName") => $"Hello {name}!"; + string nullableRouteParam(string? name) => $"Hello {name}!"; + string requiredParseableRouteParam(int age) => $"Age: {age}"; + string defaultValueParseableRouteParam(int age = 12) => $"Age: {age}"; + string nullableParseableRouteParam(int? age) => $"Age: {age}"; + + return new List + { + new object?[] { (Func)requiredRouteParam, "name", null, true, null}, + new object?[] { (Func)requiredRouteParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)defaultValueRouteParam, "name", null, false, "Hello DefaultName!" }, + new object?[] { (Func)defaultValueRouteParam, "name", "TestName", false, "Hello TestName!" }, + new object?[] { (Func)nullableRouteParam, "name", null, false, "Hello !" }, + new object?[] { (Func)nullableRouteParam, "name", "TestName", false, "Hello TestName!" }, + + new object?[] { (Func)requiredParseableRouteParam, "age", null, true, null}, + new object?[] { (Func)requiredParseableRouteParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)defaultValueParseableRouteParam, "age", null, false, "Age: 12" }, + new object?[] { (Func)defaultValueParseableRouteParam, "age", "42", false, "Age: 42" }, + new object?[] { (Func)nullableParseableRouteParam, "age", null, false, "Age: " }, + new object?[] { (Func)nullableParseableRouteParam, "age", "42", false, "Age: 42"}, + }; + } + } + + [Theory] + [MemberData(nameof(RouteParamOptionalityData))] + public async Task RequestDelegateHandlesRouteParamOptionality(Delegate @delegate, string paramName, string? routeParam, bool isInvalid, string? expectedResponse) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (routeParam is not null) + { + httpContext.Request.RouteValues[paramName] = routeParam; + } + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (isInvalid) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + var expectedType = paramName == "age" ? "Int32 age" : "String name"; + Assert.Equal($@"Required parameter ""{expectedType}"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + } + + public static IEnumerable BodyParamOptionalityData + { + get + { + string requiredBodyParam(Todo todo) => $"Todo: {todo.Name}"; + string defaultValueBodyParam(Todo? todo = null) => $"Todo: {todo?.Name}"; + string nullableBodyParam(Todo? todo) => $"Todo: {todo?.Name}"; + + return new List + { + new object?[] { (Func)requiredBodyParam, false, true, null }, + new object?[] { (Func)requiredBodyParam, true, false, "Todo: Default Todo"}, + new object?[] { (Func)defaultValueBodyParam, false, false, "Todo: "}, + new object?[] { (Func)defaultValueBodyParam, true, false, "Todo: Default Todo"}, + new object?[] { (Func)nullableBodyParam, false, false, "Todo: " }, + new object?[] { (Func)nullableBodyParam, true, false, "Todo: Default Todo" }, + }; + } + } + + [Theory] + [MemberData(nameof(BodyParamOptionalityData))] + public async Task RequestDelegateHandlesBodyParamOptionality(Delegate @delegate, bool hasBody, bool isInvalid, string? expectedResponse) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (hasBody) + { + var todo = new Todo() { Name = "Default Todo" }; + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo); + var stream = new MemoryStream(requestBodyBytes); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.ContentLength = stream.Length; + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + } + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + serviceCollection.AddSingleton(Options.Create(jsonOptions)); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (isInvalid) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + Assert.Equal(@"Required parameter ""Todo todo"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + } + + public static IEnumerable ServiceParamOptionalityData + { + get + { + string requiredExplicitService([FromService] MyService service) => $"Service: {service}"; + string defaultValueExplicitServiceParam([FromService] MyService? service = null) => $"Service: {service}"; + string nullableExplicitServiceParam([FromService] MyService? service) => $"Service: {service}"; + + return new List + { + new object?[] { (Func)requiredExplicitService, false, true}, + new object?[] { (Func)requiredExplicitService, true, false}, + + new object?[] { (Func)defaultValueExplicitServiceParam, false, false}, + new object?[] { (Func)defaultValueExplicitServiceParam, true, false}, + + new object?[] { (Func)nullableExplicitServiceParam, false, false}, + new object?[] { (Func)nullableExplicitServiceParam, true, false}, + }; + } + } + + [Theory] + [MemberData(nameof(ServiceParamOptionalityData))] + public async Task RequestDelegateHandlesServiceParamOptionality(Delegate @delegate, bool hasService, bool isInvalid) + { + var httpContext = new DefaultHttpContext(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + if (hasService) + { + var service = new MyService(); + + serviceCollection.AddSingleton(service); + } + var services = serviceCollection.BuildServiceProvider(); + httpContext.RequestServices = services; + RequestDelegateFactoryOptions options = new() { ServiceProvider = services }; + + var requestDelegate = RequestDelegateFactory.Create(@delegate, options); + + if (!isInvalid) + { + await requestDelegate(httpContext); + Assert.Equal(200, httpContext.Response.StatusCode); + } + else + { + await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + } + } + + + public static IEnumerable AllowEmptyData + { + get + { + string disallowEmptyAndNonOptional([FromBody(AllowEmpty = false)] Todo todo) => $"{todo}"; + string allowEmptyAndNonOptional([FromBody(AllowEmpty = true)] Todo todo) => $"{todo}"; + string allowEmptyAndOptional([FromBody(AllowEmpty = true)] Todo? todo = null) => $"{todo}"; + string disallowEmptyAndOptional([FromBody(AllowEmpty = false)] Todo? todo = null) => $"{todo}"; + + return new List + { + new object?[] { (Func)disallowEmptyAndNonOptional, false }, + new object?[] { (Func)allowEmptyAndNonOptional, true }, + new object?[] { (Func)allowEmptyAndOptional, true }, + new object?[] { (Func)disallowEmptyAndOptional, true } + }; + } + } + + [Theory] + [MemberData(nameof(AllowEmptyData))] + public async Task AllowEmptyOverridesOptionality(Delegate @delegate, bool allowsEmptyRequest) + { + var httpContext = new DefaultHttpContext(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var logs = TestSink.Writes.ToArray(); + + if (!allowsEmptyRequest) + { + Assert.Equal(400, httpContext.Response.StatusCode); + var log = Assert.Single(logs); + Assert.Equal(LogLevel.Debug, log.LogLevel); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), log.EventId); + Assert.Equal(@"Required parameter ""Todo todo"" was not provided.", log.Message); + } + else + { + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + } + } + +#nullable disable + + [Theory] + [InlineData(true, "Hello TestName!")] + [InlineData(false, "Hello !")] + public async Task CanSetStringParamAsOptionalWithNullabilityDisability(bool provideValue, string expectedResponse) + { + string optionalQueryParam(string name = null) => $"Hello {name}!"; + + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (provideValue) + { + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + } + + var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + + [Theory] + [InlineData(true, "Age: 42")] + [InlineData(false, "Age: 0")] + public async Task CanSetParseableStringParamAsOptionalWithNullabilityDisability(bool provideValue, string expectedResponse) + { + string optionalQueryParam(int age = default(int)) => $"Age: {age}"; + + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + if (provideValue) + { + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["age"] = "42" + }); + } + + var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(expectedResponse, decodedResponseBody); + } + +#nullable enable + private class Todo : ITodo { public int Id { get; set; } @@ -1566,5 +1975,15 @@ public void Abort() _requestAbortedCts.Cancel(); } } + + private class RequestBodyDetectionFeature : IHttpRequestBodyDetectionFeature + { + public RequestBodyDetectionFeature(bool canHaveBody) + { + CanHaveBody = canHaveBody; + } + + public bool CanHaveBody { get; } + } } }