diff --git a/internal/base/constant/cache_key.go b/internal/base/constant/cache_key.go index 31b6b7520..f8b80d094 100644 --- a/internal/base/constant/cache_key.go +++ b/internal/base/constant/cache_key.go @@ -7,6 +7,9 @@ const ( UserStatusChangedCacheTime = 7 * 24 * time.Hour UserTokenCacheKey = "answer:user:token:" UserTokenCacheTime = 7 * 24 * time.Hour + UserVisitTokenCacheKey = "answer:user:visit:" + UserVisitCacheTime = 7 * 24 * 60 * 60 + UserVisitCookiesCacheKey = "visit" AdminTokenCacheKey = "answer:admin:token:" AdminTokenCacheTime = 7 * 24 * time.Hour UserTokenMappingCacheKey = "answer:user-token:mapping:" diff --git a/internal/base/middleware/visit_img_auth.go b/internal/base/middleware/visit_img_auth.go new file mode 100644 index 000000000..5ac97b340 --- /dev/null +++ b/internal/base/middleware/visit_img_auth.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "github.com/answerdev/answer/internal/base/constant" + "github.com/gin-gonic/gin" + "net/http" +) + +// VisitAuth when user visit the site image, check visit token. This only for private mode. +func (am *AuthUserMiddleware) VisitAuth() gin.HandlerFunc { + return func(ctx *gin.Context) { + siteLogin, err := am.siteInfoCommonService.GetSiteLogin(ctx) + if err != nil { + return + } + if !siteLogin.LoginRequired { + ctx.Next() + return + } + + visitToken, err := ctx.Cookie(constant.UserVisitCookiesCacheKey) + if err != nil || len(visitToken) == 0 { + ctx.Abort() + ctx.Redirect(http.StatusFound, "/403") + return + } + + if !am.authService.CheckUserVisitToken(ctx, visitToken) { + ctx.Abort() + ctx.Redirect(http.StatusFound, "/403") + return + } + } +} diff --git a/internal/base/server/http.go b/internal/base/server/http.go index c5a5697a8..e15a5ae7b 100644 --- a/internal/base/server/http.go +++ b/internal/base/server/http.go @@ -43,7 +43,7 @@ func NewHTTPServer(debug bool, rootGroup := r.Group("") swaggerRouter.Register(rootGroup) static := r.Group("") - static.Use(avatarMiddleware.AvatarThumb()) + static.Use(avatarMiddleware.AvatarThumb(), authUserMiddleware.VisitAuth()) staticRouter.RegisterStaticRouter(static) // The route must be available without logging in diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 47a394162..6d0166887 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -1,6 +1,7 @@ package controller import ( + "github.com/answerdev/answer/internal/base/constant" "github.com/answerdev/answer/internal/base/handler" "github.com/answerdev/answer/internal/base/middleware" "github.com/answerdev/answer/internal/base/reason" @@ -18,6 +19,7 @@ import ( "github.com/gin-gonic/gin" "github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/log" + "net/url" ) // UserController user controller @@ -73,6 +75,7 @@ func (uc *UserController) GetUserInfoByUserID(ctx *gin.Context) { } resp, err := uc.userService.GetUserInfoByUserID(ctx, token, userInfo.UserID) + uc.setVisitCookies(ctx, userInfo.VisitToken) handler.HandleResponse(ctx, err, resp) } @@ -136,6 +139,7 @@ func (uc *UserController) UserEmailLogin(ctx *gin.Context) { if !isAdmin { uc.actionService.ActionRecordDel(ctx, entity.CaptchaActionPassword, ctx.ClientIP()) } + uc.setVisitCookies(ctx, resp.VisitToken) handler.HandleResponse(ctx, nil, resp) } @@ -673,3 +677,22 @@ func (uc *UserController) SearchUserListByName(ctx *gin.Context) { resp, err := uc.userService.SearchUserListByName(ctx, req) handler.HandleResponse(ctx, err, resp) } + +func (uc *UserController) setVisitCookies(ctx *gin.Context, visitToken string) { + cookie, err := ctx.Cookie(constant.UserVisitCookiesCacheKey) + if err == nil && len(cookie) > 0 { + return + } + general, err := uc.siteInfoCommonService.GetSiteGeneral(ctx) + if err != nil { + log.Errorf("get site general error: %v", err) + return + } + parsedURL, err := url.Parse(general.SiteUrl) + if err != nil { + log.Errorf("parse url error: %v", err) + return + } + ctx.SetCookie(constant.UserVisitCookiesCacheKey, + visitToken, constant.UserVisitCacheTime, "/", parsedURL.Host, true, true) +} diff --git a/internal/entity/auth_user_entity.go b/internal/entity/auth_user_entity.go index 79d24fb12..8a1c51b14 100644 --- a/internal/entity/auth_user_entity.go +++ b/internal/entity/auth_user_entity.go @@ -7,4 +7,5 @@ type UserCacheInfo struct { EmailStatus int `json:"email_status"` RoleID int `json:"role_id"` ExternalID string `json:"external_id"` + VisitToken string `json:"visit_token"` } diff --git a/internal/repo/auth/auth.go b/internal/repo/auth/auth.go index 1fa201c21..1826e96d0 100644 --- a/internal/repo/auth/auth.go +++ b/internal/repo/auth/auth.go @@ -40,7 +40,9 @@ func (ar *authRepo) GetUserCacheInfo(ctx context.Context, accessToken string) (u } // SetUserCacheInfo set user cache info -func (ar *authRepo) SetUserCacheInfo(ctx context.Context, accessToken string, userInfo *entity.UserCacheInfo) (err error) { +func (ar *authRepo) SetUserCacheInfo(ctx context.Context, + accessToken, visitToken string, userInfo *entity.UserCacheInfo) (err error) { + userInfo.VisitToken = visitToken userInfoCache, err := json.Marshal(userInfo) if err != nil { return err @@ -53,9 +55,28 @@ func (ar *authRepo) SetUserCacheInfo(ctx context.Context, accessToken string, us if err := ar.AddUserTokenMapping(ctx, userInfo.UserID, accessToken); err != nil { log.Error(err) } + if len(visitToken) == 0 { + return nil + } + if err := ar.data.Cache.SetString(ctx, constant.UserVisitTokenCacheKey+visitToken, + accessToken, constant.UserTokenCacheTime); err != nil { + log.Error(err) + } return nil } +// GetUserVisitCacheInfo get user visit cache info +func (ar *authRepo) GetUserVisitCacheInfo(ctx context.Context, visitToken string) (accessToken string, err error) { + accessToken, exist, err := ar.data.Cache.GetString(ctx, constant.UserVisitTokenCacheKey+visitToken) + if err != nil { + return "", errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() + } + if !exist { + return "", nil + } + return accessToken, nil +} + // RemoveUserCacheInfo remove user cache info func (ar *authRepo) RemoveUserCacheInfo(ctx context.Context, accessToken string) (err error) { err = ar.data.Cache.Del(ctx, constant.UserTokenCacheKey+accessToken) diff --git a/internal/schema/user_schema.go b/internal/schema/user_schema.go index 93b6291a4..f8bd66b7c 100644 --- a/internal/schema/user_schema.go +++ b/internal/schema/user_schema.go @@ -69,6 +69,8 @@ type UserLoginResp struct { Status string `json:"status"` // user have password HavePassword bool `json:"have_password"` + // visit token + VisitToken string `json:"visit_token"` } func (r *UserLoginResp) ConvertFromUserEntity(userInfo *entity.User) { diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index 33e987eba..bea9a044b 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -6,13 +6,13 @@ import ( "github.com/answerdev/answer/internal/entity" "github.com/answerdev/answer/pkg/token" "github.com/answerdev/answer/plugin" - "github.com/segmentfault/pacman/log" ) // AuthRepo auth repository type AuthRepo interface { GetUserCacheInfo(ctx context.Context, accessToken string) (userInfo *entity.UserCacheInfo, err error) - SetUserCacheInfo(ctx context.Context, accessToken string, userInfo *entity.UserCacheInfo) error + SetUserCacheInfo(ctx context.Context, accessToken, visitToken string, userInfo *entity.UserCacheInfo) error + GetUserVisitCacheInfo(ctx context.Context, visitToken string) (accessToken string, err error) RemoveUserCacheInfo(ctx context.Context, accessToken string) (err error) SetUserStatus(ctx context.Context, userID string, userInfo *entity.UserCacheInfo) (err error) GetUserStatus(ctx context.Context, userID string) (userInfo *entity.UserCacheInfo, err error) @@ -50,7 +50,7 @@ func (as *AuthService) GetUserCacheInfo(ctx context.Context, accessToken string) userCacheInfo.EmailStatus = cacheInfo.EmailStatus userCacheInfo.RoleID = cacheInfo.RoleID // update current user cache info - err := as.authRepo.SetUserCacheInfo(ctx, accessToken, userCacheInfo) + err := as.authRepo.SetUserCacheInfo(ctx, accessToken, userCacheInfo.VisitToken, userCacheInfo) if err != nil { return nil, err } @@ -66,25 +66,30 @@ func (as *AuthService) GetUserCacheInfo(ctx context.Context, accessToken string) return userCacheInfo, nil } -func (as *AuthService) SetUserCacheInfo(ctx context.Context, userInfo *entity.UserCacheInfo) (accessToken string, err error) { +func (as *AuthService) SetUserCacheInfo(ctx context.Context, userInfo *entity.UserCacheInfo) ( + accessToken string, visitToken string, err error) { accessToken = token.GenerateToken() - err = as.authRepo.SetUserCacheInfo(ctx, accessToken, userInfo) - return accessToken, err -} - -func (as *AuthService) SetUserStatus(ctx context.Context, userInfo *entity.UserCacheInfo) (err error) { - return as.authRepo.SetUserStatus(ctx, userInfo.UserID, userInfo) + visitToken = token.GenerateToken() + err = as.authRepo.SetUserCacheInfo(ctx, accessToken, visitToken, userInfo) + if err != nil { + return "", "", err + } + return accessToken, visitToken, err } -func (as *AuthService) UpdateUserCacheInfo(ctx context.Context, token string, userInfo *entity.UserCacheInfo) (err error) { - err = as.authRepo.SetUserCacheInfo(ctx, token, userInfo) +func (as *AuthService) CheckUserVisitToken(ctx context.Context, visitToken string) bool { + accessToken, err := as.authRepo.GetUserVisitCacheInfo(ctx, visitToken) if err != nil { - return err + return false } - if err := as.authRepo.RemoveUserStatus(ctx, userInfo.UserID); err != nil { - log.Error(err) + if len(accessToken) == 0 { + return false } - return + return true +} + +func (as *AuthService) SetUserStatus(ctx context.Context, userInfo *entity.UserCacheInfo) (err error) { + return as.authRepo.SetUserStatus(ctx, userInfo.UserID, userInfo) } func (as *AuthService) RemoveUserCacheInfo(ctx context.Context, accessToken string) (err error) { diff --git a/internal/service/user_common/user.go b/internal/service/user_common/user.go index 35948e409..2f76020ab 100644 --- a/internal/service/user_common/user.go +++ b/internal/service/user_common/user.go @@ -195,7 +195,7 @@ func (us *UserCommon) CacheLoginUserInfo(ctx context.Context, userID string, use ExternalID: externalID, } - accessToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) + accessToken, _, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) if err != nil { return "", nil, err } diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 560ff7c2c..cca6f7770 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -153,7 +153,7 @@ func (us *UserService) EmailLogin(ctx context.Context, req *schema.UserEmailLogi RoleID: roleID, ExternalID: externalID, } - resp.AccessToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) + resp.AccessToken, resp.VisitToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) if err != nil { return nil, err } @@ -436,7 +436,7 @@ func (us *UserService) UserRegisterByEmail(ctx context.Context, registerUserInfo UserStatus: userInfo.Status, RoleID: roleID, } - resp.AccessToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) + resp.AccessToken, resp.VisitToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) if err != nil { return nil, nil, err } @@ -640,7 +640,7 @@ func (us *UserService) UserChangeEmailVerify(ctx context.Context, content string UserStatus: userInfo.Status, RoleID: roleID, } - resp.AccessToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) + resp.AccessToken, resp.VisitToken, err = us.authService.SetUserCacheInfo(ctx, userCacheInfo) if err != nil { return nil, err }