Skip to content

Commit

Permalink
WebSockets Next: security integration
Browse files Browse the repository at this point in the history
- fixes quarkusio#40312
- also create a new Vertx duplicated context for error handler
invocation
  • Loading branch information
mkouba committed May 9, 2024
1 parent ea5c6fd commit 7cab635
Show file tree
Hide file tree
Showing 13 changed files with 418 additions and 90 deletions.
10 changes: 10 additions & 0 deletions extensions/websockets-next/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@
<artifactId>quarkus-test-vertx</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-security-deployment</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-security-test-utils</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import io.quarkus.arc.processor.DotNames;
import io.quarkus.arc.processor.InjectionPointInfo;
import io.quarkus.arc.processor.Types;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
import io.quarkus.deployment.GeneratedClassGizmoAdaptor;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
Expand All @@ -65,6 +67,7 @@
import io.quarkus.vertx.http.deployment.HttpRootPathBuildItem;
import io.quarkus.vertx.http.deployment.RouteBuildItem;
import io.quarkus.vertx.http.runtime.HandlerType;
import io.quarkus.vertx.http.runtime.HttpBuildTimeConfig;
import io.quarkus.websockets.next.InboundProcessingMode;
import io.quarkus.websockets.next.WebSocketClientConnection;
import io.quarkus.websockets.next.WebSocketClientException;
Expand All @@ -79,6 +82,7 @@
import io.quarkus.websockets.next.runtime.ConnectionManager;
import io.quarkus.websockets.next.runtime.ContextSupport;
import io.quarkus.websockets.next.runtime.JsonTextMessageCodec;
import io.quarkus.websockets.next.runtime.SecuritySupport;
import io.quarkus.websockets.next.runtime.WebSocketClientRecorder;
import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpoint;
import io.quarkus.websockets.next.runtime.WebSocketConnectionBase;
Expand Down Expand Up @@ -399,12 +403,19 @@ public String apply(String name) {
@Record(RUNTIME_INIT)
@BuildStep
public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildItem httpRootPath,
List<GeneratedEndpointBuildItem> generatedEndpoints,
List<GeneratedEndpointBuildItem> generatedEndpoints, HttpBuildTimeConfig httpConfig, Capabilities capabilities,
BuildProducer<RouteBuildItem> routes) {
for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer)
.toList()) {
RouteBuildItem.Builder builder = RouteBuildItem.builder()
.route(httpRootPath.relativePath(endpoint.path))
RouteBuildItem.Builder builder = RouteBuildItem.builder();
String relativePath = httpRootPath.relativePath(endpoint.path);
if (capabilities.isPresent(Capability.SECURITY) && !httpConfig.auth.proactive) {
// Add a special handler so that it's possible to capture the SecurityIdentity before the HTTP upgrade
builder.routeFunction(relativePath, recorder.initializeSecurityHandler());
} else {
builder.route(relativePath);
}
builder
.displayOnNotFoundPage("WebSocket Endpoint")
.handlerType(HandlerType.NORMAL)
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId));
Expand Down Expand Up @@ -545,8 +556,8 @@ private void validateOnClose(Callback callback) {
* }
*
* public Echo_WebSocketEndpoint(WebSocketConnection connection, Codecs codecs,
* WebSocketRuntimeConfig config, ContextSupport contextSupport) {
* super(connection, codecs, config, contextSupport);
* WebSocketRuntimeConfig config, ContextSupport contextSupport, SecuritySupport securitySupport) {
* super(connection, codecs, config, contextSupport, securitySupport);
* }
*
* public Uni doOnTextMessage(String message) {
Expand Down Expand Up @@ -616,12 +627,12 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
.build();

MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnectionBase.class,
Codecs.class, ContextSupport.class);
Codecs.class, ContextSupport.class, SecuritySupport.class);
constructor.invokeSpecialMethod(
MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnectionBase.class,
Codecs.class, ContextSupport.class),
Codecs.class, ContextSupport.class, SecuritySupport.class),
constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1),
constructor.getMethodParam(2));
constructor.getMethodParam(2), constructor.getMethodParam(3));
constructor.returnNull();

