diff --git a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs index f6483f94..a1d8c0a7 100644 --- a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs +++ b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs @@ -341,9 +341,16 @@ private static object GetInstance(this IServiceProvider provider, ServiceDescrip return descriptor.ImplementationInstance; } - if (descriptor.ImplementationType != null) + // Not suppose to be abstract. + Type implementationType = descriptor.ImplementationType; + if (implementationType != null) { - return provider.GetServiceOrCreateInstance(descriptor.ImplementationType); + if (implementationType != descriptor.ServiceType) + return provider.GetServiceOrCreateInstance(descriptor.ImplementationType); + + // Since implementationType is equal to ServiceType we need explicitly create an implementation type through reflections in order to avoid infinite recursion. + // Should not cause issue with singletons, since singleton will be a decorator and after this fact we can don't care about lifecycle of decorable service (for sure, if IDisposable of decorator disposes underlying type:)) + return provider.CreateInstance(implementationType); } if (descriptor.ImplementationFactory != null) diff --git a/test/Scrutor.Tests/DecorationTests.cs b/test/Scrutor.Tests/DecorationTests.cs index ed281ded..ff70d564 100644 --- a/test/Scrutor.Tests/DecorationTests.cs +++ b/test/Scrutor.Tests/DecorationTests.cs @@ -3,6 +3,7 @@ using Microsoft.Extensions.DependencyInjection; using Xunit; using System.Linq; +using static Scrutor.Tests.DecorationTests; namespace Scrutor.Tests { @@ -217,8 +218,47 @@ public void DecoratingNonRegisteredServiceThrows() Assert.Throws(() => ConfigureProvider(services => services.Decorate())); } + [Fact] + public void Issue148_Decorate_IsAbleToDecorateConcreateTypes() + { + var sp = ConfigureProvider(sc => + { + sc + .AddTransient() + .AddTransient() + .Decorate(); + }); + + var result = sp.GetService() as Decorator2; + + Assert.NotNull(result); + Assert.NotNull(result.Inner); + Assert.NotNull(result.Inner.Dependency); + } + public interface IDecoratedService { } + public class DecoratedService + { + public DecoratedService(IService dependency) + { + Dependency = dependency; + } + + public IService Dependency { get; } + } + + public class Decorator2 : DecoratedService + { + public Decorator2(DecoratedService decoratedService) + : base(null) + { + Inner = decoratedService; + } + + public DecoratedService Inner { get; } + } + public interface IService { } private class SomeRandomService : IService { }