Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rate limiting to CRUD and search APIs #1109

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions rest/src/main/groovy/whelk/rest/api/Crud.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package whelk.rest.api
import groovy.transform.CompileStatic
import groovy.transform.PackageScope
import groovy.util.logging.Log4j2 as Log

import whelk.util.RateLimiter

import org.apache.http.entity.ContentType
import whelk.Document
import whelk.IdGenerator
Expand Down Expand Up @@ -58,8 +61,19 @@ class Crud extends HttpServlet {

SiteSearch siteSearch

Map<String, Tuple2<Document, String>> cachedFetches = [:]

Map<String, Tuple2<Document, String>> cachedFetches = [:]

/**
 * Categories of API requests. Each category is looked up separately when a
 * request is rate limited.
 */
enum RequestType {
    READ,
    WRITE,
    FIND
}

// One limiter per request type; the constructor argument is the limiter's
// maximum rate (maxRateHz).
// TODO: make configurable
Map<RequestType, RateLimiter> rateLimiters = [
        (RequestType.READ) : new RateLimiter(100),
        (RequestType.WRITE): new RateLimiter(10),
        (RequestType.FIND) : new RateLimiter(100),
]

/**
 * No-argument constructor. Intentionally empty: it performs no initialisation.
 */
Crud() {
// Do nothing - only here for Tomcat to have something to call
}
Expand Down Expand Up @@ -113,14 +127,17 @@ class Crud extends HttpServlet {

void doGet2(HttpServletRequest request, HttpServletResponse response) {
RestMetrics.Measurement measurement = null

try {
if (request.pathInfo == "/") {
measurement = metrics.measure('INDEX')
displayInfo(response)
} else if (siteSearch.isSearchResource(request.pathInfo)) {
rateLimit(request, RequestType.FIND)
measurement = metrics.measure('FIND')
handleQuery(request, response)
} else {
rateLimit(request, RequestType.READ)
measurement = metrics.measure('GET')
handleGetRequest(CrudGetRequest.parse(request), response)
}
Expand Down Expand Up @@ -514,7 +531,8 @@ class Crud extends HttpServlet {
if (!isSupportedContentType(request.getContentType())) {
throw new BadRequestException("Content-Type not supported.")
}

rateLimit(request, RequestType.WRITE)

Map requestBody = getRequestBody(request)

if (isEmptyInput(requestBody)) {
Expand Down Expand Up @@ -600,7 +618,8 @@ class Crud extends HttpServlet {
if (!isSupportedContentType(request.getContentType())) {
throw new BadRequestException("Content-Type not supported.")
}

rateLimit(request, RequestType.WRITE)

Map requestBody = getRequestBody(request)

if (isEmptyInput(requestBody)) {
Expand Down Expand Up @@ -832,7 +851,15 @@ class Crud extends HttpServlet {
response.setStatus(HttpServletResponse.SC_NO_CONTENT)
}
}


/**
 * Rejects the request if the calling client has exceeded the allowed rate for
 * the given type of request.
 *
 * Does nothing when no limiter is registered for {@code requestType} or when
 * the client address cannot be determined.
 *
 * @param request     the incoming request, used to determine the client address
 * @param requestType the category of request, selects which limiter is applied
 * @throws RateLimitException if the client's limiter refuses the request
 */
void rateLimit(HttpServletRequest request, RequestType requestType) {
    RateLimiter limiter = rateLimiters[requestType]
    if (limiter == null) {
        // No limit configured for this request type
        return
    }

    // TODO: pass the configured list of trusted proxy addresses instead of an empty list
    HttpTools.getRemoteIp(request, [/*TODO*/]).ifPresent { ip ->
        if (!limiter.isOk(ip)) {
            throw new RateLimitException('Rate limit exceeded for ' + requestType + ' requests. Try again later.')
        }
    }
}

static void sendError(HttpServletRequest request, HttpServletResponse response, Exception e) {
int code = mapError(e)
metrics.failedRequests.labels(request.getMethod(), code.toString()).inc()
Expand All @@ -855,6 +882,9 @@ class Crud extends HttpServlet {
case UnsupportedContentTypeException:
return HttpServletResponse.SC_NOT_ACCEPTABLE

case RateLimitException:
return HttpTools.SC_TOO_MANY_REQUESTS

case WhelkRuntimeException:
return HttpServletResponse.SC_INTERNAL_SERVER_ERROR

Expand All @@ -864,7 +894,7 @@ class Crud extends HttpServlet {

case OtherStatusException:
return ((OtherStatusException) e).code

default:
return HttpServletResponse.SC_INTERNAL_SERVER_ERROR
}
Expand All @@ -884,6 +914,12 @@ class Crud extends HttpServlet {
}
}

/**
 * Thrown by rateLimit() when a client's rate limiter refuses a request.
 * mapError() translates it to HttpTools.SC_TOO_MANY_REQUESTS.
 */
static class RateLimitException extends NoStackTraceException {
RateLimitException(String msg) {
super(msg)
}
}

/** "Don't use exceptions for flow control" in part comes from that exceptions in Java are
* expensive to create because building the stack trace is expensive. But in the context of
* sending error responses in this API exceptions are pretty useful for flow control.
Expand Down
12 changes: 12 additions & 0 deletions rest/src/main/groovy/whelk/rest/api/HttpTools.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import static whelk.util.Jackson.mapper
*/
@Log
class HttpTools {
/** HTTP status code 429 "Too Many Requests". */
final static int SC_TOO_MANY_REQUESTS = 429

static void sendResponse(HttpServletResponse response, Map data, String contentType, int statusCode = 200) {
if (data == null) {
sendResponse(response, new byte[0], contentType, statusCode)
Expand Down Expand Up @@ -87,6 +89,16 @@ class HttpTools {
return baseUri
}

/**
 * Determines the address of the client that originated the request.
 *
 * The address of the directly connected peer ({@code request.getRemoteAddr()})
 * is used unless that peer is listed in {@code proxyIps}. In that case the
 * entries of the {@code X-Forwarded-For} header(s) are walked from the last
 * (nearest) hop towards the first, and the first address that is not a trusted
 * proxy is returned. Entries supplied by untrusted peers are never believed,
 * since the header can be set freely by the client.
 *
 * If every hop is a trusted proxy, the first entry of the forwarding chain is
 * returned.
 *
 * @param request  the incoming request
 * @param proxyIps addresses of reverse proxies whose {@code X-Forwarded-For}
 *                 entries are trusted; {@code null} is treated as empty
 * @return the client address, or empty if none could be determined
 */
static Optional<String> getRemoteIp(HttpServletRequest request, List<String> proxyIps) {
    List<String> trustedProxies = proxyIps != null ? proxyIps : []

    // All forwarded addresses in the order they were appended, first hop first.
    // A header value may itself be a comma separated list.
    List<String> forwardedFor = []
    Enumeration<String> headers = request.getHeaders('X-Forwarded-For')
    if (headers != null) {
        for (String header : Collections.list(headers)) {
            for (String ip : header.split(',')) {
                String trimmed = ip.trim()
                if (!trimmed.isEmpty()) {
                    forwardedFor.add(trimmed)
                }
            }
        }
    }

    // Step outwards through the chain only for as long as the current hop is a trusted proxy
    String candidate = request.getRemoteAddr()
    for (String hop : forwardedFor.reverse()) {
        if (!trustedProxies.contains(candidate)) {
            break
        }
        candidate = hop
    }

    if (candidate == null || candidate.isEmpty()) {
        return Optional.empty()
    }
    return Optional.of(candidate)
}

// NOTE(review): presumably selects which representation of a record is sent
// in a response — usage is not visible here, confirm against callers.
enum DisplayMode {
DOCUMENT, META, RAW
}
Expand Down
52 changes: 52 additions & 0 deletions rest/src/main/java/whelk/util/RateLimiter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package whelk.util;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Limits the rate of events per key.
 *
 * <p>Every key has a bucket. Each call to {@link #isOk(String, long)} adds one
 * (scaled) unit to the key's bucket, up to a cap of {@code maxRateHz} units.
 * At most once every {@value #UPDATE_PERIOD_MS} ms all buckets are drained by
 * {@code maxRateHz} units per elapsed second, and buckets that reach zero are
 * discarded. A call is accepted while the bucket is below the cap, so a key
 * starting from an empty bucket gets {@code maxRateHz - 1} calls accepted
 * before a drain is needed.
 *
 * <p>Intended for concurrent use; state is kept in atomics and a concurrent map.
 */
public class RateLimiter {
    /** Minimum time between two drain passes over the buckets. */
    private static final int UPDATE_PERIOD_MS = 100;

    // Make the calculations work with really low max rates
    private static final long SCALE_FACTOR = 1000;

    /** Value of {@link #lastUpdate} before the first call. */
    private static final long NEVER_UPDATED = -1;

    private final ConcurrentMap<String, AtomicLong> buckets = new ConcurrentHashMap<>();
    private final AtomicLong lastUpdate = new AtomicLong(NEVER_UPDATED);
    private final int maxRateHz;
    /** Bucket cap in scaled units. */
    private final long capacity;

    /**
     * @param maxRateHz maximum sustained number of accepted calls per second and key
     */
    public RateLimiter(int maxRateHz) {
        this.maxRateHz = maxRateHz;
        this.capacity = maxRateHz * SCALE_FACTOR;
    }

    /**
     * Registers a call for {@code key} at the current system time.
     *
     * @param key identifies the caller, e.g. an IP address
     * @return {@code true} if the call is within the allowed rate
     */
    public boolean isOk(String key) {
        return isOk(key, System.currentTimeMillis());
    }

    /**
     * Registers a call for {@code key} at the given time.
     *
     * @param key               identifies the caller
     * @param currentTimeMillis the time of the call in milliseconds
     * @return {@code true} if the call is within the allowed rate
     */
    public boolean isOk(String key, long currentTimeMillis) {
        maybeDrainBuckets(currentTimeMillis);

        AtomicLong bucket = buckets.computeIfAbsent(key, k -> new AtomicLong());
        long level = bucket.accumulateAndGet(SCALE_FACTOR,
                (current, update) -> Math.min(current + update, capacity));

        return level < capacity;
    }

    /**
     * Drains all buckets if at least {@value #UPDATE_PERIOD_MS} ms have passed
     * since the previous drain. Exactly one of several racing threads performs
     * the drain, and it uses the same "previous" timestamp it replaced.
     */
    private void maybeDrainBuckets(long currentTimeMillis) {
        long previousUpdate = lastUpdate.get();
        boolean firstCall = previousUpdate == NEVER_UPDATED;
        if (!firstCall && currentTimeMillis - previousUpdate < UPDATE_PERIOD_MS) {
            return;
        }
        if (!lastUpdate.compareAndSet(previousUpdate, currentTimeMillis)) {
            // Another thread claimed this drain period
            return;
        }
        if (firstCall) {
            // Only establishes the reference time; there is no elapsed time to drain for
            return;
        }

        long deltaMs = currentTimeMillis - previousUpdate;
        // Rounded rather than truncated so floating point error cannot shave off a unit
        long drainAmount = Math.round((deltaMs / 1000.0) * maxRateHz * SCALE_FACTOR);
        buckets.forEach((key, bucket) -> {
            long remaining = bucket.accumulateAndGet(drainAmount,
                    (current, update) -> Math.max(current - update, 0));
            if (remaining == 0) {
                // we might miss some increments here, that's ok
                buckets.remove(key, bucket);
            }
        });
    }
}
53 changes: 53 additions & 0 deletions rest/src/test/java/whelk/util/RateLimiterTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package whelk.util;

import org.junit.Assert;
import org.junit.Test;


/**
 * Unit tests for {@link RateLimiter}, driven with explicit timestamps so that
 * no real time has to pass.
 */
public class RateLimiterTest {
    private static final String KEY = "abc";

    /** Issues {@code count} calls for {@link #KEY} at {@code timestamp}, expecting each to be accepted. */
    private static void assertAccepted(RateLimiter limiter, int count, long timestamp) {
        int remaining = count;
        while (remaining > 0) {
            Assert.assertTrue(limiter.isOk(KEY, timestamp));
            remaining--;
        }
    }

    /** A burst is refused once it reaches the limit; other keys are unaffected. */
    @Test
    public void test() {
        final int maxRate = 10;
        final long now = 0;
        RateLimiter limiter = new RateLimiter(maxRate);

        assertAccepted(limiter, maxRate - 1, now);

        Assert.assertFalse(limiter.isOk(KEY, now));
        Assert.assertTrue(limiter.isOk("anotherKey", now));
    }

    /** A refused key is accepted again after time has passed, but only for as much as was drained. */
    @Test
    public void testRecover() {
        final int maxRate = 10;
        final long start = 0;
        RateLimiter limiter = new RateLimiter(maxRate);

        assertAccepted(limiter, maxRate - 1, start);
        Assert.assertFalse(limiter.isOk(KEY, start));

        final long later = start + 101;
        Assert.assertTrue(limiter.isOk(KEY, later));
        Assert.assertFalse(limiter.isOk(KEY, later));
    }

    /** Sustained traffic of 10 calls per 100 ms is never refused with a limit of 110. */
    @Test
    public void testContinuous() {
        final int maxRate = 110;
        final int burstsPerRound = maxRate / 10;
        RateLimiter limiter = new RateLimiter(maxRate);
        long now = 0;

        for (int round = 0; round < 10_000; round++) {
            for (int burst = 0; burst < burstsPerRound; burst++) {
                assertAccepted(limiter, 10, now);
                now += 100;
            }
        }
    }
}
Loading