Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

优化ws握手认证 #82

Merged
merged 6 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.JWTVerifier;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.junit.Test;

import java.util.Date;
Expand All @@ -18,7 +17,7 @@ public class CreateTokenTest {
@Test
public void create(){
String token = JWT.create()
.withClaim("uid", 123L) // 只存一个uid信息,其他的自己去redis查
.withClaim("uid", 10004L) // 只存一个uid信息,其他的自己去redis查
.withClaim("createTime", new Date())
.sign(Algorithm.HMAC256("dsfsdfsdfsdfsd")); // signature
log.info("生成的token为 {}",token);
Expand All @@ -32,13 +31,4 @@ public void create(){
log.info("decode error,token:{}", token, e);
}
}

@Test
public void verifyToken(){
String token = JWT.create()
.withClaim("uid", 1) // 只存一个uid信息,其他的自己去redis查
.withClaim("createTime", new Date())
.sign(Algorithm.HMAC256("dsfsdfsdfsdfsd")); // signature
log.info("生成的token为{}",token);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Condition;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.abin.mallchat.custom.user.websocket;

import cn.hutool.core.net.url.UrlBuilder;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
Expand All @@ -13,15 +14,27 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
HttpHeaders headers = ((FullHttpRequest) msg).headers();
FullHttpRequest request = (FullHttpRequest) msg;
UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri());

// 获取token参数
String token = urlBuilder.getQuery().get("token").toString();
NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token);

// 获取请求路径
request.setUri(urlBuilder.getPath().toString());
HttpHeaders headers = request.headers();
String ip = headers.get("X-Real-IP");
if (StringUtils.isEmpty(ip)) {//如果没经过nginx,就直接获取远端地址
InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
ip = address.getAddress().getHostAddress();
}
NettyUtil.setAttr(ctx.channel(), NettyUtil.IP, ip);
ctx.pipeline().remove(this);
ctx.fireChannelRead(request);
}else
{
ctx.fireChannelRead(msg);
}
ctx.fireChannelRead(msg);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.abin.mallchat.custom.user.websocket;

import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;

Expand All @@ -15,6 +16,7 @@ public class NettyUtil {
public static AttributeKey<String> TOKEN = AttributeKey.valueOf("token");
public static AttributeKey<String> IP = AttributeKey.valueOf("ip");
public static AttributeKey<Long> UID = AttributeKey.valueOf("uid");
public static AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY = AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {
Attribute<T> attr = channel.attr(attributeKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
Expand Down Expand Up @@ -88,7 +89,7 @@ protected void initChannel(SocketChannel socketChannel) throws Exception {
* 4. WebSocketServerProtocolHandler 核心功能是把 http协议升级为 ws 协议,保持长连接;
* 是通过一个状态码 101 来切换的
*/
pipeline.addLast(new WebSocketHandshakeHandler());
pipeline.addLast(new WebSocketServerProtocolHandler("/"));
// 自定义handler ,处理业务逻辑
pipeline.addLast(new NettyWebSocketServerHandler());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
@Slf4j
public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

private WebSocketService webSocketService;

// 当web客户端连接后,触发该方法
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
// getService().connect(ctx.channel());
this.webSocketService = getService();
}

// 客户端离线
Expand All @@ -45,7 +47,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}

private void userOffLine(ChannelHandlerContext ctx) {
getService().removed(ctx.channel());
this.webSocketService.removed(ctx.channel());
ctx.channel().close();
}

Expand All @@ -66,10 +68,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
userOffLine(ctx);
}
} else if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
getService().connect(ctx.channel());
this.webSocketService.connect(ctx.channel());
String token = NettyUtil.getAttr(ctx.channel(), NettyUtil.TOKEN);
if (StrUtil.isNotBlank(token)) {
getService().authorize(ctx.channel(), new WSAuthorize(token));
this.webSocketService.authorize(ctx.channel(), new WSAuthorize(token));
}
}
super.userEventTriggered(ctx, evt);
Expand All @@ -93,13 +95,13 @@ protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) t
WSReqTypeEnum wsReqTypeEnum = WSReqTypeEnum.of(wsBaseReq.getType());
switch (wsReqTypeEnum) {
case LOGIN:
getService().handleLoginReq(ctx.channel());
this.webSocketService.handleLoginReq(ctx.channel());
log.info("请求二维码 = " + msg.text());
break;
case HEARTBEAT:
break;
case AUTHORIZE:
getService().authorize(ctx.channel(), JSONUtil.toBean(wsBaseReq.getData(), WSAuthorize.class));
this.webSocketService.authorize(ctx.channel(), JSONUtil.toBean(wsBaseReq.getData(), WSAuthorize.class));
log.info("主动认证 = " + msg.text());
break;
default:
Expand Down

This file was deleted.