From f8e42e7e8e506dbee4459ef485010b9c2da2629f Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Thu, 9 May 2024 12:59:21 +0200 Subject: [PATCH] WebSockets Next: security integration - when quarkus-security is present and quarkus.http.auth.proactive=false, then we force the authentication before the HTTP upgrade so that it's possible to capture the SecurityIdentity and set it afterwards for all endpoint callbacks - fixes #40312 - also create a new Vertx duplicated context for error handler invocation --- extensions/websockets-next/deployment/pom.xml | 10 ++ .../next/deployment/WebSocketProcessor.java | 29 ++-- .../next/test/errors/RuntimeErrorTest.java | 4 +- .../next/test/errors/UniFailureErrorTest.java | 4 +- .../next/test/security/AdminService.java | 14 ++ .../next/test/security/EagerSecurityTest.java | 59 +++++++ .../test/security/EagerSecurityUniTest.java | 60 +++++++ .../next/test/security/LazySecurityTest.java | 60 +++++++ .../test/security/LazySecurityUniTest.java | 60 +++++++ .../security/RbacServiceSecurityTest.java | 84 ++++++++++ .../next/test/security/SecurityTestBase.java | 71 +++++++++ .../next/test/security/UserService.java | 14 ++ extensions/websockets-next/runtime/pom.xml | 5 + .../next/runtime/ContextSupport.java | 1 - .../websockets/next/runtime/Endpoints.java | 16 +- .../next/runtime/SecuritySupport.java | 32 ++++ .../next/runtime/WebSocketConnectorImpl.java | 2 +- .../next/runtime/WebSocketEndpointBase.java | 147 ++++++++++-------- .../next/runtime/WebSocketServerRecorder.java | 54 ++++++- 19 files changed, 636 insertions(+), 90 deletions(-) create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AdminService.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityUniTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityUniTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/RbacServiceSecurityTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/UserService.java create mode 100644 extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java diff --git a/extensions/websockets-next/deployment/pom.xml b/extensions/websockets-next/deployment/pom.xml index 9c33791094f42..78e90a6a61959 100644 --- a/extensions/websockets-next/deployment/pom.xml +++ b/extensions/websockets-next/deployment/pom.xml @@ -36,6 +36,16 @@ quarkus-test-vertx test + + io.quarkus + quarkus-security-deployment + test + + + io.quarkus + quarkus-security-test-utils + test + io.quarkus quarkus-junit5-internal diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index 465873ae3cad0..c9c67b9029829 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -44,6 +44,8 @@ import io.quarkus.arc.processor.DotNames; import io.quarkus.arc.processor.InjectionPointInfo; import io.quarkus.arc.processor.Types; +import io.quarkus.deployment.Capabilities; +import io.quarkus.deployment.Capability; import io.quarkus.deployment.GeneratedClassGizmoAdaptor; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -65,6 +67,7 @@ import io.quarkus.vertx.http.deployment.HttpRootPathBuildItem; import io.quarkus.vertx.http.deployment.RouteBuildItem; import io.quarkus.vertx.http.runtime.HandlerType; +import io.quarkus.vertx.http.runtime.HttpBuildTimeConfig; import io.quarkus.websockets.next.InboundProcessingMode; import io.quarkus.websockets.next.WebSocketClientConnection; import io.quarkus.websockets.next.WebSocketClientException; @@ -79,6 +82,7 @@ import io.quarkus.websockets.next.runtime.ConnectionManager; import io.quarkus.websockets.next.runtime.ContextSupport; import io.quarkus.websockets.next.runtime.JsonTextMessageCodec; +import io.quarkus.websockets.next.runtime.SecuritySupport; import io.quarkus.websockets.next.runtime.WebSocketClientRecorder; import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpoint; import io.quarkus.websockets.next.runtime.WebSocketConnectionBase; @@ -400,12 +404,19 @@ public String apply(String name) { @Record(RUNTIME_INIT) @BuildStep public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildItem httpRootPath, - List generatedEndpoints, + List generatedEndpoints, HttpBuildTimeConfig httpConfig, Capabilities capabilities, BuildProducer routes) { for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer) .toList()) { - RouteBuildItem.Builder builder = RouteBuildItem.builder() - .route(httpRootPath.relativePath(endpoint.path)) + RouteBuildItem.Builder builder = RouteBuildItem.builder(); + String relativePath = httpRootPath.relativePath(endpoint.path); + if (capabilities.isPresent(Capability.SECURITY) && !httpConfig.auth.proactive) { + // Add a special handler so that it's possible to capture the SecurityIdentity before the HTTP upgrade + builder.routeFunction(relativePath, recorder.initializeSecurityHandler()); + } else { + builder.route(relativePath); + } + builder .displayOnNotFoundPage("WebSocket Endpoint") .handlerType(HandlerType.NORMAL) .handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId)); @@ -546,8 +557,8 @@ private void validateOnClose(Callback callback) { * } * * public Echo_WebSocketEndpoint(WebSocketConnection connection, Codecs codecs, - * WebSocketRuntimeConfig config, ContextSupport contextSupport) { - * super(connection, codecs, config, contextSupport); + * WebSocketRuntimeConfig config, ContextSupport contextSupport, SecuritySupport securitySupport) { + * super(connection, codecs, config, contextSupport, securitySupport); * } * * public Uni doOnTextMessage(String message) { @@ -617,12 +628,12 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, .build(); MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnectionBase.class, - Codecs.class, ContextSupport.class); + Codecs.class, ContextSupport.class, SecuritySupport.class); constructor.invokeSpecialMethod( MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnectionBase.class, - Codecs.class, ContextSupport.class), + Codecs.class, ContextSupport.class, SecuritySupport.class), constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1), - constructor.getMethodParam(2)); + constructor.getMethodParam(2), constructor.getMethodParam(3)); constructor.returnNull(); MethodCreator inboundProcessingMode = endpointCreator.getMethodCreator("inboundProcessingMode", @@ -1044,7 +1055,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre return uniOnFailureDoOnError(endpointThis, method, callback, uniChain, endpoint, globalErrorHandlers); } } else if (callback.isReturnTypeMulti()) { - // return multiText(multi, broadcast, m -> { + // return multiText(multi, m -> { // try { // String text = encodeText(m); // return sendText(buffer,broadcast); diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java index a519c95ea9be3..420f0ba1515ef 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java @@ -80,8 +80,8 @@ String decodingError(BinaryDecodeException e) { Uni runtimeProblem(RuntimeException e, WebSocketConnection connection) { assertTrue(Context.isOnEventLoopThread()); assertEquals(connection.id(), this.connection.id()); - // The request context from @OnBinaryMessage is reused - assertEquals("ok", requestBean.getState()); + // A new request context is used + assertEquals("nok", requestBean.getState()); return connection.sendText(e.getMessage()); } diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UniFailureErrorTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UniFailureErrorTest.java index 933f681c26fcc..17164eb98836c 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UniFailureErrorTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UniFailureErrorTest.java @@ -80,8 +80,8 @@ String decodingError(BinaryDecodeException e) { String runtimeProblem(RuntimeException e, WebSocketConnection connection) { assertTrue(Context.isOnWorkerThread()); assertEquals(connection.id(), this.connection.id()); - // The request context from @OnBinaryMessage is reused - assertEquals("ok", requestBean.getState()); + // A new request context is used + assertEquals("nok", requestBean.getState()); return e.getMessage(); } diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AdminService.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AdminService.java new file mode 100644 index 0000000000000..38905495f4e66 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AdminService.java @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next.test.security; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.enterprise.context.ApplicationScoped; + +@RolesAllowed("admin") +@ApplicationScoped +public class AdminService { + + public String ping() { + return "" + 24; + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityTest.java new file mode 100644 index 0000000000000..506c1a5a55cd2 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityTest.java @@ -0,0 +1,59 @@ +package io.quarkus.websockets.next.test.security; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.Authenticated; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; + +public class EagerSecurityTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addAsResource(new StringAsset("quarkus.http.auth.proactive=true\n" + + "quarkus.http.auth.permission.secured.paths=/end\n" + + "quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties") + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class)); + + @Authenticated + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityUniTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityUniTest.java new file mode 100644 index 0000000000000..809bacfdb0627 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/EagerSecurityUniTest.java @@ -0,0 +1,60 @@ +package io.quarkus.websockets.next.test.security; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.Authenticated; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; + +public class EagerSecurityUniTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addAsResource(new StringAsset("quarkus.http.auth.proactive=true\n" + + "quarkus.http.auth.permission.secured.paths=/end\n" + + "quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties") + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class)); + + @Authenticated + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + Uni echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return Uni.createFrom().item(message); + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java new file mode 100644 index 0000000000000..7d21f28dbc2c5 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java @@ -0,0 +1,60 @@ +package io.quarkus.websockets.next.test.security; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.Authenticated; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.security.EagerSecurityTest.Endpoint; +import io.quarkus.websockets.next.test.utils.WSClient; + +public class LazySecurityTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addAsResource(new StringAsset("quarkus.http.auth.proactive=false\n" + + "quarkus.http.auth.permission.secured.paths=/end\n" + + "quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties") + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class)); + + @Authenticated + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityUniTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityUniTest.java new file mode 100644 index 0000000000000..cb968d397f890 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityUniTest.java @@ -0,0 +1,60 @@ +package io.quarkus.websockets.next.test.security; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.Authenticated; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; + +public class LazySecurityUniTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addAsResource(new StringAsset("quarkus.http.auth.proactive=false\n" + + "quarkus.http.auth.permission.secured.paths=/end\n" + + "quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties") + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class)); + + @Authenticated + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + Uni echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return Uni.createFrom().item(message); + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/RbacServiceSecurityTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/RbacServiceSecurityTest.java new file mode 100644 index 0000000000000..0207d3f1b03fd --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/RbacServiceSecurityTest.java @@ -0,0 +1,84 @@ +package io.quarkus.websockets.next.test.security; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.Set; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class RbacServiceSecurityTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot(root -> root.addClasses(Endpoint.class, AdminService.class, UserService.class, + TestIdentityProvider.class, TestIdentityController.class, WSClient.class)); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add("admin", "admin", "admin") + .add("user", "user", "user"); + } + + @Test + public void testEndpoint() { + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), endUri); + client.sendAndAwait("hello"); // admin service + client.sendAndAwait("hi"); // forbidden + client.waitForMessages(2); + assertEquals(Set.of("24", "forbidden"), Set.copyOf(client.getMessages().stream().map(Object::toString).toList())); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("user", "user"), endUri); + client.sendAndAwait("hello"); // forbidden + client.sendAndAwait("hi"); // user service + client.waitForMessages(2); + assertEquals(Set.of("42", "forbidden"), Set.copyOf(client.getMessages().stream().map(Object::toString).toList())); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + UserService userService; + + @Inject + AdminService adminService; + + @OnTextMessage + String echo(String message) { + return message.equals("hello") ? adminService.ping() : userService.ping(); + } + + @OnError + String error(ForbiddenException t) { + return "forbidden"; + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java new file mode 100644 index 0000000000000..a9c94143ae59b --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java @@ -0,0 +1,71 @@ +package io.quarkus.websockets.next.test.security; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CompletionException; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.http.HttpHeaders; +import io.vertx.core.http.UpgradeRejectedException; +import io.vertx.core.http.WebSocketConnectOptions; +import io.vertx.ext.auth.authentication.UsernamePasswordCredentials; + +public abstract class SecurityTestBase { + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add("admin", "admin", "admin") + .add("user", "user", "user"); + } + + @Test + public void testEndpoint() { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, () -> client.connect(endUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertTrue(root instanceof UpgradeRejectedException); + assertTrue(root.getMessage().contains("401")); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), endUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("hello", client.getMessages().get(1).toString()); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("user", "user"), endUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("forbidden:user", client.getMessages().get(1).toString()); + } + } + + static WebSocketConnectOptions basicAuth(String username, String password) { + return new WebSocketConnectOptions().addHeader(HttpHeaders.AUTHORIZATION.toString(), + new UsernamePasswordCredentials(username, password).applyHttpChallenge(null).toHttpAuthorization()); + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/UserService.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/UserService.java new file mode 100644 index 0000000000000..b8e8045314511 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/UserService.java @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next.test.security; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.enterprise.context.ApplicationScoped; + +@RolesAllowed("user") +@ApplicationScoped +public class UserService { + + public String ping() { + return "" + 42; + } + +} diff --git a/extensions/websockets-next/runtime/pom.xml b/extensions/websockets-next/runtime/pom.xml index 76f218d21b125..d913689652388 100644 --- a/extensions/websockets-next/runtime/pom.xml +++ b/extensions/websockets-next/runtime/pom.xml @@ -26,6 +26,11 @@ io.quarkus quarkus-jackson + + + io.quarkus.security + quarkus-security + org.junit.jupiter diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java index 0b018b6fe2eaf..b36d4dc834b3e 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -36,7 +36,6 @@ void start() { void start(ContextState requestContextState) { LOG.debugf("Start contexts: %s", connection); startSession(); - // Activate a new request context requestContext.activate(requestContextState); } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index 85ab430d8dd52..826b5d43e1a69 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -25,7 +25,8 @@ class Endpoints { private static final Logger LOG = Logger.getLogger(Endpoints.class); static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSocketConnectionBase connection, - WebSocketBase ws, String generatedEndpointClass, Optional autoPingInterval, Runnable onClose) { + WebSocketBase ws, String generatedEndpointClass, Optional autoPingInterval, + SecuritySupport securitySupport, Runnable onClose) { Context context = vertx.getOrCreateContext(); @@ -38,7 +39,8 @@ static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSo container.requestContext()); // Create an endpoint that delegates callbacks to the endpoint bean - WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, contextSupport); + WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, contextSupport, + securitySupport); // A broadcast processor is only needed if Multi is consumed by the callback BroadcastProcessor textBroadcastProcessor = endpoint.consumedTextMultiType() != null @@ -118,6 +120,7 @@ public void handle(Void event) { } else { textMessageHandler(connection, endpoint, ws, onOpenContext, m -> { contextSupport.start(); + securitySupport.start(); try { textBroadcastProcessor.onNext(endpoint.decodeTextMultiItem(m)); LOG.debugf("Text message >> Multi: %s", connection); @@ -146,6 +149,7 @@ public void handle(Void event) { } else { binaryMessageHandler(connection, endpoint, ws, onOpenContext, m -> { contextSupport.start(); + securitySupport.start(); try { binaryBroadcastProcessor.onNext(endpoint.decodeBinaryMultiItem(m)); LOG.debugf("Binary message >> Multi: %s", connection); @@ -298,8 +302,7 @@ public void handle(Void event) { } private static WebSocketEndpoint createEndpoint(String endpointClassName, Context context, - WebSocketConnectionBase connection, - Codecs codecs, ContextSupport contextSupport) { + WebSocketConnectionBase connection, Codecs codecs, ContextSupport contextSupport, SecuritySupport securitySupport) { try { ClassLoader cl = Thread.currentThread().getContextClassLoader(); if (cl == null) { @@ -309,8 +312,9 @@ private static WebSocketEndpoint createEndpoint(String endpointClassName, Contex Class endpointClazz = (Class) cl .loadClass(endpointClassName); WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz - .getDeclaredConstructor(WebSocketConnectionBase.class, Codecs.class, ContextSupport.class) - .newInstance(connection, codecs, contextSupport); + .getDeclaredConstructor(WebSocketConnectionBase.class, Codecs.class, ContextSupport.class, + SecuritySupport.class) + .newInstance(connection, codecs, contextSupport, securitySupport); return endpoint; } catch (Exception e) { throw new WebSocketException("Unable to create endpoint instance: " + endpointClassName, e); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java new file mode 100644 index 0000000000000..8ec115e085e70 --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java @@ -0,0 +1,32 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.Objects; + +import jakarta.enterprise.inject.Instance; + +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.identity.SecurityIdentity; + +public class SecuritySupport { + + static final SecuritySupport NOOP = new SecuritySupport(null, null); + + private final Instance currentIdentity; + private final SecurityIdentity identity; + + SecuritySupport(Instance currentIdentity, SecurityIdentity identity) { + this.currentIdentity = currentIdentity; + this.identity = currentIdentity != null ? Objects.requireNonNull(identity) : identity; + } + + /** + * This method is called before an endpoint callback is invoked. + */ + void start() { + if (currentIdentity != null) { + CurrentIdentityAssociation current = currentIdentity.get(); + current.setIdentity(identity); + } + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java index a4abe65f42162..d6281e5da71f4 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -115,7 +115,7 @@ public Uni connect() { connectionManager.add(clientEndpoint.generatedEndpointClass, connection); Endpoints.initialize(vertx, Arc.container(), codecs, connection, ws, - clientEndpoint.generatedEndpointClass, config.autoPingInterval(), + clientEndpoint.generatedEndpointClass, config.autoPingInterval(), SecuritySupport.NOOP, () -> { connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); client.close(); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index ed453f59a97c9..03d39284e0170 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -13,7 +13,6 @@ import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; import io.quarkus.arc.InjectableBean; -import io.quarkus.arc.InjectableContext.ContextState; import io.quarkus.virtual.threads.VirtualThreadsRecorder; import io.quarkus.websockets.next.InboundProcessingMode; import io.quarkus.websockets.next.runtime.ConcurrencyLimiter.PromiseComplete; @@ -42,15 +41,20 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private final ContextSupport contextSupport; + private final SecuritySupport securitySupport; + private final InjectableBean bean; + private final Object beanInstance; - public WebSocketEndpointBase(WebSocketConnectionBase connection, Codecs codecs, ContextSupport contextSupport) { + public WebSocketEndpointBase(WebSocketConnectionBase connection, Codecs codecs, ContextSupport contextSupport, + SecuritySupport securitySupport) { this.connection = connection; this.codecs = codecs; this.limiter = inboundProcessingMode() == InboundProcessingMode.SERIAL ? new ConcurrencyLimiter(connection) : null; this.container = Arc.container(); this.contextSupport = contextSupport; + this.securitySupport = securitySupport; InjectableBean bean = container.bean(beanIdentifier()); if (bean.getScope().equals(ApplicationScoped.class) || bean.getScope().equals(Singleton.class)) { @@ -105,18 +109,18 @@ private Future execute(M message, ExecutionModel executionModel, limiter.run(context, new Runnable() { @Override public void run() { - doExecute(context, promise, message, executionModel, action, terminateSession, complete::complete, + doExecute(context, message, executionModel, action, terminateSession, complete::complete, complete::failure); } }); } else { // No need to limit the concurrency - doExecute(context, promise, message, executionModel, action, terminateSession, promise::complete, promise::fail); + doExecute(context, message, executionModel, action, terminateSession, promise::complete, promise::fail); } return promise.future(); } - private void doExecute(Context context, Promise promise, M message, ExecutionModel executionModel, + private void doExecute(Context context, M message, ExecutionModel executionModel, Function> action, boolean terminateSession, Runnable onComplete, Consumer onFailure) { Handler contextSupportEnd = executionModel.isBlocking() ? new Handler() { @@ -133,6 +137,7 @@ public void handle(Void event) { public void run() { Context context = Vertx.currentContext(); contextSupport.start(); + securitySupport.start(); action.apply(message).subscribe().with( v -> { context.runOnContext(contextSupportEnd); @@ -150,6 +155,7 @@ public void run() { public Void call() { Context context = Vertx.currentContext(); contextSupport.start(); + securitySupport.start(); action.apply(message).subscribe().with( v -> { context.runOnContext(contextSupportEnd); @@ -165,6 +171,7 @@ public Void call() { } else { // Event loop contextSupport.start(); + securitySupport.start(); action.apply(message).subscribe().with( v -> { contextSupport.end(terminateSession); @@ -179,72 +186,76 @@ public Void call() { public Uni doErrorExecute(Throwable throwable, ExecutionModel executionModel, Function> action) { - // We need to capture the current request context state so that it can be activated - // when the error callback is executed - ContextState requestContextState = contextSupport.currentRequestContextState(); - Handler contextSupportEnd = new Handler() { - + Promise promise = Promise.promise(); + // Always exeute error handler on a new duplicated context + ContextSupport.createNewDuplicatedContext(Vertx.currentContext(), connection).runOnContext(new Handler() { @Override public void handle(Void event) { - contextSupport.end(false, false); - } - }; - contextSupportEnd.handle(null); - - Promise promise = Promise.promise(); - if (executionModel == ExecutionModel.VIRTUAL_THREAD) { - VirtualThreadsRecorder.getCurrent().execute(new Runnable() { - @Override - public void run() { - Context context = Vertx.currentContext(); - contextSupport.start(requestContextState); - action.apply(throwable).subscribe().with( - v -> { - context.runOnContext(contextSupportEnd); - promise.complete(); - }, - t -> { - context.runOnContext(contextSupportEnd); - promise.fail(t); - }); - } - }); - } else if (executionModel == ExecutionModel.WORKER_THREAD) { - Vertx.currentContext().executeBlocking(new Callable() { - @Override - public Void call() { - Context context = Vertx.currentContext(); - contextSupport.start(requestContextState); - action.apply(throwable).subscribe().with( - v -> { - context.runOnContext(contextSupportEnd); - promise.complete(); - }, - t -> { - context.runOnContext(contextSupportEnd); - promise.fail(t); - }); - return null; - } - }, false); - } else { - Vertx.currentContext().runOnContext(new Handler() { - @Override - public void handle(Void event) { - Context context = Vertx.currentContext(); - contextSupport.start(requestContextState); - action.apply(throwable).subscribe().with( - v -> { - context.runOnContext(contextSupportEnd); - promise.complete(); - }, - t -> { - context.runOnContext(contextSupportEnd); - promise.fail(t); - }); + Handler contextSupportEnd = new Handler() { + @Override + public void handle(Void event) { + contextSupport.end(false); + } + }; + + if (executionModel == ExecutionModel.VIRTUAL_THREAD) { + VirtualThreadsRecorder.getCurrent().execute(new Runnable() { + @Override + public void run() { + Context context = Vertx.currentContext(); + contextSupport.start(); + securitySupport.start(); + action.apply(throwable).subscribe().with( + v -> { + context.runOnContext(contextSupportEnd); + promise.complete(); + }, + t -> { + context.runOnContext(contextSupportEnd); + promise.fail(t); + }); + } + }); + } else if (executionModel == ExecutionModel.WORKER_THREAD) { + Vertx.currentContext().executeBlocking(new Callable() { + @Override + public Void call() { + Context context = Vertx.currentContext(); + contextSupport.start(); + securitySupport.start(); + action.apply(throwable).subscribe().with( + v -> { + context.runOnContext(contextSupportEnd); + promise.complete(); + }, + t -> { + context.runOnContext(contextSupportEnd); + promise.fail(t); + }); + return null; + } + }, false); + } else { + Vertx.currentContext().runOnContext(new Handler() { + @Override + public void handle(Void event) { + Context context = Vertx.currentContext(); + contextSupport.start(); + securitySupport.start(); + action.apply(throwable).subscribe().with( + v -> { + context.runOnContext(contextSupportEnd); + promise.complete(); + }, + t -> { + context.runOnContext(contextSupportEnd); + promise.fail(t); + }); + } + }); } - }); - } + } + }); return UniHelper.toUni(promise.future()); } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index e580cf85791e7..9384f8d60fc47 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -1,21 +1,29 @@ package io.quarkus.websockets.next.runtime; +import java.util.function.Consumer; import java.util.function.Supplier; +import jakarta.enterprise.inject.Instance; + import org.jboss.logging.Logger; import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; import io.quarkus.runtime.annotations.Recorder; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.identity.SecurityIdentity; import io.quarkus.vertx.core.runtime.VertxCoreRecorder; +import io.quarkus.vertx.http.runtime.security.QuarkusHttpUser; import io.quarkus.websockets.next.WebSocketServerException; import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Uni; import io.vertx.core.Context; import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Vertx; import io.vertx.core.http.ServerWebSocket; +import io.vertx.ext.web.Route; import io.vertx.ext.web.RoutingContext; @Recorder @@ -46,6 +54,34 @@ public Object get() { }; } + public Consumer initializeSecurityHandler() { + return new Consumer() { + + @Override + public void accept(Route route) { + // Force authentication so that it's possible to capture the SecurityIdentity before the HTTP upgrade + route.handler(new Handler() { + + @Override + public void handle(RoutingContext ctx) { + if (ctx.user() == null) { + Uni deferredIdentity = ctx + .> get(QuarkusHttpUser.DEFERRED_IDENTITY_KEY); + deferredIdentity.subscribe().with(i -> { + if (ctx.response().ended()) { + return; + } + ctx.next(); + }, ctx::fail); + } else { + ctx.next(); + } + } + }); + } + }; + } + public Handler createEndpointHandler(String generatedEndpointClass, String endpointId) { ArcContainer container = Arc.container(); ConnectionManager connectionManager = container.instance(ConnectionManager.class).get(); @@ -54,6 +90,8 @@ public Handler createEndpointHandler(String generatedEndpointCla @Override public void handle(RoutingContext ctx) { + SecuritySupport securitySupport = initializeSecuritySupport(container, ctx); + Future future = ctx.request().toWebSocket(); future.onSuccess(ws -> { Vertx vertx = VertxCoreRecorder.getVertx().get(); @@ -64,10 +102,24 @@ public void handle(RoutingContext ctx) { LOG.debugf("Connection created: %s", connection); Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass, - config.autoPingInterval(), () -> connectionManager.remove(generatedEndpointClass, connection)); + config.autoPingInterval(), securitySupport, + () -> connectionManager.remove(generatedEndpointClass, connection)); }); } }; } + SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx) { + Instance currentIdentityAssociation = container.select(CurrentIdentityAssociation.class); + if (currentIdentityAssociation.isResolvable()) { + // Security extension is present + // Obtain the current security identity from the handshake request + QuarkusHttpUser user = (QuarkusHttpUser) ctx.user(); + if (user != null) { + return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity()); + } + } + return SecuritySupport.NOOP; + } + }