Skip to content

Commit

Permalink
Refactor MySQLSequenceIdInboundHandler (#32573)
Browse files Browse the repository at this point in the history
* Refactor MySQLSequenceIdInboundHandler

* Refactor MySQLSequenceIdInboundHandler
  • Loading branch information
terrymanu authored Aug 18, 2024
1 parent 58c3eb2 commit 4f04729
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,22 @@
package org.apache.shardingsphere.db.protocol.mysql.netty;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;

import java.util.concurrent.atomic.AtomicInteger;

/**
* Handle MySQL sequence ID before sending to downstream.
*/
public final class MySQLSequenceIdInboundHandler extends ChannelInboundHandlerAdapter {

public MySQLSequenceIdInboundHandler(final Channel channel) {
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
}

@Override
public void channelRead(final ChannelHandlerContext context, final Object msg) {
ByteBuf byteBuf = (ByteBuf) msg;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
Expand All @@ -33,12 +34,15 @@

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;

class MySQLSequenceIdInboundHandlerTest {

@Test
void assertChannelReadWithFlowControl() {
EmbeddedChannel channel = new EmbeddedChannel(new FixtureOutboundHandler(), new ProxyFlowControlHandler(), new MySQLSequenceIdInboundHandler(), new FixtureInboundHandler());
EmbeddedChannel channel = new EmbeddedChannel(
new FixtureOutboundHandler(), new ProxyFlowControlHandler(), new MySQLSequenceIdInboundHandler(mock(Channel.class, RETURNS_DEEP_STUBS)), new FixtureInboundHandler());
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
channel.writeInbound(Unpooled.wrappedBuffer(new byte[1]), Unpooled.wrappedBuffer(new byte[1]), Unpooled.wrappedBuffer(new byte[1]));
assertThat(channel.<ByteBuf>readOutbound().readUnsignedByte(), is((short) 1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
Expand Down Expand Up @@ -104,10 +103,9 @@ public synchronized void connect() {

@Override
protected void initChannel(final SocketChannel socketChannel) {
socketChannel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
socketChannel.pipeline().addLast(new ChannelAttrInitializer());
socketChannel.pipeline().addLast(new PacketCodec(new MySQLPacketCodecEngine()));
socketChannel.pipeline().addLast(new MySQLSequenceIdInboundHandler());
socketChannel.pipeline().addLast(new MySQLSequenceIdInboundHandler(socketChannel));
socketChannel.pipeline().addLast(new MySQLNegotiatePackageDecoder());
socketChannel.pipeline().addLast(new MySQLCommandPacketDecoder());
socketChannel.pipeline().addLast(new MySQLNegotiateHandler(connectInfo.getUsername(), connectInfo.getPassword(), responseCallback));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import lombok.Getter;
import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
import org.apache.shardingsphere.db.protocol.mysql.codec.MySQLPacketCodecEngine;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
import org.apache.shardingsphere.db.protocol.mysql.netty.MySQLSequenceIdInboundHandler;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
Expand All @@ -31,8 +30,6 @@
import org.apache.shardingsphere.proxy.frontend.netty.FrontendChannelInboundHandler;
import org.apache.shardingsphere.proxy.frontend.spi.DatabaseProtocolFrontendEngine;

import java.util.concurrent.atomic.AtomicInteger;

/**
* Frontend engine for MySQL.
*/
Expand All @@ -47,8 +44,7 @@ public final class MySQLFrontendEngine implements DatabaseProtocolFrontendEngine

@Override
public void initChannel(final Channel channel) {
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
channel.pipeline().addBefore(FrontendChannelInboundHandler.class.getSimpleName(), MySQLSequenceIdInboundHandler.class.getSimpleName(), new MySQLSequenceIdInboundHandler());
channel.pipeline().addBefore(FrontendChannelInboundHandler.class.getSimpleName(), MySQLSequenceIdInboundHandler.class.getSimpleName(), new MySQLSequenceIdInboundHandler(channel));
}

@Override
Expand Down

0 comments on commit 4f04729

Please sign in to comment.