Skip to content

Commit

Permalink
Merge pull request #82 from zbzbzzz/main
Browse files Browse the repository at this point in the history
优化ws握手认证
  • Loading branch information
zongzibinbin authored Jul 5, 2023
2 parents 1107075 + 919bdc7 commit c0cce23
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.JWTVerifier;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.junit.Test;

import java.util.Date;
Expand All @@ -18,7 +17,7 @@ public class CreateTokenTest {
@Test
public void create(){
String token = JWT.create()
.withClaim("uid", 123L) // 只存一个uid信息,其他的自己去redis查
.withClaim("uid", 10004L) // 只存一个uid信息,其他的自己去redis查
.withClaim("createTime", new Date())
.sign(Algorithm.HMAC256("dsfsdfsdfsdfsd")); // signature
log.info("生成的token为 {}",token);
Expand All @@ -32,13 +31,4 @@ public void create(){
log.info("decode error,token:{}", token, e);
}
}

@Test
public void verifyToken(){
String token = JWT.create()
.withClaim("uid", 1) // 只存一个uid信息,其他的自己去redis查
.withClaim("createTime", new Date())
.sign(Algorithm.HMAC256("dsfsdfsdfsdfsd")); // signature
log.info("生成的token为{}",token);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Condition;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.abin.mallchat.custom.user.websocket;

import cn.hutool.core.net.url.UrlBuilder;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
Expand All @@ -13,15 +14,27 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
HttpHeaders headers = ((FullHttpRequest) msg).headers();
FullHttpRequest request = (FullHttpRequest) msg;
UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri());

// 获取token参数
String token = urlBuilder.getQuery().get("token").toString();
NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token);

// 获取请求路径
request.setUri(urlBuilder.getPath().toString());
HttpHeaders headers = request.headers();
String ip = headers.get("X-Real-IP");
if (StringUtils.isEmpty(ip)) {//如果没经过nginx,就直接获取远端地址
InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
ip = address.getAddress().getHostAddress();
}
NettyUtil.setAttr(ctx.channel(), NettyUtil.IP, ip);
ctx.pipeline().remove(this);
ctx.fireChannelRead(request);
}else
{
ctx.fireChannelRead(msg);
}
ctx.fireChannelRead(msg);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.abin.mallchat.custom.user.websocket;

import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;

Expand All @@ -15,6 +16,7 @@ public class NettyUtil {
public static AttributeKey<String> TOKEN = AttributeKey.valueOf("token");
public static AttributeKey<String> IP = AttributeKey.valueOf("ip");
public static AttributeKey<Long> UID = AttributeKey.valueOf("uid");
public static AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY = AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {
Attribute<T> attr = channel.attr(attributeKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
Expand Down Expand Up @@ -88,7 +89,7 @@ protected void initChannel(SocketChannel socketChannel) throws Exception {
* 4. WebSocketServerProtocolHandler 核心功能是把 http协议升级为 ws 协议,保持长连接;
* 是通过一个状态码 101 来切换的
*/
pipeline.addLast(new WebSocketHandshakeHandler());
pipeline.addLast(new WebSocketServerProtocolHandler("/"));
// 自定义handler ,处理业务逻辑
pipeline.addLast(new NettyWebSocketServerHandler());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
@Slf4j
public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

private WebSocketService webSocketService;

// 当web客户端连接后,触发该方法
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
// getService().connect(ctx.channel());
this.webSocketService = getService();
}

// 客户端离线
Expand All @@ -45,7 +47,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}

private void userOffLine(ChannelHandlerContext ctx) {
getService().removed(ctx.channel());
this.webSocketService.removed(ctx.channel());
ctx.channel().close();
}

Expand All @@ -66,10 +68,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
userOffLine(ctx);
}
} else if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
getService().connect(ctx.channel());
this.webSocketService.connect(ctx.channel());
String token = NettyUtil.getAttr(ctx.channel(), NettyUtil.TOKEN);
if (StrUtil.isNotBlank(token)) {
getService().authorize(ctx.channel(), new WSAuthorize(token));
this.webSocketService.authorize(ctx.channel(), new WSAuthorize(token));
}
}
super.userEventTriggered(ctx, evt);
Expand All @@ -93,13 +95,13 @@ protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) t
WSReqTypeEnum wsReqTypeEnum = WSReqTypeEnum.of(wsBaseReq.getType());
switch (wsReqTypeEnum) {
case LOGIN:
getService().handleLoginReq(ctx.channel());
this.webSocketService.handleLoginReq(ctx.channel());
log.info("请求二维码 = " + msg.text());
break;
case HEARTBEAT:
break;
case AUTHORIZE:
getService().authorize(ctx.channel(), JSONUtil.toBean(wsBaseReq.getData(), WSAuthorize.class));
this.webSocketService.authorize(ctx.channel(), JSONUtil.toBean(wsBaseReq.getData(), WSAuthorize.class));
log.info("主动认证 = " + msg.text());
break;
default:
Expand Down

This file was deleted.

0 comments on commit c0cce23

Please sign in to comment.