From 32009a8fbfffbbdaa1be34db709d041faa5e6fe1 Mon Sep 17 00:00:00 2001 From: Muluo-cyan <1217593253@qq.com> Date: Fri, 11 Oct 2024 03:14:24 +0800 Subject: [PATCH] feature: add double token support for seata netty communication authentication --- .../seata/common/ConfigurationKeys.java | 30 ++++ .../apache/seata/common/util/StringUtils.java | 44 +++++- .../console/controller/AuthController.java | 7 +- .../apache/seata/core/auth/AuthResult.java | 61 ++++++++ .../seata/core/auth/AuthResultBuilder.java | 60 ++++++++ .../seata/core/auth/JwtAuthManager.java | 131 ++++++++++++++++++ .../seata/core/auth/RegisterHandler.java | 30 ++++ .../core/protocol/RegisterRMRequest.java | 12 ++ .../seata/core/protocol/ResultCode.java | 20 ++- .../rpc/DefaultServerMessageListenerImpl.java | 124 +++++++++-------- .../core/rpc/RegisterCheckAuthHandler.java | 11 +- .../netty/AbstractNettyRemotingClient.java | 85 +++++++----- .../rpc/netty/NettyClientChannelManager.java | 27 ++-- .../core/rpc/netty/NettyPoolableFactory.java | 78 +++++++---- .../core/rpc/netty/RmNettyRemotingClient.java | 53 ++++--- .../core/rpc/netty/TmNettyRemotingClient.java | 31 +++-- .../client/ClientOnResponseProcessor.java | 34 ++--- .../rpc/processor/server/RegRmProcessor.java | 51 ++++--- .../rpc/processor/server/RegTmProcessor.java | 49 ++++--- .../seata/core/protocol/ResultCodeTest.java | 26 +++- .../raft/RaftRegistryServiceImpl.java | 80 +++++------ .../apache/seata/rm/AbstractRMHandler.java | 87 +++++++++--- script/client/spring/application.properties | 5 + script/client/spring/application.yml | 9 ++ .../AbstractIdentifyResponseCodec.java | 42 ++++-- .../protocol/AbstractResultMessageCodec.java | 58 ++++---- .../RegisterTMRequestSerializerTest.java | 18 ++- .../server/auth/AbstractCheckAuthHandler.java | 23 ++- .../server/auth/DefaultCheckAuthHandler.java | 16 ++- .../server/auth/JwtCheckAuthHandler.java | 87 ++++++++++++ .../auth/utils/ManagerRegJwtTokenUtils.java | 131 ++++++++++++++++++ ...he.seata.core.rpc.RegisterCheckAuthHandler | 2 +- server/src/main/resources/application.yml | 7 +- 33 files changed, 1156 insertions(+), 373 deletions(-) create mode 100644 core/src/main/java/org/apache/seata/core/auth/AuthResult.java create mode 100644 core/src/main/java/org/apache/seata/core/auth/AuthResultBuilder.java create mode 100644 core/src/main/java/org/apache/seata/core/auth/JwtAuthManager.java create mode 100644 core/src/main/java/org/apache/seata/core/auth/RegisterHandler.java create mode 100644 server/src/main/java/org/apache/seata/server/auth/JwtCheckAuthHandler.java create mode 100644 server/src/main/java/org/apache/seata/server/auth/utils/ManagerRegJwtTokenUtils.java diff --git a/common/src/main/java/org/apache/seata/common/ConfigurationKeys.java b/common/src/main/java/org/apache/seata/common/ConfigurationKeys.java index ff8436b6dc9..a46c31358ef 100644 --- a/common/src/main/java/org/apache/seata/common/ConfigurationKeys.java +++ b/common/src/main/java/org/apache/seata/common/ConfigurationKeys.java @@ -68,6 +68,11 @@ public interface ConfigurationKeys { */ String SEATA_PREFIX = SEATA_FILE_ROOT_CONFIG + "."; + /** + * The constant SECURITY_PREFIX + */ + String SECURITY_PREFIX = "security."; + /** * The constant SERVICE_PREFIX. */ @@ -1014,6 +1019,31 @@ public interface ConfigurationKeys { */ String SERVER_APPLICATION_DATA_SIZE_CHECK = SERVER_PREFIX + "applicationDataLimitCheck"; + /** + * The constant SECURITY_USERNAME; + */ + String SECURITY_USERNME = SECURITY_PREFIX + "username"; + + /** + * The constant SECURITY_PASSWORD; + */ + String SECURITY_PASSWORD = SECURITY_PREFIX + "password"; + + /** + * The constant SECURITY_SECRET_KEY; + */ + String SECURITY_SECRET_KEY = SECURITY_PREFIX + "secretKey"; + + /** + * The constant SECURITY_ACCESS_TOKEN_VALID_TIME; + */ + String SECURITY_ACCESS_TOKEN_VALID_TIME = SECURITY_PREFIX + "accessTokenValidityInMilliseconds"; + + /** + * The constant SECURITY_REFRESH_TOKEN_VALID_TIME; + */ + String SECURITY_REFRESH_TOKEN_VALID_TIME = SECURITY_PREFIX + "refreshTokenValidityInMilliseconds"; + /** * The constant ROCKET_MQ_MSG_TIMEOUT */ diff --git a/common/src/main/java/org/apache/seata/common/util/StringUtils.java b/common/src/main/java/org/apache/seata/common/util/StringUtils.java index 05859cf4746..b8b81480234 100644 --- a/common/src/main/java/org/apache/seata/common/util/StringUtils.java +++ b/common/src/main/java/org/apache/seata/common/util/StringUtils.java @@ -16,6 +16,10 @@ */ package org.apache.seata.common.util; +import org.apache.seata.common.Constants; +import org.apache.seata.common.exception.ShouldNeverHappenException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.lang.annotation.Annotation; @@ -25,15 +29,14 @@ import java.text.SimpleDateFormat; import java.util.Collection; import java.util.Date; +import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; -import org.apache.seata.common.Constants; -import org.apache.seata.common.exception.ShouldNeverHappenException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR; +import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR; /** * The type String utils. @@ -331,7 +334,7 @@ public static boolean isNotEmpty(final CharSequence cs) { /** * hump to Line or line to hump, only spring environment use - * + * * @param str str * @return string string */ @@ -446,4 +449,35 @@ public static boolean hasText(CharSequence str) { return false; } + public static HashMap string2Map(String inputString) { + HashMap resultMap = new HashMap<>(); + if (StringUtils.isBlank(inputString)) { + return resultMap; + } + String[] keyValuePairs = inputString.split(EXTRA_DATA_SPLIT_CHAR); + for (String pair : keyValuePairs) { + String[] keyValue = pair.trim().split(EXTRA_DATA_KV_CHAR); + if (keyValue.length == 2) { + resultMap.put(keyValue[0].trim(), keyValue[1].trim()); + } + } + return resultMap; + } + + public static String map2String(HashMap inputMap) { + if (inputMap == null || inputMap.isEmpty()) { + return ""; + } + StringBuilder resultString = new StringBuilder(); + for (Map.Entry entry : inputMap.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + String pair = key + EXTRA_DATA_KV_CHAR + value + EXTRA_DATA_SPLIT_CHAR; + resultString.append(pair); + } + if (resultString.length() > 0) { + resultString.deleteCharAt(resultString.length() - 1); + } + return resultString.toString(); + } } diff --git a/console/src/main/java/org/apache/seata/console/controller/AuthController.java b/console/src/main/java/org/apache/seata/console/controller/AuthController.java index 7f43218c4c8..c2088acb45c 100644 --- a/console/src/main/java/org/apache/seata/console/controller/AuthController.java +++ b/console/src/main/java/org/apache/seata/console/controller/AuthController.java @@ -16,11 +16,9 @@ */ package org.apache.seata.console.controller; -import javax.servlet.http.HttpServletResponse; - import org.apache.seata.common.result.Code; -import org.apache.seata.console.config.WebSecurityConfig; import org.apache.seata.common.result.SingleResult; +import org.apache.seata.console.config.WebSecurityConfig; import org.apache.seata.console.security.User; import org.apache.seata.console.utils.JwtTokenUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -34,6 +32,8 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import javax.servlet.http.HttpServletResponse; + /** * auth user * @@ -58,7 +58,6 @@ public class AuthController { public SingleResult login(HttpServletResponse response, @RequestBody User user) { UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken( user.getUsername(), user.getPassword()); - try { //AuthenticationManager(default ProviderManager) #authenticate check Authentication Authentication authentication = authenticationManager.authenticate(authenticationToken); diff --git a/core/src/main/java/org/apache/seata/core/auth/AuthResult.java b/core/src/main/java/org/apache/seata/core/auth/AuthResult.java new file mode 100644 index 00000000000..01764a836c1 --- /dev/null +++ b/core/src/main/java/org/apache/seata/core/auth/AuthResult.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.seata.core.auth; + +import org.apache.seata.core.protocol.ResultCode; + +public class AuthResult { + private ResultCode resultCode; + + private String accessToken; + + private String refreshToken; + + public AuthResult() { + } + + public AuthResult(AuthResultBuilder builder) { + this.resultCode = builder.getResultCode(); + this.accessToken = builder.getAccessToken(); + this.refreshToken = builder.getRefreshToken(); + } + + public ResultCode getResultCode() { + return resultCode; + } + + public void setResultCode(ResultCode resultCode) { + this.resultCode = resultCode; + } + + public String getAccessToken() { + return accessToken; + } + + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + public String getRefreshToken() { + return refreshToken; + } + + public void setRefreshToken(String refreshToken) { + this.refreshToken = refreshToken; + } + +} diff --git a/core/src/main/java/org/apache/seata/core/auth/AuthResultBuilder.java b/core/src/main/java/org/apache/seata/core/auth/AuthResultBuilder.java new file mode 100644 index 00000000000..e01941f3a60 --- /dev/null +++ b/core/src/main/java/org/apache/seata/core/auth/AuthResultBuilder.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.seata.core.auth; + +import org.apache.seata.core.protocol.ResultCode; + +public class AuthResultBuilder { + private ResultCode resultCode; + private String accessToken; + private String refreshToken; + + public ResultCode getResultCode() { + return resultCode; + } + + public String getAccessToken() { + return accessToken; + } + + public String getRefreshToken() { + return refreshToken; + } + + // 设置 resultCode + public AuthResultBuilder setResultCode(ResultCode resultCode) { + this.resultCode = resultCode; + return this; + } + + // 设置 accessToken + public AuthResultBuilder setAccessToken(String accessToken) { + this.accessToken = accessToken; + return this; + } + + // 设置 refreshToken + public AuthResultBuilder setRefreshToken(String refreshToken) { + this.refreshToken = refreshToken; + return this; + } + + // 构建最终的 AuthResult 对象 + public AuthResult build() { + return new AuthResult(this); + } +} diff --git a/core/src/main/java/org/apache/seata/core/auth/JwtAuthManager.java b/core/src/main/java/org/apache/seata/core/auth/JwtAuthManager.java new file mode 100644 index 00000000000..8dadaf58f83 --- /dev/null +++ b/core/src/main/java/org/apache/seata/core/auth/JwtAuthManager.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.seata.core.auth; + + +import org.apache.seata.common.ConfigurationKeys; +import org.apache.seata.common.util.StringUtils; +import org.apache.seata.config.ConfigurationFactory; + +import java.util.HashMap; + +public class JwtAuthManager { + private String refreshToken; + + private String accessToken; + + private boolean isAccessTokenNearExpiration; + + private String username; + + private String password; + + public final static String PRO_USERNAME = "username"; + + public final static String PRO_PASSWORD = "password"; + + public final static String PRO_TOKEN = "token"; + + public final static String PRO_REFRESH_TOKEN = "refresh_token"; + + private static volatile JwtAuthManager instance; + + private JwtAuthManager() { + } + + public static JwtAuthManager getInstance() { + if (instance == null) { + synchronized (JwtAuthManager.class) { + if (instance == null) { + instance = new JwtAuthManager(); + instance.username = ConfigurationFactory.CURRENT_FILE_INSTANCE.getConfig(ConfigurationKeys.SECURITY_USERNME); + instance.password = ConfigurationFactory.CURRENT_FILE_INSTANCE.getConfig(ConfigurationKeys.SECURITY_PASSWORD); + instance.isAccessTokenNearExpiration = false; + } + } + } + return instance; + } + + public void init() { + } + + public boolean isAccessTokenNearExpiration() { + return isAccessTokenNearExpiration; + } + + public void setAccessTokenNearExpiration(boolean accessTokenNearExpiration) { + isAccessTokenNearExpiration = accessTokenNearExpiration; + } + + public String getAccessToken() { + return accessToken; + } + + public String getRefreshToken() { + return refreshToken; + } + + public String getUsername() { + return username; + } + + public void setUsername(String username) { + this.username = username; + } + + public String getPassword() { + return password; + } + + public void setPassword(String password) { + this.password = password; + } + + public void refreshToken(String newAccessToken, String newRefreshToken) { + if (newAccessToken != null) { + accessToken = newAccessToken; + isAccessTokenNearExpiration = false; + } + if (newRefreshToken != null) { + refreshToken = newRefreshToken; + } + } + + public void setAccessToken(String token) { + accessToken = token; + } + + public void setRefreshToken(String token) { + refreshToken = token; + } + + public String getAuthData() { + HashMap extraDataMap = new HashMap<>(); + extraDataMap.remove(PRO_TOKEN); + if (accessToken != null && !isAccessTokenNearExpiration) { + extraDataMap.put(PRO_TOKEN, accessToken); + } else if (refreshToken != null) { + extraDataMap.put(PRO_REFRESH_TOKEN, refreshToken); + } else if (username != null && password != null) { + extraDataMap.put(PRO_USERNAME, username); + extraDataMap.put(PRO_PASSWORD, password); + } + return StringUtils.map2String(extraDataMap); + } + +} diff --git a/core/src/main/java/org/apache/seata/core/auth/RegisterHandler.java b/core/src/main/java/org/apache/seata/core/auth/RegisterHandler.java new file mode 100644 index 00000000000..a2d3a78eff8 --- /dev/null +++ b/core/src/main/java/org/apache/seata/core/auth/RegisterHandler.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.seata.core.auth; + +import io.netty.channel.Channel; +import org.apache.seata.core.protocol.RegisterRMResponse; + +public interface RegisterHandler { + /** + * On a register response received. + * + * @param response received response message + * @param channel channel of the response + */ + void onRegisterResponse(RegisterRMResponse response, Channel channel, Integer rpcId); +} diff --git a/core/src/main/java/org/apache/seata/core/protocol/RegisterRMRequest.java b/core/src/main/java/org/apache/seata/core/protocol/RegisterRMRequest.java index d5e6c3edead..597d1ec47f6 100644 --- a/core/src/main/java/org/apache/seata/core/protocol/RegisterRMRequest.java +++ b/core/src/main/java/org/apache/seata/core/protocol/RegisterRMRequest.java @@ -41,6 +41,18 @@ public RegisterRMRequest(String applicationId, String transactionServiceGroup) { super(applicationId, transactionServiceGroup); } + /** + * Instantiates a new Register rm request. + * + * @param applicationId the application id + * @param transactionServiceGroup the transaction service group + * @param extraData the extra data + */ + public RegisterRMRequest(String applicationId, String transactionServiceGroup, String extraData) { + super(applicationId, transactionServiceGroup, extraData); + } + + private String resourceIds; /** diff --git a/core/src/main/java/org/apache/seata/core/protocol/ResultCode.java b/core/src/main/java/org/apache/seata/core/protocol/ResultCode.java index d3338eb7f10..4916e10632a 100644 --- a/core/src/main/java/org/apache/seata/core/protocol/ResultCode.java +++ b/core/src/main/java/org/apache/seata/core/protocol/ResultCode.java @@ -32,7 +32,25 @@ public enum ResultCode { * Success result code. */ // Success - Success; + Success, + + /** + * Access token is expired result code. + */ + // AccessTokenExpired + AccessTokenExpired, + + /** + * Access token will expire soon result code. + */ + // AccessTokenNearExpiration + AccessTokenNearExpiration, + + /** + * Refresh token is expired result code. + */ + // RefreshTokenExpired + RefreshTokenExpired; /** * Get result code. diff --git a/core/src/main/java/org/apache/seata/core/rpc/DefaultServerMessageListenerImpl.java b/core/src/main/java/org/apache/seata/core/rpc/DefaultServerMessageListenerImpl.java index 276cca1b50e..c4a3a4bb705 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/DefaultServerMessageListenerImpl.java +++ b/core/src/main/java/org/apache/seata/core/rpc/DefaultServerMessageListenerImpl.java @@ -19,29 +19,15 @@ import io.netty.channel.ChannelHandlerContext; import org.apache.seata.common.thread.NamedThreadFactory; import org.apache.seata.common.util.NetUtil; -import org.apache.seata.core.protocol.AbstractMessage; -import org.apache.seata.core.protocol.AbstractResultMessage; -import org.apache.seata.core.protocol.HeartbeatMessage; -import org.apache.seata.core.protocol.MergeResultMessage; -import org.apache.seata.core.protocol.MergedWarpMessage; -import org.apache.seata.core.protocol.RegisterRMRequest; -import org.apache.seata.core.protocol.RegisterRMResponse; -import org.apache.seata.core.protocol.RegisterTMRequest; -import org.apache.seata.core.protocol.RegisterTMResponse; -import org.apache.seata.core.protocol.RpcMessage; -import org.apache.seata.core.protocol.Version; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.protocol.*; import org.apache.seata.core.rpc.netty.ChannelManager; -import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; /** * The type Default server message listener. @@ -74,11 +60,11 @@ public void onTrxMessage(RpcMessage request, ChannelHandlerContext ctx) { RpcContext rpcContext = ChannelManager.getContextFromIdentified(ctx.channel()); if (LOGGER.isDebugEnabled()) { LOGGER.debug("server received:{},clientIp:{},vgroup:{}", message, - NetUtil.toIpAddress(ctx.channel().remoteAddress()), rpcContext.getTransactionServiceGroup()); + NetUtil.toIpAddress(ctx.channel().remoteAddress()), rpcContext.getTransactionServiceGroup()); } else { try { logQueue.put(message + ",clientIp:" + NetUtil.toIpAddress(ctx.channel().remoteAddress()) + ",vgroup:" - + rpcContext.getTransactionServiceGroup()); + + rpcContext.getTransactionServiceGroup()); } catch (InterruptedException e) { LOGGER.error("put message to logQueue error: {}", e.getMessage(), e); } @@ -106,65 +92,89 @@ public void onTrxMessage(RpcMessage request, ChannelHandlerContext ctx) { } @Override - public void onRegRmMessage(RpcMessage request, ChannelHandlerContext ctx, RegisterCheckAuthHandler checkAuthHandler) { - RegisterRMRequest message = (RegisterRMRequest)request.getBody(); + public void onRegRmMessage(RpcMessage rpcMessage, ChannelHandlerContext ctx, RegisterCheckAuthHandler checkAuthHandler) { + RegisterRMRequest message = (RegisterRMRequest) rpcMessage.getBody(); String ipAndPort = NetUtil.toStringAddress(ctx.channel().remoteAddress()); - boolean isSuccess = false; - String errorInfo = StringUtils.EMPTY; + RegisterRMResponse response = new RegisterRMResponse(false); try { - if (checkAuthHandler == null || checkAuthHandler.regResourceManagerCheckAuth(message)) { + AuthResult authResult = (checkAuthHandler != null) ? checkAuthHandler.regResourceManagerCheckAuth(message) : null; + if (checkAuthHandler == null || authResult.getResultCode().equals(ResultCode.Success) + || authResult.getResultCode().equals(ResultCode.AccessTokenNearExpiration)) { ChannelManager.registerRMChannel(message, ctx.channel()); Version.putChannelVersion(ctx.channel(), message.getVersion()); - isSuccess = true; + response.setIdentified(true); + response.setResultCode(checkAuthHandler == null ? ResultCode.Success : authResult.getResultCode()); + response.setExtraData(checkAuthHandler.fetchNewToken(authResult)); if (LOGGER.isDebugEnabled()) { - LOGGER.debug("checkAuth for client:{},vgroup:{},applicationId:{} is OK", ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + LOGGER.debug("RM checkAuth for client:{},vgroup:{},applicationId:{} is OK", + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + } + } else { + if (authResult.getResultCode().equals(ResultCode.Failed)) { + response.setMsg("RM checkAuth failed!Please check your username/password or token."); + } else if (authResult.getResultCode().equals(ResultCode.AccessTokenExpired)) { + response.setMsg("RM checkAuth failed! The access token has been expired."); + } else if (authResult.getResultCode().equals(ResultCode.RefreshTokenExpired)) { + response.setMsg("RM checkAuth failed! The refresh token has been expired."); + } + response.setResultCode(authResult.getResultCode()); + if (LOGGER.isWarnEnabled()) { + LOGGER.warn("RM checkAuth for client:{},vgroup:{},applicationId:{} is FAIL", + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); } } - } catch (Exception exx) { - isSuccess = false; - errorInfo = exx.getMessage(); - LOGGER.error("RM register fail, error message:{}", errorInfo); - } - RegisterRMResponse response = new RegisterRMResponse(isSuccess); - if (StringUtils.isNotEmpty(errorInfo)) { - response.setMsg(errorInfo); + } catch (IncompatibleVersionException e) { + LOGGER.error("RM register fail, error message:{}", e.getMessage()); + response.setResultCode(ResultCode.Failed); } - getServerMessageSender().sendAsyncResponse(request, ctx.channel(), response); - if (LOGGER.isInfoEnabled()) { + remotingServer.sendAsyncResponse(rpcMessage, ctx.channel(), response); + if (response.isIdentified() && LOGGER.isInfoEnabled()) { LOGGER.info("RM register success,message:{},channel:{},client version:{}", message, ctx.channel(), - message.getVersion()); + message.getVersion()); } } @Override - public void onRegTmMessage(RpcMessage request, ChannelHandlerContext ctx, RegisterCheckAuthHandler checkAuthHandler) { - RegisterTMRequest message = (RegisterTMRequest)request.getBody(); + public void onRegTmMessage(RpcMessage rpcMessage, ChannelHandlerContext ctx, RegisterCheckAuthHandler checkAuthHandler) { + RegisterTMRequest message = (RegisterTMRequest) rpcMessage.getBody(); String ipAndPort = NetUtil.toStringAddress(ctx.channel().remoteAddress()); Version.putChannelVersion(ctx.channel(), message.getVersion()); - boolean isSuccess = false; - String errorInfo = StringUtils.EMPTY; + RegisterTMResponse response = new RegisterTMResponse(false); try { - if (checkAuthHandler == null || checkAuthHandler.regTransactionManagerCheckAuth(message)) { + AuthResult authResult = (checkAuthHandler != null) ? checkAuthHandler.regTransactionManagerCheckAuth(message) : null; + if (checkAuthHandler == null || authResult.getResultCode().equals(ResultCode.Success) + || authResult.getResultCode().equals(ResultCode.AccessTokenNearExpiration)) { ChannelManager.registerTMChannel(message, ctx.channel()); Version.putChannelVersion(ctx.channel(), message.getVersion()); - isSuccess = true; + response.setIdentified(true); + response.setResultCode(checkAuthHandler == null ? ResultCode.Success : authResult.getResultCode()); + response.setExtraData(checkAuthHandler.fetchNewToken(authResult)); if (LOGGER.isDebugEnabled()) { - LOGGER.debug("checkAuth for client:{},vgroup:{},applicationId:{} is OK", ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + LOGGER.debug("TM checkAuth for client:{},vgroup:{},applicationId:{} is OK", + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + } + } else { + if (authResult.getResultCode().equals(ResultCode.Failed)) { + response.setMsg("TM checkAuth failed!Please check your username/password."); + } else if (authResult.getResultCode().equals(ResultCode.AccessTokenExpired)) { + response.setMsg("TM checkAuth failed! The access token has been expired."); + } else if (authResult.getResultCode().equals(ResultCode.RefreshTokenExpired)) { + response.setMsg("TM checkAuth failed! The refresh token has been expired."); + } + response.setResultCode(authResult.getResultCode()); + if (LOGGER.isWarnEnabled()) { + LOGGER.warn("TM checkAuth for client:{},vgroup:{},applicationId:{} is FAIL", + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); } } - } catch (Exception exx) { - isSuccess = false; - errorInfo = exx.getMessage(); - LOGGER.error("TM register fail, error message:{}", errorInfo); - } - RegisterTMResponse response = new RegisterTMResponse(isSuccess); - if (StringUtils.isNotEmpty(errorInfo)) { - response.setMsg(errorInfo); + } catch (IncompatibleVersionException e) { + LOGGER.error("TM register fail, error message:{}", e.getMessage()); + response.setResultCode(ResultCode.Failed); } - getServerMessageSender().sendAsyncResponse(request, ctx.channel(), response); - if (LOGGER.isInfoEnabled()) { + remotingServer.sendAsyncResponse(rpcMessage, ctx.channel(), response); + if (response.isIdentified() && LOGGER.isInfoEnabled()) { LOGGER.info("TM register success,message:{},channel:{},client version:{}", message, ctx.channel(), - message.getVersion()); + message.getVersion()); } } @@ -185,8 +195,8 @@ public void onCheckMessage(RpcMessage request, ChannelHandlerContext ctx) { */ public void init() { ExecutorService mergeSendExecutorService = new ThreadPoolExecutor(MAX_LOG_SEND_THREAD, MAX_LOG_SEND_THREAD, - KEEP_ALIVE_TIME, TimeUnit.MILLISECONDS, new LinkedBlockingQueue(), - new NamedThreadFactory(THREAD_PREFIX, MAX_LOG_SEND_THREAD, true)); + KEEP_ALIVE_TIME, TimeUnit.MILLISECONDS, new LinkedBlockingQueue(), + new NamedThreadFactory(THREAD_PREFIX, MAX_LOG_SEND_THREAD, true)); mergeSendExecutorService.submit(new BatchLogRunnable()); } diff --git a/core/src/main/java/org/apache/seata/core/rpc/RegisterCheckAuthHandler.java b/core/src/main/java/org/apache/seata/core/rpc/RegisterCheckAuthHandler.java index c17cbbad1f4..edfd3b5e476 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/RegisterCheckAuthHandler.java +++ b/core/src/main/java/org/apache/seata/core/rpc/RegisterCheckAuthHandler.java @@ -16,6 +16,7 @@ */ package org.apache.seata.core.rpc; +import org.apache.seata.core.auth.AuthResult; import org.apache.seata.core.protocol.RegisterRMRequest; import org.apache.seata.core.protocol.RegisterTMRequest; @@ -31,7 +32,7 @@ public interface RegisterCheckAuthHandler { * @param request the request * @return the boolean */ - boolean regTransactionManagerCheckAuth(RegisterTMRequest request); + AuthResult regTransactionManagerCheckAuth(RegisterTMRequest request); /** * Reg resource manager check auth boolean. @@ -39,5 +40,11 @@ public interface RegisterCheckAuthHandler { * @param request the request * @return the boolean */ - boolean regResourceManagerCheckAuth(RegisterRMRequest request); + AuthResult regResourceManagerCheckAuth(RegisterRMRequest request); + + /** + * Fetch new token + * @return the String + */ + String fetchNewToken(AuthResult authResult) ; } diff --git a/core/src/main/java/org/apache/seata/core/rpc/netty/AbstractNettyRemotingClient.java b/core/src/main/java/org/apache/seata/core/rpc/netty/AbstractNettyRemotingClient.java index 248e8f48f6d..c6e6e34f746 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/netty/AbstractNettyRemotingClient.java +++ b/core/src/main/java/org/apache/seata/core/rpc/netty/AbstractNettyRemotingClient.java @@ -16,19 +16,6 @@ */ package org.apache.seata.core.rpc.netty; -import java.lang.reflect.Field; -import java.net.InetSocketAddress; -import java.util.List; -import java.util.Map; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.function.Function; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandler.Sharable; @@ -43,13 +30,8 @@ import org.apache.seata.common.util.CollectionUtils; import org.apache.seata.common.util.NetUtil; import org.apache.seata.common.util.StringUtils; -import org.apache.seata.core.protocol.AbstractMessage; -import org.apache.seata.core.protocol.HeartbeatMessage; -import org.apache.seata.core.protocol.MergeMessage; -import org.apache.seata.core.protocol.MergedWarpMessage; -import org.apache.seata.core.protocol.MessageFuture; -import org.apache.seata.core.protocol.ProtocolConstants; -import org.apache.seata.core.protocol.RpcMessage; +import org.apache.seata.core.auth.JwtAuthManager; +import org.apache.seata.core.protocol.*; import org.apache.seata.core.protocol.transaction.AbstractGlobalEndRequest; import org.apache.seata.core.protocol.transaction.BranchRegisterRequest; import org.apache.seata.core.protocol.transaction.BranchReportRequest; @@ -63,6 +45,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.*; +import java.util.function.Function; + import static org.apache.seata.common.exception.FrameworkErrorCode.NoAvailableService; /** @@ -72,6 +62,8 @@ public abstract class AbstractNettyRemotingClient extends AbstractNettyRemoting implements RemotingClient { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractNettyRemotingClient.class); + private static final String PRO_TOKEN = "token"; + private static final String PRO_REFRESH_TOKEN = "refresh_token"; private static final String MSG_ID_PREFIX = "msgId:"; private static final String FUTURES_PREFIX = "futures:"; private static final String SINGLE_LOG_POSTFIX = ";"; @@ -89,6 +81,11 @@ public abstract class AbstractNettyRemotingClient extends AbstractNettyRemoting */ protected final Map mergeMsgMap = new ConcurrentHashMap<>(); + /** + * When sending message type is {@link RegisterRMRequest}, will be stored to regRmRequestMap. + */ + private final Map regRmRequestMap = new ConcurrentHashMap<>(); + /** * When batch sending is enabled, the message will be stored to basketMap * Send via asynchronous thread {@link AbstractNettyRemotingClient.MergedSendRunnable} @@ -113,10 +110,10 @@ public void init() { }, SCHEDULE_DELAY_MILLS, SCHEDULE_INTERVAL_MILLS, TimeUnit.MILLISECONDS); if (this.isEnableClientBatchSendRequest()) { mergeSendExecutorService = new ThreadPoolExecutor(MAX_MERGE_SEND_THREAD, - MAX_MERGE_SEND_THREAD, - KEEP_ALIVE_TIME, TimeUnit.MILLISECONDS, - new LinkedBlockingQueue<>(), - new NamedThreadFactory(getThreadPrefix(), MAX_MERGE_SEND_THREAD)); + MAX_MERGE_SEND_THREAD, + KEEP_ALIVE_TIME, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(), + new NamedThreadFactory(getThreadPrefix(), MAX_MERGE_SEND_THREAD)); mergeSendExecutorService.submit(new MergedSendRunnable()); } super.init(); @@ -130,7 +127,7 @@ public AbstractNettyRemotingClient(NettyClientConfig nettyClientConfig, EventExe clientBootstrap = new NettyClientBootstrap(nettyClientConfig, eventExecutorGroup, transactionRole); clientBootstrap.setChannelHandlers(new ClientHandler()); clientChannelManager = new NettyClientChannelManager( - new NettyPoolableFactory(this, clientBootstrap), getPoolKeyFunction(), nettyClientConfig); + new NettyPoolableFactory(this, clientBootstrap), getPoolKeyFunction(), nettyClientConfig); } @Override @@ -151,10 +148,10 @@ public Object sendSyncRequest(Object msg) throws TimeoutException { // put message into basketMap BlockingQueue basket = CollectionUtils.computeIfAbsent(basketMap, serverAddress, - key -> new LinkedBlockingQueue<>()); + key -> new LinkedBlockingQueue<>()); if (!basket.offer(rpcMessage)) { LOGGER.error("put message into basketMap offer failed, serverAddress:{},rpcMessage:{}", - serverAddress, rpcMessage); + serverAddress, rpcMessage); return null; } if (LOGGER.isDebugEnabled()) { @@ -201,10 +198,12 @@ public void sendAsyncRequest(Channel channel, Object msg) { return; } RpcMessage rpcMessage = buildRequestMessage(msg, msg instanceof HeartbeatMessage - ? ProtocolConstants.MSGTYPE_HEARTBEAT_REQUEST - : ProtocolConstants.MSGTYPE_RESQUEST_ONEWAY); + ? ProtocolConstants.MSGTYPE_HEARTBEAT_REQUEST + : ProtocolConstants.MSGTYPE_RESQUEST_ONEWAY); if (rpcMessage.getBody() instanceof MergeMessage) { mergeMsgMap.put(rpcMessage.getId(), (MergeMessage) rpcMessage.getBody()); + } else if (rpcMessage.getBody() instanceof RegisterRMRequest) { + regRmRequestMap.put(rpcMessage.getId(), (RegisterRMRequest) rpcMessage.getBody()); } super.sendAsync(channel, rpcMessage); } @@ -236,6 +235,16 @@ public void destroy() { super.destroy(); } + public void refreshAuthToken(String extraData) { + if (StringUtils.isBlank(extraData)) { + return; + } + HashMap extraDataMap = StringUtils.string2Map(extraData); + String newAccessToken = extraDataMap.get(PRO_TOKEN); + String newRefreshToken = extraDataMap.get(PRO_REFRESH_TOKEN); + JwtAuthManager.getInstance().refreshToken(newAccessToken, newRefreshToken); + } + public void setTransactionMessageHandler(TransactionMessageHandler transactionMessageHandler) { this.transactionMessageHandler = transactionMessageHandler; } @@ -248,12 +257,20 @@ public NettyClientChannelManager getClientChannelManager() { return clientChannelManager; } + public RegisterRMRequest getRegisterRMRequest(Integer rpcMessageId) { + return regRmRequestMap.get(rpcMessageId); + } + + public RegisterRMRequest removeRegisterRMRequest(Integer rpcMessageId) { + return regRmRequestMap.remove(rpcMessageId); + } + protected String loadBalance(String transactionServiceGroup, Object msg) { InetSocketAddress address = null; try { @SuppressWarnings("unchecked") List inetSocketAddressList = - RegistryFactory.getInstance().aliveLookup(transactionServiceGroup); + RegistryFactory.getInstance().aliveLookup(transactionServiceGroup); address = this.doSelect(inetSocketAddressList, msg); } catch (Exception ex) { LOGGER.error("Select the address failed: {}", ex.getMessage()); @@ -295,6 +312,10 @@ protected String getXid(Object msg) { return StringUtils.isBlank(xid) ? String.valueOf(ThreadLocalRandom.current().nextLong(Long.MAX_VALUE)) : xid; } + protected String getAuthData() { + return JwtAuthManager.getInstance().getAuthData(); + } + private String getThreadPrefix() { return AbstractNettyRemotingClient.MERGE_THREAD_PREFIX + THREAD_PREFIX_SPLIT_CHAR + transactionRole.name(); } @@ -372,7 +393,7 @@ public void run() { MessageFuture messageFuture = futures.remove(msgId); if (messageFuture != null) { messageFuture.setResultMessage( - new RuntimeException(String.format("%s is unreachable", address), e)); + new RuntimeException(String.format("%s is unreachable", address), e)); } } LOGGER.error("client merge call failed: {}", e.getMessage(), e); @@ -471,7 +492,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { LOGGER.error(FrameworkErrorCode.ExceptionCaught.getErrCode(), - NetUtil.toStringAddress(ctx.channel().remoteAddress()) + "connect exception. " + cause.getMessage(), cause); + NetUtil.toStringAddress(ctx.channel().remoteAddress()) + "connect exception. " + cause.getMessage(), cause); clientChannelManager.releaseChannel(ctx.channel(), getAddressFromChannel(ctx.channel())); if (LOGGER.isInfoEnabled()) { LOGGER.info("remove exception rm channel:{}", ctx.channel()); diff --git a/core/src/main/java/org/apache/seata/core/rpc/netty/NettyClientChannelManager.java b/core/src/main/java/org/apache/seata/core/rpc/netty/NettyClientChannelManager.java index 7be0de2e729..f830aed742f 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/netty/NettyClientChannelManager.java +++ b/core/src/main/java/org/apache/seata/core/rpc/netty/NettyClientChannelManager.java @@ -16,20 +16,8 @@ */ package org.apache.seata.core.rpc.netty; -import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.function.Function; -import java.util.stream.Collectors; - import io.netty.channel.Channel; +import org.apache.commons.pool.impl.GenericKeyedObjectPool; import org.apache.seata.common.ConfigurationKeys; import org.apache.seata.common.exception.FrameworkErrorCode; import org.apache.seata.common.exception.FrameworkException; @@ -39,15 +27,21 @@ import org.apache.seata.discovery.registry.FileRegistryServiceImpl; import org.apache.seata.discovery.registry.RegistryFactory; import org.apache.seata.discovery.registry.RegistryService; -import org.apache.commons.pool.impl.GenericKeyedObjectPool; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.net.InetSocketAddress; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; +import java.util.stream.Collectors; + /** * Netty client pool manager. * */ -class NettyClientChannelManager { +class NettyClientChannelManager { private static final Logger LOGGER = LoggerFactory.getLogger(NettyClientChannelManager.class); @@ -62,7 +56,7 @@ class NettyClientChannelManager { private Function poolKeyFunction; NettyClientChannelManager(final NettyPoolableFactory keyPoolableFactory, final Function poolKeyFunction, - final NettyClientConfig clientConfig) { + final NettyClientConfig clientConfig) { nettyClientKeyPool = new GenericKeyedObjectPool<>(keyPoolableFactory); nettyClientKeyPool.setConfig(getNettyPoolConfig(clientConfig)); this.poolKeyFunction = poolKeyFunction; @@ -338,4 +332,3 @@ private void throwFailFastException(boolean failFast, String message) { } } - diff --git a/core/src/main/java/org/apache/seata/core/rpc/netty/NettyPoolableFactory.java b/core/src/main/java/org/apache/seata/core/rpc/netty/NettyPoolableFactory.java index 7755091eaf6..11384a71325 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/netty/NettyPoolableFactory.java +++ b/core/src/main/java/org/apache/seata/core/rpc/netty/NettyPoolableFactory.java @@ -16,17 +16,19 @@ */ package org.apache.seata.core.rpc.netty; -import java.net.InetSocketAddress; - import io.netty.channel.Channel; +import org.apache.commons.pool.KeyedPoolableObjectFactory; import org.apache.seata.common.exception.FrameworkException; import org.apache.seata.common.util.NetUtil; -import org.apache.seata.core.protocol.RegisterRMResponse; -import org.apache.seata.core.protocol.RegisterTMResponse; -import org.apache.commons.pool.KeyedPoolableObjectFactory; +import org.apache.seata.common.util.StringUtils; +import org.apache.seata.core.auth.JwtAuthManager; +import org.apache.seata.core.protocol.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.net.InetSocketAddress; +import java.util.HashMap; + /** * The type Netty key poolable factory. * @@ -63,48 +65,64 @@ public Channel makeObject(NettyPoolKey key) { throw new FrameworkException("register msg is null, role:" + key.getTransactionRole().name()); } try { + JwtAuthManager authManager = JwtAuthManager.getInstance(); response = rpcRemotingClient.sendSyncRequest(tmpChannel, key.getMessage()); - if (!isRegisterSuccess(response, key.getTransactionRole())) { + + if ((key.getTransactionRole().equals(NettyPoolKey.TransactionRole.TMROLE) + && !(response instanceof RegisterTMResponse)) + || (key.getTransactionRole().equals(NettyPoolKey.TransactionRole.RMROLE) + && !(response instanceof RegisterRMResponse))) { rpcRemotingClient.onRegisterMsgFail(key.getAddress(), tmpChannel, response, key.getMessage()); - } else { + } + AbstractIdentifyRequest request = (AbstractIdentifyRequest)key.getMessage(); + if (((AbstractIdentifyResponse) response).getResultCode().equals(ResultCode.AccessTokenExpired)) { + // refresh token to get access token + authManager.setAccessToken(null); + String identifyExtraData = authManager.getAuthData(); + request.setExtraData(identifyExtraData); + response = rpcRemotingClient.sendSyncRequest(tmpChannel, request); + } + ResultCode resultCode = ((AbstractIdentifyResponse) response).getResultCode(); + HashMap authMap = StringUtils.string2Map(request.getExtraData()); + boolean isTokenAuthFailed = resultCode.equals(ResultCode.Failed) && + (authMap.containsKey(JwtAuthManager.PRO_TOKEN) || authMap.containsKey(JwtAuthManager.PRO_REFRESH_TOKEN)); + if (resultCode.equals(ResultCode.RefreshTokenExpired) || isTokenAuthFailed) { + // relogin to get refresh token and access token + authManager.setAccessToken(null); + authManager.setRefreshToken(null); + String identifyExtraData = authManager.getAuthData(); + request.setExtraData(identifyExtraData); + response = rpcRemotingClient.sendSyncRequest(tmpChannel, request); + } + resultCode = ((AbstractIdentifyResponse) response).getResultCode(); + if (resultCode.equals(ResultCode.AccessTokenNearExpiration)) { + // access token near expiration + authManager.setAccessTokenNearExpiration(true); channelToServer = tmpChannel; rpcRemotingClient.onRegisterMsgSuccess(key.getAddress(), tmpChannel, response, key.getMessage()); + rpcRemotingClient.getClientChannelManager().registerChannel(key.getAddress(), tmpChannel); + } else if (resultCode.equals(ResultCode.Success)) { + channelToServer = tmpChannel; + rpcRemotingClient.onRegisterMsgSuccess(key.getAddress(), tmpChannel, response, key.getMessage()); + rpcRemotingClient.getClientChannelManager().registerChannel(key.getAddress(), tmpChannel); + } else { + rpcRemotingClient.onRegisterMsgFail(key.getAddress(), tmpChannel, response, key.getMessage()); } } catch (Exception exx) { if (tmpChannel != null) { tmpChannel.close(); } throw new FrameworkException( - "register " + key.getTransactionRole().name() + " error, errMsg:" + exx.getMessage()); + "register " + key.getTransactionRole().name() + " error, errMsg:" + exx.getMessage()); } if (LOGGER.isInfoEnabled()) { LOGGER.info("register success, cost " + (System.currentTimeMillis() - start) + " ms, version:" + getVersion( - response, key.getTransactionRole()) + ",role:" + key.getTransactionRole().name() + ",channel:" - + channelToServer); + response, key.getTransactionRole()) + ",role:" + key.getTransactionRole().name() + ",channel:" + + channelToServer); } return channelToServer; } - private boolean isRegisterSuccess(Object response, NettyPoolKey.TransactionRole transactionRole) { - if (response == null) { - return false; - } - if (transactionRole.equals(NettyPoolKey.TransactionRole.TMROLE)) { - if (!(response instanceof RegisterTMResponse)) { - return false; - } - RegisterTMResponse registerTMResponse = (RegisterTMResponse)response; - return registerTMResponse.isIdentified(); - } else if (transactionRole.equals(NettyPoolKey.TransactionRole.RMROLE)) { - if (!(response instanceof RegisterRMResponse)) { - return false; - } - RegisterRMResponse registerRMResponse = (RegisterRMResponse)response; - return registerRMResponse.isIdentified(); - } - return false; - } - private String getVersion(Object response, NettyPoolKey.TransactionRole transactionRole) { if (transactionRole.equals(NettyPoolKey.TransactionRole.TMROLE)) { return ((RegisterTMResponse) response).getVersion(); diff --git a/core/src/main/java/org/apache/seata/core/rpc/netty/RmNettyRemotingClient.java b/core/src/main/java/org/apache/seata/core/rpc/netty/RmNettyRemotingClient.java index 92cbafd0a5d..8d553ec018c 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/netty/RmNettyRemotingClient.java +++ b/core/src/main/java/org/apache/seata/core/rpc/netty/RmNettyRemotingClient.java @@ -16,14 +16,6 @@ */ package org.apache.seata.core.rpc.netty; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - import io.netty.channel.Channel; import io.netty.util.concurrent.EventExecutorGroup; import org.apache.seata.common.DefaultValues; @@ -43,14 +35,18 @@ import org.apache.seata.core.protocol.RegisterRMRequest; import org.apache.seata.core.protocol.RegisterRMResponse; import org.apache.seata.core.rpc.netty.NettyPoolKey.TransactionRole; -import org.apache.seata.core.rpc.processor.client.ClientHeartbeatProcessor; -import org.apache.seata.core.rpc.processor.client.ClientOnResponseProcessor; -import org.apache.seata.core.rpc.processor.client.RmBranchCommitProcessor; -import org.apache.seata.core.rpc.processor.client.RmBranchRollbackProcessor; -import org.apache.seata.core.rpc.processor.client.RmUndoLogProcessor; +import org.apache.seata.core.rpc.processor.client.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + import static org.apache.seata.common.Constants.DBKEYS_SPLIT_CHAR; /** @@ -132,10 +128,10 @@ public static RmNettyRemotingClient getInstance() { if (instance == null) { NettyClientConfig nettyClientConfig = new NettyClientConfig(); final ThreadPoolExecutor messageExecutor = new ThreadPoolExecutor( - nettyClientConfig.getClientWorkerThreads(), nettyClientConfig.getClientWorkerThreads(), - KEEP_ALIVE_TIME, TimeUnit.SECONDS, new LinkedBlockingQueue<>(MAX_QUEUE_SIZE), - new NamedThreadFactory(nettyClientConfig.getRmDispatchThreadPrefix(), - nettyClientConfig.getClientWorkerThreads()), new ThreadPoolExecutor.CallerRunsPolicy()); + nettyClientConfig.getClientWorkerThreads(), nettyClientConfig.getClientWorkerThreads(), + KEEP_ALIVE_TIME, TimeUnit.SECONDS, new LinkedBlockingQueue<>(MAX_QUEUE_SIZE), + new NamedThreadFactory(nettyClientConfig.getRmDispatchThreadPrefix(), + nettyClientConfig.getClientWorkerThreads()), new ThreadPoolExecutor.CallerRunsPolicy()); instance = new RmNettyRemotingClient(nettyClientConfig, null, messageExecutor); } } @@ -173,19 +169,18 @@ public void setResourceManager(ResourceManager resourceManager) { @Override public void onRegisterMsgSuccess(String serverAddress, Channel channel, Object response, AbstractMessage requestMessage) { - RegisterRMRequest registerRMRequest = (RegisterRMRequest)requestMessage; - RegisterRMResponse registerRMResponse = (RegisterRMResponse)response; + RegisterRMRequest registerRMRequest = (RegisterRMRequest) requestMessage; + RegisterRMResponse registerRMResponse = (RegisterRMResponse) response; + refreshAuthToken(registerRMResponse.getExtraData()); if (LOGGER.isInfoEnabled()) { LOGGER.info("register RM success. client version:{}, server version:{},channel:{}", registerRMRequest.getVersion(), registerRMResponse.getVersion(), channel); } getClientChannelManager().registerChannel(serverAddress, channel); String dbKey = getMergedResourceKeys(); - if (registerRMRequest.getResourceIds() != null) { - if (!registerRMRequest.getResourceIds().equals(dbKey)) { - sendRegisterMessage(serverAddress, channel, dbKey); - } + if (registerRMRequest.getResourceIds() != null + && !registerRMRequest.getResourceIds().equals(dbKey)) { + sendRegisterMessage(serverAddress, channel, dbKey); } - } @Override @@ -194,7 +189,7 @@ public void onRegisterMsgFail(String serverAddress, Channel channel, Object resp RegisterRMRequest registerRMRequest = (RegisterRMRequest)requestMessage; RegisterRMResponse registerRMResponse = (RegisterRMResponse)response; String errMsg = String.format( - "register RM failed. client version: %s,server version: %s, errorMsg: %s, " + "channel: %s", registerRMRequest.getVersion(), registerRMResponse.getVersion(), registerRMResponse.getMsg(), channel); + "register RM failed. client version: %s,server version: %s, errorMsg: %s, " + "channel: %s", registerRMRequest.getVersion(), registerRMResponse.getVersion(), registerRMResponse.getMsg(), channel); throw new FrameworkException(errMsg); } @@ -237,7 +232,7 @@ public void registerResource(String resourceGroupId, String resourceId) { } public void sendRegisterMessage(String serverAddress, Channel channel, String resourceId) { - RegisterRMRequest message = new RegisterRMRequest(applicationId, transactionServiceGroup); + RegisterRMRequest message = new RegisterRMRequest(applicationId, transactionServiceGroup, getAuthData()); message.setResourceIds(resourceId); try { super.sendAsyncRequest(channel, message); @@ -248,7 +243,7 @@ public void sendRegisterMessage(String serverAddress, Channel channel, String re LOGGER.info("remove not writable channel:{}", channel); } } else { - LOGGER.error("register resource failed, channel:{},resourceId:{}", channel, resourceId, e); + LOGGER.error("sendAsyncRequest register resource failed,, channel:{},resourceId:{}", channel, resourceId, e); } } } @@ -290,7 +285,7 @@ protected Function getPoolKeyFunction() { if (resourceIds != null && LOGGER.isInfoEnabled()) { LOGGER.info("RM will register :{}", resourceIds); } - RegisterRMRequest message = new RegisterRMRequest(applicationId, transactionServiceGroup); + RegisterRMRequest message = new RegisterRMRequest(applicationId, transactionServiceGroup, getAuthData()); message.setResourceIds(resourceIds); return new NettyPoolKey(NettyPoolKey.TransactionRole.RMROLE, serverAddress, message); }; @@ -323,7 +318,7 @@ private void registerProcessor() { super.registerProcessor(MessageType.TYPE_RM_DELETE_UNDOLOG, rmUndoLogProcessor, messageExecutor); // 4.registry TC response processor ClientOnResponseProcessor onResponseProcessor = - new ClientOnResponseProcessor(mergeMsgMap, super.getFutures(), getTransactionMessageHandler()); + new ClientOnResponseProcessor(mergeMsgMap, super.getFutures(), getTransactionMessageHandler()); super.registerProcessor(MessageType.TYPE_SEATA_MERGE_RESULT, onResponseProcessor, null); super.registerProcessor(MessageType.TYPE_BRANCH_REGISTER_RESULT, onResponseProcessor, null); super.registerProcessor(MessageType.TYPE_BRANCH_STATUS_REPORT_RESULT, onResponseProcessor, null); diff --git a/core/src/main/java/org/apache/seata/core/rpc/netty/TmNettyRemotingClient.java b/core/src/main/java/org/apache/seata/core/rpc/netty/TmNettyRemotingClient.java index 68ff739bbb0..07789a8c451 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/netty/TmNettyRemotingClient.java +++ b/core/src/main/java/org/apache/seata/core/rpc/netty/TmNettyRemotingClient.java @@ -16,12 +16,6 @@ */ package org.apache.seata.core.rpc.netty; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - import io.netty.channel.Channel; import io.netty.util.concurrent.EventExecutorGroup; import org.apache.commons.lang.StringUtils; @@ -46,6 +40,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR; +import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR; + /** * The rm netty client. * @@ -213,6 +216,7 @@ public void onRegisterMsgSuccess(String serverAddress, Channel channel, Object r AbstractMessage requestMessage) { RegisterTMRequest registerTMRequest = (RegisterTMRequest) requestMessage; RegisterTMResponse registerTMResponse = (RegisterTMResponse) response; + refreshAuthToken(registerTMResponse.getExtraData()); if (LOGGER.isInfoEnabled()) { LOGGER.info("register TM success. client version:{}, server version:{},channel:{}", registerTMRequest.getVersion(), registerTMResponse.getVersion(), channel); } @@ -272,15 +276,12 @@ private String getExtraData() { } String digest = signer.sign(digestSource, secretKey); StringBuilder sb = new StringBuilder(); - sb.append(RegisterTMRequest.UDATA_AK).append(org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR).append(accessKey).append( - org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR); - sb.append(RegisterTMRequest.UDATA_DIGEST).append(org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR).append(digest).append( - org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR); - sb.append(RegisterTMRequest.UDATA_TIMESTAMP).append(org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR).append(timestamp).append( - org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR); - sb.append(RegisterTMRequest.UDATA_AUTH_VERSION).append( - org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR).append(signer.getSignVersion()).append( - org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR); + sb.append(RegisterTMRequest.UDATA_AK).append(EXTRA_DATA_KV_CHAR).append(accessKey).append(EXTRA_DATA_SPLIT_CHAR); + sb.append(RegisterTMRequest.UDATA_DIGEST).append(EXTRA_DATA_KV_CHAR).append(digest).append(EXTRA_DATA_SPLIT_CHAR); + sb.append(RegisterTMRequest.UDATA_TIMESTAMP).append(EXTRA_DATA_KV_CHAR).append(timestamp).append(EXTRA_DATA_SPLIT_CHAR); + sb.append(RegisterTMRequest.UDATA_AUTH_VERSION).append(EXTRA_DATA_KV_CHAR).append(signer.getSignVersion()).append(EXTRA_DATA_SPLIT_CHAR); + String authExtraData = getAuthData(); + sb.append(authExtraData); return sb.toString(); } diff --git a/core/src/main/java/org/apache/seata/core/rpc/processor/client/ClientOnResponseProcessor.java b/core/src/main/java/org/apache/seata/core/rpc/processor/client/ClientOnResponseProcessor.java index f7b44c2e563..dadb3c38bb9 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/processor/client/ClientOnResponseProcessor.java +++ b/core/src/main/java/org/apache/seata/core/rpc/processor/client/ClientOnResponseProcessor.java @@ -16,32 +16,19 @@ */ package org.apache.seata.core.rpc.processor.client; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - import io.netty.channel.ChannelHandlerContext; -import org.apache.seata.core.protocol.AbstractResultMessage; -import org.apache.seata.core.protocol.BatchResultMessage; -import org.apache.seata.core.protocol.MergeMessage; -import org.apache.seata.core.protocol.MergeResultMessage; -import org.apache.seata.core.protocol.MergedWarpMessage; -import org.apache.seata.core.protocol.MessageFuture; -import org.apache.seata.core.protocol.RegisterRMResponse; -import org.apache.seata.core.protocol.RegisterTMResponse; -import org.apache.seata.core.protocol.RpcMessage; -import org.apache.seata.core.protocol.transaction.BranchRegisterResponse; -import org.apache.seata.core.protocol.transaction.BranchReportResponse; -import org.apache.seata.core.protocol.transaction.GlobalBeginResponse; -import org.apache.seata.core.protocol.transaction.GlobalCommitResponse; -import org.apache.seata.core.protocol.transaction.GlobalLockQueryResponse; -import org.apache.seata.core.protocol.transaction.GlobalReportResponse; -import org.apache.seata.core.protocol.transaction.GlobalRollbackResponse; +import org.apache.seata.core.auth.RegisterHandler; +import org.apache.seata.core.protocol.*; +import org.apache.seata.core.protocol.transaction.*; import org.apache.seata.core.rpc.TransactionMessageHandler; import org.apache.seata.core.rpc.processor.RemotingProcessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + /** * process TC response message. *

