Files
user-system/internal/auth/jwt.go
long-agent 8095307d82 fix: P0/P1 security and quality fixes
P0-01: Add ESCAPE clause to LIKE queries in operation_log.go and device.go
P0-02: Add atomic Increment to L1Cache and L2Cache interfaces
P0-07: Add TOTP verification step after password login
P1-01: Sanitize error messages in error.go middleware
P1-03: Remove err.Error() from export error messages
P1-04: Add error return to CountByResultSince in login_log.go
P1-05: Add transactional DeleteCascade to RoleRepository
P1-06: Add PasswordChangedAt tracking for JWT token invalidation
P1-07: Wrap theme SetDefault in database transaction
P1-08: Use config values for database pool parameters
P1-09: Add rows.Err() checks in social_account_repo.go
P1-10: Validate sortOrder with map in user.go ORDER BY
P1-11: Add GORM tags to Announcement struct
P1-15: Add pageSize upper limit (100) to device and log handlers
2026-04-18 15:33:12 +08:00

515 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package auth
import (
cryptorand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Signing algorithm identifiers understood by NewJWTWithOptions. The configured
// JWTOptions.Algorithm is trimmed and upper-cased before it is compared with these.
const (
	jwtAlgorithmHS256 = "HS256" // HMAC signing with the shared secret in JWTOptions.HS256Secret
	jwtAlgorithmRS256 = "RS256" // RSA signing with a private key, verification with the public key
)
// JWTOptions controls JWT signing behavior.
//
// Algorithm selects "HS256" or "RS256" (case-insensitive; surrounding spaces
// are ignored). When it is empty, NewJWTWithOptions picks HS256 if HS256Secret
// is set and neither RSAPrivateKeyPEM nor RSAPrivateKeyPath is set, and RS256
// otherwise.
type JWTOptions struct {
	Algorithm   string
	HS256Secret string // required when the algorithm is HS256
	// For RS256, a non-blank inline PEM takes precedence over the matching file
	// path; a path whose file does not exist counts as "not provided".
	RSAPrivateKeyPEM  string
	RSAPublicKeyPEM   string
	RSAPrivateKeyPath string
	RSAPublicKeyPath  string
	// RequireExistingRSAKeys turns "no key found at either path" into an error
	// instead of generating and writing a new key pair to those paths.
	RequireExistingRSAKeys bool
	AccessTokenExpire      time.Duration
	RefreshTokenExpire     time.Duration
	RememberLoginExpire    time.Duration // lifetime of "remember me" refresh tokens; 0 falls back to RefreshTokenExpire
}
// JWT is the JWT manager: it issues and validates signed access/refresh tokens.
// Build it with NewJWT or NewJWTWithOptions.
type JWT struct {
	algorithm           string          // one of the jwtAlgorithm* constants when built by the constructors
	secret              []byte          // HS256 shared secret
	privateKey          *rsa.PrivateKey // RS256 signing key
	publicKey           *rsa.PublicKey  // RS256 verification key
	accessTokenExpire   time.Duration
	refreshTokenExpire  time.Duration
	rememberLoginExpire time.Duration
	// initErr records a construction failure captured by NewJWT; ensureReady
	// returns it so token operations fail instead of using a half-built manager.
	initErr error
}
// Claims is the JWT payload issued and parsed by this package.
type Claims struct {
	UserID   int64  `json:"user_id"`
	Username string `json:"username"`
	Type     string `json:"type"`               // "access" or "refresh"
	Remember bool   `json:"remember,omitempty"` // set on long-lived ("remember me") refresh tokens
	JTI      string `json:"jti"`                // unique JWT ID, used for the token blacklist
	PCE      int64  `json:"pce,omitempty"`      // password-changed epoch, used to invalidate tokens issued before a password change
	// NOTE(review): jwt.RegisteredClaims also declares an ID field tagged "jti";
	// under encoding/json rules the shallower JTI field above wins. Confirm this
	// shadowing is intended.
	jwt.RegisteredClaims
}
// generateJTI returns a unique JWT ID: the current Unix time followed by 128
// bits from crypto/rand, so IDs cannot be enumerated.
//
// Layout: 16 zero-padded hex digits of the timestamp + 32 hex digits of
// randomness = 48 characters in total.
func generateJTI() (string, error) {
	issuedAt := time.Now().Unix()
	var entropy [16]byte
	if _, err := cryptorand.Read(entropy[:]); err != nil {
		return "", fmt.Errorf("generate jwt jti failed: %w", err)
	}
	return fmt.Sprintf("%016x%x", issuedAt, entropy[:]), nil
}
// NewJWT builds an HS256 manager from a shared secret. It is kept for tests and
// legacy callers that only have a secret.
//
// It never returns nil. If the options are rejected (for example an empty
// secret), the returned manager stores the error, and the methods that call
// ensureReady (token generation and parsing) report it.
func NewJWT(secret string, accessTokenExpire, refreshTokenExpire time.Duration) *JWT {
	opts := JWTOptions{
		Algorithm:          jwtAlgorithmHS256,
		HS256Secret:        secret,
		AccessTokenExpire:  accessTokenExpire,
		RefreshTokenExpire: refreshTokenExpire,
	}
	manager, err := NewJWTWithOptions(opts)
	if err != nil {
		// Defer the failure: this legacy constructor has no error return.
		manager = &JWT{
			algorithm:          jwtAlgorithmHS256,
			accessTokenExpire:  accessTokenExpire,
			refreshTokenExpire: refreshTokenExpire,
			initErr:            err,
		}
	}
	return manager
}
// ensureReady reports why the manager cannot be used. It returns an error for
// a nil receiver, the construction error recorded by NewJWT if there is one,
// and nil when the manager is usable.
func (j *JWT) ensureReady() error {
	switch {
	case j == nil:
		return errors.New("jwt manager is nil")
	case j.initErr != nil:
		return j.initErr
	default:
		return nil
	}
}
// NewJWTWithOptions creates a JWT manager from explicit signing options.
//
// The algorithm name is trimmed and upper-cased. If it is empty, HS256 is
// inferred when HS256Secret is set and no RSA private key (inline PEM or path)
// is configured; RS256 is inferred otherwise. Unknown algorithms are rejected.
// HS256 requires a non-empty secret. RS256 keys are resolved by loadRSAKeys,
// which may generate and write a new key pair to the configured paths.
func NewJWTWithOptions(opts JWTOptions) (*JWT, error) {
	algorithm := strings.ToUpper(strings.TrimSpace(opts.Algorithm))
	if algorithm == "" {
		secretOnly := opts.HS256Secret != "" && opts.RSAPrivateKeyPEM == "" && opts.RSAPrivateKeyPath == ""
		algorithm = jwtAlgorithmRS256
		if secretOnly {
			algorithm = jwtAlgorithmHS256
		}
	}
	if algorithm != jwtAlgorithmHS256 && algorithm != jwtAlgorithmRS256 {
		return nil, fmt.Errorf("unsupported jwt algorithm: %s", algorithm)
	}

	manager := &JWT{
		algorithm:           algorithm,
		accessTokenExpire:   opts.AccessTokenExpire,
		refreshTokenExpire:  opts.RefreshTokenExpire,
		rememberLoginExpire: opts.RememberLoginExpire,
	}
	if algorithm == jwtAlgorithmHS256 {
		if opts.HS256Secret == "" {
			return nil, errors.New("jwt secret is required for HS256")
		}
		manager.secret = []byte(opts.HS256Secret)
		return manager, nil
	}
	if err := manager.loadRSAKeys(opts); err != nil {
		return nil, err
	}
	return manager, nil
}
// loadRSAKeys populates j.privateKey and j.publicKey for RS256.
//
// Each key is taken from its inline PEM when given, otherwise from its file
// path; a missing file counts as "not provided".
//
// When neither key is available:
//   - both paths must be configured;
//   - RequireExistingRSAKeys must be false;
//   - a fresh pair is then generated and written to those paths.
//
// If only the private key is available, the public key is derived from it. If
// both are supplied they must form a matching pair. Otherwise every token
// signed here would fail verification, so the mismatch is reported up front as
// a configuration error.
func (j *JWT) loadRSAKeys(opts JWTOptions) error {
	privatePEM, err := readPEM(opts.RSAPrivateKeyPEM, opts.RSAPrivateKeyPath)
	if err != nil {
		return fmt.Errorf("load jwt private key failed: %w", err)
	}
	publicPEM, err := readPEM(opts.RSAPublicKeyPEM, opts.RSAPublicKeyPath)
	if err != nil {
		return fmt.Errorf("load jwt public key failed: %w", err)
	}

	if privatePEM == "" && publicPEM == "" {
		if strings.TrimSpace(opts.RSAPrivateKeyPath) == "" || strings.TrimSpace(opts.RSAPublicKeyPath) == "" {
			return errors.New("rsa private/public key paths or inline pem are required for RS256")
		}
		if opts.RequireExistingRSAKeys {
			return errors.New("existing rsa private/public key files or inline pem are required for RS256")
		}
		privatePEM, publicPEM, err = generateAndPersistRSAKeyPair(opts.RSAPrivateKeyPath, opts.RSAPublicKeyPath)
		if err != nil {
			return fmt.Errorf("generate rsa key pair failed: %w", err)
		}
	}

	if privatePEM != "" {
		privateKey, err := parseRSAPrivateKey(privatePEM)
		if err != nil {
			return err
		}
		j.privateKey = privateKey
		j.publicKey = &privateKey.PublicKey
	}
	if publicPEM != "" {
		publicKey, err := parseRSAPublicKey(publicPEM)
		if err != nil {
			return err
		}
		// A public key that does not belong to the signing key would make this
		// manager reject every token it issues; fail fast instead.
		if j.privateKey != nil && !j.privateKey.PublicKey.Equal(publicKey) {
			return errors.New("rsa public key does not match private key")
		}
		j.publicKey = publicKey
	}

	if j.privateKey == nil {
		return errors.New("rsa private key is required for signing")
	}
	if j.publicKey == nil {
		return errors.New("rsa public key is required for verification")
	}
	return nil
}
// generateAndPersistRSAKeyPair creates a new 2048-bit RSA key pair, writes it
// to the given paths and returns both PEM encodings (private, public).
//
// Storage format:
//   - private key: PKCS#1 ("RSA PRIVATE KEY"), file mode 0600;
//   - public key: PKIX ("PUBLIC KEY"), file mode 0644;
//   - missing parent directories are created with mode 0700.
//
// Both paths are trimmed and must be non-empty.
//
// The private key file's mode is forced to 0600 before any key material is
// written. OpenFile's perm argument only applies when it creates the file. An
// already existing file would otherwise keep its old, possibly world-readable,
// permissions; an empty placeholder is one example, since readPEM reports an
// empty file as "no key".
func generateAndPersistRSAKeyPair(privatePath, publicPath string) (string, string, error) {
	privatePath = strings.TrimSpace(privatePath)
	publicPath = strings.TrimSpace(publicPath)
	if privatePath == "" || publicPath == "" {
		return "", "", errors.New("rsa key paths must not be empty")
	}

	privateKey, err := rsa.GenerateKey(cryptorand.Reader, 2048)
	if err != nil {
		return "", "", err
	}
	privateDER := x509.MarshalPKCS1PrivateKey(privateKey)
	privatePEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateDER})
	publicDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
	if err != nil {
		return "", "", err
	}
	publicPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: publicDER})

	if err := os.MkdirAll(filepath.Dir(privatePath), 0o700); err != nil {
		return "", "", err
	}
	if err := os.MkdirAll(filepath.Dir(publicPath), 0o700); err != nil {
		return "", "", err
	}

	privateFile, err := os.OpenFile(privatePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return "", "", err
	}
	// Tighten the mode before writing the secret: OpenFile's perm only applies on creation.
	if err := privateFile.Chmod(0o600); err != nil {
		_ = privateFile.Close() // already failing; the Chmod error is the one to report
		return "", "", err
	}
	if _, err := privateFile.Write(privatePEM); err != nil {
		_ = privateFile.Close() // already failing; the Write error is the one to report
		return "", "", err
	}
	// A Close error can mean the key never fully reached the file.
	if err := privateFile.Close(); err != nil {
		return "", "", err
	}

	if err := os.WriteFile(publicPath, publicPEM, 0o644); err != nil {
		return "", "", err
	}
	return string(privatePEM), string(publicPEM), nil
}
// readPEM resolves PEM text from an inline value or a file.
//
// A non-blank inline value wins and is returned trimmed. Otherwise the trimmed
// path is read and its contents are returned verbatim. A blank path or a file
// that does not exist yields ("", nil), meaning "not provided". Any other read
// failure is returned as an error.
func readPEM(inlinePEM, path string) (string, error) {
	if trimmed := strings.TrimSpace(inlinePEM); trimmed != "" {
		return trimmed, nil
	}
	location := strings.TrimSpace(path)
	if location == "" {
		return "", nil
	}
	raw, err := os.ReadFile(location)
	switch {
	case err == nil:
		return string(raw), nil
	case errors.Is(err, os.ErrNotExist):
		return "", nil
	default:
		return "", err
	}
}
// parseRSAPrivateKey decodes the first PEM block in pemValue and returns the
// RSA private key it holds. PKCS#1 DER is tried first, then PKCS#8. A PKCS#8
// key of another type (e.g. EC) is rejected as "not rsa".
func parseRSAPrivateKey(pemValue string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(pemValue))
	if block == nil {
		return nil, errors.New("invalid rsa private key pem")
	}
	der := block.Bytes

	pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
	if pkcs1Err == nil {
		return pkcs1Key, nil
	}

	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, fmt.Errorf("parse rsa private key failed: %w", err)
	}
	if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
		return rsaKey, nil
	}
	return nil, errors.New("private key is not rsa")
}
// parseRSAPublicKey decodes the first PEM block in pemValue and returns the RSA
// public key it contains. Accepted DER encodings, tried in this order:
//   - PKIX / SubjectPublicKeyInfo ("PUBLIC KEY");
//   - an X.509 certificate, using its subject public key;
//   - PKCS#1 ("RSA PUBLIC KEY"), as produced by `openssl rsa -RSAPublicKey_out`.
//
// The PEM block type label is not checked; only the DER payload matters. A
// PKIX key or certificate holding a non-RSA key is rejected.
func parseRSAPublicKey(pemValue string) (*rsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(pemValue))
	if block == nil {
		return nil, errors.New("invalid rsa public key pem")
	}
	if key, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil {
		rsaKey, ok := key.(*rsa.PublicKey)
		if !ok {
			return nil, errors.New("public key is not rsa")
		}
		return rsaKey, nil
	}
	if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
		rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
		if !ok {
			return nil, errors.New("certificate public key is not rsa")
		}
		return rsaKey, nil
	}
	// PKCS#1 is tried last, so inputs handled by the PKIX and certificate paths
	// keep exactly their previous result.
	if rsaKey, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil {
		return rsaKey, nil
	}
	return nil, errors.New("parse rsa public key failed")
}
// signingMethod maps the configured algorithm to its jwt signing method.
// RS256 is used for the RSA configuration; any other value gets HS256.
func (j *JWT) signingMethod() jwt.SigningMethod {
	switch j.algorithm {
	case jwtAlgorithmRS256:
		return jwt.SigningMethodRS256
	default:
		return jwt.SigningMethodHS256
	}
}
// signingKey returns the key passed to SignedString: the RSA private key for
// RS256, the shared HMAC secret otherwise.
func (j *JWT) signingKey() interface{} {
	var key interface{} = j.secret
	if j.algorithm == jwtAlgorithmRS256 {
		key = j.privateKey
	}
	return key
}
// verifyKey is the key lookup used by ParseToken.
//
// It rejects any token whose header algorithm differs from the configured one,
// so a token cannot choose its own verification scheme. Otherwise it returns
// the RSA public key (RS256) or the shared secret.
func (j *JWT) verifyKey(token *jwt.Token) (interface{}, error) {
	if got, want := token.Method.Alg(), j.signingMethod().Alg(); got != want {
		return nil, fmt.Errorf("unexpected signing method: %s", got)
	}
	switch j.algorithm {
	case jwtAlgorithmRS256:
		return j.publicKey, nil
	default:
		return j.secret, nil
	}
}
// GetAlgorithm reports the signing algorithm this manager was configured with.
func (j *JWT) GetAlgorithm() string {
	configured := j.algorithm
	return configured
}
// GenerateAccessToken issues a signed "access" token for the user. It is valid
// for the configured access-token lifetime and carries a fresh JTI plus pce
// (the password-changed epoch).
func (j *JWT) GenerateAccessToken(userID int64, username string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}
	return j.issueSignedToken(userID, username, "access", false, pce, j.accessTokenExpire)
}

