Skip to content

Commit

Permalink
Infer that interface parameters are services (#31658)
Browse files Browse the repository at this point in the history
  • Loading branch information
halter73 authored Apr 13, 2021
1 parent e340205 commit d70439c
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 71 deletions.
32 changes: 20 additions & 12 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -203,15 +202,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
}
else if (parameterCustomAttributes.OfType<IFromBodyMetadata>().FirstOrDefault() is { } bodyAttribute)
{
if (factoryContext.JsonRequestBodyType is not null)
{
throw new InvalidOperationException("Action cannot have more than one FromBody attribute.");
}

factoryContext.JsonRequestBodyType = parameter.ParameterType;
factoryContext.AllowEmptyRequestBody = bodyAttribute.AllowEmpty;

return Expression.Convert(BodyValueExpr, parameter.ParameterType);
return BindParameterFromBody(parameter.ParameterType, bodyAttribute.AllowEmpty, factoryContext);
}
else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)))
{
Expand All @@ -229,10 +220,14 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
{
return BindParameterFromRouteValueOrQueryString(parameter, parameter.Name, factoryContext);
}
else
else if (parameter.ParameterType.IsInterface)
{
return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr);
}
else
{
return BindParameterFromBody(parameter.ParameterType, allowEmpty: false, factoryContext);
}
}

private static Expression CreateMethodCall(MethodInfo methodInfo, Expression? target, Expression[] arguments) =>
Expand Down Expand Up @@ -428,7 +423,7 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
var invoker = Expression.Lambda<Func<object?, HttpContext, object?, Task>>(
responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile();

var bodyType = factoryContext.JsonRequestBodyType!;
var bodyType = factoryContext.JsonRequestBodyType;
object? defaultBodyValue = null;

if (factoryContext.AllowEmptyRequestBody && bodyType.IsValueType)
Expand Down Expand Up @@ -627,6 +622,19 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo
return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext);
}

private static Expression BindParameterFromBody(Type parameterType, bool allowEmpty, FactoryContext factoryContext)
{
if (factoryContext.JsonRequestBodyType is not null)
{
throw new InvalidOperationException("Action cannot have more than one FromBody attribute.");
}

factoryContext.JsonRequestBodyType = parameterType;
factoryContext.AllowEmptyRequestBody = allowEmpty;

return Expression.Convert(BodyValueExpr, parameterType);
}

private static MethodInfo GetMethodInfo<T>(Expression<T> expr)
{
var mc = (MethodCallExpression)expr.Body;
Expand Down
86 changes: 51 additions & 35 deletions src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,8 @@ public static bool TryParse(string? value, out MyTryParsableRecord? result)
[MemberData(nameof(TryParsableParameters))]
public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromRouteValue(Delegate action, string? routeValue, object? expectedParameterValue)
{
var invalidDataException = new InvalidDataException();
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton(LoggerFactory);

var httpContext = new DefaultHttpContext();
httpContext.Request.RouteValues["tryParsable"] = routeValue;
httpContext.Features.Set<IHttpRequestLifetimeFeature>(new TestHttpRequestLifetimeFeature());
httpContext.RequestServices = serviceCollection.BuildServiceProvider();

var requestDelegate = RequestDelegateFactory.Create(action);

Expand Down Expand Up @@ -416,7 +410,7 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR
Assert.Equal(42, httpContext.Items["tryParsable"]);
}

public static object[][] DelegatesWithInvalidAttributes
public static object[][] DelegatesWithAttributesOnNotTryParsableParameters
{
get
{
Expand All @@ -434,7 +428,7 @@ void InvalidFromHeader([FromHeader] object notTryParsable) { }
}

[Theory]
[MemberData(nameof(DelegatesWithInvalidAttributes))]
[MemberData(nameof(DelegatesWithAttributesOnNotTryParsableParameters))]
public void CreateThrowsInvalidOperationExceptionWhenAttributeRequiresTryParseMethodThatDoesNotExist(Delegate action)
{
var ex = Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create(action));
Expand All @@ -460,7 +454,6 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2)
invoked = true;
}

var invalidDataException = new InvalidDataException();
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton(LoggerFactory);

Expand Down Expand Up @@ -542,47 +535,61 @@ void TestAction([FromHeader(Name = customHeaderName)] int value)
Assert.Equal(originalHeaderParam, deserializedRouteParam);
}

