353 lines
9.2 KiB
Go
353 lines
9.2 KiB
Go
package routing
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// defaultRedisDialTimeout is the dial timeout that normalizeRedisConfig
// substitutes when RedisConfig.DialTimeout is zero or negative.
const defaultRedisDialTimeout = 3000 * time.Millisecond
|
|
|
|
// RedisConfig describes how a RedisStickyStore reaches its Redis server.
type RedisConfig struct {
	// Addr is the TCP address dialed for every operation; it is required.
	Addr string
	// Password, when non-empty after trimming whitespace, is sent with AUTH
	// on each new connection.
	Password string
	// DB is the logical database chosen via SELECT when it is greater than zero.
	DB int
	// DialTimeout bounds connection establishment; values <= 0 fall back to
	// defaultRedisDialTimeout.
	DialTimeout time.Duration
}
|
|
|
|
type RedisStickyStore struct {
|
|
cfg RedisConfig
|
|
}
|
|
|
|
func NewRedisStickyStore(ctx context.Context, cfg RedisConfig) (*RedisStickyStore, error) {
|
|
cfg = normalizeRedisConfig(cfg)
|
|
if strings.TrimSpace(cfg.Addr) == "" {
|
|
return nil, fmt.Errorf("redis addr is required")
|
|
}
|
|
|
|
store := &RedisStickyStore{cfg: cfg}
|
|
if err := store.ping(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) Get(ctx context.Context, key string) (StickyBinding, bool, error) {
|
|
key = strings.TrimSpace(key)
|
|
if key == "" {
|
|
return StickyBinding{}, false, nil
|
|
}
|
|
|
|
payload, ok, err := s.getJSON(ctx, key)
|
|
if err != nil || !ok {
|
|
return StickyBinding{}, ok, err
|
|
}
|
|
var binding StickyBinding
|
|
if err := json.Unmarshal(payload, &binding); err != nil {
|
|
return StickyBinding{}, false, fmt.Errorf("decode sticky binding %q: %w", key, err)
|
|
}
|
|
return binding, true, nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) Set(ctx context.Context, key string, binding StickyBinding, ttl time.Duration) error {
|
|
key = strings.TrimSpace(key)
|
|
if key == "" {
|
|
return nil
|
|
}
|
|
binding, err := normalizeStickyBinding(binding, ttl, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.setJSON(ctx, key, binding, ttl)
|
|
}
|
|
|
|
func (s *RedisStickyStore) Delete(ctx context.Context, key string) error {
|
|
key = strings.TrimSpace(key)
|
|
if key == "" {
|
|
return nil
|
|
}
|
|
return s.delKey(ctx, key)
|
|
}
|
|
|
|
func (s *RedisStickyStore) GetRouteFailure(ctx context.Context, routeID string) (RouteFailureState, bool, error) {
|
|
key, err := BuildRouteFailureKey(routeID)
|
|
if err != nil {
|
|
return RouteFailureState{}, false, err
|
|
}
|
|
payload, ok, err := s.getJSON(ctx, key)
|
|
if err != nil || !ok {
|
|
return RouteFailureState{}, ok, err
|
|
}
|
|
var state RouteFailureState
|
|
if err := json.Unmarshal(payload, &state); err != nil {
|
|
return RouteFailureState{}, false, fmt.Errorf("decode route failure %q: %w", routeID, err)
|
|
}
|
|
return state, true, nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) SetRouteFailure(ctx context.Context, routeID string, state RouteFailureState, ttl time.Duration) error {
|
|
key, err := BuildRouteFailureKey(routeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
state, err = normalizeRouteFailureState(routeID, state, ttl, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.setJSON(ctx, key, state, ttl)
|
|
}
|
|
|
|
func (s *RedisStickyStore) ClearRouteFailure(ctx context.Context, routeID string) error {
|
|
key, err := BuildRouteFailureKey(routeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.delKey(ctx, key)
|
|
}
|
|
|
|
func (s *RedisStickyStore) GetCooldown(ctx context.Context, routeID string) (RouteCooldownState, bool, error) {
|
|
key, err := BuildRouteCooldownKey(routeID)
|
|
if err != nil {
|
|
return RouteCooldownState{}, false, err
|
|
}
|
|
payload, ok, err := s.getJSON(ctx, key)
|
|
if err != nil || !ok {
|
|
return RouteCooldownState{}, ok, err
|
|
}
|
|
var state RouteCooldownState
|
|
if err := json.Unmarshal(payload, &state); err != nil {
|
|
return RouteCooldownState{}, false, fmt.Errorf("decode route cooldown %q: %w", routeID, err)
|
|
}
|
|
return state, true, nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) SetCooldown(ctx context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error {
|
|
key, err := BuildRouteCooldownKey(routeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
state, err = normalizeRouteCooldownState(routeID, state, ttl, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.setJSON(ctx, key, state, ttl)
|
|
}
|
|
|
|
func (s *RedisStickyStore) ClearCooldown(ctx context.Context, routeID string) error {
|
|
key, err := BuildRouteCooldownKey(routeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.delKey(ctx, key)
|
|
}
|
|
|
|
func (s *RedisStickyStore) ping(ctx context.Context) error {
|
|
conn, reader, err := s.open(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
if err := writeRESPArray(conn, "PING"); err != nil {
|
|
return fmt.Errorf("redis ping write: %w", err)
|
|
}
|
|
reply, err := readRESPValue(reader)
|
|
if err != nil {
|
|
return fmt.Errorf("redis ping read: %w", err)
|
|
}
|
|
if reply.kind != '+' || reply.stringValue != "PONG" {
|
|
return fmt.Errorf("redis ping unexpected response: kind=%q value=%q", reply.kind, reply.stringValue)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) getJSON(ctx context.Context, key string) ([]byte, bool, error) {
|
|
conn, reader, err := s.open(ctx)
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := writeRESPArray(conn, "GET", key); err != nil {
|
|
return nil, false, fmt.Errorf("redis GET %q: write: %w", key, err)
|
|
}
|
|
reply, err := readRESPValue(reader)
|
|
if err != nil {
|
|
return nil, false, fmt.Errorf("redis GET %q: read: %w", key, err)
|
|
}
|
|
switch reply.kind {
|
|
case '$':
|
|
return []byte(reply.stringValue), true, nil
|
|
case '_':
|
|
return nil, false, nil
|
|
default:
|
|
return nil, false, fmt.Errorf("redis GET %q: unexpected response %q", key, reply.kind)
|
|
}
|
|
}
|
|
|
|
func (s *RedisStickyStore) setJSON(ctx context.Context, key string, payload any, ttl time.Duration) error {
|
|
encoded, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal redis value for %q: %w", key, err)
|
|
}
|
|
|
|
conn, reader, err := s.open(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
seconds := int(ttl / time.Second)
|
|
if ttl%time.Second != 0 {
|
|
seconds++
|
|
}
|
|
if seconds <= 0 {
|
|
seconds = 1
|
|
}
|
|
|
|
if err := writeRESPArray(conn, "SET", key, string(encoded), "EX", strconv.Itoa(seconds)); err != nil {
|
|
return fmt.Errorf("redis SET %q: write: %w", key, err)
|
|
}
|
|
reply, err := readRESPValue(reader)
|
|
if err != nil {
|
|
return fmt.Errorf("redis SET %q: read: %w", key, err)
|
|
}
|
|
if reply.kind != '+' || reply.stringValue != "OK" {
|
|
return fmt.Errorf("redis SET %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) delKey(ctx context.Context, key string) error {
|
|
conn, reader, err := s.open(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := writeRESPArray(conn, "DEL", key); err != nil {
|
|
return fmt.Errorf("redis DEL %q: write: %w", key, err)
|
|
}
|
|
reply, err := readRESPValue(reader)
|
|
if err != nil {
|
|
return fmt.Errorf("redis DEL %q: read: %w", key, err)
|
|
}
|
|
if reply.kind != ':' {
|
|
return fmt.Errorf("redis DEL %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStickyStore) open(ctx context.Context) (net.Conn, *bufio.Reader, error) {
|
|
cfg := normalizeRedisConfig(s.cfg)
|
|
dialer := &net.Dialer{Timeout: cfg.DialTimeout}
|
|
conn, err := dialer.DialContext(ctx, "tcp", cfg.Addr)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("dial redis %q: %w", cfg.Addr, err)
|
|
}
|
|
reader := bufio.NewReader(conn)
|
|
|
|
if strings.TrimSpace(cfg.Password) != "" {
|
|
if err := writeRESPArray(conn, "AUTH", cfg.Password); err != nil {
|
|
conn.Close()
|
|
return nil, nil, fmt.Errorf("redis AUTH write: %w", err)
|
|
}
|
|
reply, err := readRESPValue(reader)
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, nil, fmt.Errorf("redis AUTH read: %w", err)
|
|
}
|
|
if reply.kind != '+' || reply.stringValue != "OK" {
|
|
conn.Close()
|
|
return nil, nil, fmt.Errorf("redis AUTH unexpected response kind=%q value=%q", reply.kind, reply.stringValue)
|
|
}
|
|
}
|
|
|
|
if cfg.DB > 0 {
|
|
if err := writeRESPArray(conn, "SELECT", strconv.Itoa(cfg.DB)); err != nil {
|
|
conn.Close()
|
|
return nil, nil, fmt.Errorf("redis SELECT write: %w", err)
|
|
}
|
|
reply, err := readRESPValue(reader)
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, nil, fmt.Errorf("redis SELECT read: %w", err)
|
|
}
|
|
if reply.kind != '+' || reply.stringValue != "OK" {
|
|
conn.Close()
|
|
return nil, nil, fmt.Errorf("redis SELECT unexpected response kind=%q value=%q", reply.kind, reply.stringValue)
|
|
}
|
|
}
|
|
|
|
return conn, reader, nil
|
|
}
|
|
|
|
func normalizeRedisConfig(cfg RedisConfig) RedisConfig {
|
|
cfg.Addr = strings.TrimSpace(cfg.Addr)
|
|
cfg.Password = strings.TrimSpace(cfg.Password)
|
|
if cfg.DialTimeout <= 0 {
|
|
cfg.DialTimeout = defaultRedisDialTimeout
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
// respValue is one decoded RESP reply.
//
// kind is the reply's type byte ('+' simple string, ':' integer, '$' bulk
// string), or '_' for a null bulk reply. stringValue holds the reply payload
// as text and is empty for nulls.
type respValue struct {
	kind        byte   // reply type marker
	stringValue string // payload text
}
|
|
|
|
// writeRESPArray encodes parts as a RESP array of bulk strings and writes the
// complete command to w with a single Write call.
//
// The frame is "*<count>\r\n" followed by "$<len>\r\n<part>\r\n" for each
// element. Lengths are byte counts, so non-ASCII parts are framed correctly.
//
// The frame is built in memory first because w is typically a raw net.Conn.
// Writing each fragment separately would cost one syscall per fragment, and
// potentially one network packet each. The error from w.Write is returned
// as-is.
func writeRESPArray(w io.Writer, parts ...string) error {
	// Per element: a type byte, up to 20 length digits, and two CRLF pairs.
	const frameOverhead = 1 + 20 + 4

	size := frameOverhead
	for _, part := range parts {
		size += frameOverhead + len(part)
	}

	buf := make([]byte, 0, size)
	buf = append(buf, '*')
	buf = strconv.AppendInt(buf, int64(len(parts)), 10)
	buf = append(buf, '\r', '\n')
	for _, part := range parts {
		buf = append(buf, '$')
		buf = strconv.AppendInt(buf, int64(len(part)), 10)
		buf = append(buf, '\r', '\n')
		buf = append(buf, part...)
		buf = append(buf, '\r', '\n')
	}

	_, err := w.Write(buf)
	return err
}
|
|
|
|
func readRESPValue(r *bufio.Reader) (respValue, error) {
|
|
prefix, err := r.ReadByte()
|
|
if err != nil {
|
|
return respValue{}, err
|
|
}
|
|
|
|
line, err := r.ReadString('\n')
|
|
if err != nil {
|
|
return respValue{}, err
|
|
}
|
|
line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
|
|
|
|
switch prefix {
|
|
case '+', '-', ':':
|
|
if prefix == '-' {
|
|
return respValue{}, fmt.Errorf("redis error: %s", line)
|
|
}
|
|
return respValue{kind: prefix, stringValue: line}, nil
|
|
case '$':
|
|
size, err := strconv.Atoi(line)
|
|
if err != nil {
|
|
return respValue{}, fmt.Errorf("parse bulk length: %w", err)
|
|
}
|
|
if size < 0 {
|
|
return respValue{kind: '_'}, nil
|
|
}
|
|
payload := make([]byte, size+2)
|
|
if _, err := io.ReadFull(r, payload); err != nil {
|
|
return respValue{}, err
|
|
}
|
|
return respValue{kind: '$', stringValue: string(payload[:size])}, nil
|
|
default:
|
|
return respValue{}, fmt.Errorf("unsupported redis response prefix %q", prefix)
|
|
}
|
|
}
|