Skip to content
This repository has been archived by the owner on Jul 1, 2022. It is now read-only.

Commit

Permalink
Concurrency improvements to RemoteControlledSampler (#609)
Browse files Browse the repository at this point in the history
* Fix #608 - multithreaded performance/synchronization issues

This removes synchronization from RemoteControlledSampler
and makes RateLimiter thread-safe.

Signed-off-by: Yegor Borovikov <yegor@uber.com>

* Revert RemoteControlledSamplerTest changes

Signed-off-by: Yegor Borovikov <yegor@uber.com>

* Further optimize RateLimiter.checkCredit()

Signed-off-by: Yegor Borovikov <yegor@uber.com>

* Move clock.currentNanoTicks() inside optimistic retry loop

Signed-off-by: Yegor Borovikov <yegor@uber.com>

* Log update failure; add timer health test

Signed-off-by: Yegor Borovikov <yegor@uber.com>
  • Loading branch information
yborovikov authored and pavolloffay committed Apr 29, 2019
1 parent 20ee4e8 commit b2a0e59
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ public class RemoteControlledSampler implements Sampler {
private final int maxOperations = 2000;
private final SamplingManager manager;

@Getter(AccessLevel.PACKAGE)
private Sampler sampler;
// initialized in constructor and updated from a single (poll timer) thread
// volatile to guarantee immediate visibility of the updated sampler to other threads (remove if not a requirement)
@Getter(AccessLevel.PACKAGE) // visible for testing
private volatile Sampler sampler;

// most of the time, toString here is called from the JaegerTracer, which holds this as well
@ToString.Exclude private final String serviceName;

@ToString.Exclude private final Timer pollTimer;
@ToString.Exclude private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
@ToString.Exclude private final Metrics metrics;

private RemoteControlledSampler(Builder builder) {
Expand All @@ -67,22 +68,21 @@ private RemoteControlledSampler(Builder builder) {
new TimerTask() {
@Override
public void run() {
updateSampler();
try {
updateSampler();
} catch (Exception e) { // keep the timer thread alive
log.error("Failed to update sampler", e);
}
}
},
0,
builder.poolingIntervalMs);
return;
}

public ReentrantReadWriteLock getLock() {
return lock;
builder.pollingIntervalMs);
}

/**
* Updates {@link #sampler} to a new sampler when it is different.
*/
void updateSampler() {
void updateSampler() { // visible for testing
SamplingStrategyResponse response;
try {
response = manager.getSamplingStrategy(serviceName);
Expand Down Expand Up @@ -117,29 +117,27 @@ private void updateRateLimitingOrProbabilisticSampler(SamplingStrategyResponse r
return;
}

synchronized (this) {
if (!this.sampler.equals(sampler)) {
this.sampler = sampler;
metrics.samplerUpdated.inc(1);
}
if (!this.sampler.equals(sampler)) {
this.sampler = sampler;
metrics.samplerUpdated.inc(1);
}
}

private synchronized void updatePerOperationSampler(OperationSamplingParameters samplingParameters) {
if (sampler instanceof PerOperationSampler) {
if (((PerOperationSampler) sampler).update(samplingParameters)) {
private void updatePerOperationSampler(OperationSamplingParameters samplingParameters) {
Sampler currentSampler = sampler;
if (currentSampler instanceof PerOperationSampler) {
if (((PerOperationSampler) currentSampler).update(samplingParameters)) {
metrics.samplerUpdated.inc(1);
}
} else {
sampler = new PerOperationSampler(maxOperations, samplingParameters);
metrics.samplerUpdated.inc(1);
}
}

@Override
public SamplingStatus sample(String operation, long id) {
synchronized (this) {
return sampler.sample(operation, id);
}
return sampler.sample(operation, id);
}

@Override
Expand All @@ -149,32 +147,22 @@ public boolean equals(Object sampler) {
}
if (sampler instanceof RemoteControlledSampler) {
RemoteControlledSampler remoteSampler = ((RemoteControlledSampler) sampler);
synchronized (this) {
ReentrantReadWriteLock.ReadLock readLock = remoteSampler.getLock().readLock();
readLock.lock();
try {
return this.sampler.equals(remoteSampler.sampler);
} finally {
readLock.unlock();
}
}
return this.sampler.equals(remoteSampler.sampler);
}
return false;
}

@Override
public void close() {
synchronized (this) {
pollTimer.cancel();
}
pollTimer.cancel();
}

public static class Builder {
private final String serviceName;
private SamplingManager samplingManager;
private Sampler initialSampler;
private Metrics metrics;
private int poolingIntervalMs = DEFAULT_POLLING_INTERVAL_MS;
private int pollingIntervalMs = DEFAULT_POLLING_INTERVAL_MS;

public Builder(String serviceName) {
this.serviceName = serviceName;
Expand All @@ -196,7 +184,7 @@ public Builder withMetrics(Metrics metrics) {
}

public Builder withPollingInterval(int pollingIntervalMs) {
this.poolingIntervalMs = pollingIntervalMs;
this.pollingIntervalMs = pollingIntervalMs;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,42 @@
import io.jaegertracing.internal.clock.Clock;
import io.jaegertracing.internal.clock.SystemClock;

import java.util.concurrent.atomic.AtomicLong;

public class RateLimiter {
private final double creditsPerNanosecond;
private final Clock clock;
private double balance;
private double maxBalance;
private long lastTick;
private final double creditsPerNanosecond;
private final long maxBalance; // max balance in nano ticks
private final AtomicLong debit; // last op nano time less remaining balance

public RateLimiter(double creditsPerSecond, double maxBalance) {
this(creditsPerSecond, maxBalance, new SystemClock());
}

public RateLimiter(double creditsPerSecond, double maxBalance, Clock clock) {
this.clock = clock;
this.balance = maxBalance;
this.maxBalance = maxBalance;
this.creditsPerNanosecond = creditsPerSecond / 1.0e9;
this.maxBalance = (long) (maxBalance / creditsPerNanosecond);
this.debit = new AtomicLong(clock.currentNanoTicks() - this.maxBalance);
}

public boolean checkCredit(double itemCost) {
long currentTime = clock.currentNanoTicks();
double elapsedTime = currentTime - lastTick;
lastTick = currentTime;
balance += elapsedTime * creditsPerNanosecond;
if (balance > maxBalance) {
balance = maxBalance;
}
if (balance >= itemCost) {
balance -= itemCost;
return true;
}
return false;
long cost = (long) (itemCost / creditsPerNanosecond);
long credit;
long currentDebit;
long balance;
do {
currentDebit = debit.get();
credit = clock.currentNanoTicks();
balance = credit - currentDebit;
if (balance > maxBalance) {
balance = maxBalance;
}
balance -= cost;
if (balance < 0) {
return false;
}
} while (!debit.compareAndSet(currentDebit, credit - balance));
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -33,6 +34,9 @@
import io.jaegertracing.spi.SamplingManager;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -52,6 +56,7 @@ public class RemoteControlledSamplerTest {
@Before
public void setUp() throws Exception {
metrics = new Metrics(new InMemoryMetricsFactory());
// TODO this starts the timer with mocks not yet configured, causing NPEs; refactor to .build() from tests
undertest = new RemoteControlledSampler.Builder(SERVICE_NAME)
.withSamplingManager(samplingManager)
.withInitialSampler(initialSampler)
Expand Down Expand Up @@ -107,6 +112,7 @@ public void testUpdateToPerOperationSamplerReplacesProbabilisticSampler() throws

@Test
public void testUpdatePerOperationSamplerUpdatesExistingPerOperationSampler() throws Exception {
undertest.close();
PerOperationSampler perOperationSampler = mock(PerOperationSampler.class);
OperationSamplingParameters parameters = mock(OperationSamplingParameters.class);
when(samplingManager.getSamplingStrategy(SERVICE_NAME)).thenReturn(
Expand Down Expand Up @@ -138,6 +144,23 @@ public void testUnparseableResponse() throws Exception {
assertEquals(initialSampler, undertest.getSampler());
}

@Test
public void testUpdateFailureKeepsTimerRunning() throws InterruptedException {
undertest.close();
CountDownLatch latch = new CountDownLatch(3);
SamplingManager failingManager = serviceName -> {
latch.countDown();
throw new RuntimeException("test update failure");
};
undertest = new RemoteControlledSampler.Builder(SERVICE_NAME)
.withSamplingManager(failingManager)
.withInitialSampler(initialSampler)
.withMetrics(metrics)
.withPollingInterval(1)
.build();
assertTrue(latch.await(1, TimeUnit.SECONDS));
}

@Test
public void testSample() throws Exception {
undertest.sample("op", 1L);
Expand All @@ -160,6 +183,7 @@ public void testEquals() {

@Test
public void testDefaultProbabilisticSampler() {
undertest.close();
undertest = new RemoteControlledSampler.Builder(SERVICE_NAME)
.withSamplingManager(samplingManager)
.withInitialSampler(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@

package io.jaegertracing.internal.utils;

import static junit.framework.TestCase.assertFalse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import io.jaegertracing.internal.clock.Clock;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Test;

public class RateLimiterTest {
RateLimiter limiter;

private static class MockClock implements Clock {

Expand Down Expand Up @@ -127,4 +131,59 @@ public void testRateLimiterMaxBalance() {
assertTrue(limiter.checkCredit(1.0));
assertFalse(limiter.checkCredit(1.0));
}

/**
* Validates rate limiter behavior with {@link System#nanoTime()}-like (non-zero) initial nano ticks.
*/
@Test
public void testRateLimiterInitial() {
MockClock clock = new MockClock();
clock.timeNanos = TimeUnit.MILLISECONDS.toNanos(-1_000_000);
RateLimiter limiter = new RateLimiter(1000, 100, clock);

assertTrue(limiter.checkCredit(100)); // consume initial (max) balance
assertFalse(limiter.checkCredit(1));

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(49); // add 49 credits
assertFalse(limiter.checkCredit(50));

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(1); // add one credit
assertTrue(limiter.checkCredit(50)); // consume accrued balance
assertFalse(limiter.checkCredit(1));

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(1_000_000); // add a lot of credits (max out balance)
assertTrue(limiter.checkCredit(1)); // take one credit

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(1_000_000); // add a lot of credits (max out balance)
assertFalse(limiter.checkCredit(101)); // can't consume more than max balance
assertTrue(limiter.checkCredit(100)); // consume max balance
assertFalse(limiter.checkCredit(1));
}

/**
* Validates concurrent credit check correctness.
*/
@Test
public void testRateLimiterConcurrency() {
int numWorkers = ForkJoinPool.getCommonPoolParallelism();
int creditsPerWorker = 1000;
MockClock clock = new MockClock();
RateLimiter limiter = new RateLimiter(1, numWorkers * creditsPerWorker, clock);

AtomicInteger count = new AtomicInteger();
for (int w = 0; w < numWorkers; ++w) {
ForkJoinPool.commonPool().execute(() -> {
for (int i = 0; i < creditsPerWorker * 2; ++i) {
if (limiter.checkCredit(1)) {
count.getAndIncrement(); // count allowed operations
}
}
});
}
ForkJoinPool.commonPool().awaitQuiescence(1, TimeUnit.SECONDS);

assertEquals("Exactly the allocated number of credits must be consumed", numWorkers * creditsPerWorker,count.get());
assertFalse(limiter.checkCredit(1));
}

}

0 comments on commit b2a0e59

Please sign in to comment.