@@ -126,7 +113,12 @@ public void process(ChannelHandlerContext ctx, RpcMessage rpcMessage) throws Exc if (messageFuture != null) { messageFuture.setResultMessage(rpcMessage.getBody()); } else { - if (rpcMessage.getBody() instanceof AbstractResultMessage) { + if (rpcMessage.getBody() instanceof RegisterRMResponse) { + if (transactionMessageHandler != null) { + ((RegisterHandler) transactionMessageHandler).onRegisterResponse( + (RegisterRMResponse) rpcMessage.getBody(), ctx.channel(), rpcMessage.getId()); + } + } else if (rpcMessage.getBody() instanceof AbstractResultMessage) { if (transactionMessageHandler != null) { transactionMessageHandler.onResponse((AbstractResultMessage) rpcMessage.getBody(), null); } diff --git a/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegRmProcessor.java b/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegRmProcessor.java index 622f30039fa..ee341b58a7c 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegRmProcessor.java +++ b/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegRmProcessor.java @@ -19,15 +19,12 @@ import io.netty.channel.ChannelHandlerContext; import org.apache.seata.common.loader.EnhancedServiceLoader; import org.apache.seata.common.util.NetUtil; -import org.apache.seata.core.protocol.RegisterRMRequest; -import org.apache.seata.core.protocol.RegisterRMResponse; -import org.apache.seata.core.protocol.RpcMessage; -import org.apache.seata.core.protocol.Version; -import org.apache.seata.core.rpc.netty.ChannelManager; -import org.apache.seata.core.rpc.RemotingServer; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.protocol.*; import org.apache.seata.core.rpc.RegisterCheckAuthHandler; +import org.apache.seata.core.rpc.RemotingServer; +import org.apache.seata.core.rpc.netty.ChannelManager; import org.apache.seata.core.rpc.processor.RemotingProcessor; -import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -60,34 +57,42 @@ public void process(ChannelHandlerContext ctx, RpcMessage rpcMessage) throws Exc private void onRegRmMessage(ChannelHandlerContext ctx, RpcMessage rpcMessage) { RegisterRMRequest message = (RegisterRMRequest) rpcMessage.getBody(); String ipAndPort = NetUtil.toStringAddress(ctx.channel().remoteAddress()); - boolean isSuccess = false; - String errorInfo = StringUtils.EMPTY; + RegisterRMResponse response = new RegisterRMResponse(false); try { - if (null == checkAuthHandler || checkAuthHandler.regResourceManagerCheckAuth(message)) { + AuthResult authResult = (checkAuthHandler != null) ? checkAuthHandler.regResourceManagerCheckAuth(message) : null; + if (checkAuthHandler == null || authResult.getResultCode().equals(ResultCode.Success) + || authResult.getResultCode().equals(ResultCode.AccessTokenNearExpiration)) { ChannelManager.registerRMChannel(message, ctx.channel()); Version.putChannelVersion(ctx.channel(), message.getVersion()); - isSuccess = true; + response.setIdentified(true); + response.setResultCode(checkAuthHandler == null ? ResultCode.Success : authResult.getResultCode()); + response.setExtraData(checkAuthHandler.fetchNewToken(authResult)); if (LOGGER.isDebugEnabled()) { - LOGGER.debug("RM checkAuth for client:{},vgroup:{},applicationId:{} is OK", ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + LOGGER.debug("RM checkAuth for client:{},vgroup:{},applicationId:{} is OK", + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); } } else { + if (authResult.getResultCode().equals(ResultCode.Failed)) { + response.setMsg("RM checkAuth failed!Please check your username/password or token."); + } else if (authResult.getResultCode().equals(ResultCode.AccessTokenExpired)) { + response.setMsg("RM checkAuth failed! The access token has been expired."); + } else if (authResult.getResultCode().equals(ResultCode.RefreshTokenExpired)) { + response.setMsg("RM checkAuth failed! The refresh token has been expired."); + } + response.setResultCode(authResult.getResultCode()); if (LOGGER.isWarnEnabled()) { - LOGGER.warn("RM checkAuth for client:{},vgroup:{},applicationId:{} is FAIL", ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + LOGGER.warn("RM checkAuth for client:{},vgroup:{},applicationId:{} is FAIL", + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); } } - } catch (Exception exx) { - isSuccess = false; - errorInfo = exx.getMessage(); - LOGGER.error("RM register fail, error message:{}", errorInfo); - } - RegisterRMResponse response = new RegisterRMResponse(isSuccess); - if (StringUtils.isNotEmpty(errorInfo)) { - response.setMsg(errorInfo); + } catch (IncompatibleVersionException e) { + LOGGER.error("RM register fail, error message:{}", e.getMessage()); + response.setResultCode(ResultCode.Failed); } remotingServer.sendAsyncResponse(rpcMessage, ctx.channel(), response); - if (isSuccess && LOGGER.isInfoEnabled()) { + if (response.isIdentified() && LOGGER.isInfoEnabled()) { LOGGER.info("RM register success,message:{},channel:{},client version:{}", message, ctx.channel(), - message.getVersion()); + message.getVersion()); } } diff --git a/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegTmProcessor.java b/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegTmProcessor.java index 6090232c6c4..e428753b9ea 100644 --- a/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegTmProcessor.java +++ b/core/src/main/java/org/apache/seata/core/rpc/processor/server/RegTmProcessor.java @@ -19,15 +19,12 @@ import io.netty.channel.ChannelHandlerContext; import org.apache.seata.common.loader.EnhancedServiceLoader; import org.apache.seata.common.util.NetUtil; -import org.apache.seata.core.protocol.RegisterTMRequest; -import org.apache.seata.core.protocol.RegisterTMResponse; -import org.apache.seata.core.protocol.RpcMessage; -import org.apache.seata.core.protocol.Version; -import org.apache.seata.core.rpc.netty.ChannelManager; -import org.apache.seata.core.rpc.RemotingServer; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.protocol.*; import org.apache.seata.core.rpc.RegisterCheckAuthHandler; +import org.apache.seata.core.rpc.RemotingServer; +import org.apache.seata.core.rpc.netty.ChannelManager; import org.apache.seata.core.rpc.processor.RemotingProcessor; -import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,36 +58,42 @@ private void onRegTmMessage(ChannelHandlerContext ctx, RpcMessage rpcMessage) { RegisterTMRequest message = (RegisterTMRequest) rpcMessage.getBody(); String ipAndPort = NetUtil.toStringAddress(ctx.channel().remoteAddress()); Version.putChannelVersion(ctx.channel(), message.getVersion()); - boolean isSuccess = false; - String errorInfo = StringUtils.EMPTY; + RegisterTMResponse response = new RegisterTMResponse(false); try { - if (null == checkAuthHandler || checkAuthHandler.regTransactionManagerCheckAuth(message)) { + AuthResult authResult = (checkAuthHandler != null) ? checkAuthHandler.regTransactionManagerCheckAuth(message) : null; + if (checkAuthHandler == null || authResult.getResultCode().equals(ResultCode.Success) + || authResult.getResultCode().equals(ResultCode.AccessTokenNearExpiration)) { ChannelManager.registerTMChannel(message, ctx.channel()); Version.putChannelVersion(ctx.channel(), message.getVersion()); - isSuccess = true; + response.setIdentified(true); + response.setResultCode(checkAuthHandler == null ? ResultCode.Success : authResult.getResultCode()); + response.setExtraData(checkAuthHandler.fetchNewToken(authResult)); if (LOGGER.isDebugEnabled()) { LOGGER.debug("TM checkAuth for client:{},vgroup:{},applicationId:{} is OK", - ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); + ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); } } else { + if (authResult.getResultCode().equals(ResultCode.Failed)) { + response.setMsg("TM checkAuth failed!Please check your username/password."); + } else if (authResult.getResultCode().equals(ResultCode.AccessTokenExpired)) { + response.setMsg("TM checkAuth failed! The access token has been expired."); + } else if (authResult.getResultCode().equals(ResultCode.RefreshTokenExpired)) { + response.setMsg("TM checkAuth failed! The refresh token has been expired."); + } + response.setResultCode(authResult.getResultCode()); if (LOGGER.isWarnEnabled()) { LOGGER.warn("TM checkAuth for client:{},vgroup:{},applicationId:{} is FAIL", ipAndPort, message.getTransactionServiceGroup(), message.getApplicationId()); } } - } catch (Exception exx) { - isSuccess = false; - errorInfo = exx.getMessage(); - LOGGER.error("TM register fail, error message:{}", errorInfo); - } - RegisterTMResponse response = new RegisterTMResponse(isSuccess); - if (StringUtils.isNotEmpty(errorInfo)) { - response.setMsg(errorInfo); + } catch (IncompatibleVersionException e) { + LOGGER.error("TM register fail, error message:{}", e.getMessage()); + response.setResultCode(ResultCode.Failed); } remotingServer.sendAsyncResponse(rpcMessage, ctx.channel(), response); - if (isSuccess && LOGGER.isInfoEnabled()) { - LOGGER.info("TM register success,message:{},channel:{},client version:{},client protocol-version:{}" - , message, ctx.channel(), message.getVersion(), rpcMessage.getOtherSideVersion()); + if (response.isIdentified() && LOGGER.isInfoEnabled()) { + LOGGER.info("TM register success,message:{},channel:{},client version:{}", message, ctx.channel(), + message.getVersion()); } } diff --git a/core/src/test/java/org/apache/seata/core/protocol/ResultCodeTest.java b/core/src/test/java/org/apache/seata/core/protocol/ResultCodeTest.java index b1c7ef3e741..5ac620f7ce8 100644 --- a/core/src/test/java/org/apache/seata/core/protocol/ResultCodeTest.java +++ b/core/src/test/java/org/apache/seata/core/protocol/ResultCodeTest.java @@ -29,8 +29,11 @@ class ResultCodeTest { void getByte() { Assertions.assertEquals(ResultCode.Failed, ResultCode.get((byte) 0)); Assertions.assertEquals(ResultCode.Success, ResultCode.get((byte) 1)); + Assertions.assertEquals(ResultCode.AccessTokenExpired, ResultCode.get((byte) 2)); + Assertions.assertEquals(ResultCode.AccessTokenNearExpiration, ResultCode.get((byte) 3)); + Assertions.assertEquals(ResultCode.RefreshTokenExpired, ResultCode.get((byte) 4)); Assertions.assertThrows(IllegalArgumentException.class, () -> { - ResultCode.get((byte) 2); + ResultCode.get((byte) 5); }); } @@ -38,14 +41,19 @@ void getByte() { void getInt() { Assertions.assertEquals(ResultCode.Failed, ResultCode.get(0)); Assertions.assertEquals(ResultCode.Success, ResultCode.get(1)); + Assertions.assertEquals(ResultCode.AccessTokenExpired, ResultCode.get(2)); + Assertions.assertEquals(ResultCode.AccessTokenNearExpiration, ResultCode.get(3)); + Assertions.assertEquals(ResultCode.RefreshTokenExpired, ResultCode.get(4)); Assertions.assertThrows(IllegalArgumentException.class, () -> { - ResultCode.get(2); + ResultCode.get(5); }); } @Test void values() { - Assertions.assertArrayEquals(new ResultCode[]{ResultCode.Failed, ResultCode.Success}, ResultCode.values()); + Assertions.assertArrayEquals(new ResultCode[]{ResultCode.Failed, + ResultCode.Success, ResultCode.AccessTokenExpired, + ResultCode.AccessTokenNearExpiration, ResultCode.RefreshTokenExpired}, ResultCode.values()); } @Test @@ -58,5 +66,17 @@ void valueOf() { Assertions.assertThrows(IllegalArgumentException.class, () -> { ResultCode.valueOf("SUCCESS"); }); + Assertions.assertEquals(ResultCode.AccessTokenExpired, ResultCode.valueOf("AccessTokenExpired")); + Assertions.assertThrows(IllegalArgumentException.class, () -> { + ResultCode.valueOf("ACCESSTOKENEXPIRED"); + }); + Assertions.assertEquals(ResultCode.AccessTokenNearExpiration, ResultCode.valueOf("AccessTokenNearExpiration")); + Assertions.assertThrows(IllegalArgumentException.class, () -> { + ResultCode.valueOf("ACCESSTOKENNEAREXPIRATION"); + }); + Assertions.assertEquals(ResultCode.RefreshTokenExpired, ResultCode.valueOf("RefreshTokenExpired")); + Assertions.assertThrows(IllegalArgumentException.class, () -> { + ResultCode.valueOf("REFRESHTOKENEXPIRED"); + }); } } diff --git a/discovery/seata-discovery-raft/src/main/java/org/apache/seata/discovery/registry/raft/RaftRegistryServiceImpl.java b/discovery/seata-discovery-raft/src/main/java/org/apache/seata/discovery/registry/raft/RaftRegistryServiceImpl.java index f52464ef4ae..4f1c258a0f9 100644 --- a/discovery/seata-discovery-raft/src/main/java/org/apache/seata/discovery/registry/raft/RaftRegistryServiceImpl.java +++ b/discovery/seata-discovery-raft/src/main/java/org/apache/seata/discovery/registry/raft/RaftRegistryServiceImpl.java @@ -16,26 +16,15 @@ */ package org.apache.seata.discovery.registry.raft; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.stream.Collectors; -import java.util.stream.Stream; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.http.HttpStatus; +import org.apache.http.StatusLine; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.entity.ContentType; +import org.apache.http.protocol.HTTP; +import org.apache.http.util.EntityUtils; import org.apache.seata.common.ConfigurationKeys; import org.apache.seata.common.exception.AuthenticationFailedException; import org.apache.seata.common.exception.RetryableException; @@ -50,15 +39,26 @@ import org.apache.seata.config.Configuration; import org.apache.seata.config.ConfigurationFactory; import org.apache.seata.discovery.registry.RegistryService; -import org.apache.http.HttpStatus; -import org.apache.http.StatusLine; -import org.apache.http.client.methods.CloseableHttpResponse; -import org.apache.http.entity.ContentType; -import org.apache.http.util.EntityUtils; -import org.apache.http.protocol.HTTP; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.Stream; + /** * The type File registry service. * @@ -164,7 +164,7 @@ protected static void startQueryMetadata() { synchronized (INIT_ADDRESSES) { if (REFRESH_METADATA_EXECUTOR == null) { REFRESH_METADATA_EXECUTOR = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, - new LinkedBlockingQueue<>(), new NamedThreadFactory("refreshMetadata", 1, true)); + new LinkedBlockingQueue<>(), new NamedThreadFactory("refreshMetadata", 1, true)); REFRESH_METADATA_EXECUTOR.execute(() -> { long metadataMaxAgeMs = CONFIG.getLong(ConfigurationKeys.CLIENT_METADATA_MAX_AGE_MS, 30000L); long currentTime = System.currentTimeMillis(); @@ -221,7 +221,7 @@ private static String queryHttpAddress(String clusterName, String group) { List inetSocketAddresses = ALIVE_NODES.get(CURRENT_TRANSACTION_SERVICE_GROUP); if (CollectionUtils.isEmpty(inetSocketAddresses)) { addressList = - nodeList.stream().map(node -> node.getControl().createAddress()).collect(Collectors.toList()); + nodeList.stream().map(node -> node.getControl().createAddress()).collect(Collectors.toList()); } else { stream = inetSocketAddresses.stream(); } @@ -235,14 +235,14 @@ private static String queryHttpAddress(String clusterName, String group) { if (CollectionUtils.isNotEmpty(nodeList)) { for (Node node : nodeList) { map.put(new InetSocketAddress(node.getTransaction().getHost(), node.getTransaction().getPort()).getAddress().getHostAddress() - + IP_PORT_SPLIT_CHAR + node.getTransaction().getPort(), node); + + IP_PORT_SPLIT_CHAR + node.getTransaction().getPort(), node); } } addressList = stream.map(inetSocketAddress -> { String host = inetSocketAddress.getAddress().getHostAddress(); Node node = map.get(host + IP_PORT_SPLIT_CHAR + inetSocketAddress.getPort()); return host + IP_PORT_SPLIT_CHAR - + (node != null ? node.getControl().getPort() : inetSocketAddress.getPort()); + + (node != null ? node.getControl().getPort() : inetSocketAddress.getPort()); }).collect(Collectors.toList()); return addressList.get(ThreadLocalRandom.current().nextInt(addressList.size())); } @@ -250,22 +250,22 @@ private static String queryHttpAddress(String clusterName, String group) { private static String getRaftAddrFileKey() { return String.join(ConfigurationKeys.FILE_CONFIG_SPLIT_CHAR, ConfigurationKeys.FILE_ROOT_REGISTRY, - REGISTRY_TYPE, PRO_SERVER_ADDR_KEY); + REGISTRY_TYPE, PRO_SERVER_ADDR_KEY); } private static String getRaftUserNameKey() { return String.join(ConfigurationKeys.FILE_CONFIG_SPLIT_CHAR, ConfigurationKeys.FILE_ROOT_REGISTRY, - REGISTRY_TYPE, PRO_USERNAME_KEY); + REGISTRY_TYPE, PRO_USERNAME_KEY); } private static String getRaftPassWordKey() { return String.join(ConfigurationKeys.FILE_CONFIG_SPLIT_CHAR, ConfigurationKeys.FILE_ROOT_REGISTRY, - REGISTRY_TYPE, PRO_PASSWORD_KEY); + REGISTRY_TYPE, PRO_PASSWORD_KEY); } private static String getTokenExpireTimeInMillisecondsKey() { return String.join(ConfigurationKeys.FILE_CONFIG_SPLIT_CHAR, ConfigurationKeys.FILE_ROOT_REGISTRY, - REGISTRY_TYPE, TOKEN_VALID_TIME_MS_KEY); + REGISTRY_TYPE, TOKEN_VALID_TIME_MS_KEY); } private static boolean isTokenExpired() { @@ -314,7 +314,7 @@ private static boolean watch() throws RetryableException { header.put(AUTHORIZATION_HEADER, jwtToken); } try (CloseableHttpResponse response = - HttpClientUtil.doPost("http://" + tcAddress + "/metadata/v1/watch", param, header, 30000)) { + HttpClientUtil.doPost("http://" + tcAddress + "/metadata/v1/watch", param, header, 30000)) { if (response != null) { StatusLine statusLine = response.getStatusLine(); if (statusLine != null && statusLine.getStatusCode() == HttpStatus.SC_UNAUTHORIZED) { @@ -337,16 +337,16 @@ private static boolean watch() throws RetryableException { @Override public List refreshAliveLookup(String transactionServiceGroup, - List aliveAddress) { + List aliveAddress) { if (METADATA.isRaftMode()) { Node leader = METADATA.getLeader(getServiceGroup(transactionServiceGroup)); InetSocketAddress leaderAddress = convertInetSocketAddress(leader); return ALIVE_NODES.put(transactionServiceGroup, - aliveAddress.isEmpty() ? aliveAddress : aliveAddress.parallelStream().filter(inetSocketAddress -> { - // Since only follower will turn into leader, only the follower node needs to be listened to - return inetSocketAddress.getPort() != leaderAddress.getPort() || !inetSocketAddress.getAddress() - .getHostAddress().equals(leaderAddress.getAddress().getHostAddress()); - }).collect(Collectors.toList())); + aliveAddress.isEmpty() ? aliveAddress : aliveAddress.parallelStream().filter(inetSocketAddress -> { + // Since only follower will turn into leader, only the follower node needs to be listened to + return inetSocketAddress.getPort() != leaderAddress.getPort() || !inetSocketAddress.getAddress() + .getHostAddress().equals(leaderAddress.getAddress().getHostAddress()); + }).collect(Collectors.toList())); } else { return RegistryService.super.refreshAliveLookup(transactionServiceGroup, aliveAddress); } @@ -375,7 +375,7 @@ private static void acquireClusterMetaData(String clusterName, String group) thr param.put("group", group); String response = null; try (CloseableHttpResponse httpResponse = - HttpClientUtil.doGet("http://" + tcAddress + "/metadata/v1/cluster", param, header, 1000)) { + HttpClientUtil.doGet("http://" + tcAddress + "/metadata/v1/cluster", param, header, 1000)) { if (httpResponse != null) { if (httpResponse.getStatusLine().getStatusCode() == HttpStatus.SC_OK) { response = EntityUtils.toString(httpResponse.getEntity(), StandardCharsets.UTF_8); @@ -416,7 +416,7 @@ private static void refreshToken(String tcAddress) throws RetryableException { String response = null; tokenTimeStamp = System.currentTimeMillis(); try (CloseableHttpResponse httpResponse = - HttpClientUtil.doPost("http://" + tcAddress + "/api/v1/auth/login", param, header, 1000)) { + HttpClientUtil.doPost("http://" + tcAddress + "/api/v1/auth/login", param, header, 1000)) { if (httpResponse != null) { if (httpResponse.getStatusLine().getStatusCode() == HttpStatus.SC_OK) { response = EntityUtils.toString(httpResponse.getEntity(), StandardCharsets.UTF_8); diff --git a/rm/src/main/java/org/apache/seata/rm/AbstractRMHandler.java b/rm/src/main/java/org/apache/seata/rm/AbstractRMHandler.java index 682a03605ef..dbcc4cc816a 100644 --- a/rm/src/main/java/org/apache/seata/rm/AbstractRMHandler.java +++ b/rm/src/main/java/org/apache/seata/rm/AbstractRMHandler.java @@ -16,31 +16,32 @@ */ package org.apache.seata.rm; +import io.netty.channel.Channel; +import org.apache.seata.common.exception.FrameworkException; +import org.apache.seata.common.util.StringUtils; +import org.apache.seata.core.auth.JwtAuthManager; +import org.apache.seata.core.auth.RegisterHandler; import org.apache.seata.core.exception.AbstractExceptionHandler; import org.apache.seata.core.exception.TransactionException; import org.apache.seata.core.model.BranchStatus; import org.apache.seata.core.model.BranchType; import org.apache.seata.core.model.ResourceManager; -import org.apache.seata.core.protocol.AbstractMessage; -import org.apache.seata.core.protocol.AbstractResultMessage; -import org.apache.seata.core.protocol.transaction.AbstractTransactionRequestToRM; -import org.apache.seata.core.protocol.transaction.BranchCommitRequest; -import org.apache.seata.core.protocol.transaction.BranchCommitResponse; -import org.apache.seata.core.protocol.transaction.BranchRollbackRequest; -import org.apache.seata.core.protocol.transaction.BranchRollbackResponse; -import org.apache.seata.core.protocol.transaction.RMInboundHandler; -import org.apache.seata.core.protocol.transaction.UndoLogDeleteRequest; +import org.apache.seata.core.protocol.*; +import org.apache.seata.core.protocol.transaction.*; import org.apache.seata.core.rpc.RpcContext; import org.apache.seata.core.rpc.TransactionMessageHandler; +import org.apache.seata.core.rpc.netty.RmNettyRemotingClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; + /** * The Abstract RM event handler * */ public abstract class AbstractRMHandler extends AbstractExceptionHandler - implements RMInboundHandler, TransactionMessageHandler { + implements RMInboundHandler, TransactionMessageHandler, RegisterHandler { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRMHandler.class); @@ -50,7 +51,7 @@ public BranchCommitResponse handle(BranchCommitRequest request) { exceptionHandleTemplate(new AbstractCallback() { @Override public void execute(BranchCommitRequest request, BranchCommitResponse response) - throws TransactionException { + throws TransactionException { doBranchCommit(request, response); } }, request, response); @@ -63,7 +64,7 @@ public BranchRollbackResponse handle(BranchRollbackRequest request) { exceptionHandleTemplate(new AbstractCallback() { @Override public void execute(BranchRollbackRequest request, BranchRollbackResponse response) - throws TransactionException { + throws TransactionException { doBranchRollback(request, response); } }, request, response); @@ -87,7 +88,7 @@ public void handle(UndoLogDeleteRequest request) { * @throws TransactionException the transaction exception */ protected void doBranchCommit(BranchCommitRequest request, BranchCommitResponse response) - throws TransactionException { + throws TransactionException { String xid = request.getXid(); long branchId = request.getBranchId(); String resourceId = request.getResourceId(); @@ -96,7 +97,7 @@ protected void doBranchCommit(BranchCommitRequest request, BranchCommitResponse LOGGER.info("Branch committing: " + xid + " " + branchId + " " + resourceId + " " + applicationData); } BranchStatus status = getResourceManager().branchCommit(request.getBranchType(), xid, branchId, resourceId, - applicationData); + applicationData); response.setXid(xid); response.setBranchId(branchId); response.setBranchStatus(status); @@ -114,7 +115,7 @@ protected void doBranchCommit(BranchCommitRequest request, BranchCommitResponse * @throws TransactionException the transaction exception */ protected void doBranchRollback(BranchRollbackRequest request, BranchRollbackResponse response) - throws TransactionException { + throws TransactionException { String xid = request.getXid(); long branchId = request.getBranchId(); String resourceId = request.getResourceId(); @@ -123,7 +124,7 @@ protected void doBranchRollback(BranchRollbackRequest request, BranchRollbackRes LOGGER.info("Branch Rollbacking: " + xid + " " + branchId + " " + resourceId); } BranchStatus status = getResourceManager().branchRollback(request.getBranchType(), xid, branchId, resourceId, - applicationData); + applicationData); response.setXid(xid); response.setBranchId(branchId); response.setBranchStatus(status); @@ -155,5 +156,59 @@ public void onResponse(AbstractResultMessage response, RpcContext context) { LOGGER.info("the rm client received response msg [{}] from tc server.", response.toString()); } + @Override + public void onRegisterResponse(RegisterRMResponse response, Channel channel, Integer rpcId) { + LOGGER.info("the rm client received register response msg [{}] from tc server.", response.toString()); + try { + JwtAuthManager authManager = JwtAuthManager.getInstance(); + ResultCode resultCode = response.getResultCode(); + RegisterRMRequest request = RmNettyRemotingClient.getInstance().getRegisterRMRequest(rpcId); + HashMap authMap = StringUtils.string2Map(request.getExtraData()); + boolean isTokenAuthFailed = resultCode.equals(ResultCode.Failed) && + (authMap.containsKey(JwtAuthManager.PRO_TOKEN) || authMap.containsKey(JwtAuthManager.PRO_REFRESH_TOKEN)); + if (resultCode.equals(ResultCode.AccessTokenExpired)) { + // refresh token to get access token + authManager.setAccessToken(null); + String identifyExtraData = authManager.getAuthData(); + request.setExtraData(identifyExtraData); + RmNettyRemotingClient.getInstance().sendAsyncRequest(channel, request); + } else if (resultCode.equals(ResultCode.RefreshTokenExpired) || isTokenAuthFailed) { + // relogin to get refresh token and access token + authManager.setAccessToken(null); + authManager.setRefreshToken(null); + String identifyExtraData = authManager.getAuthData(); + request.setExtraData(identifyExtraData); + RmNettyRemotingClient.getInstance().sendAsyncRequest(channel, request); + } else if (resultCode.equals(ResultCode.AccessTokenNearExpiration)) { + // + authManager.setAccessTokenNearExpiration(true); + if (LOGGER.isInfoEnabled()) { + LOGGER.info("register RM success. client version:{}, server version:{},channel:{}", + request.getVersion(), request.getVersion(), channel); + } + + } else if (resultCode.equals(ResultCode.Success)) { + RmNettyRemotingClient.getInstance().refreshAuthToken(response.getExtraData()); + if (LOGGER.isInfoEnabled()) { + LOGGER.info("register RM success. client version:{}, server version:{},channel:{}", + request.getVersion(), request.getVersion(), channel); + } + RmNettyRemotingClient.getInstance().removeRegisterRMRequest(rpcId); + } else { + String errMsg = String.format( + "register RM failed. client version: %s,server version: %s, errorMsg: %s, " + "channel: %s", + request.getVersion(), request.getVersion(), response.getMsg(), channel); + RmNettyRemotingClient.getInstance().removeRegisterRMRequest(rpcId); + throw new FrameworkException(errMsg); + } + } catch (Exception exx) { + throw new FrameworkException( + "register RM" + " error, errMsg:" + exx.getMessage()); + } + if (LOGGER.isInfoEnabled()) { + LOGGER.info("register RM success"); + } + } + public abstract BranchType getBranchType(); } diff --git a/script/client/spring/application.properties b/script/client/spring/application.properties index 2a72d1e5f79..f67b3bf9b1c 100755 --- a/script/client/spring/application.properties +++ b/script/client/spring/application.properties @@ -27,6 +27,11 @@ seata.enable-auto-data-source-proxy=true seata.data-source-proxy-mode=AT seata.use-jdk-proxy=false seata.expose-proxy=false +security.username=seata +security.password=seata +security.secretKey=SeataSecretKey0c382ef121d778043159209298fd40bf3850a017 +security.accessTokenValidityInMilliseconds=600000 +security.refreshTokenValidityInMilliseconds=86400000 seata.client.rm.async-commit-buffer-limit=10000 seata.client.rm.report-retry-count=5 seata.client.rm.table-meta-check-enable=false diff --git a/script/client/spring/application.yml b/script/client/spring/application.yml index a6100f05740..cc2bba8ab69 100755 --- a/script/client/spring/application.yml +++ b/script/client/spring/application.yml @@ -28,6 +28,15 @@ seata: scan-packages: firstPackage,secondPackage excludes-for-scanning: firstBeanNameForExclude,secondBeanNameForExclude excludes-for-auto-proxying: firstClassNameForExclude,secondClassNameForExclude + security: + username: seata + password: seata + secretKey: SeataSecretKey0c382ef121d778043159209298fd40bf3850a017 + accessTokenValidityInMilliseconds: 600000 + refreshTokenValidityInMilliseconds: 86400000 + csrf-ignore-urls: /metadata/v1/** + ignore: + urls: /,/**/*.css,/**/*.js,/**/*.html,/**/*.map,/**/*.svg,/**/*.png,/**/*.jpeg,/**/*.ico,/api/v1/auth/login,/version.json,/health,/error,/vgroup/v1/**,/metadata/v1/auth/login client: rm: async-commit-buffer-limit: 10000 diff --git a/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractIdentifyResponseCodec.java b/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractIdentifyResponseCodec.java index 56840f9e618..aa38aecfdba 100644 --- a/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractIdentifyResponseCodec.java +++ b/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractIdentifyResponseCodec.java @@ -16,14 +16,13 @@ */ package org.apache.seata.serializer.seata.protocol; -import java.nio.ByteBuffer; - import io.netty.buffer.ByteBuf; import org.apache.seata.core.protocol.AbstractIdentifyResponse; +import java.nio.ByteBuffer; + /** * The type Abstract identify response. - * */ public abstract class AbstractIdentifyResponseCodec extends AbstractResultMessageCodec { @@ -34,25 +33,38 @@ public Class getMessageClassType() { @Override public void encode(T t, ByteBuf out) { - AbstractIdentifyResponse abstractIdentifyResponse = (AbstractIdentifyResponse)t; + super.encode(t,out); + AbstractIdentifyResponse abstractIdentifyResponse = (AbstractIdentifyResponse) t; boolean identified = abstractIdentifyResponse.isIdentified(); String version = abstractIdentifyResponse.getVersion(); + String extraData = abstractIdentifyResponse.getExtraData(); - out.writeByte(identified ? (byte)1 : (byte)0); + out.writeByte(identified ? (byte) 1 : (byte) 0); if (version != null) { byte[] bs = version.getBytes(UTF8); - out.writeShort((short)bs.length); + out.writeShort((short) bs.length); if (bs.length > 0) { out.writeBytes(bs); } } else { - out.writeShort((short)0); + out.writeShort((short) 0); + } + + if (extraData != null) { + byte[] bs = extraData.getBytes(UTF8); + out.writeShort((short) bs.length); + if (bs.length > 0) { + out.writeBytes(bs); + } + } else { + out.writeShort((short) 0); } } @Override public void decode(T t, ByteBuffer in) { - AbstractIdentifyResponse abstractIdentifyResponse = (AbstractIdentifyResponse)t; + super.decode(t,in); + AbstractIdentifyResponse abstractIdentifyResponse = (AbstractIdentifyResponse) t; abstractIdentifyResponse.setIdentified(in.get() == 1); short len = in.getShort(); @@ -65,6 +77,20 @@ public void decode(T t, ByteBuffer in) { byte[] bs = new byte[len]; in.get(bs); abstractIdentifyResponse.setVersion(new String(bs, UTF8)); + + //ExtraData len + if (in.remaining() < 2) { + return; + } + len = in.getShort(); + + if (in.remaining() >= len) { + bs = new byte[len]; + in.get(bs); + abstractIdentifyResponse.setExtraData(new String(bs, UTF8)); + } else { + //maybe null + } } } diff --git a/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractResultMessageCodec.java b/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractResultMessageCodec.java index 3eedf496872..a25e2cef571 100644 --- a/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractResultMessageCodec.java +++ b/serializer/seata-serializer-seata/src/main/java/org/apache/seata/serializer/seata/protocol/AbstractResultMessageCodec.java @@ -16,16 +16,15 @@ */ package org.apache.seata.serializer.seata.protocol; -import java.nio.ByteBuffer; - import io.netty.buffer.ByteBuf; import org.apache.seata.common.util.StringUtils; import org.apache.seata.core.protocol.AbstractResultMessage; import org.apache.seata.core.protocol.ResultCode; +import java.nio.ByteBuffer; + /** * The type Abstract result message codec. - * */ public abstract class AbstractResultMessageCodec extends AbstractMessageCodec { @@ -36,41 +35,44 @@ public Class getMessageClassType() { @Override public void encode(T t, ByteBuf out) { - AbstractResultMessage abstractResultMessage = (AbstractResultMessage)t; + AbstractResultMessage abstractResultMessage = (AbstractResultMessage) t; ResultCode resultCode = abstractResultMessage.getResultCode(); String resultMsg = abstractResultMessage.getMsg(); + if (null != resultCode) { + out.writeByte(resultCode.ordinal()); + } else { + out.writeByte(ResultCode.values().length); + } - out.writeByte(resultCode.ordinal()); - if (resultCode == ResultCode.Failed) { - if (StringUtils.isNotEmpty(resultMsg)) { - String msg; - if (resultMsg.length() > Short.MAX_VALUE) { - msg = resultMsg.substring(0, Short.MAX_VALUE); - } else { - msg = resultMsg; - } - byte[] bs = msg.getBytes(UTF8); - out.writeShort((short)bs.length); - out.writeBytes(bs); + if (StringUtils.isNotEmpty(resultMsg)) { + String msg; + if (resultMsg.length() > Short.MAX_VALUE) { + msg = resultMsg.substring(0, Short.MAX_VALUE); } else { - out.writeShort((short)0); + msg = resultMsg; } + byte[] bs = msg.getBytes(UTF8); + out.writeShort((short) bs.length); + out.writeBytes(bs); + } else { + out.writeShort((short) 0); } } @Override public void decode(T t, ByteBuffer in) { - AbstractResultMessage abstractResultMessage = (AbstractResultMessage)t; - - ResultCode resultCode = ResultCode.get(in.get()); - abstractResultMessage.setResultCode(resultCode); - if (resultCode == ResultCode.Failed) { - short len = in.getShort(); - if (len > 0) { - byte[] msg = new byte[len]; - in.get(msg); - abstractResultMessage.setMsg(new String(msg, UTF8)); - } + AbstractResultMessage abstractResultMessage = (AbstractResultMessage) t; + ResultCode resultCode = null; + byte resultCodeOrdinal = in.get(); + if (resultCodeOrdinal < ResultCode.values().length) { + resultCode = ResultCode.get(resultCodeOrdinal); + abstractResultMessage.setResultCode(resultCode); + } + short len = in.getShort(); + if (len > 0) { + byte[] msg = new byte[len]; + in.get(msg); + abstractResultMessage.setMsg(new String(msg, UTF8)); } } diff --git a/serializer/seata-serializer-seata/src/test/java/org/apache/seata/serializer/seata/protocol/RegisterTMRequestSerializerTest.java b/serializer/seata-serializer-seata/src/test/java/org/apache/seata/serializer/seata/protocol/RegisterTMRequestSerializerTest.java index 73c53e89ae2..7c9714ed322 100644 --- a/serializer/seata-serializer-seata/src/test/java/org/apache/seata/serializer/seata/protocol/RegisterTMRequestSerializerTest.java +++ b/serializer/seata-serializer-seata/src/test/java/org/apache/seata/serializer/seata/protocol/RegisterTMRequestSerializerTest.java @@ -17,14 +17,13 @@ package org.apache.seata.serializer.seata.protocol; import io.netty.buffer.ByteBuf; +import org.apache.seata.core.protocol.*; import org.apache.seata.serializer.seata.SeataSerializer; -import org.apache.seata.core.protocol.AbstractIdentifyRequest; -import org.apache.seata.core.protocol.RegisterTMRequest; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.Mockito; -import org.apache.seata.core.protocol.ProtocolConstants; + import static io.netty.buffer.Unpooled.buffer; import static org.assertj.core.api.Assertions.assertThat; @@ -68,6 +67,19 @@ public void test_codec() { assertThat(registerTMRequest2.getVersion()).isEqualTo(registerTMRequest.getVersion()); } + @Test + public void test_codec1() { + RegisterTMResponse registerTMResponse = new RegisterTMResponse(); + registerTMResponse.setIdentified(true); + registerTMResponse.setVersion("2.3.0-SNAPSHOT"); + registerTMResponse.setResultCode(ResultCode.RefreshTokenExpired); + byte[] bytes = seataSerializer.serialize(registerTMResponse); + RegisterTMResponse registerTMResponse2 = seataSerializer.deserialize(bytes); + assertThat(registerTMResponse2.isIdentified()).isEqualTo(registerTMResponse.isIdentified()); + assertThat(registerTMResponse2.getVersion()).isEqualTo(registerTMResponse.getVersion()); + assertThat(registerTMResponse2.getResultCode()).isEqualTo(registerTMResponse.getResultCode()); + } + /** * Constructor without arguments **/ diff --git a/server/src/main/java/org/apache/seata/server/auth/AbstractCheckAuthHandler.java b/server/src/main/java/org/apache/seata/server/auth/AbstractCheckAuthHandler.java index 705ec2d2c9b..7f837190a41 100644 --- a/server/src/main/java/org/apache/seata/server/auth/AbstractCheckAuthHandler.java +++ b/server/src/main/java/org/apache/seata/server/auth/AbstractCheckAuthHandler.java @@ -17,9 +17,12 @@ package org.apache.seata.server.auth; import org.apache.seata.config.ConfigurationFactory; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.auth.AuthResultBuilder; import org.apache.seata.core.constants.ConfigurationKeys; import org.apache.seata.core.protocol.RegisterRMRequest; import org.apache.seata.core.protocol.RegisterTMRequest; +import org.apache.seata.core.protocol.ResultCode; import org.apache.seata.core.rpc.RegisterCheckAuthHandler; import static org.apache.seata.common.DefaultValues.DEFAULT_SERVER_ENABLE_CHECK_AUTH; @@ -29,25 +32,31 @@ public abstract class AbstractCheckAuthHandler implements RegisterCheckAuthHandler { private static final Boolean ENABLE_CHECK_AUTH = ConfigurationFactory.getInstance().getBoolean( - ConfigurationKeys.SERVER_ENABLE_CHECK_AUTH, DEFAULT_SERVER_ENABLE_CHECK_AUTH); + ConfigurationKeys.SERVER_ENABLE_CHECK_AUTH, DEFAULT_SERVER_ENABLE_CHECK_AUTH); @Override - public boolean regTransactionManagerCheckAuth(RegisterTMRequest request) { + public AuthResult regTransactionManagerCheckAuth(RegisterTMRequest request) { if (!ENABLE_CHECK_AUTH) { - return true; + return new AuthResultBuilder().setResultCode(ResultCode.Success).build(); } return doRegTransactionManagerCheck(request); } - public abstract boolean doRegTransactionManagerCheck(RegisterTMRequest request); + public abstract AuthResult doRegTransactionManagerCheck(RegisterTMRequest request); @Override - public boolean regResourceManagerCheckAuth(RegisterRMRequest request) { + public AuthResult regResourceManagerCheckAuth(RegisterRMRequest request) { if (!ENABLE_CHECK_AUTH) { - return true; + return new AuthResultBuilder().setResultCode(ResultCode.Success).build(); } return doRegResourceManagerCheck(request); } - public abstract boolean doRegResourceManagerCheck(RegisterRMRequest request); + public abstract AuthResult doRegResourceManagerCheck(RegisterRMRequest request); + + @Override + public String fetchNewToken(AuthResult authResult) { + return null; + } + } diff --git a/server/src/main/java/org/apache/seata/server/auth/DefaultCheckAuthHandler.java b/server/src/main/java/org/apache/seata/server/auth/DefaultCheckAuthHandler.java index 05329eb27ee..105316082c8 100644 --- a/server/src/main/java/org/apache/seata/server/auth/DefaultCheckAuthHandler.java +++ b/server/src/main/java/org/apache/seata/server/auth/DefaultCheckAuthHandler.java @@ -17,21 +17,23 @@ package org.apache.seata.server.auth; import org.apache.seata.common.loader.LoadLevel; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.auth.AuthResultBuilder; import org.apache.seata.core.protocol.RegisterRMRequest; import org.apache.seata.core.protocol.RegisterTMRequest; +import org.apache.seata.core.protocol.ResultCode; /** */ -@LoadLevel(name = "defaultCheckAuthHandler", order = 100) +@LoadLevel(name = "defaultCheckAuthHandler", order = 2) public class DefaultCheckAuthHandler extends AbstractCheckAuthHandler { @Override - public boolean doRegTransactionManagerCheck(RegisterTMRequest request) { - return true; - } + public AuthResult doRegTransactionManagerCheck(RegisterTMRequest request) { + return new AuthResultBuilder().setResultCode(ResultCode.Success).build(); } @Override - public boolean doRegResourceManagerCheck(RegisterRMRequest request) { - return true; - } + public AuthResult doRegResourceManagerCheck(RegisterRMRequest request) { + return new AuthResultBuilder().setResultCode(ResultCode.Success).build(); } + } diff --git a/server/src/main/java/org/apache/seata/server/auth/JwtCheckAuthHandler.java b/server/src/main/java/org/apache/seata/server/auth/JwtCheckAuthHandler.java new file mode 100644 index 00000000000..8ddf3283b3d --- /dev/null +++ b/server/src/main/java/org/apache/seata/server/auth/JwtCheckAuthHandler.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.seata.server.auth; + + +import org.apache.seata.common.loader.LoadLevel; +import org.apache.seata.common.util.StringUtils; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.auth.AuthResultBuilder; +import org.apache.seata.core.protocol.RegisterRMRequest; +import org.apache.seata.core.protocol.RegisterTMRequest; +import org.apache.seata.core.protocol.ResultCode; +import org.apache.seata.server.auth.utils.ManagerRegJwtTokenUtils; + +import java.util.HashMap; + + +@LoadLevel(name = "jwtCheckAuthHandler", order = 1) +public class JwtCheckAuthHandler extends AbstractCheckAuthHandler { + + private static final String PRO_TOKEN = "token"; + + private static final String PRO_REFRESH_TOKEN = "refresh_token"; + + private static final String PRO_USERNAME = "username"; + + private static final String PRO_PASSWORD = "password"; + + private static final ManagerRegJwtTokenUtils jwtTokenUtils = new ManagerRegJwtTokenUtils(); + + @Override + public AuthResult doRegTransactionManagerCheck(RegisterTMRequest request) { + return checkAuthData(request.getExtraData()); + } + + @Override + public AuthResult doRegResourceManagerCheck(RegisterRMRequest request) { + return checkAuthData(request.getExtraData()); + } + + private AuthResult checkAuthData(String extraData) { + HashMap extraDataMap = StringUtils.string2Map(extraData); + // 1.check username/password + String username = extraDataMap.get(PRO_USERNAME); + String password = extraDataMap.get(PRO_PASSWORD); + String accessToken = extraDataMap.get(PRO_TOKEN); + String refreshToken = extraDataMap.get(PRO_REFRESH_TOKEN); + if (username != null && password != null) { + return jwtTokenUtils.checkUsernamePassword(username, password); + } else if (accessToken != null) { + // 2.check token + return jwtTokenUtils.checkAccessToken(accessToken); + } else if (refreshToken != null) { + return jwtTokenUtils.checkRefreshToken(refreshToken); + } + return new AuthResultBuilder().setResultCode(ResultCode.Failed).build(); + } + + @Override + public String fetchNewToken(AuthResult authResult) { + if (authResult != null && authResult.getResultCode().equals(ResultCode.Success)) { + HashMap extraDataMap = new HashMap<>(); + if (authResult.getAccessToken() != null) { + extraDataMap.put(PRO_TOKEN, authResult.getAccessToken()); + } + if (authResult.getRefreshToken() != null) { + extraDataMap.put(PRO_REFRESH_TOKEN, authResult.getRefreshToken()); + } + return StringUtils.map2String(extraDataMap); + } + return null; + } +} diff --git a/server/src/main/java/org/apache/seata/server/auth/utils/ManagerRegJwtTokenUtils.java b/server/src/main/java/org/apache/seata/server/auth/utils/ManagerRegJwtTokenUtils.java new file mode 100644 index 00000000000..23df4456120 --- /dev/null +++ b/server/src/main/java/org/apache/seata/server/auth/utils/ManagerRegJwtTokenUtils.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.seata.server.auth.utils; + +import io.jsonwebtoken.*; +import io.jsonwebtoken.io.Decoders; +import org.apache.seata.common.ConfigurationKeys; +import org.apache.seata.common.util.StringUtils; +import org.apache.seata.config.ConfigurationFactory; +import org.apache.seata.core.auth.AuthResult; +import org.apache.seata.core.auth.AuthResultBuilder; +import org.apache.seata.core.protocol.ResultCode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.crypto.spec.SecretKeySpec; +import java.util.Date; + +public class ManagerRegJwtTokenUtils { + + private static final Logger LOGGER = LoggerFactory.getLogger(ManagerRegJwtTokenUtils.class); + + private static final String AUTHORITIES_KEY = "auth"; + + private static final String secretKey = ConfigurationFactory.getInstance().getConfig(ConfigurationKeys.SECURITY_SECRET_KEY); + + private static final String accessTokenValidityInMilliseconds = ConfigurationFactory + .getInstance().getConfig(ConfigurationKeys.SECURITY_ACCESS_TOKEN_VALID_TIME); + + private static final String refreshTokenValidityInMilliseconds = ConfigurationFactory + .getInstance().getConfig(ConfigurationKeys.SECURITY_REFRESH_TOKEN_VALID_TIME); + + + /** + * Create access token + * @return token string + */ + public String createAccessToken(String username) { + return createToken(username, accessTokenValidityInMilliseconds); + } + + /** + * Create access token + * @return token string + */ + public String createRefreshToken(String username) { + return createToken(username, refreshTokenValidityInMilliseconds); + } + + private String createToken(String username, String tokenValidityInMilliseconds) { + /** + * Current time + */ + long now = (new Date()).getTime(); + /** + * Expiration date + */ + Date expirationDate = new Date(now + Long.parseLong(tokenValidityInMilliseconds)); + /** + * Key + */ + SecretKeySpec secretKeySpec = new SecretKeySpec(Decoders.BASE64.decode(secretKey), + SignatureAlgorithm.HS256.getJcaName()); + /** + * create token + */ + return Jwts.builder().setSubject(username).claim(AUTHORITIES_KEY, "").setExpiration( + expirationDate).signWith(secretKeySpec, SignatureAlgorithm.HS256).compact(); + } + + public AuthResult checkUsernamePassword(String username, String password) { + if(StringUtils.equals(username, ConfigurationFactory.getInstance().getConfig(ConfigurationKeys.SECURITY_USERNME)) + && StringUtils.equals(password, ConfigurationFactory.getInstance().getConfig(ConfigurationKeys.SECURITY_PASSWORD))){ + return new AuthResultBuilder() + .setResultCode(ResultCode.Success) + .setAccessToken(createAccessToken(username)) + .setRefreshToken(createRefreshToken(username)) + .build(); + } else { + return new AuthResultBuilder().setResultCode(ResultCode.Failed).build(); + } + } + + public AuthResult checkAccessToken(String accessToken) { + try { + Jws claimsJws = Jwts.parser().setSigningKey(secretKey).parseClaimsJws(accessToken); + Claims claims = claimsJws.getBody(); + Date expiration = claims.getExpiration(); + if (System.currentTimeMillis() > expiration.getTime() - Long.parseLong(accessTokenValidityInMilliseconds) / 3) { + LOGGER.warn("jwt token will be expired, need refresh token"); + return new AuthResultBuilder().setResultCode(ResultCode.AccessTokenNearExpiration).build(); + } + return new AuthResultBuilder().setResultCode(ResultCode.Success).build(); + } catch (ExpiredJwtException e) { + LOGGER.warn("jwt token has been expired: " + e); + return new AuthResultBuilder().setResultCode(ResultCode.AccessTokenExpired).build(); + } catch (Exception e) { + LOGGER.error("jwt token authentication failed: " + e); + return new AuthResultBuilder().setResultCode(ResultCode.Failed).build(); + } + } + + public AuthResult checkRefreshToken(String refreshToken) { + try { + Jws claimsJws = Jwts.parser().setSigningKey(secretKey).parseClaimsJws(refreshToken); + return new AuthResultBuilder().setResultCode(ResultCode.Success) + .setAccessToken(createAccessToken(claimsJws.getBody().getSubject())) + .build(); + } catch (ExpiredJwtException e) { + LOGGER.warn("jwt token has been expired: " + e); + return new AuthResultBuilder().setResultCode(ResultCode.RefreshTokenExpired).build(); + } catch (Exception e) { + LOGGER.error("jwt token authentication failed: " + e); + return new AuthResultBuilder().setResultCode(ResultCode.Failed).build(); + } + } +} diff --git a/server/src/main/resources/META-INF/services/org.apache.seata.core.rpc.RegisterCheckAuthHandler b/server/src/main/resources/META-INF/services/org.apache.seata.core.rpc.RegisterCheckAuthHandler index 4b0b8c359cf..97d5815b0ed 100644 --- a/server/src/main/resources/META-INF/services/org.apache.seata.core.rpc.RegisterCheckAuthHandler +++ b/server/src/main/resources/META-INF/services/org.apache.seata.core.rpc.RegisterCheckAuthHandler @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -org.apache.seata.server.auth.DefaultCheckAuthHandler \ No newline at end of file +org.apache.seata.server.auth.JwtCheckAuthHandler diff --git a/server/src/main/resources/application.yml b/server/src/main/resources/application.yml index dad4e4ba7f0..d0cb9b9031d 100644 --- a/server/src/main/resources/application.yml +++ b/server/src/main/resources/application.yml @@ -50,8 +50,13 @@ seata: # server: # service-port: 8091 #If not configured, the default is '${server.port} + 1000' security: + username: seata + password: seata secretKey: SeataSecretKey0c382ef121d778043159209298fd40bf3850a017 tokenValidityInMilliseconds: 1800000 + accessTokenValidityInMilliseconds: 600000 + refreshTokenValidityInMilliseconds: 86400000 csrf-ignore-urls: /metadata/v1/** ignore: - urls: /,/**/*.css,/**/*.js,/**/*.html,/**/*.map,/**/*.svg,/**/*.png,/**/*.jpeg,/**/*.ico,/api/v1/auth/login,/version.json,/health,/error,/vgroup/v1/** \ No newline at end of file + urls: /,/**/*.css,/**/*.js,/**/*.html,/**/*.map,/**/*.svg,/**/*.png,/**/*.jpeg,/**/*.ico,/api/v1/auth/login,/version.json,/health,/error,/vgroup/v1/** +