fix: P0-02 prevent login attempt counter race condition

Add atomic Increment method to cache layers:
- L2Cache interface: add Increment method signature
- RedisCache: implement using Redis INCRBY
- L1Cache: implement with mutex-protected counter
- CacheManager: add Increment that updates both L1 and L2

Update incrementFailAttempts to use atomic Increment instead
of Get-Increment-Set pattern, preventing TOCTOU race.
This commit is contained in:
2026-04-18 13:45:09 +08:00
parent 32a3d4c9e0
commit ca7ba5ccdf
4 changed files with 84 additions and 9 deletions

View File

@@ -106,3 +106,16 @@ func (cm *CacheManager) GetL1() *L1Cache {
func (cm *CacheManager) GetL2() L2Cache {
return cm.l2
}
// Increment 原子递增同时更新L1和L2
func (cm *CacheManager) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
// 先更新L1
cm.l1.Increment(key, delta, ttl)
// 再更新L2
if cm.l2 != nil {
return cm.l2.Increment(ctx, key, delta, ttl)
}
return cm.l1.Increment(key, 0, 0), nil
}

41
internal/cache/l1.go vendored
View File

@@ -169,3 +169,44 @@ func (c *L1Cache) Cleanup() {
c.removeFromAccessOrder(key)
}
}
// Increment 原子递增(用于登录失败计数器等原子操作场景)
func (c *L1Cache) Increment(key string, delta int64, ttl time.Duration) int64 {
c.mu.Lock()
defer c.mu.Unlock()
var expiration int64
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
current := int64(0)
if item, ok := c.items[key]; ok {
if item.Expired() {
delete(c.items, key)
c.removeFromAccessOrder(key)
} else {
if v, ok := item.Value.(int64); ok {
current = v
} else if v, ok := item.Value.(int); ok {
current = int64(v)
} else if v, ok := item.Value.(float64); ok {
current = int64(v)
}
}
}
newVal := current + delta
c.items[key] = &CacheItem{
Value: newVal,
Expiration: expiration,
}
if _, exists := c.items[key]; !exists {
c.accessOrder = append(c.accessOrder, key)
} else {
c.updateAccessOrder(key)
}
return newVal
}

15
internal/cache/l2.go vendored
View File

@@ -17,6 +17,7 @@ type L2Cache interface {
Delete(ctx context.Context, key string) error
Exists(ctx context.Context, key string) (bool, error)
Clear(ctx context.Context) error
Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error)
Close() error
}
@@ -127,6 +128,20 @@ func (c *RedisCache) Close() error {
return c.client.Close()
}
func (c *RedisCache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
if !c.enabled || c.client == nil {
return 0, errors.New("redis is not enabled")
}
result, err := c.client.IncrBy(ctx, key, delta).Result()
if err != nil {
return 0, err
}
if ttl > 0 {
c.client.Expire(ctx, key, ttl)
}
return result, nil
}
func decodeRedisValue(raw string) (interface{}, error) {
decoder := json.NewDecoder(strings.NewReader(raw))
decoder.UseNumber()

View File

@@ -494,17 +494,23 @@ func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int
return 0
}
current := 0
if value, ok := s.cache.Get(ctx, key); ok {
current = attemptCount(value)
}
current++
if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil {
log.Printf("auth: store login attempts failed, key=%s err=%v", key, err)
// 使用原子递增,避免竞态条件
newVal, err := s.cache.Increment(ctx, key, 1, s.loginLockDuration)
if err != nil {
log.Printf("auth: increment login attempts failed, key=%s err=%v", key, err)
// 回退到原来的非原子方式
current := 0
if value, ok := s.cache.Get(ctx, key); ok {
current = attemptCount(value)
}
current++
if setErr := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); setErr != nil {
log.Printf("auth: store login attempts failed, key=%s err=%v", key, setErr)
}
return current
}
return current
return int(newVal)
}
func isValidPhoneSimple(phone string) bool {