Skip to content

Commit

Permalink
Adding registration methods to infer the behavior types
Browse files Browse the repository at this point in the history
  • Loading branch information
jbogard committed Jul 7, 2023
1 parent 9ebdf7b commit 3e1c399
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 4 deletions.
137 changes: 135 additions & 2 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,41 @@ public MediatRServiceConfiguration RegisterServicesFromAssemblies(
public MediatRServiceConfiguration AddBehavior<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed behavior type against all <see cref="IPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed behavior implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior<TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
return AddBehavior(typeof(TImplementationType), serviceLifetime);
}

/// <summary>
/// Register a closed behavior type against all <see cref="IPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <param name="implementationType">Closed behavior implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>)));

if (implementedBehaviorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
{
BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}

return this;
}

/// <summary>
/// Register a closed behavior type
/// </summary>
Expand Down Expand Up @@ -181,6 +216,39 @@ public MediatRServiceConfiguration AddStreamBehavior(Type serviceType, Type impl
return this;
}

/// <summary>
/// Register a closed stream behavior type against all <see cref="IStreamPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed stream behavior implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddStreamBehavior<TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddStreamBehavior(typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed stream behavior type against all <see cref="IStreamPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <param name="implementationType">Closed stream behavior implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>)));

if (implementedBehaviorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
{
StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}

return this;
}

/// <summary>
/// Registers an open stream behavior type against the <see cref="IStreamPipelineBehavior{TRequest,TResponse}"/> open generic interface type
/// </summary>
Expand Down Expand Up @@ -210,7 +278,6 @@ public MediatRServiceConfiguration AddOpenStreamBehavior(Type openBehaviorType,
return this;
}


/// <summary>
/// Register a closed request pre processor type
/// </summary>
Expand All @@ -234,6 +301,40 @@ public MediatRServiceConfiguration AddRequestPreProcessor(Type serviceType, Type

return this;
}

/// <summary>
/// Register a closed request pre processor type against all <see cref="IRequestPreProcessor{TRequest}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed request pre processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor<TImplementationType>(
ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPreProcessor(typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request pre processor type against all <see cref="IRequestPreProcessor{TRequest}"/> implementations
/// </summary>
/// <param name="implementationType">Closed request pre processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPreProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));

if (implementedPreProcessorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
}

foreach (var implementedPreProcessorType in implementedPreProcessorTypes)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime));
}

return this;
}

/// <summary>
/// Registers an open request pre processor type against the <see cref="IRequestPreProcessor{TRequest}"/> open generic interface type
Expand Down Expand Up @@ -272,7 +373,7 @@ public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorT
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);
=> AddRequestPostProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request post processor type
Expand All @@ -287,6 +388,38 @@ public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Typ

return this;
}

/// <summary>
/// Register a closed request post processor type against all <see cref="IRequestPostProcessor{TRequest,TResponse}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed request post processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor<TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPostProcessor(typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request post processor type against all <see cref="IRequestPostProcessor{TRequest,TResponse}"/> implementations
/// </summary>
/// <param name="implementationType">Closed request post processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPostProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));

if (implementedPostProcessorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
}

foreach (var implementedPostProcessorType in implementedPostProcessorTypes)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime));
}
return this;
}

/// <summary>
/// Registers an open request post processor type against the <see cref="IRequestPostProcessor{TRequest,TResponse}"/> open generic interface type
Expand Down
130 changes: 128 additions & 2 deletions test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Extensions.DependencyInjection;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.DependencyInjection;

namespace MediatR.Extensions.Microsoft.DependencyInjection.Tests;

Expand Down Expand Up @@ -51,6 +52,46 @@ public async Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next,
return response;
}
}

public class OuterStreamBehavior : IStreamPipelineBehavior<Ping, Pong>
{
private readonly Logger _output;

public OuterStreamBehavior(Logger output)
{
_output = output;
}

public async IAsyncEnumerable<Pong> Handle(Ping request, StreamHandlerDelegate<Pong> next, [EnumeratorCancellation] CancellationToken cancellationToken)
{
_output.Messages.Add("Outer before");
await foreach (var item in next().WithCancellation(cancellationToken))
{
yield return item;
}
_output.Messages.Add("Outer after");
}
}