MethodCreator inboundProcessingMode = endpointCreator.getMethodCreator("inboundProcessingMode",
Expand Down Expand Up @@ -1043,7 +1054,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre
return uniOnFailureDoOnError(endpointThis, method, callback, uniChain, endpoint, globalErrorHandlers);
}
} else if (callback.isReturnTypeMulti()) {
// return multiText(multi, broadcast, m -> {
// return multiText(multi, m -> {
// try {
// String text = encodeText(m);
// return sendText(buffer,broadcast);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ String decodingError(BinaryDecodeException e) {
Uni<Void> runtimeProblem(RuntimeException e, WebSocketConnection connection) {
assertTrue(Context.isOnEventLoopThread());
assertEquals(connection.id(), this.connection.id());
// The request context from @OnBinaryMessage is reused
assertEquals("ok", requestBean.getState());
// A new request context is used
assertEquals("nok", requestBean.getState());
return connection.sendText(e.getMessage());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ String decodingError(BinaryDecodeException e) {
String runtimeProblem(RuntimeException e, WebSocketConnection connection) {
assertTrue(Context.isOnWorkerThread());
assertEquals(connection.id(), this.connection.id());
// The request context from @OnBinaryMessage is reused
assertEquals("ok", requestBean.getState());
// A new request context is used
assertEquals("nok", requestBean.getState());
return e.getMessage();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.quarkus.websockets.next.test.security;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CompletionException;

import jakarta.annotation.security.RolesAllowed;
import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.runtime.util.ExceptionUtil;
import io.quarkus.security.Authenticated;
import io.quarkus.security.ForbiddenException;
import io.quarkus.security.test.utils.TestIdentityController;
import io.quarkus.security.test.utils.TestIdentityProvider;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.UpgradeRejectedException;
import io.vertx.core.http.WebSocketConnectOptions;
import io.vertx.ext.auth.authentication.UsernamePasswordCredentials;

public class EagerSecurityTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addAsResource(new StringAsset("quarkus.http.auth.proactive=true\n" +
"quarkus.http.auth.permission.secured.paths=/end\n" +
"quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties")
.addClasses(WSClient.class, TestIdentityProvider.class, TestIdentityController.class));

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@BeforeAll
public static void setupUsers() {
TestIdentityController.resetRoles()
.add("admin", "admin", "admin")
.add("user", "user", "user");
}

@Test
public void testEndpoint() {
try (WSClient client = new WSClient(vertx)) {
CompletionException ce = assertThrows(CompletionException.class, () -> client.connect(endUri));
Throwable root = ExceptionUtil.getRootCause(ce);
assertTrue(root instanceof UpgradeRejectedException);
assertTrue(root.getMessage().contains("401"));
}
try (WSClient client = new WSClient(vertx)) {
client.connect(basicAuth("admin", "admin"), endUri);
client.sendAndAwait("hello");
client.waitForMessages(1);
assertEquals("hello", client.getMessages().get(0).toString());
}
try (WSClient client = new WSClient(vertx)) {
client.connect(basicAuth("user", "user"), endUri);
client.sendAndAwait("hello");
client.waitForMessages(1);
assertEquals("forbidden", client.getMessages().get(0).toString());
}
}

private WebSocketConnectOptions basicAuth(String username, String password) {
return new WebSocketConnectOptions().addHeader(HttpHeaders.AUTHORIZATION.toString(),
new UsernamePasswordCredentials(username, password).applyHttpChallenge(null).toHttpAuthorization());
}

@Authenticated
@WebSocket(path = "/end")
public static class Endpoint {

@RolesAllowed("admin")
@OnTextMessage
String echo(String message) {
return message;
}

@OnError
String error(ForbiddenException t) {
return "forbidden";
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.quarkus.websockets.next.test.security;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CompletionException;

import jakarta.annotation.security.RolesAllowed;
import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.runtime.util.ExceptionUtil;
import io.quarkus.security.Authenticated;
import io.quarkus.security.ForbiddenException;
import io.quarkus.security.test.utils.TestIdentityController;
import io.quarkus.security.test.utils.TestIdentityProvider;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.UpgradeRejectedException;
import io.vertx.core.http.WebSocketConnectOptions;
import io.vertx.ext.auth.authentication.UsernamePasswordCredentials;

public class LazySecurityTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addAsResource(new StringAsset("quarkus.http.auth.proactive=false\n" +
"quarkus.http.auth.permission.secured.paths=/end\n" +
"quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties")
.addClasses(WSClient.class, TestIdentityProvider.class, TestIdentityController.class));

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@BeforeAll
public static void setupUsers() {
TestIdentityController.resetRoles()
.add("admin", "admin", "admin")
.add("user", "user", "user");
}

@Test
public void testEndpoint() {
try (WSClient client = new WSClient(vertx)) {
CompletionException ce = assertThrows(CompletionException.class, () -> client.connect(endUri));
Throwable root = ExceptionUtil.getRootCause(ce);
assertTrue(root instanceof UpgradeRejectedException);
assertTrue(root.getMessage().contains("401"));
}
try (WSClient client = new WSClient(vertx)) {
client.connect(basicAuth("admin", "admin"), endUri);
client.sendAndAwait("hello");
client.waitForMessages(1);
assertEquals("hello", client.getMessages().get(0).toString());
}
try (WSClient client = new WSClient(vertx)) {
client.connect(basicAuth("user", "user"), endUri);
client.sendAndAwait("hello");
client.waitForMessages(1);
assertEquals("forbidden", client.getMessages().get(0).toString());
}
}

private WebSocketConnectOptions basicAuth(String username, String password) {
return new WebSocketConnectOptions().addHeader(HttpHeaders.AUTHORIZATION.toString(),
new UsernamePasswordCredentials(username, password).applyHttpChallenge(null).toHttpAuthorization());
}

@Authenticated
@WebSocket(path = "/end")
public static class Endpoint {

@RolesAllowed("admin")
@OnTextMessage
String echo(String message) {
return message;
}

@OnError
String error(ForbiddenException t) {
return "forbidden";
}

}

}
5 changes: 5 additions & 0 deletions extensions/websockets-next/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-jackson</artifactId>
</dependency>
<!-- Quarkus Security API -->
<dependency>
<groupId>io.quarkus.security</groupId>
<artifactId>quarkus-security</artifactId>
</dependency>
<!-- Test dependencies -->
<dependency>
<groupId>org.junit.jupiter</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ void start() {
void start(ContextState requestContextState) {
LOG.debugf("Start contexts: %s", connection);
startSession();
// Activate a new request context
requestContext.activate(requestContextState);
}

Expand Down
Loading

0 comments on commit 7cab635

Please sign in to comment.