package jwt import ( "context" "crypto/sha1" "fmt" "os" "time" "busniess-user-center/pkg/redis" contextUtil "busniess-user-center/pkg/utils/context" "busniess-user-center/pkg/utils/session" jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/pkg/errors" goRedis "github.com/redis/go-redis/v9" "go.uber.org/zap" ) var ( isLocalDevelopment = false ) func init() { isLocalDevelopment = len(os.Getenv("LOCAL_DEVELOPMENT")) > 0 } var ( JwtIdentityKey = "id" Account = "account" tokenName = "token" errInvalidToken = errors.New("invalid token") USER_ID_TOKEN_RELATION = "user_id_token_relation" whiteUrlList = map[string]bool{ "/user/login": true, "/sso": true, } ) type login struct { Account string `form:"account" json:"account" binding:"required"` Password string `form:"password" json:"password" binding:"required"` } type JwtAuthMiddleware struct { Middleware *jwt.GinJWTMiddleware authHandler gin.HandlerFunc } type TokenRefresher struct { expireIn time.Duration redis *redis.Redis } func NewTokenRefresher(redis *redis.Redis, expireIn time.Duration) *TokenRefresher { return &TokenRefresher{ expireIn: expireIn, redis: redis, } } func (tf *TokenRefresher) generateTokenKey(tokenID, token string) string { sha1Val := sha1.Sum([]byte(fmt.Sprintf("%s:%s", tokenID, token))) return fmt.Sprintf("loginbigdata:jwt:token:%x", sha1Val) } func (tf *TokenRefresher) Refresh(tokenID, token string) error { cacheKey := tf.generateTokenKey(tokenID, token) result, err := tf.redis.Client.Get(context.Background(), cacheKey).Result() if err != nil { if err == goRedis.Nil { return errInvalidToken } return errors.Wrapf(err, "failed to refresh token cause by redis. key: %s", cacheKey) } if result != "" { if err := tf.redis.Client.Expire(context.Background(), cacheKey, tf.expireIn).Err(); err != nil { return errors.Wrapf(err, "faild to refresh token case by expire redis. key: %s", cacheKey) } } return nil } func (tf *TokenRefresher) SetToken(tokenID, token, data string) error { cacheKey := tf.generateTokenKey(tokenID, token) _, err := tf.redis.Client.Set(context.Background(), cacheKey, data, tf.expireIn).Result() if err != nil { if err == goRedis.Nil { return errInvalidToken } return errors.Wrapf(err, "failed to set token cause by redis. key: %s", cacheKey) } return nil } func (tf *TokenRefresher) SetUseridTokenRelation(userid, token string) error { _, err := tf.redis.Client.HSet(context.Background(), USER_ID_TOKEN_RELATION, userid, token).Result() if err != nil { if err == goRedis.Nil { return errInvalidToken } return errors.Wrapf(err, "failed to SetUseridTokenRelation cause by redis. map:%s,key: %s", USER_ID_TOKEN_RELATION, userid) } return nil } func (tf *TokenRefresher) GetUserToken(userid string) (string, error) { token, err := tf.redis.Client.HGet(context.Background(), USER_ID_TOKEN_RELATION, userid).Result() if err != nil { if err == goRedis.Nil { return token, errInvalidToken } return token, errors.Wrapf(err, "failed to GetUserToken cause by redis. map:%s,key: %s", USER_ID_TOKEN_RELATION, userid) } return token, nil } func NewJwtAuthMiddleware(logger *zap.SugaredLogger, refresher *TokenRefresher, secret []byte, timeout time.Duration) (*JwtAuthMiddleware, error) { authMiddleware, err := jwt.New(&jwt.GinJWTMiddleware{ Realm: "", Key: secret, Timeout: timeout, MaxRefresh: time.Hour, IdentityKey: JwtIdentityKey, PayloadFunc: func(data interface{}) jwt.MapClaims { if v, ok := data.(*session.Session); ok { return jwt.MapClaims{ JwtIdentityKey: v.ID, Account: v.Account, } } return jwt.MapClaims{} }, IdentityHandler: func(c *gin.Context) interface{} { var ( id = "" account = "" ) claims := jwt.ExtractClaims(c) AccountInterface, ok := claims[Account] if ok { account = AccountInterface.(string) } idInterface, ok := claims[JwtIdentityKey] if ok { id = idInterface.(string) } return &session.Session{ ID: id, Account: account, } }, Authenticator: func(c *gin.Context) (interface{}, error) { var loginVals login if err := c.ShouldBind(&loginVals); err != nil { return "", jwt.ErrMissingLoginValues } userID := loginVals.Account password := loginVals.Password if userID == "admin" && password == "admin" { return &session.Session{ ID: "xxx", Account: "admin", }, nil } return nil, jwt.ErrFailedAuthentication }, Authorizator: func(data interface{}, c *gin.Context) bool { if v, ok := data.(*session.Session); ok { c.Request = c.Request.WithContext(contextUtil.PutSession(c.Request.Context(), v)) return v.ID != "" } return false }, Unauthorized: func(c *gin.Context, code int, message string) { c.String(code, message) }, TokenLookup: "header:Authorization, cookie:" + tokenName, TokenHeadName: "Bearer", // TimeFunc provides the current time. You can override it to use another time value. This is useful for testing or if your server uses a different time zone than your tokens. TimeFunc: time.Now, }) return &JwtAuthMiddleware{ Middleware: authMiddleware, authHandler: authMiddleware.MiddlewareFunc(), }, err } func (s *JwtAuthMiddleware) AuthHandler() gin.HandlerFunc { return func(c *gin.Context) { rPath := c.Request.URL.Path if _, ok := whiteUrlList[rPath]; !ok { s.authHandler(c) } else { c.Next() } } }