220 lines
5.5 KiB
Go
220 lines
5.5 KiB
Go
package jwt
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha1"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"busniess-user-center/pkg/redis"
|
|
contextUtil "busniess-user-center/pkg/utils/context"
|
|
"busniess-user-center/pkg/utils/session"
|
|
|
|
jwt "github.com/appleboy/gin-jwt/v2"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/pkg/errors"
|
|
goRedis "github.com/redis/go-redis/v9"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// isLocalDevelopment is true when the LOCAL_DEVELOPMENT environment variable
// held a non-empty value at package initialisation time.
var isLocalDevelopment bool

func init() {
	isLocalDevelopment = os.Getenv("LOCAL_DEVELOPMENT") != ""
}
|
|
|
|
// Claim keys and the cookie name used by the JWT middleware.
var (
	// JwtIdentityKey is the gin-jwt identity key and the claim that carries
	// the session ID.
	JwtIdentityKey = "id"
	// Account is the claim that carries the account name.
	Account = "account"
	// tokenName is the cookie consulted when looking up a token.
	tokenName = "token"
)

// Redis bookkeeping shared by TokenRefresher.
var (
	// USER_ID_TOKEN_RELATION is the Redis hash mapping user IDs to tokens.
	USER_ID_TOKEN_RELATION = "user_id_token_relation"
	// errInvalidToken signals that Redis holds no entry for the requested token.
	errInvalidToken = errors.New("invalid token")
)