// GenerateRefreshToken issues a signed standard "refresh" token. It is valid for
// the configured refresh-token lifetime and carries a fresh JTI plus pce.
func (j *JWT) GenerateRefreshToken(userID int64, username string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}
	return j.issueSignedToken(userID, username, "refresh", false, pce, j.refreshTokenExpire)
}

// GetAccessTokenExpire returns the access-token lifetime.
func (j *JWT) GetAccessTokenExpire() time.Duration {
	return j.accessTokenExpire
}

// GetRefreshTokenExpire returns the standard (non-"remember me") refresh-token lifetime.
func (j *JWT) GetRefreshTokenExpire() time.Duration {
	return j.refreshTokenExpire
}

// GenerateTokenPair issues an access token and a standard refresh token, both
// carrying pce. On any error both returned tokens are empty.
func (j *JWT) GenerateTokenPair(userID int64, username string, pce int64) (accessToken, refreshToken string, err error) {
	return j.GenerateTokenPairWithRemember(userID, username, false, pce)
}

// GenerateTokenPairWithRemember issues an access token plus a refresh token,
// both carrying pce. When remember is true the refresh token is the long-lived
// variant from GenerateLongLivedRefreshToken. On any error both returned tokens
// are empty.
func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool, pce int64) (accessToken, refreshToken string, err error) {
	accessToken, err = j.GenerateAccessToken(userID, username, pce)
	if err != nil {
		return "", "", err
	}
	if remember {
		refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username, pce)
	} else {
		refreshToken, err = j.GenerateRefreshToken(userID, username, pce)
	}
	if err != nil {
		return "", "", err
	}
	return accessToken, refreshToken, nil
}