public class InnerStreamBehavior : IStreamPipelineBehavior<Ping, Pong>
{
private readonly Logger _output;

public InnerStreamBehavior(Logger output)
{
_output = output;
}

public async IAsyncEnumerable<Pong> Handle(Ping request, StreamHandlerDelegate<Pong> next, [EnumeratorCancellation] CancellationToken cancellationToken)
{
_output.Messages.Add("Inner before");
await foreach (var item in next().WithCancellation(cancellationToken))
{
yield return item;
}
_output.Messages.Add("Inner after");
}
}

public class InnerBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : notnull
Expand Down Expand Up @@ -517,7 +558,7 @@ public async Task Should_handle_constrained_generics()
cfg.AddOpenRequestPreProcessor(typeof(FirstPreProcessor<>));
cfg.AddOpenRequestPreProcessor(typeof(NextPreProcessor<>));
cfg.AddRequestPostProcessor<IRequestPostProcessor<Ping, Pong>, FirstConcretePostProcessor>();
cfg.AddRequestPreProcessor<IRequestPostProcessor<Ping, Pong>, NextConcretePostProcessor>();
cfg.AddRequestPostProcessor<IRequestPostProcessor<Ping, Pong>, NextConcretePostProcessor>();
cfg.AddOpenRequestPostProcessor(typeof(FirstPostProcessor<,>));
cfg.AddOpenRequestPostProcessor(typeof(NextPostProcessor<,>));
});
Expand Down Expand Up @@ -608,7 +649,92 @@ public void Should_handle_open_behavior_registration()
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
}

[Fact]
public void Should_handle_inferred_behavior_registration()
{
var cfg = new MediatRServiceConfiguration();
cfg.AddBehavior<InnerBehavior>();
cfg.AddBehavior(typeof(OuterBehavior));

cfg.BehaviorsToRegister.Count.ShouldBe(2);

cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
cfg.BehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerBehavior));
cfg.BehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.BehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.BehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
cfg.BehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterBehavior));
cfg.BehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.BehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.BehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
}


[Fact]
public void Should_handle_inferred_stream_behavior_registration()
{
var cfg = new MediatRServiceConfiguration();
cfg.AddStreamBehavior<InnerStreamBehavior>();
cfg.AddStreamBehavior(typeof(OuterStreamBehavior));

cfg.StreamBehaviorsToRegister.Count.ShouldBe(2);

cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
cfg.StreamBehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerStreamBehavior));
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
cfg.StreamBehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterStreamBehavior));
cfg.StreamBehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
}

[Fact]
public void Should_handle_inferred_pre_processor_registration()
{
var cfg = new MediatRServiceConfiguration();
cfg.AddRequestPreProcessor<FirstConcretePreProcessor>();
cfg.AddRequestPreProcessor(typeof(NextConcretePreProcessor));

cfg.RequestPreProcessorsToRegister.Count.ShouldBe(2);

cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
cfg.RequestPreProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePreProcessor));
cfg.RequestPreProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
cfg.RequestPreProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePreProcessor));
cfg.RequestPreProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
}

[Fact]
public void Should_handle_inferred_post_processor_registration()
{
var cfg = new MediatRServiceConfiguration();
cfg.AddRequestPostProcessor<FirstConcretePostProcessor>();
cfg.AddRequestPostProcessor(typeof(NextConcretePostProcessor));

cfg.RequestPostProcessorsToRegister.Count.ShouldBe(2);

cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
cfg.RequestPostProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePostProcessor));
cfg.RequestPostProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
cfg.RequestPostProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePostProcessor));
cfg.RequestPostProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
}

[Fact]
public void Should_handle_open_behaviors_registration_from_a_single_type()
{
Expand Down

0 comments on commit 3e1c399

Please sign in to comment.