Skip to content

Commit

Permalink
Update handling of simple and aggregated limiters
Browse files Browse the repository at this point in the history
  • Loading branch information
John Luo committed Jun 3, 2021
1 parent 6f4f892 commit f04fb9f
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.RequestLimiter;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk.Web">

<PropertyGroup>
<TargetFramework>$(DefaultNetCoreTargetFramework)</TargetFramework>
Expand All @@ -16,7 +16,7 @@
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
<Reference Include="Microsoft.Extensions.Logging.Console" />
<Reference Include="System.Threading.ResourceLimits" />
<Reference Include="System.Runtime.RateLimits" />
</ItemGroup>

</Project>
25 changes: 19 additions & 6 deletions src/Middleware/RequestLimiter/sample/Startup.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Threading.ResourceLimits;
using System;
using System.Net;
using System.Runtime.RateLimits;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
Expand All @@ -19,16 +21,21 @@ public class Startup
public void ConfigureServices(IServiceCollection services)
{
services.AddControllersWithViews();
services.AddSingleton(new IPAggregatedRateLimiter(2, 2));
services.AddSingleton(new TokenBucketRateLimiter(2, 2));
services.AddSingleton(
new TokenBucketRateLimiter(
new TokenBucketRateLimiterOptions {
PermitLimit = 2,
TokensPerPeriod = 2,
ReplenishmentPeriod = TimeSpan.FromSeconds(1)
}));

services.AddRequestLimiter(options =>
{
options.SetDefaultPolicy(new ConcurrencyLimiter(new ConcurrencyLimiterOptions { ResourceLimit = 100 }));
options.SetDefaultPolicy(new ConcurrencyLimiter(new ConcurrencyLimiterOptions { PermitLimit = 100 }));
options.AddPolicy("ipPolicy", policy =>
{
// Add instance
policy.AddAggregatedLimiter(new IPAggregatedRateLimiter(2, 2));
policy.AddAggregatedLimiter(new AggregatedTokenBucketLimiter<IPAddress>(2, 2), context => context.Connection.RemoteIpAddress);
});
options.AddPolicy("rate", policy =>
{
Expand Down Expand Up @@ -57,7 +64,13 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILogger<
{
await Task.Delay(5000);
await context.Response.WriteAsync("Hello World!");
}).EnforceRequestLimit(new TokenBucketRateLimiter(2, 2));
}).EnforceRequestLimit(new TokenBucketRateLimiter(
new TokenBucketRateLimiterOptions
{
PermitLimit = 2,
TokensPerPeriod = 2,
ReplenishmentPeriod = TimeSpan.FromSeconds(1)
}));
endpoints.MapGet("/concurrent", async context =>
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
using System;
using System.Collections.Concurrent;
using System.Net;
using System.Threading;
using System.Runtime.RateLimits;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using System.Diagnostics.CodeAnalysis;

namespace Microsoft.AspNetCore.RequestLimiter
{
// TODO: update implementation with WaitAsync
public class IPAggregatedRateLimiter : AggregatedRateLimiter<HttpContext>
public class AggregatedTokenBucketLimiter<TContext> : AggregatedRateLimiter<TContext> where TContext: notnull
{
private int _permitCount;

private readonly int _maxPermitCount;
private readonly int _newPermitPerSecond;
private readonly Timer _renewTimer;
// TODO: This is racy
private readonly ConcurrentDictionary<IPAddress, int> _cache = new();
private readonly ConcurrentDictionary<TContext, int> _cache = new();

private static readonly RateLimitLease FailedLease = new RateLimitLease(false);
private static readonly RateLimitLease SuccessfulLease = new RateLimitLease(true);
private static readonly RateLimitLease FailedLease = new(false);
private static readonly RateLimitLease SuccessfulLease = new(true);

public IPAggregatedRateLimiter(int permitCount, int newPermitPerSecond)
public AggregatedTokenBucketLimiter(int permitCount, int newPermitPerSecond)
{
_permitCount = permitCount;
_maxPermitCount = permitCount;
Expand All @@ -33,35 +31,21 @@ public IPAggregatedRateLimiter(int permitCount, int newPermitPerSecond)
_renewTimer = new Timer(Replenish, this, TimeSpan.FromSeconds(5), TimeSpan.FromSeconds(5));
}

public override int AvailablePermits(HttpContext context)
public override int AvailablePermits(TContext context)
{
if (context.Connection.RemoteIpAddress == null)
{
// Unknown IP?
return 0;
}

return _cache.TryGetValue(context.Connection.RemoteIpAddress, out var count) ? count : 0;
return _cache.TryGetValue(context, out var count) ? count : 0;
}

public override PermitLease Acquire(HttpContext context, int permitCount)
public override PermitLease Acquire(TContext context, int permitCount)
{
if (permitCount > _maxPermitCount)
{
return FailedLease;
}

if (context.Connection.RemoteIpAddress == null)
{
// TODO: how should this case be handled?
return SuccessfulLease;
}

var key = context.Connection.RemoteIpAddress;

if (!_cache.TryGetValue(key, out var count))
if (!_cache.TryGetValue(context, out var count))
{
if (_cache.TryAdd(key, _maxPermitCount - permitCount))
if (_cache.TryAdd(context, _maxPermitCount - permitCount))
{
return SuccessfulLease;
}
Expand All @@ -76,30 +60,30 @@ public override PermitLease Acquire(HttpContext context, int permitCount)
return FailedLease;
}

if (_cache.TryUpdate(key, newCount, count))
if (_cache.TryUpdate(context, newCount, count))
{
return SuccessfulLease;
}

if (!_cache.TryGetValue(key, out count))
if (!_cache.TryGetValue(context, out count))
{
if (_cache.TryAdd(key, _maxPermitCount - permitCount))
if (_cache.TryAdd(context, _maxPermitCount - permitCount))
{
return SuccessfulLease;
}
}
}
}

public override ValueTask<PermitLease> WaitAsync(HttpContext context, int permitCount, CancellationToken cancellationToken = default)
public override ValueTask<PermitLease> WaitAsync(TContext context, int permitCount, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

private static void Replenish(object? state)
{
// Return if Replenish already running to avoid concurrency.
if (state is not IPAggregatedRateLimiter limiter)
if (state is not AggregatedTokenBucketLimiter<TContext> limiter)
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Threading;
using System.Runtime.RateLimits;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.RequestLimiter
{
internal class AggregatedLimiterWrapper<TContext> : AggregatedRateLimiter<HttpContext> where TContext: notnull
{
private readonly AggregatedRateLimiter<TContext> _limiter;
private readonly Func<HttpContext, TContext> _selector;

public AggregatedLimiterWrapper(AggregatedRateLimiter<TContext> limiter, Func<HttpContext, TContext> selector)
{
_limiter = limiter;
_selector = selector;
}

public override PermitLease Acquire(HttpContext context, int permitCount)
{
return _limiter.Acquire(_selector(context), permitCount);
}

public override int AvailablePermits(HttpContext context)
{
return _limiter.AvailablePermits(_selector(context));
}

public override ValueTask<PermitLease> WaitAsync(HttpContext context, int requestedCount, CancellationToken cancellationToken = default)
{
return _limiter.WaitAsync(_selector(context), requestedCount, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

namespace Microsoft.AspNetCore.RequestLimiter
{
internal class HttpContextLimiter : AggregatedRateLimiter<HttpContext>
internal class SimpleLimiterWrapper : AggregatedRateLimiter<HttpContext>
{
private readonly RateLimiter _limiter;

public HttpContextLimiter(RateLimiter limiter)
public SimpleLimiterWrapper(RateLimiter limiter)
{
_limiter = limiter;
}
Expand All @@ -28,7 +28,5 @@ public override ValueTask<PermitLease> WaitAsync(HttpContext context, int reques
{
return _limiter.WaitAsync(requestedCount, cancellationToken);
}

public static implicit operator HttpContextLimiter(RateLimiter limiter) => new(limiter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ public static IEndpointConventionBuilder EnforceRequestRateLimit(this IEndpointC
{
endpointBuilder.Metadata.Add(
new RequestLimitAttribute(
(HttpContextLimiter)new TokenBucketRateLimiter(
new TokenBucketRateLimiterOptions
{
PermitLimit = requestPerSecond,
ReplenishmentPeriod = TimeSpan.FromSeconds(1),
TokensPerPeriod = requestPerSecond
})));
new SimpleLimiterWrapper(
new TokenBucketRateLimiter(
new TokenBucketRateLimiterOptions
{
PermitLimit = requestPerSecond,
ReplenishmentPeriod = TimeSpan.FromSeconds(1),
TokensPerPeriod = requestPerSecond
}))));
});

return builder;
Expand All @@ -54,22 +55,29 @@ public static IEndpointConventionBuilder EnforceRequestConcurrencyLimit(this IEn
{
endpointBuilder.Metadata.Add(
new RequestLimitAttribute(
(HttpContextLimiter)new ConcurrencyLimiter(
new ConcurrencyLimiterOptions { PermitLimit = concurrentRequests })));
new SimpleLimiterWrapper(
new ConcurrencyLimiter(
new ConcurrencyLimiterOptions { PermitLimit = concurrentRequests }))));
});

return builder;
}

public static IEndpointConventionBuilder EnforceRequestLimit(this IEndpointConventionBuilder builder, RateLimiter limiter)
=> builder.EnforceRequestLimit((HttpContextLimiter)limiter);

public static IEndpointConventionBuilder EnforceRequestLimit(this IEndpointConventionBuilder builder, AggregatedRateLimiter<HttpContext> limiter)
{
builder.Add(endpointBuilder =>
{
endpointBuilder.Metadata.Add(new RequestLimitAttribute(new SimpleLimiterWrapper(limiter)));
});

return builder;
}

public static IEndpointConventionBuilder EnforceRequestLimit<TContext>(this IEndpointConventionBuilder builder, AggregatedRateLimiter<TContext> limiter, Func<HttpContext, TContext> selector) where TContext: notnull
{
builder.Add(endpointBuilder =>
{
endpointBuilder.Metadata.Add(new RequestLimitAttribute(limiter));
endpointBuilder.Metadata.Add(new RequestLimitAttribute(new AggregatedLimiterWrapper<TContext>(limiter, selector)));
});

return builder;
Expand Down
14 changes: 9 additions & 5 deletions src/Middleware/RequestLimiter/src/RequestLimiterOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@ public class RequestLimiterOptions

public void SetDefaultPolicy(RateLimiter limiter)
{
ResolveDefaultRequestLimit = _ => (HttpContextLimiter)limiter;
ResolveDefaultRequestLimit = _ => new SimpleLimiterWrapper(limiter);
}

public void SetDefaultPolicy(AggregatedRateLimiter<HttpContext> aggregatedLimiter)
public void SetDefaultPolicy<TContext>(AggregatedRateLimiter<TContext> aggregatedLimiter, Func<HttpContext, TContext> selector) where TContext : notnull
{
ResolveDefaultRequestLimit = _ => aggregatedLimiter;
ResolveDefaultRequestLimit = _ => new AggregatedLimiterWrapper<TContext>(aggregatedLimiter, selector);
}
public void SetDefaultPolicy<TRateLimiter>() where TRateLimiter : RateLimiter
{
ResolveDefaultRequestLimit = services => new AggregatedLimiterWrapper<TContext>(services.GetRequiredService<TRateLimiter>(), selector);
}

public void SetDefaultPolicy<TRateLimiter>() where TRateLimiter : AggregatedRateLimiter<HttpContext>
public void SetDefaultPolicy<TRateLimiter, TContext>(Func<HttpContext, TContext> selector) where TRateLimiter : AggregatedRateLimiter<TContext> where TContext : notnull
{
ResolveDefaultRequestLimit = services => services.GetRequiredService<TRateLimiter>();
ResolveDefaultRequestLimit = services => new AggregatedLimiterWrapper<TContext>(services.GetRequiredService<TRateLimiter>(), selector);
}

public void AddPolicy(string name, Action<RequestLimiterPolicy> configurePolicy)
Expand Down
12 changes: 6 additions & 6 deletions src/Middleware/RequestLimiter/src/RequestLimiterPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ public class RequestLimiterPolicy

public void AddLimiter(RateLimiter limiter)
{
LimiterResolvers.Add(_ => (HttpContextLimiter)limiter);
LimiterResolvers.Add(_ => new SimpleLimiterWrapper(limiter));
}

public void AddAggregatedLimiter(AggregatedRateLimiter<HttpContext> aggregatedLimiter)
public void AddAggregatedLimiter<TContext>(AggregatedRateLimiter<TContext> aggregatedLimiter, Func<HttpContext, TContext> selector) where TContext : notnull
{
LimiterResolvers.Add(_ => aggregatedLimiter);
LimiterResolvers.Add(_ => new AggregatedLimiterWrapper<TContext>(aggregatedLimiter, selector));
}

public void AddLimiter<TRateLimiter>() where TRateLimiter : RateLimiter
{
LimiterResolvers.Add(services => (HttpContextLimiter)services.GetRequiredService<TRateLimiter>());
LimiterResolvers.Add(services => new SimpleLimiterWrapper(services.GetRequiredService<TRateLimiter>()));
}

public void AddAggregatedLimiter<TAggregatedRateLimiter>() where TAggregatedRateLimiter : AggregatedRateLimiter<HttpContext>
public void AddAggregatedLimiter<TAggregatedRateLimiter, TContext>(Func<HttpContext, TContext> selector) where TAggregatedRateLimiter : AggregatedRateLimiter<TContext> where TContext : notnull
{
LimiterResolvers.Add(services => services.GetRequiredService<TAggregatedRateLimiter>());
LimiterResolvers.Add(services => new AggregatedLimiterWrapper<TContext>(services.GetRequiredService<TAggregatedRateLimiter>(), selector));
}
}
}

0 comments on commit f04fb9f

Please sign in to comment.