// GenerateLongLivedRefreshToken issues a "remember me" refresh token with the
// Remember claim set. Its lifetime is the configured remember-login duration,
// falling back to the standard refresh-token lifetime when that is zero.
func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string, pce int64) (string, error) {
	if err := j.ensureReady(); err != nil {
		return "", err
	}
	ttl := j.rememberLoginExpire
	if ttl == 0 {
		ttl = j.refreshTokenExpire
	}
	return j.issueSignedToken(userID, username, "refresh", true, pce, ttl)
}

// issueSignedToken builds the claims shared by every token type (fresh JTI,
// iat = nbf = now, exp = now + ttl) and signs them with the configured key.
// Callers must have passed ensureReady first.
func (j *JWT) issueSignedToken(userID int64, username, tokenType string, remember bool, pce int64, ttl time.Duration) (string, error) {
	now := time.Now()
	jti, err := generateJTI()
	if err != nil {
		return "", err
	}
	claims := Claims{
		UserID:   userID,
		Username: username,
		Type:     tokenType,
		Remember: remember,
		JTI:      jti,
		PCE:      pce,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
		},
	}
	return jwt.NewWithClaims(j.signingMethod(), claims).SignedString(j.signingKey())
}
// ParseToken verifies tokenString's signature and standard time claims and
// returns its Claims. It does not check the Type claim; ValidateAccessToken
// and ValidateRefreshToken do that.
//
// The exp claim is required. Every token issued by this package sets exp, so
// a correctly signed token that omits it is rejected rather than treated as
// never expiring (jwt v5 makes exp optional by default).
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
	if err := j.ensureReady(); err != nil {
		return nil, err
	}
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, j.verifyKey, jwt.WithExpirationRequired())
	if err != nil {
		return nil, err
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}
// ValidateAccessToken parses tokenString and accepts it only when its Type
// claim is "access".
func (j *JWT) ValidateAccessToken(tokenString string) (*Claims, error) {
	claims, err := j.ParseToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case claims.Type != "access":
		return nil, errors.New("invalid token type")
	}
	return claims, nil
}
// ValidateRefreshToken parses tokenString and accepts it only when its Type
// claim is "refresh". This covers both standard and long-lived refresh tokens.
func (j *JWT) ValidateRefreshToken(tokenString string) (*Claims, error) {
	claims, err := j.ParseToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case claims.Type != "refresh":
		return nil, errors.New("invalid token type")
	}
	return claims, nil
}
// RefreshAccessToken validates a refresh token and mints a new access token for
// the same user. The refresh token's password-changed epoch (PCE) is carried
// over. Only the access token is returned; the refresh token is not reissued.
func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) {
	claims, err := j.ValidateRefreshToken(refreshTokenString)
	if err != nil {
		return "", err
	}
	userID, username, pce := claims.UserID, claims.Username, claims.PCE
	return j.GenerateAccessToken(userID, username, pce)
}