[Fact]
public async Task RequestDelegatePopulatesFromBodyParameter()
public static object[][] FromBodyActions
{
Todo originalTodo = new()
get
{
Name = "Write more tests!"
};
void TestExplicitFromBody(HttpContext httpContext, [FromBody] Todo todo)
{
httpContext.Items.Add("body", todo);
}

Todo? deserializedRequestBody = null;
void TestImpliedFromBody(HttpContext httpContext, Todo myService)
{
httpContext.Items.Add("body", myService);
}

void TestAction([FromBody] Todo todo)
{
deserializedRequestBody = todo;
return new[]
{
new[] { (Action<HttpContext, Todo>)TestExplicitFromBody },
new[] { (Action<HttpContext, Todo>)TestImpliedFromBody },
};
}
}

[Theory]
[MemberData(nameof(FromBodyActions))]
public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action)
{
Todo originalTodo = new()
{
Name = "Write more tests!"
};

var httpContext = new DefaultHttpContext();
httpContext.Request.Headers["Content-Type"] = "application/json";

var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo);
httpContext.Request.Body = new MemoryStream(requestBodyBytes);

var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction);
var requestDelegate = RequestDelegateFactory.Create(action);

await requestDelegate(httpContext);

var deserializedRequestBody = httpContext.Items["body"];
Assert.NotNull(deserializedRequestBody);
Assert.Equal(originalTodo.Name, deserializedRequestBody!.Name);
Assert.Equal(originalTodo.Name, ((Todo)deserializedRequestBody!).Name);
}

[Fact]
public async Task RequestDelegateRejectsEmptyBodyGivenDefaultFromBodyParameter()
[Theory]
[MemberData(nameof(FromBodyActions))]
public async Task RequestDelegateRejectsEmptyBodyGivenFromBodyParameter(Delegate action)
{
void TestAction([FromBody] Todo todo)
{
}

var httpContext = new DefaultHttpContext();
httpContext.Request.Headers["Content-Type"] = "application/json";
httpContext.Request.Headers["Content-Length"] = "0";

var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction);
var requestDelegate = RequestDelegateFactory.Create(action);

await Assert.ThrowsAsync<JsonException>(() => requestDelegate(httpContext));
}
Expand Down Expand Up @@ -702,12 +709,16 @@ void TestAction([FromBody] Todo todo)
[Fact]
public void BuildRequestDelegateThrowsInvalidOperationExceptionGivenFromBodyOnMultipleParameters()
{
void TestAction([FromBody] int value1, [FromBody] int value2) { }
void TestAttributedInvalidAction([FromBody] int value1, [FromBody] int value2) { }
void TestInferredInvalidAction(Todo value1, Todo value2) { }
void TestBothInvalidAction(Todo value1, [FromBody] int value2) { }

Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<int, int>)TestAction));
Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<int, int>)TestAttributedInvalidAction));
Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<Todo, Todo>)TestInferredInvalidAction));
Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<Todo, int>)TestBothInvalidAction));
}

public static object[][] FromServiceParameter
public static object[][] FromServiceActions
{
get
{
Expand All @@ -716,7 +727,7 @@ void TestExplicitFromService(HttpContext httpContext, [FromService] MyService my
httpContext.Items.Add("service", myService);
}

void TestImpliedFromService(HttpContext httpContext, MyService myService)
void TestImpliedFromService(HttpContext httpContext, IMyService myService)
{
httpContext.Items.Add("service", myService);
}
Expand All @@ -730,13 +741,14 @@ void TestImpliedFromService(HttpContext httpContext, MyService myService)
}

[Theory]
[MemberData(nameof(FromServiceParameter))]
[MemberData(nameof(FromServiceActions))]
public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAttribute(Delegate action)
{
var myOriginalService = new MyService();

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton(myOriginalService);
serviceCollection.AddSingleton<IMyService>(myOriginalService);

var httpContext = new DefaultHttpContext();
httpContext.RequestServices = serviceCollection.BuildServiceProvider();
Expand All @@ -749,11 +761,11 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt
}

[Theory]
[MemberData(nameof(FromServiceParameter))]
[MemberData(nameof(FromServiceActions))]
public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action)
{
var httpContext = new DefaultHttpContext();
httpContext.RequestServices = (new ServiceCollection()).BuildServiceProvider();
httpContext.RequestServices = new ServiceCollection().BuildServiceProvider();

var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, MyService>)action);

Expand Down Expand Up @@ -1058,7 +1070,11 @@ private class FromServiceAttribute : Attribute, IFromServiceMetadata
{
}

private class MyService
private interface IMyService
{
}

