Files
user-system/internal/auth/totp.go

193 lines
5.8 KiB
Go
Raw Normal View History

package auth
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base32"
"encoding/base64"
"encoding/hex"
"fmt"
"image/png"
"regexp"
"strings"
"time"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/bcrypt"
)
const (
// TOTPIssuer 应用名称(显示在 Authenticator App 中)
TOTPIssuer = "UserManagementSystem"
// TOTPPeriod TOTP 时间步长(秒)
TOTPPeriod = 30
// TOTPDigits TOTP 位数
TOTPDigits = 6
// TOTPAlgorithm TOTP 算法(使用 SHA256 更安全)
TOTPAlgorithm = otp.AlgorithmSHA256
// RecoveryCodeCount 恢复码数量
RecoveryCodeCount = 8
// RecoveryCodeLength 每个恢复码的字节长度(生成后编码为 hex 字符串)
RecoveryCodeLength = 5
)
// TOTPManager TOTP 管理器
type TOTPManager struct{}
// NewTOTPManager 创建 TOTP 管理器
func NewTOTPManager() *TOTPManager {
return &TOTPManager{}
}
// TOTPSetup TOTP 初始化结果
type TOTPSetup struct {
Secret string `json:"secret"` // Base32 密钥(用户备用)
QRCodeBase64 string `json:"qr_code_base64"` // Base64 编码的 PNG 二维码图片
RecoveryCodes []string `json:"recovery_codes"` // 一次性恢复码列表
}
// GenerateSecret 为指定用户生成 TOTP 密钥及二维码
func (m *TOTPManager) GenerateSecret(username string) (*TOTPSetup, error) {
key, err := totp.Generate(totp.GenerateOpts{
Issuer: TOTPIssuer,
AccountName: username,
Period: TOTPPeriod,
Digits: otp.DigitsSix,
Algorithm: TOTPAlgorithm,
})
if err != nil {
return nil, fmt.Errorf("generate totp key failed: %w", err)
}
// 生成二维码图片
img, err := key.Image(200, 200)
if err != nil {
return nil, fmt.Errorf("generate qr image failed: %w", err)
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return nil, fmt.Errorf("encode qr image failed: %w", err)
}
qrBase64 := base64.StdEncoding.EncodeToString(buf.Bytes())
// 生成恢复码
codes, err := generateRecoveryCodes(RecoveryCodeCount)
if err != nil {
return nil, fmt.Errorf("generate recovery codes failed: %w", err)
}
return &TOTPSetup{
Secret: key.Secret(),
QRCodeBase64: qrBase64,
RecoveryCodes: codes,
}, nil
}
// ValidateCode 验证用户输入的 TOTP 码(允许 ±1 个时间窗口的时钟偏差)
func (m *TOTPManager) ValidateCode(secret, code string) bool {
// 注意pquerna/otp 库的 ValidateCustom 与 GenerateCode 存在算法不匹配 bugGenerateCode 固定用 SHA1
// 因此使用 totp.Validate() 代替,它内部正确处理算法检测
return totp.Validate(strings.TrimSpace(code), secret)
}
// GenerateCurrentCode 生成当前时间的 TOTP 码(用于测试)
func (m *TOTPManager) GenerateCurrentCode(secret string) (string, error) {
return totp.GenerateCode(secret, time.Now().UTC())
}
// ValidateRecoveryCode 验证恢复码(传入哈希后的已存储恢复码列表,返回匹配索引)
// 注意:调用方负责在验证后将该恢复码标记为已使用
// 使用恒定时间比较防止时序攻击
func ValidateRecoveryCode(inputCode string, storedCodes []string) (int, bool) {
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(inputCode), "-", ""))
for i, stored := range storedCodes {
storedNormalized := strings.ToUpper(strings.ReplaceAll(stored, "-", ""))
// 使用恒定时间比较防止时序攻击
if subtle.ConstantTimeCompare([]byte(normalized), []byte(storedNormalized)) == 1 {
return i, true
}
}
return -1, false
}
// HashRecoveryCode 使用 bcrypt 慢哈希恢复码(用于存储)
// P2 安全增强:将 SHA256 快速哈希升级为 bcrypt 慢哈希,防止 GPU 暴力破解
func HashRecoveryCode(code string) (string, error) {
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(code), "-", ""))
hash, err := bcrypt.GenerateFromPassword([]byte(normalized), bcrypt.DefaultCost)
if err != nil {
return "", fmt.Errorf("hash recovery code failed: %w", err)
}
return string(hash), nil
}
// sha256HexPattern 匹配 64 位十六进制字符串(旧版 SHA256 哈希格式)
var sha256HexPattern = regexp.MustCompile("^[0-9a-fA-F]{64}$")
// isLegacySHA256Hash 检测是否为旧版 SHA256 哈希值
func isLegacySHA256Hash(hash string) bool {
return sha256HexPattern.MatchString(hash)
}
// legacyHashRecoveryCode 旧版 SHA256 哈希(用于向后兼容验证)
func legacyHashRecoveryCode(code string) string {
h := sha256.Sum256([]byte(code))
return hex.EncodeToString(h[:])
}
// VerifyRecoveryCode 验证恢复码(支持 bcrypt 新哈希和 SHA256 旧哈希向后兼容)
// 使用恒定时间比较防止时序攻击
func VerifyRecoveryCode(inputCode string, hashedCodes []string) (int, bool) {
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(inputCode), "-", ""))
found := -1
for i := 0; i < len(hashedCodes); i++ {
stored := hashedCodes[i]
if stored == "" {
continue
}
var matched bool
if isLegacySHA256Hash(stored) {
// 向后兼容:旧版 SHA256 哈希
hashedInput := legacyHashRecoveryCode(inputCode)
if subtle.ConstantTimeCompare([]byte(hashedInput), []byte(stored)) == 1 {
matched = true
}
} else {
// 新版 bcrypt 哈希
if err := bcrypt.CompareHashAndPassword([]byte(stored), []byte(normalized)); err == nil {
matched = true
}
}
if matched {
found = i
}
}
if found >= 0 {
return found, true
}
return -1, false
}
// generateRecoveryCodes 生成 N 个随机恢复码格式XXXXX-XXXXX
func generateRecoveryCodes(count int) ([]string, error) {
codes := make([]string, count)
for i := 0; i < count; i++ {
b := make([]byte, RecoveryCodeLength*2)
if _, err := rand.Read(b); err != nil {
return nil, err
}
encoded := base32.StdEncoding.EncodeToString(b)
// 格式化为 XXXXX-XXXXX
part := strings.ToUpper(encoded[:10])
codes[i] = part[:5] + "-" + part[5:]
}
return codes, nil
}