Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle chunked requests on WebsocketServer #1589

Merged
merged 4 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion rskj-core/src/main/java/co/rsk/RskContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -1601,7 +1601,9 @@ private Web3WebSocketServer getWeb3WebSocketServer() {
rskSystemProperties.rpcWebSocketPort(),
jsonRpcHandler,
getJsonRpcWeb3ServerHandler(),
rskSystemProperties.rpcWebSocketServerWriteTimeoutSeconds()
rskSystemProperties.rpcWebSocketServerWriteTimeoutSeconds(),
rskSystemProperties.rpcWebSocketMaxFrameSize(),
rskSystemProperties.rpcWebSocketMaxAggregatedFrameSize()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ public class RskWebSocketServerProtocolHandler extends WebSocketServerProtocolHa
public static final String WRITE_TIMEOUT_REASON = "Exceeded write timout";
public static final int NORMAL_CLOSE_WEBSOCKET_STATUS = 1000;

public RskWebSocketServerProtocolHandler(String websocketPath) {
super(websocketPath);
public RskWebSocketServerProtocolHandler(String websocketPath, int maxFrameSize) {
// there are no subprotocols nor extensions
super(websocketPath, null, false, maxFrameSize);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if(cause instanceof WriteTimeoutException) {
ctx.writeAndFlush(new CloseWebSocketFrame(NORMAL_CLOSE_WEBSOCKET_STATUS, WRITE_TIMEOUT_REASON)).addListener(ChannelFutureListener.CLOSE);
ctx.writeAndFlush(new CloseWebSocketFrame(NORMAL_CLOSE_WEBSOCKET_STATUS, WRITE_TIMEOUT_REASON))
.addListener(ChannelFutureListener.CLOSE);
LOGGER.error("Write timeout exceeded, closing web socket channel", cause);
} else {
super.exceptionCaught(ctx, cause);
Expand Down
15 changes: 11 additions & 4 deletions rskj-core/src/main/java/co/rsk/rpc/netty/Web3WebSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -38,7 +39,6 @@

public class Web3WebSocketServer implements InternalService {
private static final Logger logger = LoggerFactory.getLogger(Web3WebSocketServer.class);
private static final int HTTP_MAX_CONTENT_LENGTH = 1024 * 1024 * 5;

private final InetAddress host;
private final int port;
Expand All @@ -48,20 +48,26 @@ public class Web3WebSocketServer implements InternalService {
private final EventLoopGroup workerGroup;
private @Nullable ChannelFuture webSocketChannel;
private final int serverWriteTimeoutSeconds;
private final int maxFrameSize;
private final int maxAggregatedFrameSize;

public Web3WebSocketServer(
InetAddress host,
int port,
RskWebSocketJsonRpcHandler webSocketJsonRpcHandler,
JsonRpcWeb3ServerHandler web3ServerHandler,
int serverWriteTimeoutSeconds) {
int serverWriteTimeoutSeconds,
int maxFrameSize,
int maxAggregatedFrameSize) {
this.host = host;
this.port = port;
this.webSocketJsonRpcHandler = webSocketJsonRpcHandler;
this.web3ServerHandler = web3ServerHandler;
this.bossGroup = new NioEventLoopGroup();
this.workerGroup = new NioEventLoopGroup();
this.serverWriteTimeoutSeconds = serverWriteTimeoutSeconds;
this.maxFrameSize = maxFrameSize;
this.maxAggregatedFrameSize = maxAggregatedFrameSize;
}

@Override
Expand All @@ -75,9 +81,10 @@ public void start() {
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new HttpServerCodec());
p.addLast(new HttpObjectAggregator(HTTP_MAX_CONTENT_LENGTH));
p.addLast(new HttpObjectAggregator(maxAggregatedFrameSize));
p.addLast(new WriteTimeoutHandler(serverWriteTimeoutSeconds, TimeUnit.SECONDS));
p.addLast(new RskWebSocketServerProtocolHandler("/websocket"));
p.addLast(new RskWebSocketServerProtocolHandler("/websocket", maxFrameSize));
p.addLast(new WebSocketFrameAggregator(maxAggregatedFrameSize));
p.addLast(webSocketJsonRpcHandler);
p.addLast(web3ServerHandler);
p.addLast(new Web3ResultWebSocketResponseHandler());
Expand Down
10 changes: 10 additions & 0 deletions rskj-core/src/main/java/org/ethereum/config/SystemProperties.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ public abstract class SystemProperties {
private static final String PROPERTY_RPC_WEBSOCKET_ADDRESS = "rpc.providers.web.ws.bind_address";
private static final String PROPERTY_RPC_WEBSOCKET_PORT = "rpc.providers.web.ws.port";
private static final String PROPERTY_RPC_WEBSOCKET_SERVER_WRITE_TIMEOUT_SECONDS = "rpc.providers.web.ws.server_write_timeout_seconds";
private static final String PROPERTY_RPC_WEBSOCKET_SERVER_MAX_FRAME_SIZE = "rpc.providers.web.ws.max_frame_size";
private static final String PROPERTY_RPC_WEBSOCKET_SERVER_MAX_AGGREGATED_FRAME_SIZE = "rpc.providers.web.ws.max_aggregated_frame_size";

public static final String PROPERTY_PUBLIC_IP = "public.ip";
public static final String PROPERTY_BIND_ADDRESS = "bind_address";
Expand Down Expand Up @@ -617,6 +619,14 @@ public int rpcWebSocketServerWriteTimeoutSeconds() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_SERVER_WRITE_TIMEOUT_SECONDS);
}

public int rpcWebSocketMaxFrameSize() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_SERVER_MAX_FRAME_SIZE);
}

public int rpcWebSocketMaxAggregatedFrameSize() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_SERVER_MAX_AGGREGATED_FRAME_SIZE);
}

public InetAddress rpcHttpBindAddress() {
return getWebBindAddress(PROPERTY_RPC_HTTP_ADDRESS);
}
Expand Down
2 changes: 2 additions & 0 deletions rskj-core/src/main/resources/expected.conf
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ rpc = {
bind_address = <bind_address>
port = <port>
server_write_timeout_seconds = <timeout>
max_aggregated_frame_size = <frame_size_bytes>
max_frame_size = <frame_size_bytes>
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions rskj-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ rpc {
port = 4445
# Shuts down the server when it's not able to write a response after a certain period (expressed in seconds)
server_write_timeout_seconds = 30
max_aggregated_frame_size = 5242880 # 5 mb, a chunked request can accumulate data up to this frame size
max_frame_size = 65536 # 64 kb, maximum supported frame size
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ public class Web3WebSocketServerTest {
private static JsonNodeFactory JSON_NODE_FACTORY = JsonNodeFactory.instance;
private static ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final int DEFAULT_WRITE_TIMEOUT_SECONDS = 30;
private static final int DEFAULT_MAX_FRAME_SIZE = 65536; // 64 kb
private static final int DEFAULT_MAX_AGGREGATED_FRAME_SIZE = 5242880; // 5 mb

private ExecutorService wsExecutor;

Expand All @@ -65,6 +67,20 @@ public void setup() {

@Test
public void smokeTest() throws Exception {
smokeTest(getJsonRpcDummyMessage("value"));
}

@Test
public void smokeTestWithBigJson() throws Exception {
smokeTest(getJsonRpcBigMessage());
}

@After
public void tearDown() {
wsExecutor.shutdown();
}

private void smokeTest(byte [] msg) throws Exception {
Web3 web3Mock = mock(Web3.class);
String mockResult = "output";
when(web3Mock.web3_sha3(anyString())).thenReturn(mockResult);
Expand All @@ -77,15 +93,21 @@ public void smokeTest() throws Exception {
RskWebSocketJsonRpcHandler handler = new RskWebSocketJsonRpcHandler(null, new JacksonBasedRpcSerializer());
JsonRpcWeb3ServerHandler serverHandler = new JsonRpcWeb3ServerHandler(web3Mock, filteredModules);
int serverWriteTimeoutSeconds = testSystemProperties.rpcWebSocketServerWriteTimeoutSeconds();
int maxFrameSize = testSystemProperties.rpcWebSocketMaxFrameSize();
int maxAggregatedFrameSize = testSystemProperties.rpcWebSocketMaxAggregatedFrameSize();

assertEquals(DEFAULT_WRITE_TIMEOUT_SECONDS, serverWriteTimeoutSeconds);
assertEquals(DEFAULT_MAX_FRAME_SIZE, maxFrameSize);
assertEquals(DEFAULT_MAX_AGGREGATED_FRAME_SIZE, maxAggregatedFrameSize);

Web3WebSocketServer websocketServer = new Web3WebSocketServer(
InetAddress.getLoopbackAddress(),
randomPort,
handler,
serverHandler,
serverWriteTimeoutSeconds
serverWriteTimeoutSeconds,
maxFrameSize,
maxAggregatedFrameSize
);
websocketServer.start();

Expand All @@ -103,7 +125,7 @@ public void smokeTest() throws Exception {
@Override
public void onOpen(WebSocket webSocket, Response response) {
wsExecutor.submit(() -> {
RequestBody body = RequestBody.create(WebSocket.TEXT, getJsonRpcDummyMessage());
RequestBody body = RequestBody.create(WebSocket.TEXT, msg);
try {
this.webSocket = webSocket;
this.webSocket.sendMessage(body);
Expand Down Expand Up @@ -154,17 +176,12 @@ public void onClose(int code, String reason) {
}
}

@After
public void tearDown() {
wsExecutor.shutdown();
}

private byte[] getJsonRpcDummyMessage() {
private byte[] getJsonRpcDummyMessage(String value) {
Map<String, JsonNode> jsonRpcRequestProperties = new HashMap<>();
jsonRpcRequestProperties.put("jsonrpc", JSON_NODE_FACTORY.textNode("2.0"));
jsonRpcRequestProperties.put("id", JSON_NODE_FACTORY.numberNode(13));
jsonRpcRequestProperties.put("method", JSON_NODE_FACTORY.textNode("web3_sha3"));
jsonRpcRequestProperties.put("params", JSON_NODE_FACTORY.arrayNode().add("value"));
jsonRpcRequestProperties.put("params", JSON_NODE_FACTORY.arrayNode().add(value));

byte[] request = new byte[0];
try {
Expand All @@ -176,4 +193,12 @@ private byte[] getJsonRpcDummyMessage() {
return request;

}

private byte[] getJsonRpcBigMessage() {
StringBuilder s = new StringBuilder();
for (int i = 0; i < 55; i++) {
s.append("thisisabigmessagethatwillbesentchunked");
}
return getJsonRpcDummyMessage(s.toString());
}
}