diff --git a/cluster/src/main/java/io/scalecube/cluster/TransportWrapper.java b/cluster/src/main/java/io/scalecube/cluster/TransportWrapper.java index c37bc4e7..5e84f3b8 100644 --- a/cluster/src/main/java/io/scalecube/cluster/TransportWrapper.java +++ b/cluster/src/main/java/io/scalecube/cluster/TransportWrapper.java @@ -7,13 +7,14 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; import reactor.core.publisher.Mono; public class TransportWrapper { private final Transport transport; - private final Map addressIndexByMember = new ConcurrentHashMap<>(); + private final Map addressIndexByMember = new ConcurrentHashMap<>(); public TransportWrapper(Transport transport) { this.transport = transport; @@ -27,21 +28,7 @@ public TransportWrapper(Transport transport) { * @return mono result */ public Mono requestResponse(Member member, Message request) { - final List
addresses = member.addresses(); - final AtomicInteger currentIndex = - addressIndexByMember.computeIfAbsent(member, m -> new AtomicInteger()); - return Mono.defer( - () -> { - synchronized (this) { - if (currentIndex.get() == addresses.size()) { - currentIndex.set(0); - } - final Address address = addresses.get(currentIndex.getAndIncrement()); - return transport.requestResponse(address, request); - } - }) - .retry(addresses.size() - 1) - .doOnError(throwable -> addressIndexByMember.remove(member, currentIndex)); + return invokeWithRetry(member, request, transport::requestResponse); } /** @@ -52,20 +39,28 @@ public Mono requestResponse(Member member, Message request) { * @return mono result */ public Mono send(Member member, Message request) { - final List
addresses = member.addresses(); - final AtomicInteger currentIndex = - addressIndexByMember.computeIfAbsent(member, m -> new AtomicInteger()); + return invokeWithRetry(member, request, transport::send); + } + + private Mono invokeWithRetry( + Member member, Message request, BiFunction> function) { return Mono.defer( - () -> { - synchronized (this) { - if (currentIndex.get() == addresses.size()) { - currentIndex.set(0); - } - final Address address = addresses.get(currentIndex.getAndIncrement()); - return transport.send(address, request); - } - }) - .retry(addresses.size() - 1) - .doOnError(throwable -> addressIndexByMember.remove(member, currentIndex)); + () -> { + final List
addresses = member.addresses(); + final Integer index = addressIndexByMember.computeIfAbsent(member, m -> 0); + final AtomicInteger currentIndex = new AtomicInteger(index); + + return Mono.defer( + () -> { + if (currentIndex.get() == addresses.size()) { + currentIndex.set(0); + } + final Address address = addresses.get(currentIndex.get()); + return function.apply(address, request); + }) + .doOnSuccess(s -> addressIndexByMember.put(member, currentIndex.get())) + .doOnError(ex -> currentIndex.incrementAndGet()) + .retry(addresses.size() - 1); + }); } } diff --git a/cluster/src/test/java/io/scalecube/cluster/TransportWrapperTest.java b/cluster/src/test/java/io/scalecube/cluster/TransportWrapperTest.java index eb0882bc..d3d15ec3 100644 --- a/cluster/src/test/java/io/scalecube/cluster/TransportWrapperTest.java +++ b/cluster/src/test/java/io/scalecube/cluster/TransportWrapperTest.java @@ -1,5 +1,7 @@ package io.scalecube.cluster; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -11,10 +13,8 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import java.util.stream.Stream.Builder; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -76,12 +76,12 @@ static void populateBuilder(Builder builder, int size) { } } - private Map addressIndexByMember() + private Map addressIndexByMember() throws NoSuchFieldException, IllegalAccessException { final Field field = TransportWrapper.class.getDeclaredField("addressIndexByMember"); field.setAccessible(true); //noinspection unchecked - return (Map) field.get(transportWrapper); + return (Map) field.get(transportWrapper); } @ParameterizedTest @@ -95,7 +95,7 @@ void requestResponseShouldWorkByRoundRobin(int size, int startIndex, int success } if (startIndex > 0) { - addressIndexByMember().put(member, new AtomicInteger(startIndex)); + addressIndexByMember().put(member, startIndex); } for (int i = 0; i < size; i++) { @@ -109,9 +109,11 @@ void requestResponseShouldWorkByRoundRobin(int size, int startIndex, int success } StepVerifier.create(transportWrapper.requestResponse(member, request)) - .assertNext(message -> Assertions.assertSame(response, message, "response")) + .assertNext(message -> assertSame(response, message, "response")) .thenCancel() .verify(); + + assertEquals(successIndex, addressIndexByMember().get(member), "successIndex"); } @Test @@ -124,13 +126,12 @@ void requestResponseShouldWorkThenFail() { .thenReturn(Mono.error(new RuntimeException("Error"))); StepVerifier.create(transportWrapper.requestResponse(member, request)) - .assertNext(message -> Assertions.assertSame(response, message, "response")) + .assertNext(message -> assertSame(response, message, "response")) .thenCancel() .verify(); StepVerifier.create(transportWrapper.requestResponse(member, request)) - .verifyErrorSatisfies( - throwable -> Assertions.assertEquals("Error", throwable.getMessage())); + .verifyErrorSatisfies(throwable -> assertEquals("Error", throwable.getMessage())); } @Test @@ -143,11 +144,10 @@ void requestResponseShouldFailThenWork() { .thenReturn(Mono.just(response)); StepVerifier.create(transportWrapper.requestResponse(member, request)) - .verifyErrorSatisfies( - throwable -> Assertions.assertEquals("Error", throwable.getMessage())); + .verifyErrorSatisfies(throwable -> assertEquals("Error", throwable.getMessage())); StepVerifier.create(transportWrapper.requestResponse(member, request)) - .assertNext(message -> Assertions.assertSame(response, message, "response")) + .assertNext(message -> assertSame(response, message, "response")) .thenCancel() .verify(); } @@ -163,7 +163,7 @@ void requestResponseShouldFailByRoundRobin(int size, int startIndex, int ignore) } if (startIndex > 0) { - addressIndexByMember().put(member, new AtomicInteger(startIndex)); + addressIndexByMember().put(member, startIndex); } for (int i = 0; i < size; i++) { @@ -173,8 +173,9 @@ void requestResponseShouldFailByRoundRobin(int size, int startIndex, int ignore) } StepVerifier.create(transportWrapper.requestResponse(member, request)) - .verifyErrorSatisfies( - throwable -> Assertions.assertEquals("Error", throwable.getMessage())); + .verifyErrorSatisfies(throwable -> assertEquals("Error", throwable.getMessage())); + + assertEquals(startIndex, addressIndexByMember().get(member), "startIndex"); } @ParameterizedTest @@ -187,7 +188,7 @@ void sendShouldWorkByRoundRobin(int size, int startIndex, int successIndex) thro } if (startIndex > 0) { - addressIndexByMember().put(member, new AtomicInteger(startIndex)); + addressIndexByMember().put(member, startIndex); } for (int i = 0; i < size; i++) { @@ -201,6 +202,8 @@ void sendShouldWorkByRoundRobin(int size, int startIndex, int successIndex) thro } StepVerifier.create(transportWrapper.send(member, request)).verifyComplete(); + + assertEquals(successIndex, addressIndexByMember().get(member), "successIndex"); } @ParameterizedTest @@ -213,7 +216,7 @@ void sendShouldFailByRoundRobin(int size, int startIndex, int ignore) throws Exc } if (startIndex > 0) { - addressIndexByMember().put(member, new AtomicInteger(startIndex)); + addressIndexByMember().put(member, startIndex); } for (int i = 0; i < size; i++) { @@ -222,8 +225,9 @@ void sendShouldFailByRoundRobin(int size, int startIndex, int ignore) throws Exc } StepVerifier.create(transportWrapper.send(member, request)) - .verifyErrorSatisfies( - throwable -> Assertions.assertEquals("Error", throwable.getMessage())); + .verifyErrorSatisfies(throwable -> assertEquals("Error", throwable.getMessage())); + + assertEquals(startIndex, addressIndexByMember().get(member), "startIndex"); } @Test @@ -237,8 +241,7 @@ void sendShouldWorkThenFail() { StepVerifier.create(transportWrapper.send(member, request)).verifyComplete(); StepVerifier.create(transportWrapper.send(member, request)) - .verifyErrorSatisfies( - throwable -> Assertions.assertEquals("Error", throwable.getMessage())); + .verifyErrorSatisfies(throwable -> assertEquals("Error", throwable.getMessage())); } @Test @@ -251,8 +254,7 @@ void sendShouldFailThenWork() { .thenReturn(Mono.empty()); StepVerifier.create(transportWrapper.send(member, request)) - .verifyErrorSatisfies( - throwable -> Assertions.assertEquals("Error", throwable.getMessage())); + .verifyErrorSatisfies(throwable -> assertEquals("Error", throwable.getMessage())); StepVerifier.create(transportWrapper.send(member, request)).verifyComplete(); } }