diff --git a/src/Middleware/RequestLimiter/sample/Controllers/HomeController.cs b/src/Middleware/RequestLimiter/sample/Controllers/HomeController.cs index b743626a7547..c1342aa25650 100644 --- a/src/Middleware/RequestLimiter/sample/Controllers/HomeController.cs +++ b/src/Middleware/RequestLimiter/sample/Controllers/HomeController.cs @@ -1,7 +1,3 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.RequestLimiter; diff --git a/src/Middleware/RequestLimiter/sample/RequestLimiterSample.csproj b/src/Middleware/RequestLimiter/sample/RequestLimiterSample.csproj index 1be3a3ab1333..0d524ffbd097 100644 --- a/src/Middleware/RequestLimiter/sample/RequestLimiterSample.csproj +++ b/src/Middleware/RequestLimiter/sample/RequestLimiterSample.csproj @@ -1,4 +1,4 @@ - + $(DefaultNetCoreTargetFramework) @@ -16,7 +16,7 @@ - + diff --git a/src/Middleware/RequestLimiter/sample/Startup.cs b/src/Middleware/RequestLimiter/sample/Startup.cs index d38a42a49a9a..30ec2e616968 100644 --- a/src/Middleware/RequestLimiter/sample/Startup.cs +++ b/src/Middleware/RequestLimiter/sample/Startup.cs @@ -1,7 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Threading.ResourceLimits; +using System; +using System.Net; +using System.Runtime.RateLimits; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; @@ -19,16 +21,21 @@ public class Startup public void ConfigureServices(IServiceCollection services) { services.AddControllersWithViews(); - services.AddSingleton(new IPAggregatedRateLimiter(2, 2)); - services.AddSingleton(new TokenBucketRateLimiter(2, 2)); + services.AddSingleton( + new TokenBucketRateLimiter( + new TokenBucketRateLimiterOptions { + PermitLimit = 2, + TokensPerPeriod = 2, + ReplenishmentPeriod = TimeSpan.FromSeconds(1) + })); services.AddRequestLimiter(options => { - options.SetDefaultPolicy(new ConcurrencyLimiter(new ConcurrencyLimiterOptions { ResourceLimit = 100 })); + options.SetDefaultPolicy(new ConcurrencyLimiter(new ConcurrencyLimiterOptions { PermitLimit = 100 })); options.AddPolicy("ipPolicy", policy => { // Add instance - policy.AddAggregatedLimiter(new IPAggregatedRateLimiter(2, 2)); + policy.AddAggregatedLimiter(new AggregatedTokenBucketLimiter(2, 2), context => context.Connection.RemoteIpAddress); }); options.AddPolicy("rate", policy => { @@ -57,7 +64,13 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILogger< { await Task.Delay(5000); await context.Response.WriteAsync("Hello World!"); - }).EnforceRequestLimit(new TokenBucketRateLimiter(2, 2)); + }).EnforceRequestLimit(new TokenBucketRateLimiter( + new TokenBucketRateLimiterOptions + { + PermitLimit = 2, + TokensPerPeriod = 2, + ReplenishmentPeriod = TimeSpan.FromSeconds(1) + })); endpoints.MapGet("/concurrent", async context => { diff --git a/src/Middleware/RequestLimiter/src/IPAggregatedRateLimiter.cs b/src/Middleware/RequestLimiter/src/AggregatedTokenBucketLimiter.cs similarity index 68% rename from src/Middleware/RequestLimiter/src/IPAggregatedRateLimiter.cs rename to src/Middleware/RequestLimiter/src/AggregatedTokenBucketLimiter.cs index 91e7020f39fe..5727436f9916 100644 --- a/src/Middleware/RequestLimiter/src/IPAggregatedRateLimiter.cs +++ b/src/Middleware/RequestLimiter/src/AggregatedTokenBucketLimiter.cs @@ -1,16 +1,14 @@ using System; using System.Collections.Concurrent; -using System.Net; using System.Threading; using System.Runtime.RateLimits; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; using System.Diagnostics.CodeAnalysis; namespace Microsoft.AspNetCore.RequestLimiter { // TODO: update implementation with WaitAsync - public class IPAggregatedRateLimiter : AggregatedRateLimiter + public class AggregatedTokenBucketLimiter : AggregatedRateLimiter where TContext: notnull { private int _permitCount; @@ -18,12 +16,12 @@ public class IPAggregatedRateLimiter : AggregatedRateLimiter private readonly int _newPermitPerSecond; private readonly Timer _renewTimer; // TODO: This is racy - private readonly ConcurrentDictionary _cache = new(); + private readonly ConcurrentDictionary _cache = new(); - private static readonly RateLimitLease FailedLease = new RateLimitLease(false); - private static readonly RateLimitLease SuccessfulLease = new RateLimitLease(true); + private static readonly RateLimitLease FailedLease = new(false); + private static readonly RateLimitLease SuccessfulLease = new(true); - public IPAggregatedRateLimiter(int permitCount, int newPermitPerSecond) + public AggregatedTokenBucketLimiter(int permitCount, int newPermitPerSecond) { _permitCount = permitCount; _maxPermitCount = permitCount; @@ -33,35 +31,21 @@ public IPAggregatedRateLimiter(int permitCount, int newPermitPerSecond) _renewTimer = new Timer(Replenish, this, TimeSpan.FromSeconds(5), TimeSpan.FromSeconds(5)); } - public override int AvailablePermits(HttpContext context) + public override int AvailablePermits(TContext context) { - if (context.Connection.RemoteIpAddress == null) - { - // Unknown IP? - return 0; - } - - return _cache.TryGetValue(context.Connection.RemoteIpAddress, out var count) ? count : 0; + return _cache.TryGetValue(context, out var count) ? count : 0; } - public override PermitLease Acquire(HttpContext context, int permitCount) + public override PermitLease Acquire(TContext context, int permitCount) { if (permitCount > _maxPermitCount) { return FailedLease; } - if (context.Connection.RemoteIpAddress == null) - { - // TODO: how should this case be handled? - return SuccessfulLease; - } - - var key = context.Connection.RemoteIpAddress; - - if (!_cache.TryGetValue(key, out var count)) + if (!_cache.TryGetValue(context, out var count)) { - if (_cache.TryAdd(key, _maxPermitCount - permitCount)) + if (_cache.TryAdd(context, _maxPermitCount - permitCount)) { return SuccessfulLease; } @@ -76,14 +60,14 @@ public override PermitLease Acquire(HttpContext context, int permitCount) return FailedLease; } - if (_cache.TryUpdate(key, newCount, count)) + if (_cache.TryUpdate(context, newCount, count)) { return SuccessfulLease; } - if (!_cache.TryGetValue(key, out count)) + if (!_cache.TryGetValue(context, out count)) { - if (_cache.TryAdd(key, _maxPermitCount - permitCount)) + if (_cache.TryAdd(context, _maxPermitCount - permitCount)) { return SuccessfulLease; } @@ -91,7 +75,7 @@ public override PermitLease Acquire(HttpContext context, int permitCount) } } - public override ValueTask WaitAsync(HttpContext context, int permitCount, CancellationToken cancellationToken = default) + public override ValueTask WaitAsync(TContext context, int permitCount, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } @@ -99,7 +83,7 @@ public override ValueTask WaitAsync(HttpContext context, int permit private static void Replenish(object? state) { // Return if Replenish already running to avoid concurrency. - if (state is not IPAggregatedRateLimiter limiter) + if (state is not AggregatedTokenBucketLimiter limiter) { return; } diff --git a/src/Middleware/RequestLimiter/src/LimiterWrappers/AggregatedLimiterWrapper.cs b/src/Middleware/RequestLimiter/src/LimiterWrappers/AggregatedLimiterWrapper.cs new file mode 100644 index 000000000000..be1fbe5f4cad --- /dev/null +++ b/src/Middleware/RequestLimiter/src/LimiterWrappers/AggregatedLimiterWrapper.cs @@ -0,0 +1,35 @@ +using System; +using System.Threading; +using System.Runtime.RateLimits; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.RequestLimiter +{ + internal class AggregatedLimiterWrapper : AggregatedRateLimiter where TContext: notnull + { + private readonly AggregatedRateLimiter _limiter; + private readonly Func _selector; + + public AggregatedLimiterWrapper(AggregatedRateLimiter limiter, Func selector) + { + _limiter = limiter; + _selector = selector; + } + + public override PermitLease Acquire(HttpContext context, int permitCount) + { + return _limiter.Acquire(_selector(context), permitCount); + } + + public override int AvailablePermits(HttpContext context) + { + return _limiter.AvailablePermits(_selector(context)); + } + + public override ValueTask WaitAsync(HttpContext context, int requestedCount, CancellationToken cancellationToken = default) + { + return _limiter.WaitAsync(_selector(context), requestedCount, cancellationToken); + } + } +} diff --git a/src/Middleware/RequestLimiter/src/HttpContextLimiter.cs b/src/Middleware/RequestLimiter/src/LimiterWrappers/SimpleLimiterWrapper.cs similarity index 77% rename from src/Middleware/RequestLimiter/src/HttpContextLimiter.cs rename to src/Middleware/RequestLimiter/src/LimiterWrappers/SimpleLimiterWrapper.cs index 6346ce3e46f4..f20f73dd69d3 100644 --- a/src/Middleware/RequestLimiter/src/HttpContextLimiter.cs +++ b/src/Middleware/RequestLimiter/src/LimiterWrappers/SimpleLimiterWrapper.cs @@ -5,11 +5,11 @@ namespace Microsoft.AspNetCore.RequestLimiter { - internal class HttpContextLimiter : AggregatedRateLimiter + internal class SimpleLimiterWrapper : AggregatedRateLimiter { private readonly RateLimiter _limiter; - public HttpContextLimiter(RateLimiter limiter) + public SimpleLimiterWrapper(RateLimiter limiter) { _limiter = limiter; } @@ -28,7 +28,5 @@ public override ValueTask WaitAsync(HttpContext context, int reques { return _limiter.WaitAsync(requestedCount, cancellationToken); } - - public static implicit operator HttpContextLimiter(RateLimiter limiter) => new(limiter); } } diff --git a/src/Middleware/RequestLimiter/src/RequestLimiterEndpointExtensions.cs b/src/Middleware/RequestLimiter/src/RequestLimiterEndpointExtensions.cs index d34c9279ab86..89b2342afb69 100644 --- a/src/Middleware/RequestLimiter/src/RequestLimiterEndpointExtensions.cs +++ b/src/Middleware/RequestLimiter/src/RequestLimiterEndpointExtensions.cs @@ -36,13 +36,14 @@ public static IEndpointConventionBuilder EnforceRequestRateLimit(this IEndpointC { endpointBuilder.Metadata.Add( new RequestLimitAttribute( - (HttpContextLimiter)new TokenBucketRateLimiter( - new TokenBucketRateLimiterOptions - { - PermitLimit = requestPerSecond, - ReplenishmentPeriod = TimeSpan.FromSeconds(1), - TokensPerPeriod = requestPerSecond - }))); + new SimpleLimiterWrapper( + new TokenBucketRateLimiter( + new TokenBucketRateLimiterOptions + { + PermitLimit = requestPerSecond, + ReplenishmentPeriod = TimeSpan.FromSeconds(1), + TokensPerPeriod = requestPerSecond + })))); }); return builder; @@ -54,22 +55,29 @@ public static IEndpointConventionBuilder EnforceRequestConcurrencyLimit(this IEn { endpointBuilder.Metadata.Add( new RequestLimitAttribute( - (HttpContextLimiter)new ConcurrencyLimiter( - new ConcurrencyLimiterOptions { PermitLimit = concurrentRequests }))); + new SimpleLimiterWrapper( + new ConcurrencyLimiter( + new ConcurrencyLimiterOptions { PermitLimit = concurrentRequests })))); }); return builder; } public static IEndpointConventionBuilder EnforceRequestLimit(this IEndpointConventionBuilder builder, RateLimiter limiter) - => builder.EnforceRequestLimit((HttpContextLimiter)limiter); - - public static IEndpointConventionBuilder EnforceRequestLimit(this IEndpointConventionBuilder builder, AggregatedRateLimiter limiter) { + builder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(new RequestLimitAttribute(new SimpleLimiterWrapper(limiter))); + }); + return builder; + } + + public static IEndpointConventionBuilder EnforceRequestLimit(this IEndpointConventionBuilder builder, AggregatedRateLimiter limiter, Func selector) where TContext: notnull + { builder.Add(endpointBuilder => { - endpointBuilder.Metadata.Add(new RequestLimitAttribute(limiter)); + endpointBuilder.Metadata.Add(new RequestLimitAttribute(new AggregatedLimiterWrapper(limiter, selector))); }); return builder; diff --git a/src/Middleware/RequestLimiter/src/RequestLimiterOptions.cs b/src/Middleware/RequestLimiter/src/RequestLimiterOptions.cs index 29998cef1433..ff38cd76ad4c 100644 --- a/src/Middleware/RequestLimiter/src/RequestLimiterOptions.cs +++ b/src/Middleware/RequestLimiter/src/RequestLimiterOptions.cs @@ -18,17 +18,21 @@ public class RequestLimiterOptions public void SetDefaultPolicy(RateLimiter limiter) { - ResolveDefaultRequestLimit = _ => (HttpContextLimiter)limiter; + ResolveDefaultRequestLimit = _ => new SimpleLimiterWrapper(limiter); } - public void SetDefaultPolicy(AggregatedRateLimiter aggregatedLimiter) + public void SetDefaultPolicy(AggregatedRateLimiter aggregatedLimiter, Func selector) where TContext : notnull { - ResolveDefaultRequestLimit = _ => aggregatedLimiter; + ResolveDefaultRequestLimit = _ => new AggregatedLimiterWrapper(aggregatedLimiter, selector); + } + public void SetDefaultPolicy() where TRateLimiter : RateLimiter + { + ResolveDefaultRequestLimit = services => new AggregatedLimiterWrapper(services.GetRequiredService(), selector); } - public void SetDefaultPolicy() where TRateLimiter : AggregatedRateLimiter + public void SetDefaultPolicy(Func selector) where TRateLimiter : AggregatedRateLimiter where TContext : notnull { - ResolveDefaultRequestLimit = services => services.GetRequiredService(); + ResolveDefaultRequestLimit = services => new AggregatedLimiterWrapper(services.GetRequiredService(), selector); } public void AddPolicy(string name, Action configurePolicy) diff --git a/src/Middleware/RequestLimiter/src/RequestLimiterPolicy.cs b/src/Middleware/RequestLimiter/src/RequestLimiterPolicy.cs index 99017b03e9f7..84f85007a4e6 100644 --- a/src/Middleware/RequestLimiter/src/RequestLimiterPolicy.cs +++ b/src/Middleware/RequestLimiter/src/RequestLimiterPolicy.cs @@ -15,22 +15,22 @@ public class RequestLimiterPolicy public void AddLimiter(RateLimiter limiter) { - LimiterResolvers.Add(_ => (HttpContextLimiter)limiter); + LimiterResolvers.Add(_ => new SimpleLimiterWrapper(limiter)); } - public void AddAggregatedLimiter(AggregatedRateLimiter aggregatedLimiter) + public void AddAggregatedLimiter(AggregatedRateLimiter aggregatedLimiter, Func selector) where TContext : notnull { - LimiterResolvers.Add(_ => aggregatedLimiter); + LimiterResolvers.Add(_ => new AggregatedLimiterWrapper(aggregatedLimiter, selector)); } public void AddLimiter() where TRateLimiter : RateLimiter { - LimiterResolvers.Add(services => (HttpContextLimiter)services.GetRequiredService()); + LimiterResolvers.Add(services => new SimpleLimiterWrapper(services.GetRequiredService())); } - public void AddAggregatedLimiter() where TAggregatedRateLimiter : AggregatedRateLimiter + public void AddAggregatedLimiter(Func selector) where TAggregatedRateLimiter : AggregatedRateLimiter where TContext : notnull { - LimiterResolvers.Add(services => services.GetRequiredService()); + LimiterResolvers.Add(services => new AggregatedLimiterWrapper(services.GetRequiredService(), selector)); } } }