60 lines
2.1 KiB
Go
60 lines
2.1 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// SubjectRateLimitWindow mirrors one row of the subject_rate_limits table:
// the number of hits recorded for a subject/action pair within one window.
type SubjectRateLimitWindow struct {
	// ID is the row identifier (presumably the table's primary key); it is
	// deliberately excluded from JSON output.
	ID int64 `json:"-"`
	// SubjectID identifies whose hits are being counted.
	SubjectID string `json:"subject_id"`
	// Action names the rate-limited operation; the repository stores it
	// trimmed and lower-cased.
	Action string `json:"action"`
	// WindowStart is the caller-supplied key for the window this row counts.
	WindowStart string `json:"window_start"`
	// HitCount is the number of hits recorded in this window.
	HitCount int64 `json:"hit_count"`
	// UpdatedAt is the last-update timestamp; IncrementWindow writes it with
	// strftime('%Y-%m-%dT%H:%M:%SZ','now').
	UpdatedAt string `json:"updated_at"`
}
|
|
|
|
type SubjectRateLimitsRepo struct {
|
|
db execQuerier
|
|
}
|
|
|
|
func newSubjectRateLimitsRepo(db execQuerier) *SubjectRateLimitsRepo {
|
|
return &SubjectRateLimitsRepo{db: db}
|
|
}
|
|
|
|
func (r *SubjectRateLimitsRepo) IncrementWindow(ctx context.Context, subjectID, action, windowStart string) (int64, error) {
|
|
subjectID = strings.TrimSpace(subjectID)
|
|
action = strings.ToLower(strings.TrimSpace(action))
|
|
windowStart = strings.TrimSpace(windowStart)
|
|
if subjectID == "" {
|
|
return 0, fmt.Errorf("subject_id is required")
|
|
}
|
|
if action == "" {
|
|
return 0, fmt.Errorf("action is required")
|
|
}
|
|
if windowStart == "" {
|
|
return 0, fmt.Errorf("window_start is required")
|
|
}
|
|
_, err := r.db.ExecContext(ctx, `INSERT INTO subject_rate_limits (subject_id, action, window_start, hit_count, updated_at)
|
|
VALUES (?, ?, ?, 1, strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
|
ON CONFLICT(subject_id, action, window_start)
|
|
DO UPDATE SET hit_count = hit_count + 1, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')`, subjectID, action, windowStart)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("increment subject rate limit %s/%s/%s: %w", subjectID, action, windowStart, err)
|
|
}
|
|
return r.GetCount(ctx, subjectID, action, windowStart)
|
|
}
|
|
|
|
func (r *SubjectRateLimitsRepo) GetCount(ctx context.Context, subjectID, action, windowStart string) (int64, error) {
|
|
subjectID = strings.TrimSpace(subjectID)
|
|
action = strings.ToLower(strings.TrimSpace(action))
|
|
windowStart = strings.TrimSpace(windowStart)
|
|
row := r.db.QueryRowContext(ctx, `SELECT hit_count FROM subject_rate_limits WHERE subject_id = ? AND action = ? AND window_start = ?`, subjectID, action, windowStart)
|
|
var count int64
|
|
if err := row.Scan(&count); err != nil {
|
|
return 0, fmt.Errorf("get subject rate limit %s/%s/%s: %w", subjectID, action, windowStart, err)
|
|
}
|
|
return count, nil
|
|
}
|