private class MyService : IMyService
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ namespace Microsoft.AspNetCore.Builder
/// <summary>
/// Builds conventions that will be used for customization of MapAction <see cref="EndpointBuilder"/> instances.
/// </summary>
public sealed class MapActionEndpointConventionBuilder : IEndpointConventionBuilder
public sealed class MinimalActionEndpointConventionBuilder : IEndpointConventionBuilder
{
private readonly List<IEndpointConventionBuilder> _endpointConventionBuilders;

internal MapActionEndpointConventionBuilder(IEndpointConventionBuilder endpointConventionBuilder)
internal MinimalActionEndpointConventionBuilder(IEndpointConventionBuilder endpointConventionBuilder)
{
_endpointConventionBuilders = new List<IEndpointConventionBuilder>() { endpointConventionBuilder };
}

internal MapActionEndpointConventionBuilder(List<IEndpointConventionBuilder> endpointConventionBuilders)
internal MinimalActionEndpointConventionBuilder(List<IEndpointConventionBuilder> endpointConventionBuilders)
{
_endpointConventionBuilders = endpointConventionBuilders;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Builder
/// <summary>
/// Provides extension methods for <see cref="IEndpointRouteBuilder"/> to define HTTP API endpoints.
/// </summary>
public static class MapActionEndpointRouteBuilderExtensions
public static class MinmalActionEndpointRouteBuilderExtensions
{
// Avoid creating a new array every call
private static readonly string[] GetVerb = new[] { "GET" };
Expand All @@ -30,7 +30,7 @@ public static class MapActionEndpointRouteBuilderExtensions
/// <param name="pattern">The route pattern.</param>
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that can be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder MapGet(
public static MinimalActionEndpointConventionBuilder MapGet(
this IEndpointRouteBuilder endpoints,
string pattern,
Delegate action)
Expand All @@ -46,7 +46,7 @@ public static MapActionEndpointConventionBuilder MapGet(
/// <param name="pattern">The route pattern.</param>
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that can be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder MapPost(
public static MinimalActionEndpointConventionBuilder MapPost(
this IEndpointRouteBuilder endpoints,
string pattern,
Delegate action)
Expand All @@ -62,7 +62,7 @@ public static MapActionEndpointConventionBuilder MapPost(
/// <param name="pattern">The route pattern.</param>
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that canaction be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder MapPut(
public static MinimalActionEndpointConventionBuilder MapPut(
this IEndpointRouteBuilder endpoints,
string pattern,
Delegate action)
Expand All @@ -78,7 +78,7 @@ public static MapActionEndpointConventionBuilder MapPut(
/// <param name="pattern">The route pattern.</param>
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that can be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder MapDelete(
public static MinimalActionEndpointConventionBuilder MapDelete(
this IEndpointRouteBuilder endpoints,
string pattern,
Delegate action)
Expand All @@ -95,7 +95,7 @@ public static MapActionEndpointConventionBuilder MapDelete(
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <param name="httpMethods">HTTP methods that the endpoint will match.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that can be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder MapMethods(
public static MinimalActionEndpointConventionBuilder MapMethods(
this IEndpointRouteBuilder endpoints,
string pattern,
IEnumerable<string> httpMethods,
Expand All @@ -120,7 +120,7 @@ public static MapActionEndpointConventionBuilder MapMethods(
/// <param name="pattern">The route pattern.</param>
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that can be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder Map(
public static MinimalActionEndpointConventionBuilder Map(
this IEndpointRouteBuilder endpoints,
string pattern,
Delegate action)
Expand All @@ -136,7 +136,7 @@ public static MapActionEndpointConventionBuilder Map(
/// <param name="pattern">The route pattern.</param>
/// <param name="action">The delegate executed when the endpoint is matched.</param>
/// <returns>A <see cref="IEndpointConventionBuilder"/> that can be used to further customize the endpoint.</returns>
public static MapActionEndpointConventionBuilder Map(
public static MinimalActionEndpointConventionBuilder Map(
this IEndpointRouteBuilder endpoints,
RoutePattern pattern,
Delegate action)
Expand Down Expand Up @@ -185,7 +185,7 @@ public static MapActionEndpointConventionBuilder Map(
endpoints.DataSources.Add(dataSource);
}

return new MapActionEndpointConventionBuilder(dataSource.AddEndpointBuilder(builder));
return new MinimalActionEndpointConventionBuilder(dataSource.AddEndpointBuilder(builder));
}
}
}
Loading

0 comments on commit d70439c

Please sign in to comment.