// Source file: user-system/internal/auth/jwt.go
//
// NOTE: the lines that originally stood here were repository web-viewer page
// chrome (file size, "Raw / Blame / History" links and an "ambiguous Unicode
// characters" warning), not Go source. They are summarized in this comment so
// that the file starts with valid Go.

package auth
import (
cryptorand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Identifiers of the supported JWT signing algorithms. NewJWTWithOptions
// compares the trimmed, upper-cased JWTOptions.Algorithm against these values.
const (
// jwtAlgorithmHS256 selects HMAC-SHA256 signing with a shared secret.
jwtAlgorithmHS256 = "HS256"
// jwtAlgorithmRS256 selects RSA-SHA256 signing with an RSA key pair.
jwtAlgorithmRS256 = "RS256"
)
// JWTOptions controls JWT signing behavior.
//
// It is the input of NewJWTWithOptions, which copies the lifetimes onto the
// manager and uses the remaining fields to select and load signing material.
type JWTOptions struct {
// Algorithm is "HS256" or "RS256"; it is trimmed and upper-cased before use.
// When empty, HS256 is chosen only if HS256Secret is set and no RSA private
// key (inline or path) is given; otherwise RS256 is chosen.
Algorithm string
// HS256Secret is the shared secret; it is required when the algorithm is HS256.
HS256Secret string
// RSAPrivateKeyPEM and RSAPublicKeyPEM hold inline PEM text. When non-blank
// they take precedence over the corresponding *Path fields.
RSAPrivateKeyPEM string
RSAPublicKeyPEM string
// RSAPrivateKeyPath and RSAPublicKeyPath name files that are read when no
// inline PEM is given; a missing file is treated as "no key", not as an error.
RSAPrivateKeyPath string
RSAPublicKeyPath string
// RequireExistingRSAKeys, when true, makes RS256 setup fail instead of
// generating a new key pair when neither key could be loaded.
RequireExistingRSAKeys bool
// AccessTokenExpire and RefreshTokenExpire are the lifetimes added to the
// issue time of access and regular refresh tokens.
AccessTokenExpire time.Duration
RefreshTokenExpire time.Duration
RememberLoginExpire time.Duration // refresh token lifetime when "remember login" is used
}
// JWT is the JWT manager: it holds the signing material and the token
// lifetimes used to issue and verify tokens.
type JWT struct {
// algorithm is the normalized signing algorithm ("HS256" or "RS256").
algorithm string
// secret is the shared HS256 secret; NewJWTWithOptions sets it only for HS256.
secret []byte
// privateKey signs and publicKey verifies RS256 tokens; both are populated by
// loadRSAKeys.
privateKey *rsa.PrivateKey
publicKey *rsa.PublicKey
// Lifetimes added to the issue time of access and regular refresh tokens.
accessTokenExpire time.Duration
refreshTokenExpire time.Duration
// rememberLoginExpire is the lifetime of long-lived refresh tokens; a zero
// value makes GenerateLongLivedRefreshToken use refreshTokenExpire instead.
rememberLoginExpire time.Duration
// initErr, when non-nil, is returned by ensureReady and therefore by every
// method that calls it.
// NOTE(review): nothing in the visible code assigns initErr — confirm where it is set.
initErr error
}
// Claims is the JWT claim set used by this package: application-specific
// fields plus the embedded jwt.RegisteredClaims.
type Claims struct {
UserID int64 `json:"user_id"`
Username string `json:"username"`
Type string `json:"type"` // token kind: "access", "refresh" or "totp_challenge"
Remember bool `json:"remember,omitempty"` // "remember login" marker, set on long-lived refresh tokens
JTI string `json:"jti"` // JWT ID, used for blacklisting
PCE int64 `json:"pce,omitempty"` // Password Changed Epoch: password-change timestamp used by the token invalidation mechanism
DeviceID string `json:"device_id,omitempty"` // in the visible code only GenerateTOTPChallengeToken populates this (trimmed)
jwt.RegisteredClaims
}
// generateJTI produces a unique JWT ID.
//
// The ID is the current Unix timestamp followed by cryptographically secure
// random data, which makes IDs impractical to enumerate.
// Layout: 16 hex characters of timestamp (a zero-padded 64-bit value) followed
// by 32 hex characters of randomness (16 bytes / 128 bits) — 48 characters in
// total.
//
// It returns an error only when the system random source cannot be read.
func generateJTI() (string, error) {
	// Seconds since the Unix epoch; rendered below as a zero-padded 64-bit hex value.
	issuedAt := time.Now().Unix()

	// 128 bits of randomness from the CSPRNG.
	var nonce [16]byte
	if _, err := cryptorand.Read(nonce[:]); err != nil {
		return "", fmt.Errorf("generate jwt jti failed: %w", err)
	}

	// Timestamp first, random part second.
	return fmt.Sprintf("%016x%x", issuedAt, nonce[:]), nil
}
// NewJWT builds an HS256-only JWT manager from a shared secret. It exists for
// tests and legacy callers that do not supply full JWTOptions.
//
// The remember-login lifetime is left at zero. Any error from
// NewJWTWithOptions (for example an empty secret) is returned unchanged.
func NewJWT(secret string, accessTokenExpire, refreshTokenExpire time.Duration) (*JWT, error) {
	legacyOpts := JWTOptions{
		Algorithm:          jwtAlgorithmHS256,
		HS256Secret:        secret,
		AccessTokenExpire:  accessTokenExpire,
		RefreshTokenExpire: refreshTokenExpire,
	}
	return NewJWTWithOptions(legacyOpts)
}
// ensureReady reports whether the manager can be used: it returns an error for
// a nil receiver, the stored initErr when one is set, and nil otherwise.
func (j *JWT) ensureReady() error {
	switch {
	case j == nil:
		return errors.New("jwt manager is nil")
	case j.initErr != nil:
		return j.initErr
	default:
		return nil
	}
}
// NewJWTWithOptions creates a JWT manager from explicit signing options.
//
// opts.Algorithm is trimmed and upper-cased. When it is empty, HS256 is
// selected only if a shared secret is present and no RSA private key (inline
// or path) is configured; in every other case RS256 is selected.
//
// Errors: an empty secret for HS256, any failure from loadRSAKeys for RS256,
// or an unsupported algorithm name.
func NewJWTWithOptions(opts JWTOptions) (*JWT, error) {
	alg := strings.ToUpper(strings.TrimSpace(opts.Algorithm))
	if alg == "" {
		// No explicit choice: default to RS256 unless only a shared secret was given.
		hasRSAPrivateKey := opts.RSAPrivateKeyPEM != "" || opts.RSAPrivateKeyPath != ""
		alg = jwtAlgorithmRS256
		if opts.HS256Secret != "" && !hasRSAPrivateKey {
			alg = jwtAlgorithmHS256
		}
	}

	m := &JWT{
		algorithm:           alg,
		accessTokenExpire:   opts.AccessTokenExpire,
		refreshTokenExpire:  opts.RefreshTokenExpire,
		rememberLoginExpire: opts.RememberLoginExpire,
	}

	if alg == jwtAlgorithmHS256 {
		if opts.HS256Secret == "" {
			return nil, errors.New("jwt secret is required for HS256")
		}
		m.secret = []byte(opts.HS256Secret)
		return m, nil
	}

	if alg == jwtAlgorithmRS256 {
		if err := m.loadRSAKeys(opts); err != nil {
			return nil, err
		}
		return m, nil
	}

	return nil, fmt.Errorf("unsupported jwt algorithm: %s", alg)
}
// loadRSAKeys populates j.privateKey and j.publicKey for RS256.
//
// Each key is taken from its inline PEM when that is non-blank, otherwise from
// its file path (a missing file counts as "no key"). When neither key is
// available, a new key pair is generated and written to the two configured
// paths — unless a path is blank or opts.RequireExistingRSAKeys is set, in
// which case an error is returned.
//
// When only the private key is available, the public key is derived from it.
// When both are supplied they must belong together: a public key that does not
// match the private key is rejected, because tokens signed by this manager
// could never be verified by it.
//
// Errors: read or parse failures, key generation failures, a missing private
// key, or a mismatched key pair.
func (j *JWT) loadRSAKeys(opts JWTOptions) error {
	privatePEM, err := readPEM(opts.RSAPrivateKeyPEM, opts.RSAPrivateKeyPath)
	if err != nil {
		return fmt.Errorf("load jwt private key failed: %w", err)
	}
	publicPEM, err := readPEM(opts.RSAPublicKeyPEM, opts.RSAPublicKeyPath)
	if err != nil {
		return fmt.Errorf("load jwt public key failed: %w", err)
	}

	if privatePEM == "" && publicPEM == "" {
		// Nothing configured or found on disk: bootstrap a fresh pair if allowed.
		if strings.TrimSpace(opts.RSAPrivateKeyPath) == "" || strings.TrimSpace(opts.RSAPublicKeyPath) == "" {
			return errors.New("rsa private/public key paths or inline pem are required for RS256")
		}
		if opts.RequireExistingRSAKeys {
			return errors.New("existing rsa private/public key files or inline pem are required for RS256")
		}
		privatePEM, publicPEM, err = generateAndPersistRSAKeyPair(opts.RSAPrivateKeyPath, opts.RSAPublicKeyPath)
		if err != nil {
			return fmt.Errorf("generate rsa key pair failed: %w", err)
		}
	}

	if privatePEM != "" {
		privateKey, err := parseRSAPrivateKey(privatePEM)
		if err != nil {
			return err
		}
		j.privateKey = privateKey
		// Default verification key; replaced below if an explicit public key is given.
		j.publicKey = &privateKey.PublicKey
	}

	if publicPEM != "" {
		publicKey, err := parseRSAPublicKey(publicPEM)
		if err != nil {
			return err
		}
		// Fail fast on a mismatched pair instead of issuing tokens that this
		// manager itself would reject.
		if j.privateKey != nil && !publicKey.Equal(&j.privateKey.PublicKey) {
			return errors.New("rsa public key does not match rsa private key")
		}
		j.publicKey = publicKey
	}

	if j.privateKey == nil {
		return errors.New("rsa private key is required for signing")
	}
	if j.publicKey == nil {
		return errors.New("rsa public key is required for verification")
	}
	return nil
}
// generateAndPersistRSAKeyPair creates a new 2048-bit RSA key pair, writes it
// to disk and returns both keys as PEM text (private first, public second).
//
// The private key is written as a PKCS#1 "RSA PRIVATE KEY" block with mode
// 0600, the public key as a PKIX "PUBLIC KEY" block with mode 0644. Parent
// directories are created with mode 0700 when missing.
//
// Both paths are trimmed; they must be non-empty and must not refer to the
// same file. Without the latter check the public key would overwrite the
// private key that was just written, and the next start would fail to parse
// the private key file.
//
// Errors: invalid paths, key generation or encoding failures, and any
// filesystem error, returned as-is.
func generateAndPersistRSAKeyPair(privatePath, publicPath string) (string, string, error) {
	privatePath = strings.TrimSpace(privatePath)
	publicPath = strings.TrimSpace(publicPath)
	if privatePath == "" || publicPath == "" {
		return "", "", errors.New("rsa key paths must not be empty")
	}
	// Checked before the (comparatively expensive) key generation.
	if filepath.Clean(privatePath) == filepath.Clean(publicPath) {
		return "", "", errors.New("rsa private and public key paths must differ")
	}

	privateKey, err := rsa.GenerateKey(cryptorand.Reader, 2048)
	if err != nil {
		return "", "", err
	}

	privateDER := x509.MarshalPKCS1PrivateKey(privateKey)
	privatePEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateDER})

	publicDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
	if err != nil {
		return "", "", err
	}
	publicPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: publicDER})

	if err := os.MkdirAll(filepath.Dir(privatePath), 0o700); err != nil {
		return "", "", err
	}
	if err := os.MkdirAll(filepath.Dir(publicPath), 0o700); err != nil {
		return "", "", err
	}

	// Private key readable by the owner only; public key world-readable.
	if err := os.WriteFile(privatePath, privatePEM, 0o600); err != nil {
		return "", "", err
	}
	if err := os.WriteFile(publicPath, publicPEM, 0o644); err != nil {
		return "", "", err
	}

	return string(privatePEM), string(publicPEM), nil
}
// readPEM resolves PEM text from an inline value or, failing that, a file.
//
// A non-blank inlinePEM wins and is returned trimmed. Otherwise the trimmed
// path is read and its contents are returned unmodified. A blank path or a
// file that does not exist yields ("", nil) so that callers can treat the key
// as "not provided"; any other read error is returned as-is.
func readPEM(inlinePEM, path string) (string, error) {
	if inline := strings.TrimSpace(inlinePEM); inline != "" {
		return inline, nil
	}

	file := strings.TrimSpace(path)
	if file == "" {
		return "", nil
	}

	contents, err := os.ReadFile(file)
	switch {
	case err == nil:
		return string(contents), nil
	case errors.Is(err, os.ErrNotExist):
		// A missing file is not an error: the key is simply absent.
		return "", nil
	default:
		return "", err
	}
}
// parseRSAPrivateKey decodes the first PEM block in pemValue into an RSA
// private key.
//
// The block is tried as PKCS#1 first and as PKCS#8 second. It returns an error
// when no PEM block is found, when neither encoding parses, or when a PKCS#8
// key is not an RSA key.
func parseRSAPrivateKey(pemValue string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(pemValue))
	if block == nil {
		return nil, errors.New("invalid rsa private key pem")
	}
	der := block.Bytes

	// PKCS#1 ("RSA PRIVATE KEY") is attempted first.
	pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
	if pkcs1Err == nil {
		return pkcs1Key, nil
	}

	// Fall back to PKCS#8 ("PRIVATE KEY"), which may wrap non-RSA keys too.
	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, fmt.Errorf("parse rsa private key failed: %w", err)
	}
	if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
		return rsaKey, nil
	}
	return nil, errors.New("private key is not rsa")
}
// parseRSAPublicKey decodes the first PEM block in pemValue into an RSA public
// key.
//
// Accepted encodings, tried in this order:
//  1. PKIX / SubjectPublicKeyInfo ("PUBLIC KEY")
//  2. PKCS#1 ("RSA PUBLIC KEY") — accepted so that public keys are handled
//     symmetrically with private keys, which may also be PKCS#1
//  3. an X.509 certificate, from which the public key is extracted
//
// It returns an error when no PEM block is found, when the key or certificate
// parses but is not RSA, or when none of the encodings match.
func parseRSAPublicKey(pemValue string) (*rsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(pemValue))
	if block == nil {
		return nil, errors.New("invalid rsa public key pem")
	}

	if key, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil {
		rsaKey, ok := key.(*rsa.PublicKey)
		if !ok {
			return nil, errors.New("public key is not rsa")
		}
		return rsaKey, nil
	}

	// PKCS#1 public keys are RSA by definition, so no type check is needed.
	if rsaKey, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil {
		return rsaKey, nil
	}

	if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
		rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
		if !ok {
			return nil, errors.New("certificate public key is not rsa")
		}
		return rsaKey, nil
	}

	return nil, errors.New("parse rsa public key failed")
}
// signingMethod maps the configured algorithm to its jwt signing method:
// RS256 when the manager is configured for RS256, HS256 in every other case.
func (j *JWT) signingMethod() jwt.SigningMethod {
	switch j.algorithm {
	case jwtAlgorithmRS256:
		return jwt.SigningMethodRS256
	default:
		return jwt.SigningMethodHS256
	}
}
// signingKey returns the key used to sign tokens: the RSA private key for
// RS256, the shared secret bytes in every other case.
func (j *JWT) signingKey() interface{} {
	switch j.algorithm {
	case jwtAlgorithmRS256:
		return j.privateKey
	default:
		return j.secret
	}
}
// verifyKey is the key lookup used while parsing a token.
//
// It rejects any token whose algorithm differs from the one this manager signs
// with, so a token cannot choose its own verification algorithm. Otherwise it
// returns the RSA public key for RS256 and the shared secret in every other
// case.
func (j *JWT) verifyKey(token *jwt.Token) (interface{}, error) {
	presented := token.Method.Alg()
	if expected := j.signingMethod().Alg(); presented != expected {
		return nil, fmt.Errorf("unexpected signing method: %s", presented)
	}

	switch j.algorithm {
	case jwtAlgorithmRS256:
		return j.publicKey, nil
	default:
		return j.secret, nil
	}
}
// GetAlgorithm returns the signing algorithm stored on the manager. For a
// manager built by NewJWTWithOptions this is "HS256" or "RS256".
// It does not guard against a nil receiver.
func (j *JWT) GetAlgorithm() string {
return j.algorithm
}
// GenerateAccessToken issues a signed access token carrying a fresh JTI and
// the password-changed epoch (pce).
//
// The token has type "access", is valid from now and expires after the
// configured access-token lifetime. It returns an error when the manager is
// not ready, when a JTI cannot be generated, or when signing fails.
func (j *JWT) GenerateAccessToken(userID int64, username string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}

	issuedAt := time.Now()
	tokenID, err := generateJTI()
	if err != nil {
		return "", err
	}

	registered := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(j.accessTokenExpire)),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
	}
	accessClaims := Claims{
		UserID:           userID,
		Username:         username,
		Type:             "access",
		JTI:              tokenID,
		PCE:              pce,
		RegisteredClaims: registered,
	}

	return jwt.NewWithClaims(j.signingMethod(), accessClaims).SignedString(j.signingKey())
}
// GenerateRefreshToken issues a signed refresh token carrying a fresh JTI and
// the password-changed epoch (pce).
//
// The token has type "refresh", is valid from now and expires after the
// configured refresh-token lifetime. It returns an error when the manager is
// not ready, when a JTI cannot be generated, or when signing fails.
func (j *JWT) GenerateRefreshToken(userID int64, username string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}

	issuedAt := time.Now()
	tokenID, err := generateJTI()
	if err != nil {
		return "", err
	}

	registered := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(j.refreshTokenExpire)),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
	}
	refreshClaims := Claims{
		UserID:           userID,
		Username:         username,
		Type:             "refresh",
		JTI:              tokenID,
		PCE:              pce,
		RegisteredClaims: registered,
	}

	return jwt.NewWithClaims(j.signingMethod(), refreshClaims).SignedString(j.signingKey())
}
// GetAccessTokenExpire returns the configured access-token lifetime, i.e. the
// duration GenerateAccessToken adds to the issue time to compute expiry.
// It does not guard against a nil receiver.
func (j *JWT) GetAccessTokenExpire() time.Duration {
return j.accessTokenExpire
}
// GetRefreshTokenExpire returns the configured lifetime of regular refresh
// tokens, i.e. the duration GenerateRefreshToken adds to the issue time to
// compute expiry. It does not guard against a nil receiver.
func (j *JWT) GetRefreshTokenExpire() time.Duration {
return j.refreshTokenExpire
}
// GenerateTokenPair issues an access token and a regular refresh token for the
// same user, both carrying the password-changed epoch (pce).
//
// The access token is generated first. If either step fails, both returned
// tokens are empty and the error is passed through.
func (j *JWT) GenerateTokenPair(userID int64, username string, pce int64) (accessToken, refreshToken string, err error) {
	if accessToken, err = j.GenerateAccessToken(userID, username, pce); err != nil {
		return "", "", err
	}
	if refreshToken, err = j.GenerateRefreshToken(userID, username, pce); err != nil {
		return "", "", err
	}
	return accessToken, refreshToken, nil
}
// GenerateTokenPairWithRemember issues an access token plus a refresh token,
// with "remember login" support; both carry the password-changed epoch (pce).
//
// When remember is true, the refresh token is a long-lived one from
// GenerateLongLivedRefreshToken; otherwise it is a regular refresh token. If
// either step fails, both returned tokens are empty and the error is passed
// through.
func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool, pce int64) (accessToken, refreshToken string, err error) {
	if accessToken, err = j.GenerateAccessToken(userID, username, pce); err != nil {
		return "", "", err
	}

	// Pick the refresh-token issuer according to the remember flag.
	issueRefresh := j.GenerateRefreshToken
	if remember {
		issueRefresh = j.GenerateLongLivedRefreshToken
	}

	if refreshToken, err = issueRefresh(userID, username, pce); err != nil {
		return "", "", err
	}
	return accessToken, refreshToken, nil
}
// GenerateLongLivedRefreshToken issues the long-lived refresh token used for
// "remember login"; it carries a fresh JTI and the password-changed epoch (pce).
//
// The token has type "refresh" and the Remember flag set. Its lifetime is the
// configured remember-login lifetime. When that value is not configured (zero)
// or is negative, the regular refresh-token lifetime is used instead, so a
// misconfigured negative value cannot produce a token that is already expired
// when issued.
//
// It returns an error when the manager is not ready, when a JTI cannot be
// generated, or when signing fails.
func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}

	now := time.Now()
	jti, err := generateJTI()
	if err != nil {
		return "", err
	}

	// Prefer rememberLoginExpire; fall back to refreshTokenExpire when it is
	// unset or not positive.
	expireDuration := j.rememberLoginExpire
	if expireDuration <= 0 {
		expireDuration = j.refreshTokenExpire
	}

	claims := Claims{
		UserID:   userID,
		Username: username,
		Type:     "refresh",
		Remember: true, // marks a long-lived session
		JTI:      jti,
		PCE:      pce,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(expireDuration)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
		},
	}

	token := jwt.NewWithClaims(j.signingMethod(), claims)
	return token.SignedString(j.signingKey())
}
// ParseToken parses and verifies tokenString and returns its claims.
//
// Verification uses verifyKey, which rejects tokens signed with an algorithm
// other than the configured one. It returns an error when the manager is not
// ready, when parsing or verification fails, or when the parsed token is not
// valid.
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
	if err := j.ensureReady(); err != nil {
		return nil, err
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &Claims{}, j.verifyKey)
	if err != nil {
		return nil, err
	}

	claims, ok := parsed.Claims.(*Claims)
	if !ok || !parsed.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}
