diff --git a/docs/src/main/asciidoc/websockets-next-reference.adoc b/docs/src/main/asciidoc/websockets-next-reference.adoc index 6eb75e98c601e..9260687981444 100644 --- a/docs/src/main/asciidoc/websockets-next-reference.adoc +++ b/docs/src/main/asciidoc/websockets-next-reference.adoc @@ -274,7 +274,7 @@ Uni consumeAsync(Message m) { } @OnTextMessage -ReponseMessage process(Message m) { +ResponseMessage process(Message m) { // Process the incoming message and send a response to the client. // The method is called for each incoming message. // Note that if the method returns `null`, no response will be sent to the client. @@ -287,7 +287,7 @@ Uni processAsync(Message m) { // Note that if the method returns `null`, no response will be sent to the client. The method completes when the returned Uni emits its item. } -OnTextMessage +@OnTextMessage Multi stream(Message m) { // Process the incoming message and send multiple responses to the client. // The method is called for each incoming message. @@ -643,6 +643,43 @@ Other options for securing HTTP upgrade requests, such as using the security ann NOTE: When OpenID Connect extension is used and token expires, Quarkus automatically closes connection. +== Inspect and/or reject HTTP upgrade + +To inspect an HTTP upgrade, you must provide a CDI bean implementing the `io.quarkus.websockets.next.HttpUpgradeCheck` interface. +Quarkus calls the `HttpUpgradeCheck#perform` method on every HTTP request that should be upgraded to a WebSocket connection. +Inside this method, you can perform any business logic and/or reject the HTTP upgrade. + +.Example HttpUpgradeCheck +[source, java] +---- +package io.quarkus.websockets.next.test; + +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.smallrye.mutiny.Uni; +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped <1> +public class ExampleHttpUpgradeCheck implements HttpUpgradeCheck { + + @Override + public Uni perform(HttpUpgradeContext ctx) { + if (rejectUpgrade(ctx)) { + return CheckResult.rejectUpgrade(400); <2> + } + return CheckResult.permitUpgrade(); + } + + private boolean rejectUpgrade(HttpUpgradeContext ctx) { + var headers = ctx.httpRequest().headers(); + // implement your business logic in here + } +} +---- +<1> The CDI beans implementing `HttpUpgradeCheck` interface can be either `@ApplicationScoped`, `@Singleton` or `@Dependent` beans, but never the `@RequestScoped` beans. +<2> Reject the HTTP upgrade. Initial HTTP handshake ends with the 400 Bad Request response status code. + +TIP: You can choose WebSocket endpoints to which the `HttpUpgradeCheck` is applied with the `HttpUpgradeCheck#appliesTo` method. + [[websocket-next-configuration-reference]] == Configuration reference diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index c9c67b9029829..8b690f688344a 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -36,6 +36,7 @@ import io.quarkus.arc.deployment.CustomScopeBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; +import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.arc.deployment.ValidationPhaseBuildItem; import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem; import io.quarkus.arc.processor.Annotations; @@ -68,6 +69,7 @@ import io.quarkus.vertx.http.deployment.RouteBuildItem; import io.quarkus.vertx.http.runtime.HandlerType; import io.quarkus.vertx.http.runtime.HttpBuildTimeConfig; +import io.quarkus.websockets.next.HttpUpgradeCheck; import io.quarkus.websockets.next.InboundProcessingMode; import io.quarkus.websockets.next.WebSocketClientConnection; import io.quarkus.websockets.next.WebSocketClientException; @@ -106,6 +108,7 @@ public class WebSocketProcessor { static final String SERVER_ENDPOINT_SUFFIX = "_WebSocketServerEndpoint"; static final String CLIENT_ENDPOINT_SUFFIX = "_WebSocketClientEndpoint"; static final String NESTED_SEPARATOR = "$_"; + static final DotName HTTP_UPGRADE_CHECK_NAME = DotName.createSimple(HttpUpgradeCheck.class); // Parameter names consist of alphanumeric characters and underscore private static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}"); @@ -424,6 +427,32 @@ public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildIt } } + @BuildStep + UnremovableBeanBuildItem makeHttpUpgradeChecksUnremovable() { + // we access the checks programmatically + return UnremovableBeanBuildItem.beanTypes(HTTP_UPGRADE_CHECK_NAME); + } + + @BuildStep + List validateHttpUpgradeCheckNotRequestScoped( + ValidationPhaseBuildItem validationPhase) { + return validationPhase + .getContext() + .beans() + .withBeanType(HTTP_UPGRADE_CHECK_NAME) + .filter(b -> { + var targetScope = BuiltinScope.from(b.getScope().getDotName()); + return BuiltinScope.APPLICATION != targetScope + && BuiltinScope.SINGLETON != targetScope + && BuiltinScope.DEPENDENT != targetScope; + }) + .stream() + .map(b -> new ValidationErrorBuildItem(new RuntimeException(("Bean '%s' scope is '%s', but the '%s' " + + "implementors must be one either `@ApplicationScoped', '@Singleton' or '@Dependent' beans") + .formatted(b.getBeanClass(), b.getScope().getDotName(), HTTP_UPGRADE_CHECK_NAME)))) + .toList(); + } + @BuildStep @Record(RUNTIME_INIT) void serverSyntheticBeans(WebSocketServerRecorder recorder, List generatedEndpoints, diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java index 3351c71033053..442e2cfa75712 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java @@ -68,7 +68,12 @@ public void testConnectionClosedWhenAuthExpires() { } else if (System.currentTimeMillis() > threeSecondsFromNow) { Assertions.fail("Authentication expired, therefore connection should had been closed"); } - client.sendAndAwaitReply("Hello #" + i + " from "); + try { + client.sendAndAwaitReply("Hello #" + i + " from "); + } catch (RuntimeException e) { + // this sometimes fails as connection is closed when waiting for the reply + break; + } } var receivedMessages = client.getMessages().stream().map(Buffer::toString).toList(); @@ -82,6 +87,8 @@ public void testConnectionClosedWhenAuthExpires() { .atMost(Duration.ofSeconds(1)) .untilAsserted(() -> assertTrue(Endpoint.CLOSED_MESSAGE.get() .startsWith("Connection closed with reason 'Authentication expired'"))); + + assertTrue(client.isClosed()); } } diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/AbstractHttpUpgradeCheckTestBase.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/AbstractHttpUpgradeCheckTestBase.java new file mode 100644 index 0000000000000..daae34f623232 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/AbstractHttpUpgradeCheckTestBase.java @@ -0,0 +1,117 @@ +package io.quarkus.websockets.next.test.upgrade; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.time.Duration; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicBoolean; + +import jakarta.inject.Inject; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.http.UpgradeRejectedException; +import io.vertx.core.http.WebSocketConnectOptions; + +public abstract class AbstractHttpUpgradeCheckTestBase { + + @Inject + Vertx vertx; + + @TestHTTPResource("opening") + URI openingUri; + + @TestHTTPResource("responding") + URI respondingUri; + + @TestHTTPResource("rejecting") + URI rejectingUri; + + @BeforeEach + public void cleanUp() { + Opening.OPENED.set(false); + OpeningHttpUpgradeCheck.INVOKED.set(0); + } + + @Test + public void testHttpUpgradeRejected() { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect( + new WebSocketConnectOptions().addHeader(RejectingHttpUpgradeCheck.REJECT_HEADER, "ignored"), + rejectingUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("403"), root.getMessage()); + } + } + + @Test + public void testHttpUpgradePermitted() { + try (WSClient client = new WSClient(vertx)) { + client.connect(openingUri); + Awaitility.await().atMost(Duration.ofSeconds(2)).until(() -> OpeningHttpUpgradeCheck.INVOKED.get() == 1); + } + } + + @Test + public void testHttpUpgradeOkAndResponding() { + // test no HTTP Upgrade check rejected the upgrade or recorded value + try (WSClient client = new WSClient(vertx)) { + client.connect(new WebSocketConnectOptions(), respondingUri); + var response = client.sendAndAwaitReply("Ho").toString(); + assertEquals("Ho Hey", response); + assertEquals(0, OpeningHttpUpgradeCheck.INVOKED.get()); + } + } + + @WebSocket(path = "/rejecting", endpointId = "rejecting-id") + public static class Rejecting { + + @OnTextMessage + public void onMessage(String message) { + // do nothing + } + + } + + @WebSocket(path = "/opening", endpointId = "opening-id") + public static class Opening { + + static final AtomicBoolean OPENED = new AtomicBoolean(false); + + @OnTextMessage + public void onMessage(String message) { + // do nothing + } + + @OnOpen + void onOpen() { + OPENED.set(true); + } + + } + + @WebSocket(path = "/responding", endpointId = "closing-id") + public static class Responding { + + @OnTextMessage + public String onMessage(String message) { + return message + " Hey"; + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java new file mode 100644 index 0000000000000..cd00b64edd6e2 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java @@ -0,0 +1,223 @@ +package io.quarkus.websockets.next.test.upgrade; + +import static io.quarkus.websockets.next.test.upgrade.GlobalHttpUpgradeCheckTest.ChainHttpUpgradeCheckBase.TEST_CHECK_CHAIN; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.annotation.Priority; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.Dependent; +import jakarta.enterprise.event.Observes; +import jakarta.inject.Singleton; + +import org.assertj.core.api.Assertions; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.http.UpgradeRejectedException; +import io.vertx.core.http.WebSocketConnectOptions; +import io.vertx.ext.web.Router; + +public class GlobalHttpUpgradeCheckTest extends AbstractHttpUpgradeCheckTestBase { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> root + .addClasses(Opening.class, Responding.class, OpeningHttpUpgradeCheck.class, + RejectingHttpUpgradeCheck.class, WSClient.class, OpeningHttpUpgradeCheckBean.class, + RejectingHttpUpgradeCheckBean.class, ChainHttpUpgradeCheckBase.class, + ChainHttpUpgradeCheck4.class, ChainHttpUpgradeCheck3.class, ChainHttpUpgradeCheck2.class, + ChainHttpUpgradeCheck1.class, NullCheckResultCheck.class, Rejecting.class, + ResponseHeadersObserver.class)); + + @Test + public void testNullCheckResultNotAllowed() { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect( + new WebSocketConnectOptions().addHeader(NullCheckResultCheck.NULL_CHECK, "ignored"), + openingUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("500"), root.getMessage()); + } + } + + @Test + public void testHttpUpgradeChecksOrdered() { + ChainHttpUpgradeCheckBase.INVOCATION_COUNT.set(0); + ResponseHeadersObserver.responseHeaders = null; + + // expect the checks are ordered by @Priority + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect( + new WebSocketConnectOptions().addHeader(TEST_CHECK_CHAIN, "ignored"), + openingUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("401"), root.getMessage()); + + Awaitility.await() + .atMost(Duration.ofSeconds(2)) + .until(() -> ResponseHeadersObserver.responseHeaders != null + && !ResponseHeadersObserver.responseHeaders.isEmpty()); + var headers = ResponseHeadersObserver.responseHeaders; + var orderedPriorities = headers + .entries() + .stream() + .filter(e -> "1".equals(e.getKey()) || "2".equals(e.getKey()) || "3".equals(e.getKey()) + || "4".equals(e.getKey())) + .sorted(Comparator.comparingInt(o -> Integer.parseInt(o.getKey()))) + .map(Map.Entry::getValue) + .map(Integer::parseInt) + .toList(); + assertEquals(4, orderedPriorities.size()); + + int prev = 1000000; + for (int next : orderedPriorities) { + if (prev <= next) { + Assertions.fail("HttpUpgradeChecks are not ordered: " + orderedPriorities); + } + prev = next; + } + } + } + + @Singleton + public static class OpeningHttpUpgradeCheckBean extends OpeningHttpUpgradeCheck { + + } + + @ApplicationScoped + public static class RejectingHttpUpgradeCheckBean extends RejectingHttpUpgradeCheck { + + } + + public static abstract class ChainHttpUpgradeCheckBase implements HttpUpgradeCheck { + + static final String TEST_CHECK_CHAIN = "test-check-chain"; + static final AtomicInteger INVOCATION_COUNT = new AtomicInteger(0); + + @Override + public Uni perform(HttpUpgradeContext request) { + if (identityPropagated(request) && testCheckChain(request)) { + return CheckResult.permitUpgrade(getResponseHeaders()); + } + return CheckResult.permitUpgrade(); + } + + protected Map> getResponseHeaders() { + return Map.of(Integer.toString(INVOCATION_COUNT.incrementAndGet()), List.of(Integer.toString(priority()))); + } + + protected abstract int priority(); + + protected static boolean testCheckChain(HttpUpgradeContext context) { + return context.httpRequest().headers().contains(TEST_CHECK_CHAIN); + } + + private static boolean identityPropagated(HttpUpgradeContext context) { + // point of this method is to check that identity is present in the context + return context.securityIdentity() != null && context.securityIdentity().isAnonymous(); + } + + } + + @Dependent + public static final class ChainHttpUpgradeCheck1 extends ChainHttpUpgradeCheckBase { + + @Override + public Uni perform(HttpUpgradeContext request) { + if (testCheckChain(request)) { + return CheckResult.rejectUpgrade(401, getResponseHeaders()); + } + return super.perform(request); + } + + @Override + protected int priority() { + // default priority + return 0; + } + } + + @Priority(10) + @Dependent + public static final class ChainHttpUpgradeCheck2 extends ChainHttpUpgradeCheckBase { + + @Override + protected int priority() { + return 10; + } + } + + @Priority(100) + @Dependent + public static final class ChainHttpUpgradeCheck3 extends ChainHttpUpgradeCheckBase { + + @Override + protected int priority() { + return 100; + } + } + + @Priority(1000) + @Dependent + public static final class ChainHttpUpgradeCheck4 extends ChainHttpUpgradeCheckBase { + + @Override + protected int priority() { + return 1000; + } + } + + @Dependent + public static final class NullCheckResultCheck implements HttpUpgradeCheck { + + static final String NULL_CHECK = "null-check"; + + @Override + public Uni perform(HttpUpgradeContext context) { + if (context.httpRequest().headers().contains(NULL_CHECK)) { + return Uni.createFrom().nullItem(); + } + return CheckResult.permitUpgrade(); + } + } + + public static final class ResponseHeadersObserver { + + static volatile MultiMap responseHeaders = null; + + void observer(@Observes Router router) { + router.route().order(0).handler(ctx -> { + ctx.addHeadersEndHandler(new Handler() { + @Override + public void handle(Void unused) { + responseHeaders = ctx.response().headers(); + } + }); + ctx.next(); + }); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/HttpUpgradeCheckHeaderMergingTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/HttpUpgradeCheckHeaderMergingTest.java new file mode 100644 index 0000000000000..cbec099986e8a --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/HttpUpgradeCheckHeaderMergingTest.java @@ -0,0 +1,94 @@ +package io.quarkus.websockets.next.test.upgrade; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import jakarta.enterprise.context.Dependent; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.restassured.RestAssured; +import io.restassured.http.Header; +import io.smallrye.mutiny.Uni; + +public class HttpUpgradeCheckHeaderMergingTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> root + .addClasses(Headers.class, Header1HttpUpgradeCheck.class, + Header2HttpUpgradeCheck.class, Header3HttpUpgradeCheck.class, WSClient.class)); + + @TestHTTPResource("headers") + URI headersUri; + + @Test + public void testHeadersMultiMap() { + // this is a way to test scenario where HttpUpgradeChecks set headers + // but the checks itself did not reject upgrade, the upgrade wasn't performed due to incorrect headers + var headers = RestAssured.given().get(headersUri).then().statusCode(400).extract().headers(); + + assertNotNull(headers); + assertTrue(headers.size() >= 3); + Stream.of("k", "k2", "k3").forEach(k -> { + assertNotNull(headers.getList(k)); + var vals = headers.getList(k).stream().map(Header::getValue).toList(); + assertEquals(4, vals.size(), vals.toString()); + assertTrue(vals.contains("val1"), vals.toString()); + assertTrue(vals.contains("val2"), vals.toString()); + assertTrue(vals.contains("val3"), vals.toString()); + assertTrue(vals.contains("val4"), vals.toString()); + }); + } + + @Dependent + public static class Header1HttpUpgradeCheck implements HttpUpgradeCheck { + + @Override + public Uni perform(HttpUpgradeContext context) { + return CheckResult.permitUpgrade(Map.of("k", List.of("val1"))); + } + } + + @Dependent + public static class Header2HttpUpgradeCheck implements HttpUpgradeCheck { + + @Override + public Uni perform(HttpUpgradeContext context) { + return CheckResult.permitUpgrade(Map.of("k", List.of("val2", "val3", "val4"), "k2", List.of("val1"))); + } + } + + @Dependent + public static class Header3HttpUpgradeCheck implements HttpUpgradeCheck { + + @Override + public Uni perform(HttpUpgradeContext context) { + return CheckResult.permitUpgrade( + Map.of("k3", List.of("val1", "val2", "val3", "val4"), "k2", List.of("val2", "val3", "val4"))); + } + } + + @WebSocket(path = "/headers") + public static class Headers { + + @OnTextMessage + public String onMessage(String message) { + return "Hola " + message; + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/LocalHttpUpgradeCheckTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/LocalHttpUpgradeCheckTest.java new file mode 100644 index 0000000000000..f8ad9c15bcff6 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/LocalHttpUpgradeCheckTest.java @@ -0,0 +1,44 @@ +package io.quarkus.websockets.next.test.upgrade; + +import jakarta.inject.Singleton; + +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.test.utils.WSClient; + +public class LocalHttpUpgradeCheckTest extends AbstractHttpUpgradeCheckTestBase { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> root + .addClasses(Opening.class, Responding.class, OpeningHttpUpgradeCheck.class, + RejectingHttpUpgradeCheck.class, WSClient.class, Rejecting.class, + AlwaysRejectingHttpUpgradeCheck.class, AlwaysInvokedOpeningHttpUpgradeCheck.class)); + + @Singleton + public static final class AlwaysInvokedOpeningHttpUpgradeCheck extends OpeningHttpUpgradeCheck { + @Override + protected boolean shouldCheckUpgrade(HttpUpgradeContext context) { + return true; + } + + @Override + public boolean appliesTo(String endpointId) { + return "opening-id".equals(endpointId); + } + } + + @Singleton + public static final class AlwaysRejectingHttpUpgradeCheck extends RejectingHttpUpgradeCheck { + @Override + protected boolean shouldCheckUpgrade(HttpUpgradeContext context) { + return true; + } + + @Override + public boolean appliesTo(String endpointId) { + return "rejecting-id".equals(endpointId); + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/OpeningHttpUpgradeCheck.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/OpeningHttpUpgradeCheck.java new file mode 100644 index 0000000000000..74b0a2ee495d3 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/OpeningHttpUpgradeCheck.java @@ -0,0 +1,23 @@ +package io.quarkus.websockets.next.test.upgrade; + +import java.util.concurrent.atomic.AtomicInteger; + +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.smallrye.mutiny.Uni; + +public class OpeningHttpUpgradeCheck implements HttpUpgradeCheck { + + public static final AtomicInteger INVOKED = new AtomicInteger(0); + + @Override + public Uni perform(HttpUpgradeContext context) { + if (shouldCheckUpgrade(context)) { + INVOKED.incrementAndGet(); + } + return CheckResult.permitUpgrade(); + } + + protected boolean shouldCheckUpgrade(HttpUpgradeContext context) { + return context.httpRequest().path().contains("/opening"); + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/RejectingHttpUpgradeCheck.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/RejectingHttpUpgradeCheck.java new file mode 100644 index 0000000000000..0cdea75934c63 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/RejectingHttpUpgradeCheck.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next.test.upgrade; + +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.smallrye.mutiny.Uni; + +public class RejectingHttpUpgradeCheck implements HttpUpgradeCheck { + + static final String REJECT_HEADER = "reject"; + + @Override + public Uni perform(HttpUpgradeContext context) { + if (shouldCheckUpgrade(context)) { + return CheckResult.rejectUpgrade(403); + } + return CheckResult.permitUpgrade(); + } + + protected boolean shouldCheckUpgrade(HttpUpgradeContext context) { + return context.httpRequest().headers().contains(REJECT_HEADER); + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/RequestScopedHttpUpgradeCheckValidationFailureTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/RequestScopedHttpUpgradeCheckValidationFailureTest.java new file mode 100644 index 0000000000000..02fa164acc193 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/RequestScopedHttpUpgradeCheckValidationFailureTest.java @@ -0,0 +1,43 @@ +package io.quarkus.websockets.next.test.upgrade; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.enterprise.context.RequestScoped; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.smallrye.mutiny.Uni; + +public class RequestScopedHttpUpgradeCheckValidationFailureTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> root + .addClasses(RequestScopedHttpUpgradeCheck.class)) + .assertException(t -> { + assertTrue(t.getMessage().contains("RequestScopedHttpUpgradeCheck"), t.getMessage()); + assertTrue(t.getMessage().contains("jakarta.enterprise.context.RequestScoped"), t.getMessage()); + assertTrue(t.getMessage().contains( + "but the '%s' implementors must be one either `@ApplicationScoped', '@Singleton' or '@Dependent' beans" + .formatted(HttpUpgradeCheck.class.getName())), + t.getMessage()); + }); + + @Test + public void test() { + Assertions.fail(); + } + + @RequestScoped + public static class RequestScopedHttpUpgradeCheck implements HttpUpgradeCheck { + + @Override + public Uni perform(HttpUpgradeContext context) { + return CheckResult.permitUpgrade(); + } + } +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java new file mode 100644 index 0000000000000..4697d7785dc9d --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java @@ -0,0 +1,137 @@ +package io.quarkus.websockets.next; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import io.quarkus.security.identity.SecurityIdentity; +import io.smallrye.mutiny.Uni; +import io.vertx.core.http.HttpServerRequest; + +/** + * A check that controls which requests are allowed to upgrade the HTTP connection to a WebSocket connection. + * CDI beans implementing this interface are invoked on every request. + * The CDI beans implementing `HttpUpgradeCheck` interface can be either `@ApplicationScoped`, `@Singleton` + * or `@Dependent` beans, but never the `@RequestScoped` beans. + *

+ * The checks are called orderly according to a bean priority. + * When no priority is declared (for example with the `@jakarta.annotation.Priority` annotation), default priority is used. + * If one of the checks rejects the upgrade, remaining checks are not called. + */ +public interface HttpUpgradeCheck { + + /** + * This method inspects HTTP Upgrade context and either allows or denies upgrade to a WebSocket connection. + * + * @param context {@link HttpUpgradeContext} + * @return check result; must never be null + */ + Uni perform(HttpUpgradeContext context); + + /** + * Determines WebSocket endpoints this check is applied to. + * + * @param endpointId WebSocket endpoint id, @see {@link WebSocket#endpointId()} for more information + * @return true if this check should be applied on a WebSocket endpoint with given id + */ + default boolean appliesTo(String endpointId) { + return true; + } + + /** + * @param httpRequest {@link HttpServerRequest}; the HTTP 1.X request employing the 'Upgrade' header + * @param securityIdentity {@link SecurityIdentity}; the identity is null if the Quarkus Security extension is absent + */ + record HttpUpgradeContext(HttpServerRequest httpRequest, SecurityIdentity securityIdentity) { + } + + final class CheckResult { + + private static final CheckResult PERMIT_UPGRADE = new CheckResult(true, null, Map.of()); + + private final boolean upgradePermitted; + private final int httpResponseCode; + private final Map> responseHeaders; + + private CheckResult(boolean upgradePermitted, Integer httpResponseCode, Map> responseHeaders) { + this.upgradePermitted = upgradePermitted; + this.httpResponseCode = httpResponseCode == null ? 500 : httpResponseCode; + this.responseHeaders = toUnmodifiableMap(responseHeaders); + } + + public boolean isUpgradePermitted() { + return upgradePermitted; + } + + public int getHttpResponseCode() { + return httpResponseCode; + } + + public Map> getResponseHeaders() { + return this.responseHeaders; + } + + public CheckResult withHeaders(Map> responseHeaders) { + if (responseHeaders == null || responseHeaders.isEmpty()) { + return this; + } + + var newHeaders = new HashMap<>(responseHeaders); + this.responseHeaders.forEach((k, v) -> newHeaders.put(k, merge(v, newHeaders.get(k)))); + + return new CheckResult(this.upgradePermitted, this.httpResponseCode, newHeaders); + } + + public static Uni rejectUpgrade(Integer httpResponseCode, Map> responseHeaders) { + return Uni.createFrom().item(rejectUpgradeSync(httpResponseCode, responseHeaders)); + } + + public static Uni rejectUpgrade(Integer httpResponseCode) { + return rejectUpgrade(httpResponseCode, null); + } + + public static CheckResult rejectUpgradeSync(Integer httpResponseCode, Map> responseHeaders) { + return new CheckResult(false, httpResponseCode, responseHeaders); + } + + public static Uni permitUpgrade(Map> responseHeaders) { + return Uni.createFrom().item(permitUpgradeSync(responseHeaders)); + } + + public static CheckResult permitUpgradeSync(Map> responseHeaders) { + return new CheckResult(true, null, responseHeaders); + } + + public static Uni permitUpgrade() { + return Uni.createFrom().item(permitUpgradeSync()); + } + + public static CheckResult permitUpgradeSync() { + return PERMIT_UPGRADE; + } + + /** + * Merge two lists. + * + * @param a never null + * @param b nullable + * @return list containing both {@code a} and {@code b} (if present) + */ + private static List merge(List a, List b) { + if (b == null || b.isEmpty()) { + return a; + } + return Stream.concat(a.stream(), b.stream()).toList(); + } + + private static Map> toUnmodifiableMap(Map> responseHeaders) { + if (responseHeaders == null || responseHeaders.isEmpty()) { + return Map.of(); + } + var mutableMap = new HashMap<>(responseHeaders); + mutableMap.replaceAll((k, v) -> List.copyOf(v)); + return Map.copyOf(mutableMap); + } + } +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index 2878f921d680c..0715daaf2114d 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -1,5 +1,7 @@ package io.quarkus.websockets.next.runtime; +import java.util.ArrayList; +import java.util.List; import java.util.function.Consumer; import java.util.function.Supplier; @@ -14,6 +16,9 @@ import io.quarkus.security.identity.SecurityIdentity; import io.quarkus.vertx.core.runtime.VertxCoreRecorder; import io.quarkus.vertx.http.runtime.security.QuarkusHttpUser; +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.quarkus.websockets.next.HttpUpgradeCheck.CheckResult; +import io.quarkus.websockets.next.HttpUpgradeCheck.HttpUpgradeContext; import io.quarkus.websockets.next.WebSocketServerException; import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; import io.smallrye.common.vertx.VertxContext; @@ -88,8 +93,28 @@ public Handler createEndpointHandler(String generatedEndpointCla Codecs codecs = container.instance(Codecs.class).get(); return new Handler() { + private final HttpUpgradeCheck[] httpUpgradeChecks = getHttpUpgradeChecks(endpointId, container); + @Override public void handle(RoutingContext ctx) { + if (httpUpgradeChecks != null) { + checkHttpUpgrade(ctx).subscribe().with(result -> { + if (!result.getResponseHeaders().isEmpty()) { + result.getResponseHeaders().forEach((k, v) -> ctx.response().putHeader(k, v)); + } + + if (result.isUpgradePermitted()) { + httpUpgrade(ctx); + } else { + ctx.response().setStatusCode(result.getHttpResponseCode()).end(); + } + }, ctx::fail); + } else { + httpUpgrade(ctx); + } + } + + private void httpUpgrade(RoutingContext ctx) { Future future = ctx.request().toWebSocket(); future.onSuccess(ws -> { Vertx vertx = VertxCoreRecorder.getVertx().get(); @@ -106,9 +131,44 @@ public void handle(RoutingContext ctx) { () -> connectionManager.remove(generatedEndpointClass, connection)); }); } + + private Uni checkHttpUpgrade(RoutingContext ctx) { + SecurityIdentity identity = ctx.user() instanceof QuarkusHttpUser user ? user.getSecurityIdentity() : null; + return checkHttpUpgrade(new HttpUpgradeContext(ctx.request(), identity), httpUpgradeChecks, 0); + } + + private static Uni checkHttpUpgrade(HttpUpgradeContext ctx, + HttpUpgradeCheck[] checks, int idx) { + return checks[idx].perform(ctx).flatMap(res -> { + if (res == null) { + return Uni.createFrom().failure(new IllegalStateException( + "The '%s' returned null CheckResult, please make sure non-null value is returned" + .formatted(checks[idx]))); + } + if (idx < checks.length - 1 && res.isUpgradePermitted()) { + return checkHttpUpgrade(ctx, checks, idx + 1) + .map(n -> n.withHeaders(res.getResponseHeaders())); + } + return Uni.createFrom().item(res); + }); + } }; } + private static HttpUpgradeCheck[] getHttpUpgradeChecks(String endpointId, ArcContainer container) { + List httpUpgradeChecks = null; + for (var check : container.select(HttpUpgradeCheck.class)) { + if (!check.appliesTo(endpointId)) { + continue; + } + if (httpUpgradeChecks == null) { + httpUpgradeChecks = new ArrayList<>(); + } + httpUpgradeChecks.add(check); + } + return httpUpgradeChecks == null ? null : httpUpgradeChecks.toArray(new HttpUpgradeCheck[0]); + } + SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx, Vertx vertx, WebSocketConnectionImpl connection) { Instance currentIdentityAssociation = container.select(CurrentIdentityAssociation.class);