From 10368f9c822a728f31eef827ef5b0bdcac3a4447 Mon Sep 17 00:00:00 2001 From: guosl Date: Thu, 4 Jul 2024 18:05:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=96=E6=B6=88=E7=99=BB=E9=99=86=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/user/interface.go | 1 + internal/service/user/user.go | 12 +++++++++++- internal/service/user/util.go | 22 ++++++++++++++++------ pkg/middleware/jwt/jwt.go | 7 ++++++- pkg/utils/gin/handler.go | 15 +++++++++++++++ pkg/utils/token/token.go | 15 +++++++++++++++ server/user/user.go | 7 ++++++- 7 files changed, 70 insertions(+), 9 deletions(-) diff --git a/internal/service/user/interface.go b/internal/service/user/interface.go index f6119b1..6547c5f 100644 --- a/internal/service/user/interface.go +++ b/internal/service/user/interface.go @@ -8,4 +8,5 @@ import ( type UserService interface { Add(ctx context.Context, info *models.AddInfo) (id uint, err error) Login(ctx context.Context, lInfo models.LoginInfo) error + Logout(ctx context.Context) error } diff --git a/internal/service/user/user.go b/internal/service/user/user.go index f34b72d..6169c40 100644 --- a/internal/service/user/user.go +++ b/internal/service/user/user.go @@ -122,10 +122,20 @@ func (u *userService) Login(ctx context.Context, info models.LoginInfo) error { return u.setLoginStatus(ctx, user, claims) } -func (u *userService) LoginOut() error { +func (u *userService) Logout(ctx context.Context) error { // 获取当前用户信息 + session, err := contextUtil.GetSession(ctx) + if err != nil { + return err + } + // 删除redis缓存 + if err = u.tokenRefresher.DeleteToken(session.ID); err != nil { + return err + } + // remome cookie + u.removeCookie(ctx) return nil } diff --git a/internal/service/user/util.go b/internal/service/user/util.go index 98a5b1b..702e2c9 100644 --- a/internal/service/user/util.go +++ b/internal/service/user/util.go @@ -47,12 +47,22 @@ func (u *userService) setLoginStatus(ctx context.Context, user repo.User, claims return fmt.Errorf("设置redis失败:%s", err.Error()) } - c := ctx.(*gin.Context) - expires := u.conf.Jwt.Expires - domain := u.conf.App.Host - c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=%s; Max-Age=%d; Path=/;Domain=%s", COOKIE_KEY_TOKEN, tokenStr, expires, domain)) - c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=%s; Max-Age=%d; Path=/;Domain=%s", COOKIE_KEY_ACCOUNT, claims["account"], expires, domain)) - c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=%s; Max-Age=%d; Path=/;Domain=%s", COOKIE_KEY_ID, claims["id"], expires, domain)) + if c, ok := ctx.(*gin.Context); ok { + expires := u.conf.Jwt.Expires + domain := u.conf.App.Host + c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=%s; Max-Age=%d; Path=/;Domain=%s", COOKIE_KEY_TOKEN, tokenStr, expires, domain)) + c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=%s; Max-Age=%d; Path=/;Domain=%s", COOKIE_KEY_ACCOUNT, claims["account"], expires, domain)) + c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=%s; Max-Age=%d; Path=/;Domain=%s", COOKIE_KEY_ID, claims["id"], expires, domain)) + } return nil } + +func (u *userService) removeCookie(ctx context.Context) { + if c, ok := ctx.(*gin.Context); ok { + domain := u.conf.App.Host + c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=; Max-Age=0; Path=/;Domain=%s", COOKIE_KEY_TOKEN, domain)) + c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=; Max-Age=0; Path=/;Domain=%s", COOKIE_KEY_ACCOUNT, domain)) + c.Writer.Header().Add("Set-Cookie", fmt.Sprintf("%s=; Max-Age=0; Path=/;Domain=%s", COOKIE_KEY_ID, domain)) + } +} diff --git a/pkg/middleware/jwt/jwt.go b/pkg/middleware/jwt/jwt.go index c11abcf..61f9b83 100644 --- a/pkg/middleware/jwt/jwt.go +++ b/pkg/middleware/jwt/jwt.go @@ -81,7 +81,7 @@ func NewJwtAuthMiddleware(logger *zap.SugaredLogger, refresher *token.TokenRefre // 从redis获取 _, err := refresher.GetUserToken(id) if err != nil { - return fmt.Errorf("没有登陆") + return fmt.Errorf("没有权限") } return &session.Session{ @@ -115,6 +115,11 @@ func NewJwtAuthMiddleware(logger *zap.SugaredLogger, refresher *token.TokenRefre return false }, Unauthorized: func(c *gin.Context, code int, message string) { + if msg, ok := c.Get(JwtIdentityKey); ok { + if err, ok := msg.(error); ok { + message = err.Error() + } + } c.String(code, message) }, diff --git a/pkg/utils/gin/handler.go b/pkg/utils/gin/handler.go index c793f92..50870ff 100644 --- a/pkg/utils/gin/handler.go +++ b/pkg/utils/gin/handler.go @@ -16,6 +16,8 @@ type NoReqHandler[RspType any] func(ctx context.Context) (rsp RspType, err error type NoRspHandler[ReqType any] func(ctx context.Context, req ReqType) (err error) +type NoHandler func(ctx context.Context) error + func Wrap[ReqType any, RspType any](f Handler[*ReqType, RspType]) gin.HandlerFunc { return func(c *gin.Context) { req := new(ReqType) @@ -134,3 +136,16 @@ func ConvertBody(c *gin.Context, obj interface{}) error { b := binding.Default(c.Request.Method, c.ContentType()) return c.ShouldBindWith(obj, b) } + +func WrapNo(f NoHandler) gin.HandlerFunc { + return func(c *gin.Context) { + err := f(c) + if err != nil { + _ = c.Error(err) + c.JSON(200, gin.H{"result": false, "code": 400, "msg": err.Error()}) + return + } + + c.JSON(200, gin.H{"result": true, "msg": "ok"}) + } +} diff --git a/pkg/utils/token/token.go b/pkg/utils/token/token.go index 8275669..3aba10f 100644 --- a/pkg/utils/token/token.go +++ b/pkg/utils/token/token.go @@ -94,3 +94,18 @@ func (tf *TokenRefresher) GetUserToken(userid uint) (string, error) { return token, nil } + +func (tf *TokenRefresher) DeleteToken(userId uint) error { + key := strconv.Itoa(int(userId)) + + _, err := tf.redis.Client.HDel(context.Background(), USER_ID_TOKEN_RELATION, key).Result() + if err != nil { + if err == goRedis.Nil { + return errInvalidToken + } + + return errors.Wrapf(err, "failed to GetUserToken cause by redis. map:%s,key: %d", USER_ID_TOKEN_RELATION, userId) + } + + return nil +} diff --git a/server/user/user.go b/server/user/user.go index 64e4b0f..5663baf 100644 --- a/server/user/user.go +++ b/server/user/user.go @@ -33,6 +33,7 @@ func RegisterRoute(api *gin.RouterGroup) { server := do.MustInvoke[*UserServer](nil) api.POST("/add", ginUtil.Wrap(server.Add)) api.POST("/login", ginUtil.WrapNoRsp(server.Login)) + api.POST("/logout", ginUtil.WrapNo(server.Logout)) } func (u *UserServer) Add(ctx context.Context, req *models.AddInfo) (rsp proto.AddResponse, err error) { @@ -50,7 +51,7 @@ func (u *UserServer) Add(ctx context.Context, req *models.AddInfo) (rsp proto.Ad return } -func (u *UserServer) Login(ctx context.Context, req *models.LoginInfo) (err error) { +func (u *UserServer) Login(ctx context.Context, req *proto.LoginRequest) (err error) { // 转换dto info := models.LoginInfo{ Account: req.Account, @@ -59,3 +60,7 @@ func (u *UserServer) Login(ctx context.Context, req *models.LoginInfo) (err erro return u.userService.Login(ctx, info) } + +func (u *UserServer) Logout(ctx context.Context) error { + return u.userService.Logout(ctx) +}