// ValidateAccessToken parses tokenString and additionally requires its type to
// be "access". Parse errors are passed through; any other token type yields an
// "invalid token type" error.
func (j *JWT) ValidateAccessToken(tokenString string) (*Claims, error) {
	claims, err := j.ParseToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case claims.Type != "access":
		return nil, errors.New("invalid token type")
	}
	return claims, nil
}
// ValidateRefreshToken parses tokenString and additionally requires its type
// to be "refresh". Parse errors are passed through; any other token type
// yields an "invalid token type" error.
func (j *JWT) ValidateRefreshToken(tokenString string) (*Claims, error) {
	claims, err := j.ParseToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case claims.Type != "refresh":
		return nil, errors.New("invalid token type")
	}
	return claims, nil
}
// GenerateTOTPChallengeToken issues a short-lived token of type
// "totp_challenge" for the given user and device.
//
// The token carries a fresh JTI, the password-changed epoch (pce) and the
// trimmed deviceID, is valid from now and expires after a fixed five minutes.
// It returns an error when the manager is not ready, when a JTI cannot be
// generated, or when signing fails.
func (j *JWT) GenerateTOTPChallengeToken(userID int64, username, deviceID string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}

	issuedAt := time.Now()
	tokenID, err := generateJTI()
	if err != nil {
		return "", err
	}

	registered := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(5 * time.Minute)),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
	}
	challengeClaims := Claims{
		UserID:           userID,
		Username:         username,
		Type:             "totp_challenge",
		JTI:              tokenID,
		PCE:              pce,
		DeviceID:         strings.TrimSpace(deviceID),
		RegisteredClaims: registered,
	}

	return jwt.NewWithClaims(j.signingMethod(), challengeClaims).SignedString(j.signingKey())
}
// ValidateTOTPChallengeToken parses tokenString and additionally requires its
// type to be "totp_challenge". Parse errors are passed through; any other
// token type yields an "invalid token type" error.
func (j *JWT) ValidateTOTPChallengeToken(tokenString string) (*Claims, error) {
	claims, err := j.ParseToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case claims.Type != "totp_challenge":
		return nil, errors.New("invalid token type")
	}
	return claims, nil
}
// RefreshAccessToken exchanges a valid refresh token for a new access token.
//
// The refresh token is validated first; the new access token reuses its user
// ID, username and password-changed epoch. Validation and signing errors are
// passed through.
func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) {
	refreshClaims, err := j.ValidateRefreshToken(refreshTokenString)
	if err != nil {
		return "", err
	}

	userID, username, pce := refreshClaims.UserID, refreshClaims.Username, refreshClaims.PCE
	return j.GenerateAccessToken(userID, username, pce)
}