diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java index 9023669ffff6..f34c16a024a3 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java @@ -145,15 +145,16 @@ public void onOpen(CoreSession coreSession, Callback callback) container.notifySessionListeners((listener) -> listener.onWebSocketSessionOpened(session)); callback.succeeded(); + + if (openHandle != null) + autoDemand(); + else + session.getCoreSession().demand(); } catch (Throwable cause) { callback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " OPEN method error: " + cause.getMessage(), cause)); } - finally - { - autoDemand(); - } } private static MessageSink createMessageSink(Class sinkClass, WebSocketSession session, MethodHandle msgHandle, boolean autoDemanding) @@ -320,7 +321,7 @@ private void onPingFrame(Frame frame, Callback callback) public void succeed() { callback.succeeded(); - autoDemand(); + session.getCoreSession().demand(); } @Override @@ -328,6 +329,7 @@ public void fail(Throwable x) { // Ignore failures, we might be output closed and receive a PING. callback.succeeded(); + session.getCoreSession().demand(); } }); } @@ -355,7 +357,7 @@ private void onPongFrame(Frame frame, Callback callback) } else { - autoDemand(); + session.getCoreSession().demand(); } } @@ -384,7 +386,7 @@ private void acceptFrame(Frame frame, Callback callback) if (activeMessageSink == null) { callback.succeeded(); - autoDemand(); + session.getCoreSession().demand(); return; } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/pom.xml b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/pom.xml index c21c1f345ed4..a3533dcceb9b 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/pom.xml +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/pom.xml @@ -17,6 +17,11 @@ + + org.awaitility + awaitility + test + org.eclipse.jetty jetty-alpn-java-server diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java index 9f6c5366c03f..620707d1c6a9 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java @@ -15,12 +15,17 @@ import java.io.IOException; import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.websocket.api.Callback; +import org.eclipse.jetty.websocket.api.Frame; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.client.WebSocketClient; @@ -29,6 +34,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static org.awaitility.Awaitility.await; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertNull; @@ -55,9 +61,23 @@ public void onMessage(String message) throws IOException } } + @WebSocket(autoDemand = false) + public static class ListenerSocket implements Session.Listener + { + final List frames = new CopyOnWriteArrayList<>(); + + @Override + public void onWebSocketFrame(Frame frame, Callback callback) + { + frames.add(frame); + callback.succeed(); + } + } + private final Server server = new Server(); private final WebSocketClient client = new WebSocketClient(); private final SuspendSocket serverSocket = new SuspendSocket(); + private final ListenerSocket listenerSocket = new ListenerSocket(); private ServerConnector connector; @BeforeEach @@ -67,7 +87,10 @@ public void start() throws Exception server.addConnector(connector); WebSocketUpgradeHandler wsHandler = WebSocketUpgradeHandler.from(server, container -> - container.addMapping("/suspend", (rq, rs, cb) -> serverSocket)); + { + container.addMapping("/suspend", (rq, rs, cb) -> serverSocket); + container.addMapping("/listenerSocket", (rq, rs, cb) -> listenerSocket); + }); server.setHandler(wsHandler); server.start(); @@ -114,4 +137,27 @@ public void testNoDemandWhenProcessingFrame() throws Exception assertNull(clientSocket.error); assertNull(serverSocket.error); } + + @Test + public void testNoAutoDemand() throws Exception + { + URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/listenerSocket"); + ListenerSocket listenerSocket = new ListenerSocket(); + Future connect = client.connect(listenerSocket, uri); + Session session = connect.get(5, TimeUnit.SECONDS); + + session.sendPing(ByteBuffer.wrap("ping-0".getBytes(StandardCharsets.UTF_8)), Callback.NOOP); + session.sendText("test-text", Callback.NOOP); + session.sendPing(ByteBuffer.wrap("ping-1".getBytes(StandardCharsets.UTF_8)), Callback.NOOP); + session.close(); + + await().atMost(5, TimeUnit.SECONDS).until(listenerSocket.frames::size, is(3)); + Frame frame0 = listenerSocket.frames.get(0); + assertThat(frame0.getType(), is(Frame.Type.PONG)); + assertThat(StandardCharsets.UTF_8.decode(frame0.getPayload()).toString(), is("ping-0")); + Frame frame1 = listenerSocket.frames.get(1); + assertThat(frame1.getType(), is(Frame.Type.PONG)); + assertThat(StandardCharsets.UTF_8.decode(frame1.getPayload()).toString(), is("ping-1")); + assertThat(listenerSocket.frames.get(2).getType(), is(Frame.Type.CLOSE)); + } }