Files
user-system/internal/repository/user.go
Your Name 363c77d020 feat: atomic TOTP verification for DisableTOTP
- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction
- Update DisableTOTP to prefer atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with non-atomic fallback

Refs: TOTP verification atomicity completion
2026-05-29 12:47:05 +08:00

537 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
// escapeLikePattern escapes the characters that are special inside a SQL LIKE
// pattern so that user input is matched literally.
//
// Every backslash, percent sign and underscore in s is prefixed with a
// backslash; all other bytes are copied through unchanged.
//
// NOTE(review): the result is only interpreted correctly when the database
// treats backslash as the LIKE escape character (by default or through an
// ESCAPE clause) — confirm for the target database.
func escapeLikePattern(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	// Byte-wise iteration is safe here: the three special characters are
	// ASCII and can never appear inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		ch := s[i]
		switch ch {
		case '\\', '%', '_':
			out.WriteByte('\\')
		}
		out.WriteByte(ch)
	}
	return out.String()
}
// UserRepository is the data-access layer for user records.
// All of its methods run their queries through the wrapped GORM handle.
type UserRepository struct {
db *gorm.DB
}
// NewUserRepository returns a UserRepository that issues its queries
// through db.
func NewUserRepository(db *gorm.DB) *UserRepository {
	repo := new(UserRepository)
	repo.db = db
	return repo
}
// DB exposes the underlying GORM handle so that callers can open their own
// transactions on it.
func (r *UserRepository) DB() *gorm.DB {
	handle := r.db
	return handle
}
// Create inserts user as a new row.
//
// It returns the error reported by GORM, or nil on success.
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
	result := r.db.WithContext(ctx).Create(user)
	return result.Error
}
// Update persists user through GORM's Save.
//
// It returns the error reported by GORM, or nil on success.
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
	result := r.db.WithContext(ctx).Save(user)
	return result.Error
}
// Delete removes the user whose primary key is id.
//
// NOTE(review): the original comment describes this as a soft delete; that
// holds only if domain.User declares a GORM soft-delete field, which is not
// visible from this file — confirm.
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
	scoped := r.db.WithContext(ctx)
	return scoped.Delete(&domain.User{}, id).Error
}
// GetByID loads the user whose primary key is id.
//
// On any lookup error (including the error GORM's First reports when no row
// matches) it returns a nil user together with that error.
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, error) {
	found := new(domain.User)
	if err := r.db.WithContext(ctx).First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetByIDs loads all users whose primary key is in ids with a single query,
// so callers do not have to issue one lookup per ID.
//
// An empty ids slice short-circuits to an empty, non-nil result without
// touching the database. On a query error it returns nil and the error.
func (r *UserRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.User, error) {
	if len(ids) == 0 {
		return []*domain.User{}, nil
	}

	var found []*domain.User
	result := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&found)
	if result.Error != nil {
		return nil, result.Error
	}
	return found, nil
}
// GetByUsername loads the first user whose username equals the argument.
//
// On any lookup error it returns a nil user together with that error.
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
	found := new(domain.User)
	result := r.db.WithContext(ctx).Where("username = ?", username).First(found)
	if result.Error != nil {
		return nil, result.Error
	}
	return found, nil
}
// GetByEmail loads the first user whose email equals the argument.
//
// On any lookup error it returns a nil user together with that error.
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	found := new(domain.User)
	result := r.db.WithContext(ctx).Where("email = ?", email).First(found)
	if result.Error != nil {
		return nil, result.Error
	}
	return found, nil
}
// GetByPhone loads the first user whose phone equals the argument.
//
// On any lookup error it returns a nil user together with that error.
func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
	found := new(domain.User)
	result := r.db.WithContext(ctx).Where("phone = ?", phone).First(found)
	if result.Error != nil {
		return nil, result.Error
	}
	return found, nil
}
// List returns one page of users together with the total number of users.
//
// offset and limit are passed straight to the query. No ORDER BY is applied
// here, so the row order is whatever the database returns.
// On any error it returns (nil, 0, err).
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
	base := r.db.WithContext(ctx).Model(&domain.User{})

	// Count first so the total reflects the unpaginated result set.
	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var page []*domain.User
	if err := base.Offset(offset).Limit(limit).Find(&page).Error; err != nil {
		return nil, 0, err
	}
	return page, total, nil
}
// ListByStatus returns one page of the users whose status equals the
// argument, together with the total number of such users.
//
// offset and limit are passed straight to the query; no ORDER BY is applied.
// On any error it returns (nil, 0, err).
func (r *UserRepository) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
	base := r.db.WithContext(ctx).Model(&domain.User{}).Where("status = ?", status)

	// Count first so the total reflects the unpaginated result set.
	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var page []*domain.User
	if err := base.Offset(offset).Limit(limit).Find(&page).Error; err != nil {
		return nil, 0, err
	}
	return page, total, nil
}
// UpdateStatus sets the status column of the user identified by id.
//
// Only the query error is returned; the number of affected rows is not
// checked, so an id that matches nothing is not reported by this method.
func (r *UserRepository) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("id = ?", id).
		Update("status", status)
	return result.Error
}
// BatchUpdateStatus sets the status column for every user whose ID is in ids.
//
// An empty ids slice is a no-op that returns nil without touching the
// database.
func (r *UserRepository) BatchUpdateStatus(ctx context.Context, ids []int64, status domain.UserStatus) error {
	if len(ids) == 0 {
		return nil
	}
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("id IN ?", ids).
		Update("status", status)
	return result.Error
}
// BatchDelete deletes every user whose ID is in ids.
//
// An empty ids slice is a no-op that returns nil without touching the
// database.
func (r *UserRepository) BatchDelete(ctx context.Context, ids []int64) error {
	if len(ids) == 0 {
		return nil
	}
	scoped := r.db.WithContext(ctx).Where("id IN ?", ids)
	return scoped.Delete(&domain.User{}).Error
}
// UpdateLastLogin records a login for the user identified by id: it sets
// last_login_time to the current time and last_login_ip to ip.
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
	now := time.Now()
	fields := map[string]interface{}{
		"last_login_time": &now,
		"last_login_ip":   ip,
	}
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("id = ?", id).
		Updates(fields)
	return result.Error
}
// ExistsByUsername reports whether at least one user row has the given
// username. The query error, if any, is returned alongside the flag.
func (r *UserRepository) ExistsByUsername(ctx context.Context, username string) (bool, error) {
	var matches int64
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("username = ?", username).
		Count(&matches)
	return matches > 0, result.Error
}
// ExistsByEmail reports whether at least one user row has the given email.
// The query error, if any, is returned alongside the flag.
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	var matches int64
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("email = ?", email).
		Count(&matches)
	return matches > 0, result.Error
}
// ExistsByPhone reports whether at least one user row has the given phone
// number. The query error, if any, is returned alongside the flag.
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
	var matches int64
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("phone = ?", phone).
		Count(&matches)
	return matches > 0, result.Error
}
// Search returns one page of users whose username, email, phone or nickname
// contains keyword, together with the total number of matches.
//
// keyword is passed through escapeLikePattern before being wrapped in
// wildcards, so that "%" and "_" typed by the user are meant literally.
// On any error it returns (nil, 0, err).
//
// NOTE(review): the LIKE expressions carry no ESCAPE clause, so the escaping
// relies on the database treating backslash as the default LIKE escape
// character — confirm for the target database.
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
	like := "%" + escapeLikePattern(keyword) + "%"
	matches := r.db.WithContext(ctx).Model(&domain.User{}).Where(
		"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
		like, like, like, like,
	)

	// Count first so the total reflects the unpaginated result set.
	var total int64
	if err := matches.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var page []*domain.User
	if err := matches.Offset(offset).Limit(limit).Find(&page).Error; err != nil {
		return nil, 0, err
	}
	return page, total, nil
}
// UpdateTOTP writes the three TOTP-related columns (totp_enabled,
// totp_secret, totp_recovery_codes) of user from the values currently held
// in the struct. A map is used so that zero values are written as well.
func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) error {
	totpFields := map[string]interface{}{
		"totp_enabled":        user.TOTPEnabled,
		"totp_secret":         user.TOTPSecret,
		"totp_recovery_codes": user.TOTPRecoveryCodes,
	}
	result := r.db.WithContext(ctx).Model(user).Updates(totpFields)
	return result.Error
}
// ConsumeTOTPRecoveryCode verifies a recovery code for userID and, when it
// matches, removes it from the stored list so it cannot be used again.
//
// The user is read, checked and updated inside one transaction. Because the
// read takes no row lock, the write is a compare-and-swap: the UPDATE only
// applies if totp_recovery_codes still holds the value that was read. If a
// concurrent request changed the list in between, no row is updated and an
// error is returned, so the same code can never be consumed twice.
//
// Returns:
//   - (user, true, nil) when the code matched and was removed;
//   - (user, false, nil) when the code did not match any stored hash;
//   - (nil, false, err) when the user cannot be loaded, TOTP is not enabled,
//     the stored list cannot be decoded/encoded, the update fails, or the
//     list was modified concurrently.
//
// NOTE(review): the compare-and-swap uses "=" on totp_recovery_codes, which
// assumes a text-like column type — confirm against the schema.
func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
	var user domain.User
	var consumed bool

	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Re-read the user inside the transaction.
		if err := tx.First(&user, userID).Error; err != nil {
			return err
		}
		if !user.TOTPEnabled {
			return errors.New("TOTP 未启用")
		}

		// Decode the stored list of hashed recovery codes.
		var hashedCodes []string
		if user.TOTPRecoveryCodes != "" {
			if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
				return fmt.Errorf("解析恢复码失败: %w", err)
			}
		}

		idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
		if !matched {
			// A mismatch is a normal outcome, not an error.
			return nil
		}

		// Remember exactly what was read; it is the CAS guard below.
		storedCodes := user.TOTPRecoveryCodes

		// Drop the used code and re-encode the remainder.
		remaining := append(hashedCodes[:idx], hashedCodes[idx+1:]...)
		codesJSON, err := json.Marshal(remaining)
		if err != nil {
			return fmt.Errorf("序列化恢复码失败: %w", err)
		}

		// Conditional update: only succeeds if nobody changed the list
		// since it was read in this transaction.
		result := tx.Model(&user).
			Where("totp_recovery_codes = ?", storedCodes).
			Update("totp_recovery_codes", string(codesJSON))
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return errors.New("recovery codes were modified concurrently")
		}

		user.TOTPRecoveryCodes = string(codesJSON)
		consumed = true
		return nil
	})
	if err != nil {
		return nil, false, err
	}
	return &user, consumed, nil
}
// VerifyTOTPOrRecoveryCode checks code against the user's TOTP secret and,
// if that fails, against the stored recovery-code hashes. It never modifies
// the stored recovery codes.
//
// The user is loaded and checked inside a single transaction.
//
// Returns:
//   - (true, nil) when code is a valid TOTP code or matches a recovery code;
//   - (false, nil) when code matches neither;
//   - (false, err) when the user cannot be loaded, TOTP is not enabled, or
//     the stored recovery codes cannot be decoded.
func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var user domain.User
		if err := tx.First(&user, userID).Error; err != nil {
			return err
		}
		if !user.TOTPEnabled {
			return errors.New("TOTP 未启用")
		}

		// Try the time-based code first.
		manager := auth.NewTOTPManager()
		if manager.ValidateCode(user.TOTPSecret, code) {
			return nil
		}

		// Not a valid TOTP code: fall back to the recovery codes.
		var hashedCodes []string
		if user.TOTPRecoveryCodes != "" {
			if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
				return fmt.Errorf("解析恢复码失败: %w", err)
			}
		}
		if _, matched := auth.VerifyRecoveryCode(code, hashedCodes); !matched {
			// Signal "no match" to the caller below via the sentinel.
			return errVerificationFailed
		}
		return nil
	})

	switch {
	case errors.Is(err, errVerificationFailed):
		// errors.Is keeps this working even if the sentinel gets wrapped
		// on its way out of the transaction.
		return false, nil
	case err != nil:
		return false, err
	}
	return true, nil
}
// errVerificationFailed is an internal sentinel returned from the transaction
// callback in VerifyTOTPOrRecoveryCode when neither the TOTP code nor any
// recovery code matches; that method translates it into (false, nil).
var errVerificationFailed = errors.New("verification failed")
// UpdatePassword stores hashedPassword in the password column of the user
// identified by id. The value is written as given; this method does no
// hashing of its own.
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("id = ?", id).
		Update("password", hashedPassword)
	return result.Error
}
// ListCreatedAfter returns the users whose created_at is at or after since,
// together with the total number of such users.
//
// When limit is greater than zero the result is paginated with offset and
// limit; otherwise all matching rows are returned and offset is ignored.
// On any error it returns (nil, 0, err).
func (r *UserRepository) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) {
	scope := r.db.WithContext(ctx).Model(&domain.User{}).Where("created_at >= ?", since)

	var total int64
	if err := scope.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Pagination is opt-in: a non-positive limit means "no limit".
	if limit > 0 {
		scope = scope.Offset(offset).Limit(limit)
	}

	var recent []*domain.User
	if err := scope.Find(&recent).Error; err != nil {
		return nil, 0, err
	}
	return recent, total, nil
}
// AdvancedFilter describes the optional criteria of an advanced user query.
// Unset criteria (empty string, nil pointer, empty slice) are skipped.
type AdvancedFilter struct {
Keyword string // keyword matched as a substring against username, email, phone and nickname
Status int // status filter: -1 means all; 0/1/2/3 correspond to UserStatus values
RoleIDs []int64 // role IDs; restricts the result to users that hold one of these roles
CreatedFrom *time.Time // registration time range, inclusive lower bound on created_at
CreatedTo *time.Time // registration time range, inclusive upper bound on created_at
LastLoginFrom *time.Time // last-login time range, inclusive lower bound on last_login_time
SortBy string // sort column: created_at, last_login_time, username (updated_at is also accepted by the queries in this file)
SortOrder string // sort direction: asc or desc
Offset int
Limit int
}
// AdvancedSearch returns one page of users matching every criterion set in
// filter, together with the total number of matches before pagination.
//
// Criteria: keyword substring on username/email/phone/nickname (LIKE
// wildcards in the keyword are escaped), status when filter.Status >= 0,
// created_at range, last_login_time lower bound, and role membership via
// the user_roles table.
//
// Sorting: filter.SortBy must be one of created_at, last_login_time,
// username or updated_at, otherwise created_at is used; filter.SortOrder
// must be asc or desc (case-insensitive), otherwise DESC is used.
//
// Paging: filter.Limit defaults to 20 when not positive and is capped at 200.
// On any error it returns (nil, 0, err). filter must not be nil.
func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFilter) ([]*domain.User, int64, error) {
	query := r.db.WithContext(ctx).Model(&domain.User{})

	// Keyword: escaped, then wrapped in wildcards for a substring match.
	if kw := filter.Keyword; kw != "" {
		pattern := "%" + escapeLikePattern(kw) + "%"
		query = query.Where(
			"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
			pattern, pattern, pattern, pattern,
		)
	}

	// Status: negative means "all".
	if filter.Status >= 0 {
		query = query.Where("status = ?", filter.Status)
	}

	// Registration time range.
	if from := filter.CreatedFrom; from != nil {
		query = query.Where("created_at >= ?", from)
	}
	if to := filter.CreatedTo; to != nil {
		query = query.Where("created_at <= ?", to)
	}

	// Last-login lower bound.
	if from := filter.LastLoginFrom; from != nil {
		query = query.Where("last_login_time >= ?", from)
	}

	// Role membership through a subquery on user_roles.
	if len(filter.RoleIDs) > 0 {
		query = query.Where(
			"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
			filter.RoleIDs,
		)
	}

	// Total is taken before ordering and pagination are added.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Sort column and direction are whitelisted because they are
	// concatenated into the ORDER BY clause.
	orderColumn := "created_at"
	switch filter.SortBy {
	case "created_at", "last_login_time", "username", "updated_at":
		orderColumn = filter.SortBy
	}
	direction := "DESC"
	switch strings.ToLower(filter.SortOrder) {
	case "asc", "desc":
		direction = strings.ToUpper(filter.SortOrder)
	}
	query = query.Order(orderColumn + " " + direction)

	// Page size: default 20, maximum 200.
	pageSize := filter.Limit
	switch {
	case pageSize <= 0:
		pageSize = 20
	case pageSize > 200:
		pageSize = 200
	}
	query = query.Offset(filter.Offset).Limit(pageSize)

	var users []*domain.User
	if err := query.Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// ListCursor returns one page of users using cursor (keyset) pagination and
// reports whether more rows follow.
//
// Filtering: keyword substring on username/email/phone/nickname, status when
// it lies in 0..3, role membership, created_at range and last_login_time
// lower bound. A nil filter is treated as "no criteria".
//
// Ordering: "<sortBy> <order>, id <order>", where sortBy is whitelisted
// (default created_at) and order is asc or desc (default DESC).
//
// Cursor: applied only when sorting by created_at, because the cursor carries
// a created_at value. The comparison follows the sort direction: rows
// strictly before the cursor for DESC, strictly after it for ASC.
// NOTE(review): with any other sort column a supplied cursor is ignored and
// the first page is returned again — callers should only paginate by
// created_at.
//
// limit is the page size; a non-positive limit falls back to 20. One extra
// row is fetched to detect whether another page exists.
//
// Returns (users, hasMore, nil) on success and (nil, false, err) on error.
func (r *UserRepository) ListCursor(ctx context.Context, filter *AdvancedFilter, limit int, cursor *pagination.Cursor) ([]*domain.User, bool, error) {
	if filter == nil {
		// Status -1 means "all statuses".
		filter = &AdvancedFilter{Status: -1}
	}
	if limit <= 0 {
		// A negative limit would otherwise make users[:limit] panic below.
		limit = 20
	}

	var users []*domain.User
	query := r.db.WithContext(ctx).Model(&domain.User{})

	// Filters.
	if filter.Keyword != "" {
		pattern := "%" + escapeLikePattern(filter.Keyword) + "%"
		query = query.Where(
			"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
			pattern, pattern, pattern, pattern,
		)
	}
	if filter.Status >= 0 && filter.Status <= 3 {
		query = query.Where("status = ?", filter.Status)
	}
	if len(filter.RoleIDs) > 0 {
		query = query.Where(
			"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
			filter.RoleIDs,
		)
	}
	if filter.CreatedFrom != nil {
		query = query.Where("created_at >= ?", *filter.CreatedFrom)
	}
	if filter.CreatedTo != nil {
		query = query.Where("created_at <= ?", *filter.CreatedTo)
	}
	if filter.LastLoginFrom != nil {
		query = query.Where("last_login_time >= ?", *filter.LastLoginFrom)
	}

	// Sort column and direction are whitelisted because they are
	// concatenated into the ORDER BY clause.
	sortBy := "created_at"
	switch filter.SortBy {
	case "created_at", "last_login_time", "username", "updated_at":
		sortBy = filter.SortBy
	}
	sortOrder := "DESC"
	switch strings.ToLower(filter.SortOrder) {
	case "asc", "desc":
		sortOrder = strings.ToUpper(filter.SortOrder)
	}

	// Keyset condition: must point in the same direction as the ordering,
	// otherwise an ascending listing would walk backwards.
	if cursor != nil && cursor.LastID > 0 && sortBy == "created_at" {
		op := "<"
		if sortOrder == "ASC" {
			op = ">"
		}
		query = query.Where(
			"(created_at "+op+" ? OR (created_at = ? AND id "+op+" ?))",
			cursor.LastValue, cursor.LastValue, cursor.LastID,
		)
	}

	// id is the tie-breaker so rows with equal sort values have a stable order.
	orderClause := sortBy + " " + sortOrder + ", id " + sortOrder
	if err := query.Order(orderClause).Limit(limit + 1).Find(&users).Error; err != nil {
		return nil, false, err
	}

	// The extra row only signals that another page exists; drop it.
	hasMore := len(users) > limit
	if hasMore {
		users = users[:limit]
	}
	return users, hasMore, nil
}