package auth import ( cryptorand "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "fmt" "os" "path/filepath" "strings" "time" "github.com/golang-jwt/jwt/v5" ) const ( jwtAlgorithmHS256 = "HS256" jwtAlgorithmRS256 = "RS256" ) // JWTOptions controls JWT signing behavior. type JWTOptions struct { Algorithm string HS256Secret string RSAPrivateKeyPEM string RSAPublicKeyPEM string RSAPrivateKeyPath string RSAPublicKeyPath string RequireExistingRSAKeys bool AccessTokenExpire time.Duration RefreshTokenExpire time.Duration RememberLoginExpire time.Duration // 记住登录时的refresh token有效期 } // JWT JWT管理器 type JWT struct { algorithm string secret []byte privateKey *rsa.PrivateKey publicKey *rsa.PublicKey accessTokenExpire time.Duration refreshTokenExpire time.Duration rememberLoginExpire time.Duration initErr error } // Claims JWT声明 type Claims struct { UserID int64 `json:"user_id"` Username string `json:"username"` Type string `json:"type"` // access, refresh Remember bool `json:"remember,omitempty"` // 记住登录标记 JTI string `json:"jti"` // JWT ID,用于黑名单 jwt.RegisteredClaims } // generateJTI 生成唯一的 JWT ID // 使用 crypto/rand 生成密码学安全的随机数,仅使用随机数不包含时间戳 func generateJTI() (string, error) { // 生成 16 字节的密码学安全随机数 b := make([]byte, 16) if _, err := cryptorand.Read(b); err != nil { return "", fmt.Errorf("generate jwt jti failed: %w", err) } // 使用十六进制编码,仅使用随机数确保不可预测 return fmt.Sprintf("%x", b), nil } // NewJWT creates a legacy HS256 JWT manager for compatibility in tests and callers // that still only provide a shared secret. func NewJWT(secret string, accessTokenExpire, refreshTokenExpire time.Duration) *JWT { manager, err := NewJWTWithOptions(JWTOptions{ Algorithm: jwtAlgorithmHS256, HS256Secret: secret, AccessTokenExpire: accessTokenExpire, RefreshTokenExpire: refreshTokenExpire, }) if err != nil { return &JWT{ algorithm: jwtAlgorithmHS256, accessTokenExpire: accessTokenExpire, refreshTokenExpire: refreshTokenExpire, initErr: err, } } return manager } func (j *JWT) ensureReady() error { if j == nil { return errors.New("jwt manager is nil") } if j.initErr != nil { return j.initErr } return nil } // NewJWTWithOptions creates a JWT manager from explicit signing options. func NewJWTWithOptions(opts JWTOptions) (*JWT, error) { algorithm := strings.ToUpper(strings.TrimSpace(opts.Algorithm)) if algorithm == "" { if opts.HS256Secret != "" && opts.RSAPrivateKeyPEM == "" && opts.RSAPrivateKeyPath == "" { algorithm = jwtAlgorithmHS256 } else { algorithm = jwtAlgorithmRS256 } } manager := &JWT{ algorithm: algorithm, accessTokenExpire: opts.AccessTokenExpire, refreshTokenExpire: opts.RefreshTokenExpire, rememberLoginExpire: opts.RememberLoginExpire, } switch algorithm { case jwtAlgorithmHS256: if opts.HS256Secret == "" { return nil, errors.New("jwt secret is required for HS256") } manager.secret = []byte(opts.HS256Secret) case jwtAlgorithmRS256: if err := manager.loadRSAKeys(opts); err != nil { return nil, err } default: return nil, fmt.Errorf("unsupported jwt algorithm: %s", algorithm) } return manager, nil } func (j *JWT) loadRSAKeys(opts JWTOptions) error { privatePEM, err := readPEM(opts.RSAPrivateKeyPEM, opts.RSAPrivateKeyPath) if err != nil { return fmt.Errorf("load jwt private key failed: %w", err) } publicPEM, err := readPEM(opts.RSAPublicKeyPEM, opts.RSAPublicKeyPath) if err != nil { return fmt.Errorf("load jwt public key failed: %w", err) } if privatePEM == "" && publicPEM == "" { if strings.TrimSpace(opts.RSAPrivateKeyPath) == "" || strings.TrimSpace(opts.RSAPublicKeyPath) == "" { return errors.New("rsa private/public key paths or inline pem are required for RS256") } if opts.RequireExistingRSAKeys { return errors.New("existing rsa private/public key files or inline pem are required for RS256") } privatePEM, publicPEM, err = generateAndPersistRSAKeyPair(opts.RSAPrivateKeyPath, opts.RSAPublicKeyPath) if err != nil { return fmt.Errorf("generate rsa key pair failed: %w", err) } } if privatePEM != "" { privateKey, err := parseRSAPrivateKey(privatePEM) if err != nil { return err } j.privateKey = privateKey j.publicKey = &privateKey.PublicKey } if publicPEM != "" { publicKey, err := parseRSAPublicKey(publicPEM) if err != nil { return err } j.publicKey = publicKey } if j.privateKey == nil { return errors.New("rsa private key is required for signing") } if j.publicKey == nil { return errors.New("rsa public key is required for verification") } return nil } func generateAndPersistRSAKeyPair(privatePath, publicPath string) (string, string, error) { privatePath = strings.TrimSpace(privatePath) publicPath = strings.TrimSpace(publicPath) if privatePath == "" || publicPath == "" { return "", "", errors.New("rsa key paths must not be empty") } privateKey, err := rsa.GenerateKey(cryptorand.Reader, 2048) if err != nil { return "", "", err } privateDER := x509.MarshalPKCS1PrivateKey(privateKey) privatePEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateDER}) publicDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) if err != nil { return "", "", err } publicPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: publicDER}) if err := os.MkdirAll(filepath.Dir(privatePath), 0o700); err != nil { return "", "", err } if err := os.MkdirAll(filepath.Dir(publicPath), 0o700); err != nil { return "", "", err } if err := os.WriteFile(privatePath, privatePEM, 0o600); err != nil { return "", "", err } if err := os.WriteFile(publicPath, publicPEM, 0o644); err != nil { return "", "", err } return string(privatePEM), string(publicPEM), nil } func readPEM(inlinePEM, path string) (string, error) { inlinePEM = strings.TrimSpace(inlinePEM) if inlinePEM != "" { return inlinePEM, nil } path = strings.TrimSpace(path) if path == "" { return "", nil } data, err := os.ReadFile(path) if err != nil { if errors.Is(err, os.ErrNotExist) { return "", nil } return "", err } return string(data), nil } func parseRSAPrivateKey(pemValue string) (*rsa.PrivateKey, error) { block, _ := pem.Decode([]byte(pemValue)) if block == nil { return nil, errors.New("invalid rsa private key pem") } if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { return key, nil } key, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("parse rsa private key failed: %w", err) } rsaKey, ok := key.(*rsa.PrivateKey) if !ok { return nil, errors.New("private key is not rsa") } return rsaKey, nil } func parseRSAPublicKey(pemValue string) (*rsa.PublicKey, error) { block, _ := pem.Decode([]byte(pemValue)) if block == nil { return nil, errors.New("invalid rsa public key pem") } if key, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil { rsaKey, ok := key.(*rsa.PublicKey) if !ok { return nil, errors.New("public key is not rsa") } return rsaKey, nil } if cert, err := x509.ParseCertificate(block.Bytes); err == nil { rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) if !ok { return nil, errors.New("certificate public key is not rsa") } return rsaKey, nil } return nil, errors.New("parse rsa public key failed") } func (j *JWT) signingMethod() jwt.SigningMethod { if j.algorithm == jwtAlgorithmRS256 { return jwt.SigningMethodRS256 } return jwt.SigningMethodHS256 } func (j *JWT) signingKey() interface{} { if j.algorithm == jwtAlgorithmRS256 { return j.privateKey } return j.secret } func (j *JWT) verifyKey(token *jwt.Token) (interface{}, error) { if token.Method.Alg() != j.signingMethod().Alg() { return nil, fmt.Errorf("unexpected signing method: %s", token.Method.Alg()) } if j.algorithm == jwtAlgorithmRS256 { return j.publicKey, nil } return j.secret, nil } // GetAlgorithm returns the configured JWT signing algorithm. func (j *JWT) GetAlgorithm() string { return j.algorithm } // GenerateAccessToken 生成访问令牌(含JTI) func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error) { if err := j.ensureReady(); err != nil { return "", err } now := time.Now() jti, err := generateJTI() if err != nil { return "", err } claims := Claims{ UserID: userID, Username: username, Type: "access", JTI: jti, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), }, } token := jwt.NewWithClaims(j.signingMethod(), claims) return token.SignedString(j.signingKey()) } // GenerateRefreshToken 生成刷新令牌(含JTI) func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error) { if err := j.ensureReady(); err != nil { return "", err } now := time.Now() jti, err := generateJTI() if err != nil { return "", err } claims := Claims{ UserID: userID, Username: username, Type: "refresh", JTI: jti, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), }, } token := jwt.NewWithClaims(j.signingMethod(), claims) return token.SignedString(j.signingKey()) } // GetAccessTokenExpire 获取访问令牌有效期 func (j *JWT) GetAccessTokenExpire() time.Duration { return j.accessTokenExpire } // GetRefreshTokenExpire 获取刷新令牌有效期 func (j *JWT) GetRefreshTokenExpire() time.Duration { return j.refreshTokenExpire } // GenerateTokenPair 生成令牌对 func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, refreshToken string, err error) { accessToken, err = j.GenerateAccessToken(userID, username) if err != nil { return "", "", err } refreshToken, err = j.GenerateRefreshToken(userID, username) if err != nil { return "", "", err } return accessToken, refreshToken, nil } // GenerateTokenPairWithRemember 生成令牌对(支持记住登录) func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool) (accessToken, refreshToken string, err error) { accessToken, err = j.GenerateAccessToken(userID, username) if err != nil { return "", "", err } if remember { refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username) } else { refreshToken, err = j.GenerateRefreshToken(userID, username) } if err != nil { return "", "", err } return accessToken, refreshToken, nil } // GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用) func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (string, error) { if err := j.ensureReady(); err != nil { return "", err } now := time.Now() jti, err := generateJTI() if err != nil { return "", err } // 使用rememberLoginExpire,如果未配置则使用默认的refreshTokenExpire expireDuration := j.rememberLoginExpire if expireDuration == 0 { expireDuration = j.refreshTokenExpire } claims := Claims{ UserID: userID, Username: username, Type: "refresh", Remember: true, // 长期会话标记 JTI: jti, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(expireDuration)), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), }, } token := jwt.NewWithClaims(j.signingMethod(), claims) return token.SignedString(j.signingKey()) } // ParseToken 解析令牌 func (j *JWT) ParseToken(tokenString string) (*Claims, error) { if err := j.ensureReady(); err != nil { return nil, err } token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { return j.verifyKey(token) }) if err != nil { return nil, err } if claims, ok := token.Claims.(*Claims); ok && token.Valid { return claims, nil } return nil, errors.New("invalid token") } // ValidateAccessToken 验证访问令牌 func (j *JWT) ValidateAccessToken(tokenString string) (*Claims, error) { claims, err := j.ParseToken(tokenString) if err != nil { return nil, err } if claims.Type != "access" { return nil, errors.New("invalid token type") } return claims, nil } // ValidateRefreshToken 验证刷新令牌 func (j *JWT) ValidateRefreshToken(tokenString string) (*Claims, error) { claims, err := j.ParseToken(tokenString) if err != nil { return nil, err } if claims.Type != "refresh" { return nil, errors.New("invalid token type") } return claims, nil } // RefreshAccessToken 刷新访问令牌 func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) { claims, err := j.ValidateRefreshToken(refreshTokenString) if err != nil { return "", err } return j.GenerateAccessToken(claims.UserID, claims.Username) }