Skip to content

Commit

Permalink
Refactor MySQLConstants (#32575)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu authored Aug 18, 2024
1 parent 4f04729 commit 63f6a1c
Show file tree
Hide file tree
Showing 20 changed files with 41 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void encode(final ChannelHandlerContext context, final DatabasePacket mes
new MySQLErrPacket(new UnknownSQLException(ex).toSQLException()).write(payload);
} finally {
if (out.readableBytes() - PAYLOAD_LENGTH - SEQUENCE_LENGTH < MAX_PACKET_LENGTH) {
updateMessageHeader(out, context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().getAndIncrement());
updateMessageHeader(out, context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().getAndIncrement());
} else {
writeMultiPackets(context, out);
}
Expand All @@ -116,7 +116,7 @@ private void updateMessageHeader(final ByteBuf byteBuf, final int sequenceId) {
private void writeMultiPackets(final ChannelHandlerContext context, final ByteBuf byteBuf) {
int packetCount = byteBuf.skipBytes(PAYLOAD_LENGTH + SEQUENCE_LENGTH).readableBytes() / MAX_PACKET_LENGTH + 1;
CompositeByteBuf result = context.alloc().compositeBuffer(packetCount * 2);
AtomicInteger sequenceId = context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get();
AtomicInteger sequenceId = context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get();
for (int i = 0; i < packetCount; i++) {
ByteBuf header = context.alloc().ioBuffer(4, 4);
int packetLength = Math.min(byteBuf.readableBytes(), MAX_PACKET_LENGTH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class MySQLConstants {

public static final AttributeKey<AtomicInteger> MYSQL_SEQUENCE_ID = AttributeKey.valueOf("MYSQL_SEQUENCE_ID");
public static final AttributeKey<AtomicInteger> SEQUENCE_ID_ATTRIBUTE_KEY = AttributeKey.valueOf("MYSQL_SEQUENCE_ID");

public static final AttributeKey<MySQLCharacterSet> MYSQL_CHARACTER_SET_ATTRIBUTE_KEY = AttributeKey.valueOf(MySQLCharacterSet.class.getName());
public static final AttributeKey<MySQLCharacterSet> CHARACTER_SET_ATTRIBUTE_KEY = AttributeKey.valueOf(MySQLCharacterSet.class.getName());

public static final AttributeKey<Integer> MYSQL_OPTION_MULTI_STATEMENTS = AttributeKey.valueOf("MYSQL_OPTION_MULTI_STATEMENTS");
public static final AttributeKey<Integer> OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY = AttributeKey.valueOf("MYSQL_OPTION_MULTI_STATEMENTS");

/**
* Protocol version is always 0x0A.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
public final class MySQLSequenceIdInboundHandler extends ChannelInboundHandlerAdapter {

public MySQLSequenceIdInboundHandler(final Channel channel) {
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).set(new AtomicInteger());
}

@Override
public void channelRead(final ChannelHandlerContext context, final Object msg) {
ByteBuf byteBuf = (ByteBuf) msg;
short sequenceId = byteBuf.readUnsignedByte();
context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(sequenceId + 1);
context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(sequenceId + 1);
context.fireChannelRead(byteBuf.readSlice(byteBuf.readableBytes()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class MySQLPacketCodecEngineTest {
@BeforeEach
void setup() {
when(context.channel().attr(AttributeKey.<Charset>valueOf(Charset.class.getName())).get()).thenReturn(StandardCharsets.UTF_8);
when(context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get()).thenReturn(new AtomicInteger());
when(context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get()).thenReturn(new AtomicInteger());
}

@Test
Expand Down Expand Up @@ -135,7 +135,7 @@ void assertEncode() {
when(byteBuf.markWriterIndex()).thenReturn(byteBuf);
when(byteBuf.readableBytes()).thenReturn(8);
MySQLPacket actualMessage = mock(MySQLPacket.class);
context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(1);
context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(1);
new MySQLPacketCodecEngine().encode(context, actualMessage, byteBuf);
verify(byteBuf).writeInt(0);
verify(byteBuf).markWriterIndex();
Expand Down Expand Up @@ -175,7 +175,7 @@ void assertEncodeOccursException() {
RuntimeException ex = mock(RuntimeException.class);
MySQLPacket actualMessage = mock(MySQLPacket.class);
doThrow(ex).when(actualMessage).write(any(MySQLPacketPayload.class));
context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(2);
context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(2);
new MySQLPacketCodecEngine().encode(context, actualMessage, byteBuf);
verify(byteBuf).writeInt(0);
verify(byteBuf).markWriterIndex();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MySQLSequenceIdInboundHandlerTest {
void assertChannelReadWithFlowControl() {
EmbeddedChannel channel = new EmbeddedChannel(
new FixtureOutboundHandler(), new ProxyFlowControlHandler(), new MySQLSequenceIdInboundHandler(mock(Channel.class, RETURNS_DEEP_STUBS)), new FixtureInboundHandler());
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).set(new AtomicInteger());
channel.writeInbound(Unpooled.wrappedBuffer(new byte[1]), Unpooled.wrappedBuffer(new byte[1]), Unpooled.wrappedBuffer(new byte[1]));
assertThat(channel.<ByteBuf>readOutbound().readUnsignedByte(), is((short) 1));
assertThat(channel.<ByteBuf>readOutbound().readUnsignedByte(), is((short) 1));
Expand All @@ -54,7 +54,7 @@ private static class FixtureOutboundHandler extends ChannelOutboundHandlerAdapte

@Override
public void write(final ChannelHandlerContext context, final Object msg, final ChannelPromise promise) {
byte sequenceId = (byte) context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().getAndIncrement();
byte sequenceId = (byte) context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().getAndIncrement();
context.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{sequenceId}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private void dumpBinlog(final String binlogFileName, final long binlogPosition,
}

private void resetSequenceID() {
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(0);
channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(0);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ void setUp() throws InterruptedException {
return null;
});
when(channel.localAddress()).thenReturn(new InetSocketAddress("host", 3306));
when(channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).get()).thenReturn(new AtomicInteger());
when(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get()).thenReturn(new AtomicInteger());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ private AuthenticationResult authenticatePhaseFastPath(final ChannelHandlerConte
}

private void setMultiStatementsOption(final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet handshakeResponsePacket) {
context.channel().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).set(handshakeResponsePacket.getMultiStatementsOption());
context.channel().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).set(handshakeResponsePacket.getMultiStatementsOption());
}

private void setCharacterSet(final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet handshakeResponsePacket) {
MySQLCharacterSet characterSet = MySQLCharacterSet.findById(handshakeResponsePacket.getCharacterSet());
context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(characterSet.getCharset());
context.channel().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
context.channel().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
}

private boolean isClientPluginAuthenticate(final MySQLHandshakeResponse41Packet packet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public final class MySQLComSetOptionExecutor implements CommandExecutor {

@Override
public Collection<DatabasePacket> execute() {
connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).set(packet.getValue());
connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).set(packet.getValue());
return Collections.singleton(new MySQLOKPacket(ServerStatusFlagCalculator.calculateFor(connectionSession)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private MySQLServerPreparedStatement updateAndGetPreparedStatement() {

private Collection<DatabasePacket> processQuery(final QueryResponseHeader queryResponseHeader) {
responseType = ResponseType.QUERY;
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
return ResponsePacketBuilder.buildQueryResponsePackets(queryResponseHeader, characterSet, ServerStatusFlagCalculator.calculateFor(connectionSession));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ public Collection<DatabasePacket> execute() {

private void failedIfContainsMultiStatements() {
// TODO Multi statements should be identified by SQL Parser instead of checking if sql contains ";".
if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get()
if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get()
&& packet.getSQL().contains(";")) {
throw new UnsupportedPreparedStatementException();
}
Expand All @@ -108,7 +108,7 @@ private Collection<DatabasePacket> createPackets(final SQLStatementContext sqlSt
int parameterCount = sqlStatementContext.getSqlStatement().getParameterCount();
ShardingSpherePreconditions.checkState(parameterCount <= MAX_PARAMETER_COUNT, TooManyPlaceholdersException::new);
result.add(new MySQLComStmtPrepareOKPacket(statementId, projections.size(), parameterCount, 0));
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int statusFlags = ServerStatusFlagCalculator.calculateFor(connectionSession);
if (parameterCount > 0) {
result.addAll(createParameterColumnDefinition41Packets(sqlStatementContext, characterSet, serverPreparedStatement));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public Collection<DatabasePacket> execute() throws SQLException {

private Collection<DatabasePacket> createColumnDefinition41Packets(final String databaseName) throws SQLException {
Collection<DatabasePacket> result = new LinkedList<>();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
while (databaseConnector.next()) {
String columnName = databaseConnector.getRowData().getCells().iterator().next().getData().toString();
result.add(new MySQLColumnDefinition41Packet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public MySQLComQueryPacketExecutor(final MySQLComQueryPacket packet, final Conne
SQLStatement sqlStatement = ProxySQLComQueryParser.parse(packet.getSQL(), databaseType, connectionSession);
proxyBackendHandler = areMultiStatements(connectionSession, sqlStatement, packet.getSQL()) ? new MySQLMultiStatementsHandler(connectionSession, sqlStatement, packet.getSQL())
: ProxyBackendHandlerFactory.newInstance(databaseType, packet.getSQL(), sqlStatement, connectionSession, packet.getHintValueContext());
characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
}

private boolean areMultiStatements(final ConnectionSession connectionSession, final SQLStatement sqlStatement, final String sql) {
Expand All @@ -73,8 +73,8 @@ private boolean areMultiStatements(final ConnectionSession connectionSession, fi
}

private boolean isMultiStatementsEnabled(final ConnectionSession connectionSession) {
return connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get();
return connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get();
}

private boolean isSuitableMultiStatementsSQLStatement(final SQLStatement sqlStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void setUp() {
@Test
void assertInitChannel() {
engine.initChannel(channel);
verify(channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID)).set(any(AtomicInteger.class));
verify(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).set(any(AtomicInteger.class));
verify(channel.pipeline())
.addBefore(eq(FrontendChannelInboundHandler.class.getSimpleName()), eq(MySQLSequenceIdInboundHandler.class.getSimpleName()), isA(MySQLSequenceIdInboundHandler.class));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ void assertAuthenticationMethodMismatch() {
when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 3307));
when(channel.attr(CommonConstants.CHARSET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channelHandlerContext.channel()).thenReturn(channel);
when(payload.readInt1()).thenReturn(1);
when(payload.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
Expand Down Expand Up @@ -321,9 +321,9 @@ private Channel getChannel() {
Channel result = mock(Channel.class);
doReturn(getRemoteAddress()).when(result).remoteAddress();
when(result.attr(CommonConstants.CHARSET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.MYSQL_SEQUENCE_ID)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MySQLCommandExecutorFactoryTest {

@BeforeEach
void setUp() {
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class MySQLComSetOptionExecutorTest {
@Test
void assertExecute() {
when(packet.getValue()).thenReturn(MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON);
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(attribute);
when(connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(attribute);
Collection<DatabasePacket> actual = executor.execute();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), instanceOf(MySQLOKPacket.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class MySQLComStmtExecuteExecutorTest {

@BeforeEach
void setUp() {
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
SQLStatementContext selectStatementContext = prepareSelectStatementContext();
when(connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ class MySQLComStmtPrepareExecutorTest {
@BeforeEach
void setup() {
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new ServerPreparedStatementRegistry());
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_UNICODE_CI);
when(connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_UNICODE_CI);
}

@Test
void assertPrepareMultiStatements() {
when(packet.getSQL()).thenReturn("update t set v=v+1 where id=1;update t set v=v+1 where id=2;update t set v=v+1 where id=3");
when(connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(true);
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get()).thenReturn(0);
when(connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(true);
when(connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get()).thenReturn(0);
assertThrows(UnsupportedPreparedStatementException.class, () -> new MySQLComStmtPrepareExecutor(packet, connectionSession).execute());
}

Expand Down
Loading

0 comments on commit 63f6a1c

Please sign in to comment.