Skip to content

Commit

Permalink
perf($ApiGateway): enhance global exception handler
Browse files Browse the repository at this point in the history
Abandon custom HttpStatus, use unified Spring Web HttpStatus

BREAKING CHANGE: abandon custom HttpStatus, use unified Spring Web HttpStatus
  • Loading branch information
Johnny Miller (锺俊) committed Dec 23, 2020
1 parent 7bb6ffa commit 0d3f265
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 548 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import com.jmsoftware.maf.apigateway.security.configuration.JwtConfiguration;
import com.jmsoftware.maf.apigateway.universal.configuration.Constants;
import com.jmsoftware.maf.apigateway.universal.configuration.RedisService;
import com.jmsoftware.maf.common.constant.HttpStatus;
import com.jmsoftware.maf.common.exception.SecurityException;
import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -82,7 +82,7 @@ public String createJwt(Boolean rememberMe, Long id, String subject, List<String
if (redisOperationResult) {
return jwt;
} else {
throw new SecurityException(HttpStatus.ERROR, "Cannot persist JWT into Redis", null);
throw new SecurityException(HttpStatus.INTERNAL_SERVER_ERROR, "Cannot persist JWT into Redis", null);
}
}

Expand All @@ -91,34 +91,34 @@ public Claims parseJwt(String jwt) throws SecurityException {
Claims claims;
try {
claims = Optional.ofNullable(jwtParser.parseClaimsJws(jwt).getBody())
.orElseThrow(() -> new SecurityException(HttpStatus.TOKEN_PARSE_ERROR,
.orElseThrow(() -> new SecurityException(HttpStatus.INTERNAL_SERVER_ERROR,
"The JWT Claims Set is null", null));
} catch (ExpiredJwtException e) {
log.error("JWT was expired. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_EXPIRED);
log.error("JWT is expired. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is expired (JWT itself)");
} catch (UnsupportedJwtException e) {
log.error("JWT is unsupported. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_PARSE_ERROR);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is unsupported");
} catch (MalformedJwtException e) {
log.error("JWT is invalid. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_PARSE_ERROR);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is invalid");
} catch (IllegalArgumentException e) {
log.error("The parameter of JWT is invalid. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_PARSE_ERROR);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "The parameter of JWT is invalid");
}
val username = claims.getSubject();
val redisKeyOfJwt = Constants.REDIS_JWT_KEY_PREFIX + username;
// Check if JWT exists
val expire = redisService.getExpire(redisKeyOfJwt, TimeUnit.MILLISECONDS);
if (ObjectUtil.isNull(expire) || expire <= 0) {
throw new SecurityException(HttpStatus.TOKEN_EXPIRED);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is expired (Redis expiration)");
}
// Check if the current JWT is equal to the one in Redis.
// If it's noe equal, that indicates current user has signed out or logged in before.
// Both situations reveal the JWT has expired.
val jwtInRedis = redisService.get(redisKeyOfJwt);
if (!StrUtil.equals(jwt, jwtInRedis)) {
throw new SecurityException(HttpStatus.TOKEN_OUT_OF_CONTROL);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is expired (Not equaled)");
}
return claims;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.jmsoftware.maf.common.bean.ResponseBodyBean;
import com.jmsoftware.maf.muscleandfitnessserverreactivespringbootstarter.util.RequestUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.springframework.boot.web.reactive.error.ErrorWebExceptionHandler;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.buffer.DataBufferFactory;
Expand All @@ -16,8 +18,6 @@
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.Date;

/**
* <h1>GlobalExceptionHandler</h1>
* <p>
Expand All @@ -36,27 +36,41 @@ public class GlobalExceptionHandler implements ErrorWebExceptionHandler {
@Override
@SuppressWarnings("NullableProblems")
public Mono<Void> handle(ServerWebExchange exchange, Throwable ex) {
val request = exchange.getRequest();
log.error("Exception occurred when [{}] requested access. Request URL: [{}] {}",
RequestUtil.getRequestIpAndPort(request), request.getMethod(), request.getURI());
ServerHttpResponse response = exchange.getResponse();

log.error(ex.getMessage(), ex.fillInStackTrace());
if (response.isCommitted()) {
return Mono.error(ex);
}

// header set
// Set HTTP header
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
if (ex instanceof ResponseStatusException) {
response.setStatusCode(((ResponseStatusException) ex).getStatus());
}

return response.writeWith(Mono.fromSupplier(() -> {
DataBufferFactory bufferFactory = response.bufferFactory();
final var responseBody = ResponseBodyBean.builder().timestamp(new Date()).status(HttpStatus.INTERNAL_SERVER_ERROR.value()).message(ex.getMessage()).build();
final var responseBody = setResponseBody(response, ex);
try {
return bufferFactory.wrap(objectMapper.writeValueAsBytes(responseBody));
} catch (JsonProcessingException e) {
log.warn("Error writing response", ex);
log.warn("Exception occurred when writing response", e);
return bufferFactory.wrap(new byte[0]);
}
}));
}

/**
* Sets response body.
*
* @param response the response
* @param ex the ex
* @return the response body
*/
private ResponseBodyBean<?> setResponseBody(ServerHttpResponse response, Throwable ex) {
if (ex instanceof ResponseStatusException) {
response.setStatusCode(((ResponseStatusException) ex).getStatus());
return ResponseBodyBean.ofStatus(((ResponseStatusException) ex).getStatus());
}
response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR);
return ResponseBodyBean.ofStatus(HttpStatus.INTERNAL_SERVER_ERROR, ex.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import com.jmsoftware.maf.authcenter.universal.domain.UserPrincipal;
import com.jmsoftware.maf.authcenter.universal.service.JwtService;
import com.jmsoftware.maf.authcenter.universal.service.RedisService;
import com.jmsoftware.maf.common.constant.HttpStatus;
import com.jmsoftware.maf.common.exception.SecurityException;
import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -85,7 +85,7 @@ public String createJwt(Boolean rememberMe, Long id, String subject, List<String
if (redisOperationResult) {
return jwt;
} else {
throw new SecurityException(HttpStatus.ERROR, "Cannot persist JWT into Redis", null);
throw new SecurityException(HttpStatus.INTERNAL_SERVER_ERROR, "Cannot persist JWT into Redis", null);
}
}

Expand All @@ -94,34 +94,34 @@ public Claims parseJwt(String jwt) throws SecurityException {
Claims claims;
try {
claims = Optional.ofNullable(jwtParser.parseClaimsJws(jwt).getBody())
.orElseThrow(() -> new SecurityException(HttpStatus.TOKEN_PARSE_ERROR,
.orElseThrow(() -> new SecurityException(HttpStatus.INTERNAL_SERVER_ERROR,
"The JWT Claims Set is null", null));
} catch (ExpiredJwtException e) {
log.error("JWT is expired. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_EXPIRED);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is expired (JWT itself)");
} catch (UnsupportedJwtException e) {
log.error("JWT is unsupported. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_PARSE_ERROR);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is unsupported");
} catch (MalformedJwtException e) {
log.error("JWT is invalid. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_PARSE_ERROR);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is invalid");
} catch (IllegalArgumentException e) {
log.error("The parameter of JWT is invalid. Message: {} JWT: {}", e.getMessage(), jwt);
throw new SecurityException(HttpStatus.TOKEN_PARSE_ERROR);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "The parameter of JWT is invalid");
}
val username = claims.getSubject();
val redisKeyOfJwt = Constants.REDIS_JWT_KEY_PREFIX + username;
// Check if JWT exists
val expire = redisService.getExpire(redisKeyOfJwt, TimeUnit.MILLISECONDS);
if (ObjectUtil.isNull(expire) || expire <= 0) {
throw new SecurityException(HttpStatus.TOKEN_EXPIRED);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is expired (Redis expiration)");
}
// Check if the current JWT is equal to the one in Redis.
// If it's noe equal, that indicates current user has signed out or logged in before.
// Both situations reveal the JWT has expired.
val jwtInRedis = redisService.get(redisKeyOfJwt);
if (!StrUtil.equals(jwt, jwtInRedis)) {
throw new SecurityException(HttpStatus.TOKEN_OUT_OF_CONTROL);
throw new SecurityException(HttpStatus.UNAUTHORIZED, "JWT is expired (Not equaled)");
}
return claims;
}
Expand Down
4 changes: 4 additions & 0 deletions common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
<artifactId>spring-boot-starter-web</artifactId>
</dependency>-->

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-validation</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import cn.hutool.json.JSONConfig;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.jmsoftware.maf.common.constant.HttpStatus;
import com.jmsoftware.maf.common.constant.IUniversalStatus;
import com.jmsoftware.maf.common.exception.BaseException;
import com.jmsoftware.maf.common.exception.BusinessException;
import lombok.Builder;
import lombok.NonNull;
import lombok.Value;
import lombok.val;
import org.springframework.http.HttpStatus;
import org.springframework.lang.Nullable;

import java.io.Serializable;
Expand Down Expand Up @@ -63,11 +62,11 @@ public class ResponseBodyBean<ResponseBodyDataType> implements Serializable {
* @param status IUniversalStatus
* @return response body for ExceptionControllerAdvice javax.servlet.http.HttpServletResponse, Exception)
*/
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofStatus(@NonNull final IUniversalStatus status) {
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofStatus(@NonNull final HttpStatus status) {
return ResponseBodyBean.<ResponseBodyDataType>builder()
.timestamp(new Date())
.status(status.getCode())
.message(status.getMessage())
.status(status.value())
.message(status.getReasonPhrase())
.build();
}

Expand All @@ -82,12 +81,12 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSt
* @param data data to be responded to client
* @return response body for ExceptionControllerAdvice
*/
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofStatus(@NonNull final IUniversalStatus status,
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofStatus(@NonNull final HttpStatus status,
@NonNull final ResponseBodyDataType data) {
return ResponseBodyBean.<ResponseBodyDataType>builder()
.timestamp(new Date())
.status(status.getCode())
.message(status.getMessage())
.status(status.value())
.message(status.getReasonPhrase())
.data(data)
.build();
}
Expand Down Expand Up @@ -130,7 +129,7 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> setR
@NonNull final String message,
final ResponseBodyDataType data)
throws BaseException {
if (!HttpStatus.OK.getCode().equals(status)) {
if (!HttpStatus.valueOf(status).is2xxSuccessful()) {
throw new BaseException(status, message, data);
}
return ResponseBodyBean.<ResponseBodyDataType>builder()
Expand All @@ -150,8 +149,8 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> setR
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSuccess() {
return ResponseBodyBean.<ResponseBodyDataType>builder()
.timestamp(new Date())
.status(HttpStatus.OK.getCode())
.message(HttpStatus.OK.getMessage())
.status(HttpStatus.OK.value())
.message(HttpStatus.OK.getReasonPhrase())
.build();
}

Expand All @@ -165,8 +164,8 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSu
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSuccess(final ResponseBodyDataType data) {
return ResponseBodyBean.<ResponseBodyDataType>builder()
.timestamp(new Date())
.status(HttpStatus.OK.getCode())
.message(HttpStatus.OK.getMessage())
.status(HttpStatus.OK.value())
.message(HttpStatus.OK.getReasonPhrase())
.data(data)
.build();
}
Expand All @@ -180,7 +179,7 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSu
*/
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSuccess(@NonNull final String message) {
return ResponseBodyBean.<ResponseBodyDataType>builder().timestamp(new Date())
.status(HttpStatus.OK.getCode())
.status(HttpStatus.OK.value())
.message(message)
.build();
}
Expand All @@ -196,7 +195,7 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSu
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofSuccess(final ResponseBodyDataType data,
@NonNull final String message) {
return ResponseBodyBean.<ResponseBodyDataType>builder().timestamp(new Date())
.status(HttpStatus.OK.getCode())
.status(HttpStatus.OK.value())
.message(message)
.data(data)
.build();
Expand Down Expand Up @@ -244,7 +243,8 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofFa
* @return response body
*/
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofError() throws BaseException {
return setResponse(HttpStatus.ERROR.getCode(), HttpStatus.ERROR.getMessage(), null);
return setResponse(HttpStatus.INTERNAL_SERVER_ERROR.value(), HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(),
null);
}

/**
Expand All @@ -254,9 +254,9 @@ public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofEr
* @param status Error status, not OK(200)
* @return response body
*/
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofError(@NonNull final IUniversalStatus status)
public static <ResponseBodyDataType> ResponseBodyBean<ResponseBodyDataType> ofError(@NonNull final HttpStatus status)
throws BaseException {
return setResponse(status.getCode(), status.getMessage(), null);
return setResponse(status.value(), status.getReasonPhrase(), null);
}

/**
Expand Down
Loading

0 comments on commit 0d3f265

Please sign in to comment.