From 71228976ce03a2eb5ae499a3ffee860cfde02060 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 21 Nov 2023 10:56:35 +0200 Subject: [PATCH] Allow SSE events to be filtered out from REST Client (cherry picked from commit c9d1eeae65a34632ee993b19ab48755e16cdf596) --- .../reactive/jackson/test/MultiSseTest.java | 85 +++++++++++++++++++ .../client/reactive/deployment/DotNames.java | 3 + .../RestClientReactiveProcessor.java | 37 ++++++++ .../resteasy/reactive/client/SseEvent.java | 27 ++++++ .../reactive/client/SseEventFilter.java | 22 +++++ .../reactive/client/impl/MultiInvoker.java | 69 +++++++++++++-- 6 files changed, 236 insertions(+), 7 deletions(-) create mode 100644 independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEventFilter.java diff --git a/extensions/resteasy-reactive/rest-client-reactive-jackson/deployment/src/test/java/io/quarkus/rest/client/reactive/jackson/test/MultiSseTest.java b/extensions/resteasy-reactive/rest-client-reactive-jackson/deployment/src/test/java/io/quarkus/rest/client/reactive/jackson/test/MultiSseTest.java index 780bb6b931694..629b881a93bec 100644 --- a/extensions/resteasy-reactive/rest-client-reactive-jackson/deployment/src/test/java/io/quarkus/rest/client/reactive/jackson/test/MultiSseTest.java +++ b/extensions/resteasy-reactive/rest-client-reactive-jackson/deployment/src/test/java/io/quarkus/rest/client/reactive/jackson/test/MultiSseTest.java @@ -9,6 +9,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import java.util.function.Predicate; import jakarta.ws.rs.GET; import jakarta.ws.rs.POST; @@ -23,6 +24,7 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; import org.jboss.resteasy.reactive.RestStreamElementType; import org.jboss.resteasy.reactive.client.SseEvent; +import org.jboss.resteasy.reactive.client.SseEventFilter; import org.jboss.resteasy.reactive.server.jackson.JacksonBasicMessageBodyReader; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -136,6 +138,25 @@ public void accept(SseEvent event) { new EventContainer("id1", "name1", new Dto("name1", "1")))); } + @Test + void shouldBeAbleReadEntireEventWhileAlsoBeingAbleToFilterEvents() { + var resultList = new CopyOnWriteArrayList<>(); + createClient() + .eventWithFilter() + .subscribe().with(new Consumer<>() { + @Override + public void accept(SseEvent event) { + resultList.add(new EventContainer(event.id(), event.name(), event.data())); + } + }); + await().atMost(5, TimeUnit.SECONDS) + .untilAsserted( + () -> assertThat(resultList).containsExactly( + new EventContainer("id", "n0", new Dto("name0", "0")), + new EventContainer("id", "n1", new Dto("name1", "1")), + new EventContainer("id", "n2", new Dto("name2", "2")))); + } + static class EventContainer { final String id; final String name; @@ -212,6 +233,26 @@ public interface SseClient { @Path("/event") @Produces(MediaType.SERVER_SENT_EVENTS) Multi> event(); + + @GET + @Path("/event-with-filter") + @Produces(MediaType.SERVER_SENT_EVENTS) + @SseEventFilter(CustomFilter.class) + Multi> eventWithFilter(); + } + + public static class CustomFilter implements Predicate> { + + @Override + public boolean test(SseEvent event) { + if ("heartbeat".equals(event.id())) { + return false; + } + if ("END".equals(event.data())) { + return false; + } + return true; + } } @Path("/sse") @@ -261,6 +302,50 @@ public void event(@Context SseEventSink sink, @Context Sse sse) { } } } + + @GET + @Path("/event-with-filter") + @Produces(MediaType.SERVER_SENT_EVENTS) + public void eventWithFilter(@Context SseEventSink sink, @Context Sse sse) { + try (sink) { + sink.send(sse.newEventBuilder() + .id("id") + .mediaType(MediaType.APPLICATION_JSON_TYPE) + .data(Dto.class, new Dto("name0", "0")) + .name("n0") + .build()); + + sink.send(sse.newEventBuilder() + .id("heartbeat") + .comment("heartbeat") + .mediaType(MediaType.APPLICATION_JSON_TYPE) + .build()); + + sink.send(sse.newEventBuilder() + .id("id") + .mediaType(MediaType.APPLICATION_JSON_TYPE) + .data(Dto.class, new Dto("name1", "1")) + .name("n1") + .build()); + + sink.send(sse.newEventBuilder() + .id("heartbeat") + .comment("heartbeat") + .build()); + + sink.send(sse.newEventBuilder() + .id("id") + .mediaType(MediaType.APPLICATION_JSON_TYPE) + .data(Dto.class, new Dto("name2", "2")) + .name("n2") + .build()); + + sink.send(sse.newEventBuilder() + .id("end") + .data("END") + .build()); + } + } } @Path("/sse-rest-stream-element-type") diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/DotNames.java b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/DotNames.java index add3e44795d65..f635e470595a4 100644 --- a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/DotNames.java +++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/DotNames.java @@ -12,6 +12,7 @@ import org.eclipse.microprofile.rest.client.annotation.RegisterProviders; import org.eclipse.microprofile.rest.client.ext.ResponseExceptionMapper; import org.jboss.jandex.DotName; +import org.jboss.resteasy.reactive.client.SseEventFilter; import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.ClientFormParam; @@ -41,6 +42,8 @@ public class DotNames { static final DotName METHOD = DotName.createSimple(Method.class.getName()); + public static final DotName SSE_EVENT_FILTER = DotName.createSimple(SseEventFilter.class); + private DotNames() { } } diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/RestClientReactiveProcessor.java b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/RestClientReactiveProcessor.java index 49ee3402e3daf..22a4b76f9b69e 100644 --- a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/RestClientReactiveProcessor.java +++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/main/java/io/quarkus/rest/client/reactive/deployment/RestClientReactiveProcessor.java @@ -64,6 +64,7 @@ import org.jboss.resteasy.reactive.common.util.QuarkusMultivaluedHashMap; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem; import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem; import io.quarkus.arc.deployment.GeneratedBeanBuildItem; import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor; @@ -371,6 +372,42 @@ void registerCompressionInterceptors(BuildProducer ref } } + @BuildStep + void handleSseEventFilter(BuildProducer reflectiveClasses, + BeanArchiveIndexBuildItem beanArchiveIndexBuildItem) { + var index = beanArchiveIndexBuildItem.getIndex(); + Collection instances = index.getAnnotations(DotNames.SSE_EVENT_FILTER); + if (instances.isEmpty()) { + return; + } + + List filterClassNames = new ArrayList<>(instances.size()); + for (AnnotationInstance instance : instances) { + if (instance.target().kind() != AnnotationTarget.Kind.METHOD) { + continue; + } + if (instance.value() == null) { + continue; // can't happen + } + Type filterType = instance.value().asClass(); + DotName filterClassName = filterType.name(); + ClassInfo filterClassInfo = index.getClassByName(filterClassName.toString()); + if (filterClassInfo == null) { + log.warn("Unable to find class '" + filterType.name() + "' in index"); + } else if (!filterClassInfo.hasNoArgsConstructor()) { + throw new RestClientDefinitionException( + "Classes used in @SseEventFilter must have a no-args constructor. Offending class is '" + + filterClassName + "'"); + } else { + filterClassNames.add(filterClassName.toString()); + } + } + reflectiveClasses.produce(ReflectiveClassBuildItem + .builder(filterClassNames.toArray(new String[0])) + .constructors(true) + .build()); + } + @BuildStep @Record(ExecutionTime.STATIC_INIT) void addRestClientBeans(Capabilities capabilities, diff --git a/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEvent.java b/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEvent.java index a6978b93d2dc7..bcbee51c809dc 100644 --- a/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEvent.java +++ b/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEvent.java @@ -5,11 +5,38 @@ */ public interface SseEvent { + /** + * Get event identifier. + *

+ * Contains value of SSE {@code "id"} field. This field is optional. Method may return {@code null}, if the event + * identifier is not specified. + * + * @return event id. + */ String id(); + /** + * Get event name. + *

+ * Contains value of SSE {@code "event"} field. This field is optional. Method may return {@code null}, if the event + * name is not specified. + * + * @return event name, or {@code null} if not set. + */ String name(); + /** + * Get a comment string that accompanies the event. + *

+ * Contains value of the comment associated with SSE event. This field is optional. Method may return {@code null}, if + * the event comment is not specified. + * + * @return comment associated with the event. + */ String comment(); + /** + * Get event data. + */ T data(); } diff --git a/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEventFilter.java b/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEventFilter.java new file mode 100644 index 0000000000000..d9419dca5dfdb --- /dev/null +++ b/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/SseEventFilter.java @@ -0,0 +1,22 @@ +package org.jboss.resteasy.reactive.client; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.function.Predicate; + +/** + * Used when not all SSE events streamed from the server should be included in the event stream returned by the client. + *

+ * IMPORTANT: implementations MUST contain a no-args constructor + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface SseEventFilter { + + /** + * Predicate which decides whether an event should be included in the event stream returned by the client. + */ + Class>> value(); +} diff --git a/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/impl/MultiInvoker.java b/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/impl/MultiInvoker.java index e483baa0ce357..4459e66000227 100644 --- a/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/impl/MultiInvoker.java +++ b/independent-projects/resteasy-reactive/client/runtime/src/main/java/org/jboss/resteasy/reactive/client/impl/MultiInvoker.java @@ -2,16 +2,19 @@ import java.io.ByteArrayInputStream; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.GenericType; import jakarta.ws.rs.core.MediaType; -import jakarta.ws.rs.core.Response; import org.jboss.resteasy.reactive.client.SseEvent; +import org.jboss.resteasy.reactive.client.SseEventFilter; import org.jboss.resteasy.reactive.common.jaxrs.ResponseImpl; import org.jboss.resteasy.reactive.common.util.RestMediaType; @@ -45,8 +48,8 @@ public Multi get(GenericType responseType) { /** * We need this class to work around a bug in Mutiny where we can register our cancel listener - * after the subscription is cancelled and we never get notified - * See https://github.com/smallrye/smallrye-mutiny/issues/417 + * after the subscription is cancelled, and we never get notified + * See ... */ static class MultiRequest { @@ -127,9 +130,11 @@ public Multi method(String name, Entity entity, GenericType respons if (!emitter.isCancelled()) { if (response.getStatus() == 200 && MediaType.SERVER_SENT_EVENTS_TYPE.isCompatible(response.getMediaType())) { - registerForSse(multiRequest, responseType, response, vertxResponse, + registerForSse( + multiRequest, responseType, vertxResponse, (String) restClientRequestContext.getProperties() - .get(RestClientRequestContext.DEFAULT_CONTENT_TYPE_PROP)); + .get(RestClientRequestContext.DEFAULT_CONTENT_TYPE_PROP), + restClientRequestContext.getInvokedMethod()); } else if (response.getStatus() == 200 && RestMediaType.APPLICATION_STREAM_JSON_TYPE.isCompatible(response.getMediaType())) { registerForJsonStream(multiRequest, restClientRequestContext, responseType, response, @@ -156,14 +161,16 @@ private boolean isNewlineDelimited(ResponseImpl response) { @SuppressWarnings({ "unchecked", "rawtypes" }) private void registerForSse(MultiRequest multiRequest, GenericType responseType, - Response response, - HttpClientResponse vertxResponse, String defaultContentType) { + HttpClientResponse vertxResponse, String defaultContentType, + Method invokedMethod) { boolean returnSseEvent = SseEvent.class.equals(responseType.getRawType()); GenericType responseTypeFirstParam = responseType.getType() instanceof ParameterizedType ? new GenericType(((ParameterizedType) responseType.getType()).getActualTypeArguments()[0]) : null; + Predicate> eventPredicate = createEventPredicate(invokedMethod); + // honestly, isn't reconnect contradictory with completion? // FIXME: Reconnect settings? // For now we don't want multi to reconnect @@ -172,8 +179,39 @@ private void registerForSse(MultiRequest multiRequest, multiRequest.onCancel(sseSource::close); sseSource.register(event -> { + + // TODO: we might want to cut down on the allocations here... + + if (eventPredicate != null) { + boolean keep = eventPredicate.test(new SseEvent<>() { + @Override + public String id() { + return event.getId(); + } + + @Override + public String name() { + return event.getName(); + } + + @Override + public String comment() { + return event.getComment(); + } + + @Override + public String data() { + return event.readData(); + } + }); + if (!keep) { + return; + } + } + // DO NOT pass the response mime type because it's SSE: let the event pick between the X-SSE-Content-Type header or // the content-type SSE field + if (returnSseEvent) { multiRequest.emit((R) new SseEvent() { @Override @@ -212,6 +250,23 @@ public Object data() { sseSource.registerAfterRequest(vertxResponse); } + private Predicate> createEventPredicate(Method invokedMethod) { + if (invokedMethod == null) { + return null; // should never happen + } + + SseEventFilter filterAnnotation = invokedMethod.getAnnotation(SseEventFilter.class); + if (filterAnnotation == null) { + return null; + } + + try { + return filterAnnotation.value().getConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + private void registerForChunks(MultiRequest multiRequest, RestClientRequestContext restClientRequestContext, GenericType responseType,