Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MySQLConstants #32575

Merged
merged 1 commit into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void encode(final ChannelHandlerContext context, final DatabasePacket mes
new MySQLErrPacket(new UnknownSQLException(ex).toSQLException()).write(payload);
} finally {
if (out.readableBytes() - PAYLOAD_LENGTH - SEQUENCE_LENGTH < MAX_PACKET_LENGTH) {
updateMessageHeader(out, context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().getAndIncrement());
updateMessageHeader(out, context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().getAndIncrement());
} else {
writeMultiPackets(context, out);
}
Expand All @@ -116,7 +116,7 @@ private void updateMessageHeader(final ByteBuf byteBuf, final int sequenceId) {
private void writeMultiPackets(final ChannelHandlerContext context, final ByteBuf byteBuf) {
int packetCount = byteBuf.skipBytes(PAYLOAD_LENGTH + SEQUENCE_LENGTH).readableBytes() / MAX_PACKET_LENGTH + 1;
CompositeByteBuf result = context.alloc().compositeBuffer(packetCount * 2);
AtomicInteger sequenceId = context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get();
AtomicInteger sequenceId = context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get();
for (int i = 0; i < packetCount; i++) {
ByteBuf header = context.alloc().ioBuffer(4, 4);
int packetLength = Math.min(byteBuf.readableBytes(), MAX_PACKET_LENGTH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class MySQLConstants {

public static final AttributeKey<AtomicInteger> MYSQL_SEQUENCE_ID = AttributeKey.valueOf("MYSQL_SEQUENCE_ID");
public static final AttributeKey<AtomicInteger> SEQUENCE_ID_ATTRIBUTE_KEY = AttributeKey.valueOf("MYSQL_SEQUENCE_ID");

public static final AttributeKey<MySQLCharacterSet> MYSQL_CHARACTER_SET_ATTRIBUTE_KEY = AttributeKey.valueOf(MySQLCharacterSet.class.getName());
public static final AttributeKey<MySQLCharacterSet> CHARACTER_SET_ATTRIBUTE_KEY = AttributeKey.valueOf(MySQLCharacterSet.class.getName());

public static final AttributeKey<Integer> MYSQL_OPTION_MULTI_STATEMENTS = AttributeKey.valueOf("MYSQL_OPTION_MULTI_STATEMENTS");
public static final AttributeKey<Integer> OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY = AttributeKey.valueOf("MYSQL_OPTION_MULTI_STATEMENTS");

/**
* Protocol version is always 0x0A.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
public final class MySQLSequenceIdInboundHandler extends ChannelInboundHandlerAdapter {

public MySQLSequenceIdInboundHandler(final Channel channel) {
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).set(new AtomicInteger());
}

@Override
public void channelRead(final ChannelHandlerContext context, final Object msg) {
ByteBuf byteBuf = (ByteBuf) msg;
short sequenceId = byteBuf.readUnsignedByte();
context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(sequenceId + 1);
context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(sequenceId + 1);
context.fireChannelRead(byteBuf.readSlice(byteBuf.readableBytes()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class MySQLPacketCodecEngineTest {
@BeforeEach
void setup() {
when(context.channel().attr(AttributeKey.<Charset>valueOf(Charset.class.getName())).get()).thenReturn(StandardCharsets.UTF_8);
when(context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get()).thenReturn(new AtomicInteger());
when(context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get()).thenReturn(new AtomicInteger());
}

@Test
Expand Down Expand Up @@ -135,7 +135,7 @@ void assertEncode() {
when(byteBuf.markWriterIndex()).thenReturn(byteBuf);
when(byteBuf.readableBytes()).thenReturn(8);
MySQLPacket actualMessage = mock(MySQLPacket.class);
context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(1);
context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(1);
new MySQLPacketCodecEngine().encode(context, actualMessage, byteBuf);
verify(byteBuf).writeInt(0);
verify(byteBuf).markWriterIndex();
Expand Down Expand Up @@ -175,7 +175,7 @@ void assertEncodeOccursException() {
RuntimeException ex = mock(RuntimeException.class);
MySQLPacket actualMessage = mock(MySQLPacket.class);
doThrow(ex).when(actualMessage).write(any(MySQLPacketPayload.class));
context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(2);
context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(2);
new MySQLPacketCodecEngine().encode(context, actualMessage, byteBuf);
verify(byteBuf).writeInt(0);
verify(byteBuf).markWriterIndex();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MySQLSequenceIdInboundHandlerTest {
void assertChannelReadWithFlowControl() {
EmbeddedChannel channel = new EmbeddedChannel(
new FixtureOutboundHandler(), new ProxyFlowControlHandler(), new MySQLSequenceIdInboundHandler(mock(Channel.class, RETURNS_DEEP_STUBS)), new FixtureInboundHandler());
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).set(new AtomicInteger());
channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).set(new AtomicInteger());
channel.writeInbound(Unpooled.wrappedBuffer(new byte[1]), Unpooled.wrappedBuffer(new byte[1]), Unpooled.wrappedBuffer(new byte[1]));
assertThat(channel.<ByteBuf>readOutbound().readUnsignedByte(), is((short) 1));
assertThat(channel.<ByteBuf>readOutbound().readUnsignedByte(), is((short) 1));
Expand All @@ -54,7 +54,7 @@ private static class FixtureOutboundHandler extends ChannelOutboundHandlerAdapte

@Override
public void write(final ChannelHandlerContext context, final Object msg, final ChannelPromise promise) {
byte sequenceId = (byte) context.channel().attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().getAndIncrement();
byte sequenceId = (byte) context.channel().attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().getAndIncrement();
context.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{sequenceId}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private void dumpBinlog(final String binlogFileName, final long binlogPosition,
}

private void resetSequenceID() {
channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).get().set(0);
channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get().set(0);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ void setUp() throws InterruptedException {
return null;
});
when(channel.localAddress()).thenReturn(new InetSocketAddress("host", 3306));
when(channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID).get()).thenReturn(new AtomicInteger());
when(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY).get()).thenReturn(new AtomicInteger());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ private AuthenticationResult authenticatePhaseFastPath(final ChannelHandlerConte
}

private void setMultiStatementsOption(final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet handshakeResponsePacket) {
context.channel().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).set(handshakeResponsePacket.getMultiStatementsOption());
context.channel().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).set(handshakeResponsePacket.getMultiStatementsOption());
}

private void setCharacterSet(final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet handshakeResponsePacket) {
MySQLCharacterSet characterSet = MySQLCharacterSet.findById(handshakeResponsePacket.getCharacterSet());
context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(characterSet.getCharset());
context.channel().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
context.channel().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
}

private boolean isClientPluginAuthenticate(final MySQLHandshakeResponse41Packet packet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public final class MySQLComSetOptionExecutor implements CommandExecutor {

@Override
public Collection<DatabasePacket> execute() {
connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).set(packet.getValue());
connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).set(packet.getValue());
return Collections.singleton(new MySQLOKPacket(ServerStatusFlagCalculator.calculateFor(connectionSession)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private MySQLServerPreparedStatement updateAndGetPreparedStatement() {

private Collection<DatabasePacket> processQuery(final QueryResponseHeader queryResponseHeader) {
responseType = ResponseType.QUERY;
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
return ResponsePacketBuilder.buildQueryResponsePackets(queryResponseHeader, characterSet, ServerStatusFlagCalculator.calculateFor(connectionSession));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ public Collection<DatabasePacket> execute() {

private void failedIfContainsMultiStatements() {
// TODO Multi statements should be identified by SQL Parser instead of checking if sql contains ";".
if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get()
if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get()
&& packet.getSQL().contains(";")) {
throw new UnsupportedPreparedStatementException();
}
Expand All @@ -108,7 +108,7 @@ private Collection<DatabasePacket> createPackets(final SQLStatementContext sqlSt
int parameterCount = sqlStatementContext.getSqlStatement().getParameterCount();
ShardingSpherePreconditions.checkState(parameterCount <= MAX_PARAMETER_COUNT, TooManyPlaceholdersException::new);
result.add(new MySQLComStmtPrepareOKPacket(statementId, projections.size(), parameterCount, 0));
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int statusFlags = ServerStatusFlagCalculator.calculateFor(connectionSession);
if (parameterCount > 0) {
result.addAll(createParameterColumnDefinition41Packets(sqlStatementContext, characterSet, serverPreparedStatement));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public Collection<DatabasePacket> execute() throws SQLException {

private Collection<DatabasePacket> createColumnDefinition41Packets(final String databaseName) throws SQLException {
Collection<DatabasePacket> result = new LinkedList<>();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
while (databaseConnector.next()) {
String columnName = databaseConnector.getRowData().getCells().iterator().next().getData().toString();
result.add(new MySQLColumnDefinition41Packet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public MySQLComQueryPacketExecutor(final MySQLComQueryPacket packet, final Conne
SQLStatement sqlStatement = ProxySQLComQueryParser.parse(packet.getSQL(), databaseType, connectionSession);
proxyBackendHandler = areMultiStatements(connectionSession, sqlStatement, packet.getSQL()) ? new MySQLMultiStatementsHandler(connectionSession, sqlStatement, packet.getSQL())
: ProxyBackendHandlerFactory.newInstance(databaseType, packet.getSQL(), sqlStatement, connectionSession, packet.getHintValueContext());
characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
}

private boolean areMultiStatements(final ConnectionSession connectionSession, final SQLStatement sqlStatement, final String sql) {
Expand All @@ -73,8 +73,8 @@ private boolean areMultiStatements(final ConnectionSession connectionSession, fi
}

private boolean isMultiStatementsEnabled(final ConnectionSession connectionSession) {
return connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get();
return connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)
&& MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get();
}

private boolean isSuitableMultiStatementsSQLStatement(final SQLStatement sqlStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void setUp() {
@Test
void assertInitChannel() {
engine.initChannel(channel);
verify(channel.attr(MySQLConstants.MYSQL_SEQUENCE_ID)).set(any(AtomicInteger.class));
verify(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).set(any(AtomicInteger.class));
verify(channel.pipeline())
.addBefore(eq(FrontendChannelInboundHandler.class.getSimpleName()), eq(MySQLSequenceIdInboundHandler.class.getSimpleName()), isA(MySQLSequenceIdInboundHandler.class));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ void assertAuthenticationMethodMismatch() {
when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 3307));
when(channel.attr(CommonConstants.CHARSET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channel.attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(channelHandlerContext.channel()).thenReturn(channel);
when(payload.readInt1()).thenReturn(1);
when(payload.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
Expand Down Expand Up @@ -321,9 +321,9 @@ private Channel getChannel() {
Channel result = mock(Channel.class);
doReturn(getRemoteAddress()).when(result).remoteAddress();
when(result.attr(CommonConstants.CHARSET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.MYSQL_SEQUENCE_ID)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
when(result.attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MySQLCommandExecutorFactoryTest {

@BeforeEach
void setUp() {
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class MySQLComSetOptionExecutorTest {
@Test
void assertExecute() {
when(packet.getValue()).thenReturn(MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON);
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(attribute);
when(connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(attribute);
Collection<DatabasePacket> actual = executor.execute();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), instanceOf(MySQLOKPacket.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class MySQLComStmtExecuteExecutorTest {

@BeforeEach
void setUp() {
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
SQLStatementContext selectStatementContext = prepareSelectStatementContext();
when(connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ class MySQLComStmtPrepareExecutorTest {
@BeforeEach
void setup() {
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new ServerPreparedStatementRegistry());
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_UNICODE_CI);
when(connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_UNICODE_CI);
}

@Test
void assertPrepareMultiStatements() {
when(packet.getSQL()).thenReturn("update t set v=v+1 where id=1;update t set v=v+1 where id=2;update t set v=v+1 where id=3");
when(connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)).thenReturn(true);
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get()).thenReturn(0);
when(connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(true);
when(connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get()).thenReturn(0);
assertThrows(UnsupportedPreparedStatementException.class, () -> new MySQLComStmtPrepareExecutor(packet, connectionSession).execute());
}

Expand Down
Loading