busniess-user-center/pkg/middleware/jwt/jwt.go

220 lines
5.5 KiB
Go

package jwt
import (
"context"
"crypto/sha1"
"fmt"
"os"
"time"
"busniess-user-center/pkg/redis"
contextUtil "busniess-user-center/pkg/utils/context"
"busniess-user-center/pkg/utils/session"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
goRedis "github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// isLocalDevelopment records whether the LOCAL_DEVELOPMENT environment
// variable held a non-empty value when this package was initialized.
var isLocalDevelopment bool

// init captures LOCAL_DEVELOPMENT once at package initialization.
func init() {
	isLocalDevelopment = os.Getenv("LOCAL_DEVELOPMENT") != ""
}
// Exported names shared with the rest of the service.
var (
	// JwtIdentityKey is the claim name carrying the session ID; it is also
	// the gin-jwt IdentityKey.
	JwtIdentityKey = "id"
	// Account is the claim name carrying the session's account name.
	Account = "account"
	// USER_ID_TOKEN_RELATION is the name of the Redis hash that maps a user ID
	// to that user's token.
	USER_ID_TOKEN_RELATION = "user_id_token_relation"
)

// Package-internal settings.
var (
	// tokenName is the cookie name the middleware also reads tokens from.
	tokenName = "token"
	// errInvalidToken is returned when Redis holds no entry for a token/user.
	errInvalidToken = errors.New("invalid token")
	// whiteUrlList holds the request paths AuthHandler lets through without
	// authentication; only key presence is checked.
	whiteUrlList = map[string]bool{
		"/sso":        true,
		"/user/login": true,
	}
)
// login holds the credentials the Authenticator in NewJwtAuthMiddleware binds
// from the request (form or JSON); both fields are required by the binding.
type login struct {
// Account is the submitted account name.
Account string `form:"account" json:"account" binding:"required"`
// Password is the submitted plaintext password.
Password string `form:"password" json:"password" binding:"required"`
}
// JwtAuthMiddleware wraps a configured gin-jwt middleware together with its
// request handler, as built by NewJwtAuthMiddleware.
type JwtAuthMiddleware struct {
// Middleware is the underlying gin-jwt middleware.
Middleware *jwt.GinJWTMiddleware
// authHandler is Middleware.MiddlewareFunc(), invoked by AuthHandler for
// non-whitelisted paths.
authHandler gin.HandlerFunc
}
// TokenRefresher stores token entries and user→token relations in Redis.
type TokenRefresher struct {
// expireIn is the TTL SetToken applies and Refresh re-applies to token keys.
expireIn time.Duration
// redis wraps the client (its Client field) used for every command.
redis *redis.Redis
}
// NewTokenRefresher returns a TokenRefresher that sends commands through
// client and uses expireIn as the TTL for the token keys it writes and
// refreshes.
func NewTokenRefresher(client *redis.Redis, expireIn time.Duration) *TokenRefresher {
	tf := new(TokenRefresher)
	tf.redis = client
	tf.expireIn = expireIn
	return tf
}
// generateTokenKey derives the Redis key for a (tokenID, token) pair. The key
// is "loginbigdata:jwt:token:" followed by the lowercase hex SHA-1 digest of
// "tokenID:token". As a result, the key never contains the raw token and
// always has the same length.
func (tf *TokenRefresher) generateTokenKey(tokenID, token string) string {
	digest := sha1.Sum([]byte(tokenID + ":" + token))
	return fmt.Sprintf("loginbigdata:jwt:token:%x", digest)
}
// Refresh resets the TTL of the token entry for (tokenID, token) to the
// refresher's expireIn.
//
// It returns errInvalidToken when no entry exists for the pair. This covers
// entries that were never stored and entries that expired before, or while,
// being refreshed. An entry whose stored value is empty is left untouched and
// reported as success. Any other Redis failure is returned wrapped with the
// key.
func (tf *TokenRefresher) Refresh(tokenID, token string) error {
	ctx := context.Background()
	cacheKey := tf.generateTokenKey(tokenID, token)

	result, err := tf.redis.Client.Get(ctx, cacheKey).Result()
	if err != nil {
		if err == goRedis.Nil {
			return errInvalidToken
		}
		return errors.Wrapf(err, "failed to refresh token cause by redis. key: %s", cacheKey)
	}
	if result == "" {
		return nil
	}

	// GET and EXPIRE are separate round trips, so the key can expire in
	// between. EXPIRE then replies false rather than failing, and that case
	// must not be reported as a successful refresh.
	refreshed, err := tf.redis.Client.Expire(ctx, cacheKey, tf.expireIn).Result()
	if err != nil {
		return errors.Wrapf(err, "failed to refresh token cause by expire redis. key: %s", cacheKey)
	}
	if !refreshed {
		return errInvalidToken
	}
	return nil
}
// SetToken stores data under the key derived from (tokenID, token), with the
// refresher's expireIn as TTL, replacing any existing entry. Redis failures
// are returned wrapped with the key.
func (tf *TokenRefresher) SetToken(tokenID, token, data string) error {
	cacheKey := tf.generateTokenKey(tokenID, token)
	// A plain SET (no NX/XX) always replies OK, so goRedis.Nil cannot occur
	// here and every error is a genuine Redis failure.
	if err := tf.redis.Client.Set(context.Background(), cacheKey, data, tf.expireIn).Err(); err != nil {
		return errors.Wrapf(err, "failed to set token cause by redis. key: %s", cacheKey)
	}
	return nil
}
// SetUseridTokenRelation records token as the current token for userid in the
// Redis hash named by USER_ID_TOKEN_RELATION, replacing any previous value.
// No TTL is applied to the hash. Redis failures are returned wrapped with the
// hash name and field.
func (tf *TokenRefresher) SetUseridTokenRelation(userid, token string) error {
	// HSET replies with an integer count, never nil, so goRedis.Nil cannot
	// occur here and every error is a genuine Redis failure.
	if err := tf.redis.Client.HSet(context.Background(), USER_ID_TOKEN_RELATION, userid, token).Err(); err != nil {
		return errors.Wrapf(err, "failed to SetUseridTokenRelation cause by redis. map:%s,key: %s", USER_ID_TOKEN_RELATION, userid)
	}
	return nil
}
// GetUserToken looks up the token recorded for userid in the Redis hash named
// by USER_ID_TOKEN_RELATION. It returns errInvalidToken when the user has no
// entry. Other Redis failures are returned wrapped with the hash name and
// field.
func (tf *TokenRefresher) GetUserToken(userid string) (string, error) {
	token, err := tf.redis.Client.HGet(context.Background(), USER_ID_TOKEN_RELATION, userid).Result()
	switch {
	case err == nil:
		return token, nil
	case err == goRedis.Nil:
		// Field absent: the user has no recorded token.
		return token, errInvalidToken
	default:
		return token, errors.Wrapf(err, "failed to GetUserToken cause by redis. map:%s,key: %s", USER_ID_TOKEN_RELATION, userid)
	}
}
// NewJwtAuthMiddleware configures a gin-jwt middleware and wraps it in a
// JwtAuthMiddleware.
//
// Configuration:
//   - Tokens are signed with secret, expire after timeout, and may be
//     refreshed for up to one hour (MaxRefresh).
//   - Tokens are read from the "Authorization: Bearer <token>" header or from
//     the cookie named by tokenName.
//   - The JwtIdentityKey and Account claims are turned into a
//     *session.Session. Authorizator attaches that session to the request
//     context and authorizes the request only when the session ID is
//     non-empty.
//
// If gin-jwt rejects the configuration (for example, an empty secret), the
// error is returned with a nil middleware.
//
// NOTE(review): logger and refresher are accepted but not used here, so
// tokens handled by this middleware are not checked against Redis.
func NewJwtAuthMiddleware(logger *zap.SugaredLogger, refresher *TokenRefresher, secret []byte, timeout time.Duration) (*JwtAuthMiddleware, error) {
	authMiddleware, err := jwt.New(&jwt.GinJWTMiddleware{
		Realm:       "",
		Key:         secret,
		Timeout:     timeout,
		MaxRefresh:  time.Hour,
		IdentityKey: JwtIdentityKey,
		// PayloadFunc copies the session's ID and account into the token claims.
		PayloadFunc: func(data interface{}) jwt.MapClaims {
			if v, ok := data.(*session.Session); ok {
				return jwt.MapClaims{
					JwtIdentityKey: v.ID,
					Account:        v.Account,
				}
			}
			return jwt.MapClaims{}
		},
		// IdentityHandler rebuilds the session from the token's claims.
		IdentityHandler: func(c *gin.Context) interface{} {
			claims := jwt.ExtractClaims(c)
			// Comma-ok assertions: a missing or non-string claim (e.g. a
			// numeric id) becomes "", which Authorizator rejects. The previous
			// unchecked assertions panicked inside the request handler.
			id, _ := claims[JwtIdentityKey].(string)
			account, _ := claims[Account].(string)
			return &session.Session{
				ID:      id,
				Account: account,
			}
		},
		Authenticator: func(c *gin.Context) (interface{}, error) {
			var loginVals login
			if err := c.ShouldBind(&loginVals); err != nil {
				return "", jwt.ErrMissingLoginValues
			}
			userID := loginVals.Account
			password := loginVals.Password
			// NOTE(review): hard-coded admin/admin credentials grant a valid
			// session to anyone using this handler for login. Confirm this is
			// intended, or restrict it to local development.
			if userID == "admin" && password == "admin" {
				return &session.Session{
					ID:      "xxx",
					Account: "admin",
				}, nil
			}
			return nil, jwt.ErrFailedAuthentication
		},
		// Authorizator stores the session in the request context and accepts
		// only sessions with a non-empty ID.
		Authorizator: func(data interface{}, c *gin.Context) bool {
			if v, ok := data.(*session.Session); ok {
				c.Request = c.Request.WithContext(contextUtil.PutSession(c.Request.Context(), v))
				return v.ID != ""
			}
			return false
		},
		Unauthorized: func(c *gin.Context, code int, message string) {
			c.String(code, message)
		},
		TokenLookup:   "header:Authorization, cookie:" + tokenName,
		TokenHeadName: "Bearer",
		// TimeFunc provides the current time. You can override it to use another time value. This is useful for testing or if your server uses a different time zone than your tokens.
		TimeFunc: time.Now,
	})
	if err != nil {
		// jwt.New returns a nil middleware on error; do not hand back a
		// half-built wrapper around it.
		return nil, err
	}
	return &JwtAuthMiddleware{
		Middleware:  authMiddleware,
		authHandler: authMiddleware.MiddlewareFunc(),
	}, nil
}
// AuthHandler returns a gin handler with the following behavior:
//   - If the request's URL path is a key of whiteUrlList, the request passes
//     straight to the next handler.
//   - Otherwise, the request goes through the JWT authentication handler.
func (s *JwtAuthMiddleware) AuthHandler() gin.HandlerFunc {
	return func(c *gin.Context) {
		if _, whitelisted := whiteUrlList[c.Request.URL.Path]; whitelisted {
			c.Next()
			return
		}
		s.authHandler(c)
	}
}