diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index c6c4444aef4e7..ec6054118bc56 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -88,7 +88,6 @@ import io.quarkus.security.spi.runtime.SecurityCheck; import io.quarkus.vertx.http.deployment.RouteBuildItem; import io.quarkus.vertx.http.runtime.HandlerType; -import io.quarkus.vertx.http.runtime.HttpBuildTimeConfig; import io.quarkus.websockets.next.HttpUpgradeCheck; import io.quarkus.websockets.next.InboundProcessingMode; import io.quarkus.websockets.next.WebSocketClientConnection; @@ -445,18 +444,11 @@ public String apply(String name) { @Record(RUNTIME_INIT) @BuildStep public void registerRoutes(WebSocketServerRecorder recorder, List generatedEndpoints, - HttpBuildTimeConfig httpConfig, Capabilities capabilities, BuildProducer routes) { for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer) .toList()) { - RouteBuildItem.Builder builder = RouteBuildItem.builder(); - if (capabilities.isPresent(Capability.SECURITY) && !httpConfig.auth.proactive) { - // Add a special handler so that it's possible to capture the SecurityIdentity before the HTTP upgrade - builder.routeFunction(endpoint.path, recorder.initializeSecurityHandler()); - } else { - builder.route(endpoint.path); - } - builder + RouteBuildItem.Builder builder = RouteBuildItem.builder() + .route(endpoint.path) .displayOnNotFoundPage("WebSocket Endpoint") .handlerType(HandlerType.NORMAL) .handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId)); diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java index cd00b64edd6e2..f4ce7202d1002 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/upgrade/GlobalHttpUpgradeCheckTest.java @@ -118,10 +118,12 @@ public static abstract class ChainHttpUpgradeCheckBase implements HttpUpgradeChe @Override public Uni perform(HttpUpgradeContext request) { - if (identityPropagated(request) && testCheckChain(request)) { - return CheckResult.permitUpgrade(getResponseHeaders()); - } - return CheckResult.permitUpgrade(); + return request.securityIdentity().chain(identity -> { + if (identity != null && identity.isAnonymous() && testCheckChain(request)) { + return CheckResult.permitUpgrade(getResponseHeaders()); + } + return CheckResult.permitUpgrade(); + }); } protected Map> getResponseHeaders() { @@ -134,11 +136,6 @@ protected static boolean testCheckChain(HttpUpgradeContext context) { return context.httpRequest().headers().contains(TEST_CHECK_CHAIN); } - private static boolean identityPropagated(HttpUpgradeContext context) { - // point of this method is to check that identity is present in the context - return context.securityIdentity() != null && context.securityIdentity().isAnonymous(); - } - } @Dependent diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java index 2e342aad4d00f..542efb782a146 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java @@ -48,7 +48,7 @@ default boolean appliesTo(String endpointId) { * @param securityIdentity {@link SecurityIdentity}; the identity is null if the Quarkus Security extension is absent * @param endpointId {@link WebSocket#endpointId()} */ - record HttpUpgradeContext(HttpServerRequest httpRequest, SecurityIdentity securityIdentity, String endpointId) { + record HttpUpgradeContext(HttpServerRequest httpRequest, Uni securityIdentity, String endpointId) { } final class CheckResult { diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecurityHttpUpgradeCheck.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecurityHttpUpgradeCheck.java index f87baabf3244e..91f50a62c9406 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecurityHttpUpgradeCheck.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecurityHttpUpgradeCheck.java @@ -26,11 +26,10 @@ public class SecurityHttpUpgradeCheck implements HttpUpgradeCheck { @Override public Uni perform(HttpUpgradeContext context) { - return endpointToCheck - .get(context.endpointId()) - .nonBlockingApply(context.securityIdentity(), (MethodDescription) null, null) + return context.securityIdentity().chain(identity -> endpointToCheck.get(context.endpointId()) + .nonBlockingApply(identity, (MethodDescription) null, null) .replaceWith(CheckResult::permitUpgradeSync) - .onFailure(SecurityException.class).recoverWithItem(this::rejectUpgrade); + .onFailure(SecurityException.class).recoverWithItem(this::rejectUpgrade)); } @Override diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index c1e464e4b0190..ff5030af7ee24 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -3,7 +3,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.function.Consumer; import java.util.function.Supplier; import jakarta.enterprise.inject.Instance; @@ -31,7 +30,6 @@ import io.vertx.core.Handler; import io.vertx.core.Vertx; import io.vertx.core.http.ServerWebSocket; -import io.vertx.ext.web.Route; import io.vertx.ext.web.RoutingContext; @Recorder @@ -62,34 +60,6 @@ public Object get() { }; } - public Consumer initializeSecurityHandler() { - return new Consumer() { - - @Override - public void accept(Route route) { - // Force authentication so that it's possible to capture the SecurityIdentity before the HTTP upgrade - route.handler(new Handler() { - - @Override - public void handle(RoutingContext ctx) { - if (ctx.user() == null) { - Uni deferredIdentity = ctx - .> get(QuarkusHttpUser.DEFERRED_IDENTITY_KEY); - deferredIdentity.subscribe().with(i -> { - if (ctx.response().ended()) { - return; - } - ctx.next(); - }, ctx::fail); - } else { - ctx.next(); - } - } - }); - } - }; - } - public Handler createEndpointHandler(String generatedEndpointClass, String endpointId) { ArcContainer container = Arc.container(); ConnectionManager connectionManager = container.instance(ConnectionManager.class).get(); @@ -142,7 +112,13 @@ private void httpUpgrade(RoutingContext ctx) { } private Uni checkHttpUpgrade(RoutingContext ctx, String endpointId) { - SecurityIdentity identity = ctx.user() instanceof QuarkusHttpUser user ? user.getSecurityIdentity() : null; + QuarkusHttpUser user = (QuarkusHttpUser) ctx.user(); + Uni identity; + if (user == null) { + identity = ctx.> get(QuarkusHttpUser.DEFERRED_IDENTITY_KEY); + } else { + identity = Uni.createFrom().item(user.getSecurityIdentity()); + } return checkHttpUpgrade(new HttpUpgradeContext(ctx.request(), identity, endpointId), httpUpgradeChecks, 0); }