// whiteUrlList lists request paths (exact match) that AuthHandler lets
// through without JWT authentication. Keep every value true.
var whiteUrlList = map[string]bool{
	"/sso":        true,
	"/user/login": true,
}
|
|
|
|
// login is the request body bound by the Authenticator in
// NewJwtAuthMiddleware via c.ShouldBind. Both fields are mandatory
// (binding:"required"); a missing one makes binding fail.
type login struct {
	// Account is the submitted login name.
	Account string `form:"account" json:"account" binding:"required"`
	// Password is the submitted password, compared as-is by the Authenticator.
	Password string `form:"password" json:"password" binding:"required"`
}
|
|
|
|
// JwtAuthMiddleware bundles the configured gin-jwt middleware with the gin
// handler derived from it.
type JwtAuthMiddleware struct {
	// Middleware is the underlying gin-jwt middleware; presumably exported so
	// callers can mount its other handlers (login/refresh) — verify at call sites.
	Middleware *jwt.GinJWTMiddleware
	// authHandler is Middleware.MiddlewareFunc(), built once at construction
	// and invoked by AuthHandler for paths that are not whitelisted.
	authHandler gin.HandlerFunc
}
|
|
|
|
// TokenRefresher keeps per-token entries in Redis (keys built by
// generateTokenKey) and a user-ID -> token hash (USER_ID_TOKEN_RELATION).
type TokenRefresher struct {
	// expireIn is the TTL applied by SetToken and Refresh.
	expireIn time.Duration
	// redis is the project Redis wrapper; the methods in this file use only
	// its Client field.
	redis *redis.Redis
}
|
|
|
|
func NewTokenRefresher(redis *redis.Redis, expireIn time.Duration) *TokenRefresher {
|
|
return &TokenRefresher{
|
|
expireIn: expireIn,
|
|
redis: redis,
|
|
}
|
|
}
|
|
|
|
func (tf *TokenRefresher) generateTokenKey(tokenID, token string) string {
|
|
sha1Val := sha1.Sum([]byte(fmt.Sprintf("%s:%s", tokenID, token)))
|
|
return fmt.Sprintf("loginbigdata:jwt:token:%x", sha1Val)
|
|
}
|
|
|
|
func (tf *TokenRefresher) Refresh(tokenID, token string) error {
|
|
cacheKey := tf.generateTokenKey(tokenID, token)
|
|
result, err := tf.redis.Client.Get(context.Background(), cacheKey).Result()
|
|
|
|
if err != nil {
|
|
if err == goRedis.Nil {
|
|
return errInvalidToken
|
|
}
|
|
|
|
return errors.Wrapf(err, "failed to refresh token cause by redis. key: %s", cacheKey)
|
|
}
|
|
|
|
if result != "" {
|
|
if err := tf.redis.Client.Expire(context.Background(), cacheKey, tf.expireIn).Err(); err != nil {
|
|
return errors.Wrapf(err, "faild to refresh token case by expire redis. key: %s", cacheKey)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (tf *TokenRefresher) SetToken(tokenID, token, data string) error {
|
|
cacheKey := tf.generateTokenKey(tokenID, token)
|
|
_, err := tf.redis.Client.Set(context.Background(), cacheKey, data, tf.expireIn).Result()
|
|
if err != nil {
|
|
if err == goRedis.Nil {
|
|
return errInvalidToken
|
|
}
|
|
|
|
return errors.Wrapf(err, "failed to set token cause by redis. key: %s", cacheKey)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (tf *TokenRefresher) SetUseridTokenRelation(userid, token string) error {
|
|
_, err := tf.redis.Client.HSet(context.Background(), USER_ID_TOKEN_RELATION, userid, token).Result()
|
|
if err != nil {
|
|
if err == goRedis.Nil {
|
|
return errInvalidToken
|
|
}
|
|
|
|
return errors.Wrapf(err, "failed to SetUseridTokenRelation cause by redis. map:%s,key: %s", USER_ID_TOKEN_RELATION, userid)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (tf *TokenRefresher) GetUserToken(userid string) (string, error) {
|
|
token, err := tf.redis.Client.HGet(context.Background(), USER_ID_TOKEN_RELATION, userid).Result()
|
|
if err != nil {
|
|
if err == goRedis.Nil {
|
|
return token, errInvalidToken
|
|
}
|
|
|
|
return token, errors.Wrapf(err, "failed to GetUserToken cause by redis. map:%s,key: %s", USER_ID_TOKEN_RELATION, userid)
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func NewJwtAuthMiddleware(logger *zap.SugaredLogger, refresher *TokenRefresher, secret []byte, timeout time.Duration) (*JwtAuthMiddleware, error) {
|
|
authMiddleware, err := jwt.New(&jwt.GinJWTMiddleware{
|
|
Realm: "",
|
|
Key: secret,
|
|
Timeout: timeout,
|
|
MaxRefresh: time.Hour,
|
|
IdentityKey: JwtIdentityKey,
|
|
PayloadFunc: func(data interface{}) jwt.MapClaims {
|
|
if v, ok := data.(*session.Session); ok {
|
|
return jwt.MapClaims{
|
|
JwtIdentityKey: v.ID,
|
|
Account: v.Account,
|
|
}
|
|
}
|
|
return jwt.MapClaims{}
|
|
},
|
|
IdentityHandler: func(c *gin.Context) interface{} {
|
|
var (
|
|
id = ""
|
|
account = ""
|
|
)
|
|
|
|
claims := jwt.ExtractClaims(c)
|
|
|
|
AccountInterface, ok := claims[Account]
|
|
if ok {
|
|
account = AccountInterface.(string)
|
|
}
|
|
|
|
idInterface, ok := claims[JwtIdentityKey]
|
|
if ok {
|
|
id = idInterface.(string)
|
|
}
|
|
|
|
return &session.Session{
|
|
ID: id,
|
|
Account: account,
|
|
}
|
|
},
|
|
Authenticator: func(c *gin.Context) (interface{}, error) {
|
|
var loginVals login
|
|
if err := c.ShouldBind(&loginVals); err != nil {
|
|
return "", jwt.ErrMissingLoginValues
|
|
}
|
|
userID := loginVals.Account
|
|
password := loginVals.Password
|
|
|
|
if userID == "admin" && password == "admin" {
|
|
return &session.Session{
|
|
ID: "xxx",
|
|
Account: "admin",
|
|
}, nil
|
|
}
|
|
|
|
return nil, jwt.ErrFailedAuthentication
|
|
},
|
|
Authorizator: func(data interface{}, c *gin.Context) bool {
|
|
if v, ok := data.(*session.Session); ok {
|
|
c.Request = c.Request.WithContext(contextUtil.PutSession(c.Request.Context(), v))
|
|
return v.ID != ""
|
|
}
|
|
return false
|
|
},
|
|
Unauthorized: func(c *gin.Context, code int, message string) {
|
|
c.String(code, message)
|
|
},
|
|
|
|
TokenLookup: "header:Authorization, cookie:" + tokenName,
|
|
TokenHeadName: "Bearer",
|
|
|
|
// TimeFunc provides the current time. You can override it to use another time value. This is useful for testing or if your server uses a different time zone than your tokens.
|
|
TimeFunc: time.Now,
|
|
})
|
|
|
|
return &JwtAuthMiddleware{
|
|
Middleware: authMiddleware,
|
|
authHandler: authMiddleware.MiddlewareFunc(),
|
|
}, err
|
|
}
|
|
|
|
func (s *JwtAuthMiddleware) AuthHandler() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
rPath := c.Request.URL.Path
|
|
if _, ok := whiteUrlList[rPath]; !ok {
|
|
s.authHandler(c)
|
|
} else {
|
|
c.Next()
|
|
}
|
|
}
|
|
}
|