- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository within a transaction
- Update DisableTOTP to prefer the atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with the non-atomic fallback

Refs: TOTP verification atomicity completion
537 lines
16 KiB
Go
537 lines
16 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
|
||
"github.com/user-management-system/internal/auth"
|
||
"github.com/user-management-system/internal/domain"
|
||
"github.com/user-management-system/internal/pagination"
|
||
)
|
||
|
||
// likePatternEscaper prefixes the LIKE metacharacters %, _ and the escape
// character \ itself with a backslash. strings.Replacer scans the input once,
// so a backslash introduced for % or _ is never escaped a second time.
var likePatternEscaper = strings.NewReplacer(
	`\`, `\\`,
	`%`, `\%`,
	`_`, `\_`,
)

// escapeLikePattern escapes s so that it matches literally inside a LIKE
// pattern: every \, % and _ is prefixed with a backslash.
//
// NOTE(review): the LIKE queries in this file do not add an ESCAPE clause, so
// these escapes rely on the database treating `\` as the default LIKE escape
// character (MySQL and PostgreSQL do; SQLite has no default and would need
// `ESCAPE '\'`). Confirm against the deployed database.
func escapeLikePattern(s string) string {
	return likePatternEscaper.Replace(s)
}
|
||
|
||
// UserRepository 用户数据访问层
|
||
type UserRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewUserRepository 创建用户数据访问层
|
||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||
return &UserRepository{db: db}
|
||
}
|
||
|
||
// DB returns the underlying GORM DB for transaction support
|
||
func (r *UserRepository) DB() *gorm.DB {
|
||
return r.db
|
||
}
|
||
|
||
// Create 创建用户
|
||
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
|
||
return r.db.WithContext(ctx).Create(user).Error
|
||
}
|
||
|
||
// Update 更新用户
|
||
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
|
||
return r.db.WithContext(ctx).Save(user).Error
|
||
}
|
||
|
||
// Delete 删除用户(软删除)
|
||
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||
return r.db.WithContext(ctx).Delete(&domain.User{}, id).Error
|
||
}
|
||
|
||
// GetByID 根据ID获取用户
|
||
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).First(&user, id).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByIDs 批量获取用户(消除 N+1 查询)
|
||
func (r *UserRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.User, error) {
|
||
if len(ids) == 0 {
|
||
return []*domain.User{}, nil
|
||
}
|
||
var users []*domain.User
|
||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return users, nil
|
||
}
|
||
|
||
// GetByUsername 根据用户名获取用户
|
||
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByEmail 根据邮箱获取用户
|
||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByPhone 根据手机号获取用户
|
||
func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// List 获取用户列表
|
||
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// ListByStatus 根据状态获取用户列表
|
||
func (r *UserRepository) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("status = ?", status)
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// UpdateStatus 更新用户状态
|
||
func (r *UserRepository) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("status", status).Error
|
||
}
|
||
|
||
// BatchUpdateStatus 批量更新用户状态
|
||
func (r *UserRepository) BatchUpdateStatus(ctx context.Context, ids []int64, status domain.UserStatus) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id IN ?", ids).Update("status", status).Error
|
||
}
|
||
|
||
// BatchDelete 批量删除用户
|
||
func (r *UserRepository) BatchDelete(ctx context.Context, ids []int64) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
return r.db.WithContext(ctx).Where("id IN ?", ids).Delete(&domain.User{}).Error
|
||
}
|
||
|
||
// UpdateLastLogin 更新最后登录信息
|
||
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
|
||
now := time.Now()
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||
"last_login_time": &now,
|
||
"last_login_ip": ip,
|
||
}).Error
|
||
}
|
||
|
||
// ExistsByUsername 检查用户名是否存在
|
||
func (r *UserRepository) ExistsByUsername(ctx context.Context, username string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("username = ?", username).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// ExistsByEmail 检查邮箱是否存在
|
||
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("email = ?", email).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// ExistsByPhone 检查手机号是否存在
|
||
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("phone = ?", phone).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// Search 搜索用户
|
||
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
// 转义 LIKE 特殊字符,防止搜索被意外干扰
|
||
escapedKeyword := escapeLikePattern(keyword)
|
||
pattern := "%" + escapedKeyword + "%"
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where(
|
||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||
pattern, pattern, pattern, pattern,
|
||
)
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// UpdateTOTP 更新用户的 TOTP 字段
|
||
func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) error {
|
||
return r.db.WithContext(ctx).Model(user).Updates(map[string]interface{}{
|
||
"totp_enabled": user.TOTPEnabled,
|
||
"totp_secret": user.TOTPSecret,
|
||
"totp_recovery_codes": user.TOTPRecoveryCodes,
|
||
}).Error
|
||
}
|
||
|
||
// ConsumeTOTPRecoveryCode 原子性地消费一个恢复码
|
||
// 在事务中验证恢复码并更新,避免并发竞争窗口
|
||
func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
||
var user domain.User
|
||
var consumed bool
|
||
|
||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// 在事务中重新获取用户
|
||
// 注意:SQLite 不完全支持 FOR UPDATE,依赖事务隔离
|
||
if err := tx.First(&user, userID).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if !user.TOTPEnabled {
|
||
return errors.New("TOTP 未启用")
|
||
}
|
||
|
||
// 解析存储的哈希恢复码
|
||
var hashedCodes []string
|
||
if user.TOTPRecoveryCodes != "" {
|
||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||
}
|
||
}
|
||
|
||
// 验证恢复码(输入会被哈希后与存储的哈希比较)
|
||
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||
if !matched {
|
||
// 不匹配,标记消费失败但不返回错误
|
||
consumed = false
|
||
return nil
|
||
}
|
||
|
||
// 从列表中移除已使用的恢复码
|
||
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
|
||
codesJSON, err := json.Marshal(hashedCodes)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化恢复码失败: %w", err)
|
||
}
|
||
user.TOTPRecoveryCodes = string(codesJSON)
|
||
|
||
// 在同一事务中更新
|
||
if err := tx.Model(&user).Update("totp_recovery_codes", user.TOTPRecoveryCodes).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
consumed = true
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, false, err
|
||
}
|
||
|
||
return &user, consumed, nil
|
||
}
|
||
|
||
// VerifyTOTPOrRecoveryCode 原子性地验证 TOTP 码或恢复码(不消费恢复码)
|
||
// 返回 (true, nil) 表示验证成功
|
||
// 返回 (false, nil) 表示验证失败(码不匹配)
|
||
// 返回 (false, error) 表示执行出错
|
||
func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
|
||
var user domain.User
|
||
|
||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.First(&user, userID).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if !user.TOTPEnabled {
|
||
return errors.New("TOTP 未启用")
|
||
}
|
||
|
||
// 先验证 TOTP 码
|
||
manager := auth.NewTOTPManager()
|
||
if manager.ValidateCode(user.TOTPSecret, code) {
|
||
return nil
|
||
}
|
||
|
||
// TOTP 码无效,尝试验证恢复码
|
||
var hashedCodes []string
|
||
if user.TOTPRecoveryCodes != "" {
|
||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||
}
|
||
}
|
||
|
||
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||
if !matched {
|
||
// 恢复码也不匹配,标记验证失败
|
||
return errVerificationFailed
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err == errVerificationFailed {
|
||
return false, nil
|
||
}
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return true, nil
|
||
}
|
||
|
||
// errVerificationFailed is an internal sentinel returned from inside the
// VerifyTOTPOrRecoveryCode transaction when neither the TOTP code nor a
// recovery code matches. That method maps it to a (false, nil) result.
var (
	errVerificationFailed = errors.New("verification failed")
)
|
||
|
||
// UpdatePassword 更新用户密码
|
||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||
}
|
||
|
||
// ListCreatedAfter 查询指定时间之后创建的用户(limit=0表示不限制数量)
|
||
func (r *UserRepository) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("created_at >= ?", since)
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
if limit > 0 {
|
||
query = query.Offset(offset).Limit(limit)
|
||
}
|
||
if err := query.Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
return users, total, nil
|
||
}
|
||
|
||
// AdvancedFilter carries the combinable criteria for an advanced user search.
// Nil pointers, empty strings and empty slices mean "no constraint". Status is
// the exception: 0 is a real status, so callers must set -1 to disable it.
type AdvancedFilter struct {
	// Keyword is matched as a substring against username, email, phone and nickname.
	Keyword string
	// Status selects one domain.UserStatus value (0/1/2/3); -1 means any status.
	Status int
	// RoleIDs restricts results to users holding at least one of these roles.
	RoleIDs []int64
	// CreatedFrom is the inclusive lower bound on the registration time.
	CreatedFrom *time.Time
	// CreatedTo is the inclusive upper bound on the registration time.
	CreatedTo *time.Time
	// LastLoginFrom is the inclusive lower bound on the last login time.
	LastLoginFrom *time.Time
	// SortBy names the sort column: created_at, last_login_time, username or
	// updated_at; anything else falls back to created_at.
	SortBy string
	// SortOrder is "asc" or "desc" (case-insensitive); anything else means DESC.
	SortOrder string
	// Offset is the number of matching rows to skip (offset pagination).
	Offset int
	// Limit is the page size for offset pagination.
	Limit int
}
|
||
|
||
// AdvancedSearch 高级用户搜索(支持多维度组合筛选)
|
||
func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFilter) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||
|
||
// 关键字搜索(转义 LIKE 特殊字符)
|
||
if filter.Keyword != "" {
|
||
like := "%" + escapeLikePattern(filter.Keyword) + "%"
|
||
query = query.Where(
|
||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||
like, like, like, like,
|
||
)
|
||
}
|
||
|
||
// 状态筛选
|
||
if filter.Status >= 0 {
|
||
query = query.Where("status = ?", filter.Status)
|
||
}
|
||
|
||
// 注册时间范围
|
||
if filter.CreatedFrom != nil {
|
||
query = query.Where("created_at >= ?", filter.CreatedFrom)
|
||
}
|
||
if filter.CreatedTo != nil {
|
||
query = query.Where("created_at <= ?", filter.CreatedTo)
|
||
}
|
||
|
||
// 最后登录时间范围
|
||
if filter.LastLoginFrom != nil {
|
||
query = query.Where("last_login_time >= ?", filter.LastLoginFrom)
|
||
}
|
||
|
||
// 按角色筛选(子查询)
|
||
if len(filter.RoleIDs) > 0 {
|
||
query = query.Where(
|
||
"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
|
||
filter.RoleIDs,
|
||
)
|
||
}
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 排序
|
||
sortBy := "created_at"
|
||
sortOrder := "DESC"
|
||
if filter.SortBy != "" {
|
||
allowedFields := map[string]bool{
|
||
"created_at": true, "last_login_time": true,
|
||
"username": true, "updated_at": true,
|
||
}
|
||
if allowedFields[filter.SortBy] {
|
||
sortBy = filter.SortBy
|
||
}
|
||
}
|
||
allowedSortOrders := map[string]bool{"asc": true, "desc": true}
|
||
if allowedSortOrders[strings.ToLower(filter.SortOrder)] {
|
||
sortOrder = strings.ToUpper(filter.SortOrder)
|
||
}
|
||
query = query.Order(sortBy + " " + sortOrder)
|
||
|
||
// 分页
|
||
limit := filter.Limit
|
||
if limit <= 0 {
|
||
limit = 20
|
||
}
|
||
if limit > 200 {
|
||
limit = 200
|
||
}
|
||
query = query.Offset(filter.Offset).Limit(limit)
|
||
|
||
if err := query.Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// ListCursor 游标分页查询用户列表(支持筛选)
|
||
// Sort column: created_at DESC, id DESC
|
||
func (r *UserRepository) ListCursor(ctx context.Context, filter *AdvancedFilter, limit int, cursor *pagination.Cursor) ([]*domain.User, bool, error) {
|
||
var users []*domain.User
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||
|
||
// Apply filters (same as AdvancedFilter)
|
||
if filter.Keyword != "" {
|
||
escapedKeyword := escapeLikePattern(filter.Keyword)
|
||
pattern := "%" + escapedKeyword + "%"
|
||
query = query.Where(
|
||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||
pattern, pattern, pattern, pattern,
|
||
)
|
||
}
|
||
if filter.Status >= 0 && filter.Status <= 3 {
|
||
query = query.Where("status = ?", filter.Status)
|
||
}
|
||
if len(filter.RoleIDs) > 0 {
|
||
query = query.Where(
|
||
"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
|
||
filter.RoleIDs,
|
||
)
|
||
}
|
||
if filter.CreatedFrom != nil {
|
||
query = query.Where("created_at >= ?", *filter.CreatedFrom)
|
||
}
|
||
if filter.CreatedTo != nil {
|
||
query = query.Where("created_at <= ?", *filter.CreatedTo)
|
||
}
|
||
|
||
// Apply cursor condition
|
||
// 安全修复:游标分页必须与排序字段一致。
|
||
// 如果排序字段不是 created_at,游标分页会返回错误结果。
|
||
// 因此只有在按 created_at 排序时才允许使用游标。
|
||
sortBy := "created_at"
|
||
if filter.SortBy != "" {
|
||
allowedFields := map[string]bool{
|
||
"created_at": true, "last_login_time": true,
|
||
"username": true, "updated_at": true,
|
||
}
|
||
if allowedFields[filter.SortBy] {
|
||
sortBy = filter.SortBy
|
||
}
|
||
}
|
||
|
||
// 只有在按 created_at 排序时才应用游标条件
|
||
if cursor != nil && cursor.LastID > 0 && sortBy == "created_at" {
|
||
query = query.Where(
|
||
"(created_at < ? OR (created_at = ? AND id < ?))",
|
||
cursor.LastValue, cursor.LastValue, cursor.LastID,
|
||
)
|
||
}
|
||
|
||
sortOrder := "DESC"
|
||
allowedSortOrders := map[string]bool{"asc": true, "desc": true}
|
||
if allowedSortOrders[strings.ToLower(filter.SortOrder)] {
|
||
sortOrder = strings.ToUpper(filter.SortOrder)
|
||
}
|
||
|
||
orderClause := sortBy + " " + sortOrder + ", id " + sortOrder
|
||
if err := query.Order(orderClause).Limit(limit + 1).Find(&users).Error; err != nil {
|
||
return nil, false, err
|
||
}
|
||
|
||
hasMore := len(users) > limit
|
||
if hasMore {
|
||
users = users[:limit]
|
||
}
|
||
return users, hasMore, nil
|
||
}
|