Files
sub2api-cn-relay-manager/internal/routing/sticky_redis.go
2026-05-29 07:43:29 +08:00

353 lines
9.2 KiB
Go

package routing
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
)
// defaultRedisDialTimeout is the dial timeout applied when
// RedisConfig.DialTimeout is zero or negative.
const defaultRedisDialTimeout = 3 * time.Second
// RedisConfig describes how RedisStickyStore reaches its Redis server.
type RedisConfig struct {
	// Addr is the TCP address (host:port) to dial. It is required.
	Addr string
	// Password, when non-blank after trimming, is sent with AUTH right after
	// connecting.
	Password string
	// DB is the logical database chosen with SELECT. Values <= 0 send no
	// SELECT.
	DB int
	// DialTimeout bounds connection establishment. Zero or negative values
	// fall back to defaultRedisDialTimeout.
	DialTimeout time.Duration
}
// RedisStickyStore keeps sticky bindings, route failure state and route
// cooldown state in Redis as JSON values. It holds only its configuration;
// every operation dials its own short-lived connection.
type RedisStickyStore struct {
	cfg RedisConfig // normalized connection settings
}
// NewRedisStickyStore normalizes cfg, rejects an empty address, and checks
// that the server answers PING before handing back a store.
func NewRedisStickyStore(ctx context.Context, cfg RedisConfig) (*RedisStickyStore, error) {
	normalized := normalizeRedisConfig(cfg)
	if len(strings.TrimSpace(normalized.Addr)) == 0 {
		return nil, fmt.Errorf("redis addr is required")
	}

	store := &RedisStickyStore{cfg: normalized}
	if pingErr := store.ping(ctx); pingErr != nil {
		return nil, pingErr
	}
	return store, nil
}
// Get loads the sticky binding stored under key (surrounding whitespace
// trimmed). A blank key or a missing Redis key yields ok=false and a nil
// error. A stored value that does not decode as StickyBinding JSON yields an
// error.
func (s *RedisStickyStore) Get(ctx context.Context, key string) (StickyBinding, bool, error) {
	trimmed := strings.TrimSpace(key)
	if trimmed == "" {
		return StickyBinding{}, false, nil
	}

	raw, found, err := s.getJSON(ctx, trimmed)
	if err != nil || !found {
		return StickyBinding{}, found, err
	}

	binding := StickyBinding{}
	if decodeErr := json.Unmarshal(raw, &binding); decodeErr != nil {
		return StickyBinding{}, false, fmt.Errorf("decode sticky binding %q: %w", trimmed, decodeErr)
	}
	return binding, true, nil
}
// Set normalizes binding for the given ttl and stores it under key
// (surrounding whitespace trimmed) with that expiry. A blank key is a no-op.
func (s *RedisStickyStore) Set(ctx context.Context, key string, binding StickyBinding, ttl time.Duration) error {
	trimmed := strings.TrimSpace(key)
	if trimmed == "" {
		return nil
	}

	normalized, normErr := normalizeStickyBinding(binding, ttl, time.Now())
	if normErr != nil {
		return normErr
	}
	return s.setJSON(ctx, trimmed, normalized, ttl)
}
// Delete removes the sticky binding stored under key (surrounding
// whitespace trimmed). A blank key is a no-op.
func (s *RedisStickyStore) Delete(ctx context.Context, key string) error {
	if trimmed := strings.TrimSpace(key); trimmed != "" {
		return s.delKey(ctx, trimmed)
	}
	return nil
}
// GetRouteFailure loads the failure state recorded for routeID. A missing
// key yields ok=false and a nil error. Key-building and JSON decode failures
// are returned as errors.
func (s *RedisStickyStore) GetRouteFailure(ctx context.Context, routeID string) (RouteFailureState, bool, error) {
	redisKey, err := BuildRouteFailureKey(routeID)
	if err != nil {
		return RouteFailureState{}, false, err
	}

	raw, found, err := s.getJSON(ctx, redisKey)
	if err != nil || !found {
		return RouteFailureState{}, found, err
	}

	state := RouteFailureState{}
	if decodeErr := json.Unmarshal(raw, &state); decodeErr != nil {
		return RouteFailureState{}, false, fmt.Errorf("decode route failure %q: %w", routeID, decodeErr)
	}
	return state, true, nil
}
// SetRouteFailure normalizes state for routeID and ttl, then stores it with
// that expiry.
func (s *RedisStickyStore) SetRouteFailure(ctx context.Context, routeID string, state RouteFailureState, ttl time.Duration) error {
	redisKey, keyErr := BuildRouteFailureKey(routeID)
	if keyErr != nil {
		return keyErr
	}

	normalized, normErr := normalizeRouteFailureState(routeID, state, ttl, time.Now())
	if normErr != nil {
		return normErr
	}
	return s.setJSON(ctx, redisKey, normalized, ttl)
}
// ClearRouteFailure deletes any failure state recorded for routeID.
func (s *RedisStickyStore) ClearRouteFailure(ctx context.Context, routeID string) error {
	redisKey, keyErr := BuildRouteFailureKey(routeID)
	if keyErr != nil {
		return keyErr
	}
	return s.delKey(ctx, redisKey)
}
// GetCooldown loads the cooldown state recorded for routeID. A missing key
// yields ok=false and a nil error. Key-building and JSON decode failures are
// returned as errors.
func (s *RedisStickyStore) GetCooldown(ctx context.Context, routeID string) (RouteCooldownState, bool, error) {
	redisKey, err := BuildRouteCooldownKey(routeID)
	if err != nil {
		return RouteCooldownState{}, false, err
	}

	raw, found, err := s.getJSON(ctx, redisKey)
	if err != nil || !found {
		return RouteCooldownState{}, found, err
	}

	state := RouteCooldownState{}
	if decodeErr := json.Unmarshal(raw, &state); decodeErr != nil {
		return RouteCooldownState{}, false, fmt.Errorf("decode route cooldown %q: %w", routeID, decodeErr)
	}
	return state, true, nil
}
// SetCooldown normalizes state for routeID and ttl, then stores it with that
// expiry.
func (s *RedisStickyStore) SetCooldown(ctx context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error {
	redisKey, keyErr := BuildRouteCooldownKey(routeID)
	if keyErr != nil {
		return keyErr
	}

	normalized, normErr := normalizeRouteCooldownState(routeID, state, ttl, time.Now())
	if normErr != nil {
		return normErr
	}
	return s.setJSON(ctx, redisKey, normalized, ttl)
}
// ClearCooldown deletes any cooldown state recorded for routeID.
func (s *RedisStickyStore) ClearCooldown(ctx context.Context, routeID string) error {
	redisKey, keyErr := BuildRouteCooldownKey(routeID)
	if keyErr != nil {
		return keyErr
	}
	return s.delKey(ctx, redisKey)
}
// ping opens a fresh connection (running any AUTH/SELECT handshake) and
// sends PING. It succeeds only when the reply is the simple string PONG.
func (s *RedisStickyStore) ping(ctx context.Context) error {
	conn, r, err := s.open(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	if writeErr := writeRESPArray(conn, "PING"); writeErr != nil {
		return fmt.Errorf("redis ping write: %w", writeErr)
	}
	reply, readErr := readRESPValue(r)
	if readErr != nil {
		return fmt.Errorf("redis ping read: %w", readErr)
	}
	if isPong := reply.kind == '+' && reply.stringValue == "PONG"; !isPong {
		return fmt.Errorf("redis ping unexpected response: kind=%q value=%q", reply.kind, reply.stringValue)
	}
	return nil
}
// getJSON runs GET key on a fresh connection.
//
// Results by reply type:
//   - bulk reply: returns the payload bytes with ok=true.
//   - null bulk reply: returns (nil, false, nil).
//   - any other reply type: returns an error.
func (s *RedisStickyStore) getJSON(ctx context.Context, key string) ([]byte, bool, error) {
	conn, r, err := s.open(ctx)
	if err != nil {
		return nil, false, err
	}
	defer conn.Close()

	if err = writeRESPArray(conn, "GET", key); err != nil {
		return nil, false, fmt.Errorf("redis GET %q: write: %w", key, err)
	}
	reply, err := readRESPValue(r)
	if err != nil {
		return nil, false, fmt.Errorf("redis GET %q: read: %w", key, err)
	}

	if reply.kind == '_' {
		return nil, false, nil
	}
	if reply.kind != '$' {
		return nil, false, fmt.Errorf("redis GET %q: unexpected response %q", key, reply.kind)
	}
	return []byte(reply.stringValue), true, nil
}
// setJSON JSON-encodes payload and stores it with SET key value EX seconds.
// The ttl is rounded up to whole seconds, and at least 1 second is always
// sent. The server must answer +OK.
func (s *RedisStickyStore) setJSON(ctx context.Context, key string, payload any, ttl time.Duration) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal redis value for %q: %w", key, err)
	}

	// EX only accepts whole seconds: round partial seconds up and clamp to 1.
	seconds := int(ttl / time.Second)
	if ttl%time.Second != 0 {
		seconds++
	}
	if seconds < 1 {
		seconds = 1
	}

	conn, r, err := s.open(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	args := []string{"SET", key, string(encoded), "EX", strconv.Itoa(seconds)}
	if writeErr := writeRESPArray(conn, args...); writeErr != nil {
		return fmt.Errorf("redis SET %q: write: %w", key, writeErr)
	}
	reply, readErr := readRESPValue(r)
	if readErr != nil {
		return fmt.Errorf("redis SET %q: read: %w", key, readErr)
	}
	if accepted := reply.kind == '+' && reply.stringValue == "OK"; !accepted {
		return fmt.Errorf("redis SET %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue)
	}
	return nil
}
// delKey runs DEL key on a fresh connection. Any integer reply counts as
// success, whether or not the key existed.
func (s *RedisStickyStore) delKey(ctx context.Context, key string) error {
	conn, r, err := s.open(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	if writeErr := writeRESPArray(conn, "DEL", key); writeErr != nil {
		return fmt.Errorf("redis DEL %q: write: %w", key, writeErr)
	}
	reply, readErr := readRESPValue(r)
	if readErr != nil {
		return fmt.Errorf("redis DEL %q: read: %w", key, readErr)
	}
	if reply.kind == ':' {
		return nil
	}
	return fmt.Errorf("redis DEL %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue)
}
// open dials Redis over TCP and performs the optional AUTH (non-blank
// password) and SELECT (DB > 0) handshake. Each handshake step must be
// answered with +OK.
//
// The caller owns the returned connection and must Close it. On any error
// the connection has already been closed.
//
// DialContext only honours ctx while the connection is being established, so
// the ctx deadline, if any, is copied onto the connection. That bounds every
// later read and write. A ctx without a deadline leaves post-dial I/O
// unbounded.
func (s *RedisStickyStore) open(ctx context.Context) (net.Conn, *bufio.Reader, error) {
	cfg := normalizeRedisConfig(s.cfg)
	dialer := &net.Dialer{Timeout: cfg.DialTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", cfg.Addr)
	if err != nil {
		return nil, nil, fmt.Errorf("dial redis %q: %w", cfg.Addr, err)
	}
	// Without this, a server that accepts the connection but never replies
	// would block the caller forever, regardless of ctx.
	if deadline, ok := ctx.Deadline(); ok {
		if err := conn.SetDeadline(deadline); err != nil {
			conn.Close()
			return nil, nil, fmt.Errorf("redis set deadline: %w", err)
		}
	}
	reader := bufio.NewReader(conn)

	// expectOK sends one handshake command and requires a +OK status reply.
	expectOK := func(name string, args ...string) error {
		if err := writeRESPArray(conn, append([]string{name}, args...)...); err != nil {
			return fmt.Errorf("redis %s write: %w", name, err)
		}
		reply, err := readRESPValue(reader)
		if err != nil {
			return fmt.Errorf("redis %s read: %w", name, err)
		}
		if reply.kind != '+' || reply.stringValue != "OK" {
			return fmt.Errorf("redis %s unexpected response kind=%q value=%q", name, reply.kind, reply.stringValue)
		}
		return nil
	}

	if strings.TrimSpace(cfg.Password) != "" {
		if err := expectOK("AUTH", cfg.Password); err != nil {
			conn.Close()
			return nil, nil, err
		}
	}
	if cfg.DB > 0 {
		if err := expectOK("SELECT", strconv.Itoa(cfg.DB)); err != nil {
			conn.Close()
			return nil, nil, err
		}
	}
	return conn, reader, nil
}
// normalizeRedisConfig returns a copy of cfg with Addr and Password trimmed
// of surrounding whitespace. A zero or negative DialTimeout is replaced by
// defaultRedisDialTimeout.
func normalizeRedisConfig(cfg RedisConfig) RedisConfig {
	out := cfg
	out.Addr, out.Password = strings.TrimSpace(out.Addr), strings.TrimSpace(out.Password)
	if out.DialTimeout <= 0 {
		out.DialTimeout = defaultRedisDialTimeout
	}
	return out
}
// respValue is one decoded RESP reply as produced by readRESPValue.
type respValue struct {
	// kind is the reply's type prefix ('+', ':', '$'), or the synthetic '_'
	// for a null bulk string.
	kind byte
	// stringValue is the simple-string/integer text or the bulk payload.
	stringValue string
}
// writeRESPArray encodes parts as a RESP array of bulk strings and sends it
// to w with a single Write call.
//
// The whole frame is built first. Writing the header and each argument
// separately on an unbuffered TCP connection (Go enables TCP_NODELAY by
// default) can cost one syscall and one packet per piece. Bulk lengths are
// byte counts, as RESP requires.
func writeRESPArray(w io.Writer, parts ...string) error {
	// Capacity is only an estimate; append grows the slice if needed.
	capacity := 16
	for _, part := range parts {
		capacity += len(part) + 32
	}
	frame := make([]byte, 0, capacity)

	frame = append(frame, '*')
	frame = strconv.AppendInt(frame, int64(len(parts)), 10)
	frame = append(frame, '\r', '\n')
	for _, part := range parts {
		frame = append(frame, '$')
		frame = strconv.AppendInt(frame, int64(len(part)), 10)
		frame = append(frame, '\r', '\n')
		frame = append(frame, part...)
		frame = append(frame, '\r', '\n')
	}

	// io.Writer guarantees a non-nil error whenever fewer bytes are written.
	_, err := w.Write(frame)
	return err
}
// readRESPValue reads one RESP2 reply from r.
//
// Supported reply types:
//   - simple strings ('+') and integers (':'): returned with their text.
//   - bulk strings ('$'): returned with their payload.
//   - null bulk strings (negative length): reported with the synthetic
//     kind '_'.
//   - error replies ('-'): returned as a Go error.
//
// Any other prefix (e.g. arrays) is rejected with an error.
func readRESPValue(r *bufio.Reader) (respValue, error) {
	// Redis's default proto-max-bulk-len is 512 MiB. A larger length is
	// treated as a corrupt header. The cap also keeps size+2 from
	// overflowing, which would otherwise make make() panic.
	const maxBulkLen = 512 << 20

	prefix, err := r.ReadByte()
	if err != nil {
		return respValue{}, err
	}
	line, err := r.ReadString('\n')
	if err != nil {
		return respValue{}, err
	}
	line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")

	switch prefix {
	case '-':
		return respValue{}, fmt.Errorf("redis error: %s", line)
	case '+', ':':
		return respValue{kind: prefix, stringValue: line}, nil
	case '$':
		size, err := strconv.Atoi(line)
		if err != nil {
			return respValue{}, fmt.Errorf("parse bulk length: %w", err)
		}
		if size < 0 {
			return respValue{kind: '_'}, nil
		}
		if size > maxBulkLen {
			return respValue{}, fmt.Errorf("bulk length %d exceeds limit %d", size, maxBulkLen)
		}
		// Read the payload plus its trailing CRLF in one go.
		payload := make([]byte, size+2)
		if _, err := io.ReadFull(r, payload); err != nil {
			return respValue{}, err
		}
		// A missing terminator means the stream is out of sync; refuse it
		// rather than misparse the next reply.
		if payload[size] != '\r' || payload[size+1] != '\n' {
			return respValue{}, fmt.Errorf("bulk string of length %d not terminated by CRLF", size)
		}
		return respValue{kind: '$', stringValue: string(payload[:size])}, nil
	default:
		return respValue{}, fmt.Errorf("unsupported redis response prefix %q", prefix)
	}
}