feat(vNext.4): implement trusted-subject security chain for portal user key self-service

- Add portal_auth.go: Portal user session auth with HMAC-signed cookies
- Add /api/portal/session/{login,logout,state} endpoints
- Update nginx config template: cookie-to-header trusted proxy pattern
- Update frontend: sync CRM session on login/logout
- Add TRUSTED_SUBJECT_DEPLOY_GUIDE.md with remote43 deployment steps
- Update EXECUTION_BOARD.md: mark trusted-subject blocking issue as resolved

This implements the secure chain:
  Browser → Portal → nginx (cookie→header) → CRM (verify proxy secret)

Required remote43 actions:
1. Generate 64-char hex secret
2. Update .env.crm with TRUSTED_* config
3. Update nginx with cookie map and header injection
4. Restart services

Fixes EXECUTION_BOARD.md 2026-06-08 blocking issue
This commit is contained in:
phamnazage-jpg
2026-06-09 07:48:03 +08:00
parent dd6f332b53
commit 4e2ee087fd
25 changed files with 1861 additions and 177 deletions

View File

@@ -141,7 +141,7 @@ func TestAPIAdminSessionLoginSetsCookieAndAuthorizesSubsequentRequest(t *testing
ListPacks: func(context.Context) ([]PackInfo, error) {
return []PackInfo{{PackID: "openai-cn-pack", Version: "1.1.6"}}, nil
},
}, "")
}, "", "")
loginRequest := httptestRequest(t, http.MethodPost, "/api/admin/session/login", map[string]any{
"username": "admin",
@@ -177,7 +177,7 @@ func TestAPIAdminSessionRejectsInvalidPassword(t *testing.T) {
Token: "secret-token",
Username: "admin",
Password: "pass-123",
}, ActionSet{}, "")
}, ActionSet{}, "", "")
request := httptestRequest(t, http.MethodPost, "/api/admin/session/login", map[string]any{
"username": "admin",
"password": "wrong",
@@ -192,7 +192,7 @@ func TestAPIAdminSessionLogoutClearsCookie(t *testing.T) {
Token: "secret-token",
Username: "admin",
Password: "pass-123",
}, ActionSet{}, "")
}, ActionSet{}, "", "")
request := httptestRequest(t, http.MethodPost, "/api/admin/session/logout", nil, "")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusNoContent)
@@ -219,7 +219,7 @@ func TestAPIAdminSessionMeReportsAuthenticationState(t *testing.T) {
Now: func() time.Time {
return now
},
}, ActionSet{}, "")
}, ActionSet{}, "", "")
request := httptestRequest(t, http.MethodGet, "/api/admin/session", nil, "")
response := httptestRecorder(handler, request)

View File

@@ -30,7 +30,7 @@ func Bootstrap(ctx context.Context) (*Server, error) {
Username: adminSession.Username,
Password: adminSession.Password,
SessionTTL: adminSession.SessionTTL,
}, NewActionSetWithStickyRuntime(cfg.Database.SQLiteDSN, stickyRuntime), cfg.Database.SQLiteDSN)
}, NewActionSetWithStickyRuntime(cfg.Database.SQLiteDSN, stickyRuntime, cfg.UserKeyAuth), cfg.Database.SQLiteDSN, cfg.UserKeyAuth.TrustedProxySecret)
return NewServer(cfg.Server.ListenAddr, handler, nil), nil
}

View File

@@ -336,14 +336,10 @@ func NewAPIHandler(adminToken string, actions ActionSet, dsn ...string) http.Han
if len(dsn) > 0 {
dsnVal = dsn[0]
}
return NewAPIHandlerWithAuth(AdminAuthConfig{Token: adminToken}, actions, dsnVal)
return NewAPIHandlerWithAuth(AdminAuthConfig{Token: adminToken}, actions, dsnVal, "")
}
func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet, dsn ...string) http.Handler {
sqliteDSN := ""
if len(dsn) > 0 {
sqliteDSN = dsn[0]
}
func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet, sqliteDSN string, portalSessionSecret string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", healthz)
mux.HandleFunc("GET /version", handleVersion)
@@ -366,6 +362,19 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet, dsn ...
mux.HandleFunc("GET /api/portal/logical-groups/{groupID}/models", func(w http.ResponseWriter, r *http.Request) {
handleListPortalLogicalGroupModels(w, r, actions.ListPortalLogicalGroupModels)
})
// Portal user session endpoints
portalAuth := PortalAuthConfig{
SessionSecret: portalSessionSecret,
}
mux.HandleFunc("GET /api/portal/session", func(w http.ResponseWriter, r *http.Request) {
handlePortalSessionState(w, r, portalAuth)
})
mux.HandleFunc("POST /api/portal/session/login", func(w http.ResponseWriter, r *http.Request) {
handlePortalSessionLogin(w, r, portalAuth)
})
mux.HandleFunc("POST /api/portal/session/logout", func(w http.ResponseWriter, r *http.Request) {
handlePortalSessionLogout(w, r)
})
mux.Handle("POST /api/batch-import/runs", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleCreateBatchImportRun(w, r, actions.CreateBatchImportRun)
})))
@@ -1296,6 +1305,14 @@ func writeJSON(w http.ResponseWriter, statusCode int, body any) {
_ = json.NewEncoder(w).Encode(body)
}
func nonEmptyString(value, fallback string) string {
value = strings.TrimSpace(value)
if value != "" {
return value
}
return fallback
}
func classifyError(err error) *httpError {
if err == nil {
return nil
@@ -1337,7 +1354,7 @@ func NewActionSet(sqliteDSN string) ActionSet {
return NewActionSetWithStickyRuntime(sqliteDSN, defaultStickyStoreRuntime())
}
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime) ActionSet {
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime, authCfg ...config.UserKeyAuthConfig) ActionSet {
routeLogWriter := newLazyRouteLogWriter(sqliteDSN)
resolveRoute := buildResolveRouteAction(sqliteDSN, stickyRuntime, routeLogWriter)
proxyRouteChatCompletions := buildProxyRouteChatCompletionsAction(sqliteDSN, resolveRoute, routeLogWriter)
@@ -1383,7 +1400,7 @@ func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRu
GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime),
ListProviderAccounts: buildListProviderAccountsAction(sqliteDSN),
GetProviderAccountBindingCandidates: buildGetProviderAccountBindingCandidatesAction(sqliteDSN),
UserKeyHandler: buildUserKeyHandler(sqliteDSN),
UserKeyHandler: buildUserKeyHandler(sqliteDSN, authCfg...),
UpdateProviderAccountBinding: buildUpdateProviderAccountBindingAction(sqliteDSN),
EnableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusActive),
DisableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusDisabled),
@@ -2741,6 +2758,19 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
writeHTTPError(w, &httpError{StatusCode: http.StatusForbidden, Code: "quota_exhausted", Message: "API key quota exhausted"})
return
}
if key.ExpiresAt != "" {
expiresAt, parseErr := time.Parse(time.RFC3339, key.ExpiresAt)
if parseErr != nil {
metrics.RecordUserKeyChatRequest("key_metadata_error")
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "key_metadata_error", Message: "invalid key expiry metadata"})
return
}
if !expiresAt.After(time.Now().UTC()) {
metrics.RecordUserKeyChatRequest("key_expired")
writeHTTPError(w, &httpError{StatusCode: http.StatusForbidden, Code: "key_expired", Message: "API key has expired"})
return
}
}
// 4. Parse request body (OpenAI-compatible)
body, err := io.ReadAll(io.LimitReader(r.Body, maxJSONBodyBytes))
@@ -2768,6 +2798,20 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "model is required"})
return
}
if len(key.AllowedModels) > 0 {
modelAllowed := false
for _, allowedModel := range key.AllowedModels {
if strings.TrimSpace(allowedModel) == model {
modelAllowed = true
break
}
}
if !modelAllowed {
metrics.RecordUserKeyChatRequest("model_not_allowed")
writeHTTPError(w, &httpError{StatusCode: http.StatusForbidden, Code: "model_not_allowed", Message: "requested model is not allowed for this API key"})
return
}
}
// 5. Map to proxy request
proxyReq := ProxyRouteChatCompletionsRequest{
@@ -2804,7 +2848,28 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
}
if upstreamResp == nil {
// Fallback: construct a minimal response from proxy info
upstreamResp = map[string]any{}
}
if !result.Forward.OK {
statusCode := result.Forward.UpstreamStatus
if statusCode <= 0 {
statusCode = http.StatusBadGateway
}
upstreamResp["upstream_http_code"] = statusCode
if _, hasError := upstreamResp["error"]; !hasError {
upstreamResp["error"] = map[string]any{
"code": nonEmptyString(result.Forward.ErrorClass, "upstream_error"),
"message": nonEmptyString(result.Forward.ErrorMessage, fmt.Sprintf("upstream request failed with status %d", statusCode)),
}
}
metrics.RecordUserKeyChatRequest(nonEmptyString(result.Forward.ErrorClass, "upstream_error"))
writeJSON(w, statusCode, upstreamResp)
return
}
// Fallback: construct a minimal success response from proxy info
if len(upstreamResp) == 0 {
upstreamResp = map[string]any{
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()),
"object": "chat.completion",
@@ -2820,10 +2885,8 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
}},
}
}
// Ensure upstream HTTP code is reflected
if !result.Forward.OK && result.Forward.UpstreamStatus > 0 {
upstreamResp["upstream_http_code"] = result.Forward.UpstreamStatus
if err := store.UserKeys().TouchLastUsed(r.Context(), key.KeyID); err != nil {
log.Printf("gateway: touch last_used_at for key %s failed: %v", key.KeyID, err)
}
// Wrap in OpenAI standard envelope if upstream didn't return one

View File

@@ -19,13 +19,16 @@ func generatePlaintextKey() (string, string) {
}
type UserKeyHandler struct {
createFn func(ctx context.Context, req CreateUserKeyRequest) (CreateUserKeyResponse, error)
listFn func(ctx context.Context, subjectID string) ([]UserKeyMeta, error)
getFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
resetFn func(ctx context.Context, keyID, subjectID string) (ResetUserKeyResponse, error)
pauseFn func(ctx context.Context, keyID, subjectID, reason string) (UserKeyMeta, error)
resumeFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
deleteFn func(ctx context.Context, keyID, subjectID string) error
TrustedSubjectHeader string
TrustedProxySecretHeader string
TrustedProxySecret string
createFn func(ctx context.Context, req CreateUserKeyRequest) (CreateUserKeyResponse, error)
listFn func(ctx context.Context, subjectID string) ([]UserKeyMeta, error)
getFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
resetFn func(ctx context.Context, keyID, subjectID string) (ResetUserKeyResponse, error)
pauseFn func(ctx context.Context, keyID, subjectID, reason string) (UserKeyMeta, error)
resumeFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
deleteFn func(ctx context.Context, keyID, subjectID string) error
}
type CreateUserKeyRequest struct {
@@ -60,22 +63,22 @@ type UserKeyMeta struct {
}
func (h *UserKeyHandler) extractSubjectID(r *http.Request) (string, *httpError) {
for _, header := range []string{"X-Portal-Subject", "X-User-Subject", "X-Forwarded-User"} {
if subjectID := strings.TrimSpace(r.Header.Get(header)); subjectID != "" {
return subjectID, nil
}
if h == nil {
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "user credentials required"}
}
if hdr := r.Header.Get("Authorization"); strings.HasPrefix(hdr, "Bearer ") {
token := strings.TrimSpace(strings.TrimPrefix(hdr, "Bearer "))
if token != "" {
n := 8
if len(token) < n {
n = len(token)
}
return "skeleton_user_" + token[:n], nil
}
subjectHeader := strings.TrimSpace(h.TrustedSubjectHeader)
secretHeader := strings.TrimSpace(h.TrustedProxySecretHeader)
secret := strings.TrimSpace(h.TrustedProxySecret)
if subjectHeader == "" || secretHeader == "" || secret == "" {
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "trusted user identity proxy not configured"}
}
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "user credentials required"}
if got := strings.TrimSpace(r.Header.Get(secretHeader)); got != secret {
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "trusted proxy authentication required"}
}
if subjectID := strings.TrimSpace(r.Header.Get(subjectHeader)); subjectID != "" {
return subjectID, nil
}
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "trusted subject header required"}
}
func writeSvcNotImplError(w http.ResponseWriter) {
@@ -181,7 +184,16 @@ func handlePauseUserKey(w http.ResponseWriter, r *http.Request, h *UserKeyHandle
return
}
keyID := r.PathValue("key_id")
key, svcErr := h.pauseFn(r.Context(), keyID, subjectID, "")
var req struct {
Reason string `json:"reason"`
}
if r.Body != nil && r.ContentLength != 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_json", Message: err.Error()})
return
}
}
key, svcErr := h.pauseFn(r.Context(), keyID, subjectID, strings.TrimSpace(req.Reason))
if svcErr != nil {
writeHTTPError(w, classifyError(svcErr))
return

View File

@@ -4,12 +4,35 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
const (
testTrustedSubjectHeader = "X-CRM-Authenticated-Subject"
testTrustedProxySecretHeader = "X-CRM-Trusted-Proxy"
testTrustedProxySecret = "shared-secret"
)
func withTrustedProxyAuth(h *UserKeyHandler) *UserKeyHandler {
if h == nil {
return nil
}
clone := *h
clone.TrustedSubjectHeader = testTrustedSubjectHeader
clone.TrustedProxySecretHeader = testTrustedProxySecretHeader
clone.TrustedProxySecret = testTrustedProxySecret
return &clone
}
func applyTrustedProxyHeaders(req *http.Request, subjectID string) {
req.Header.Set(testTrustedSubjectHeader, subjectID)
req.Header.Set(testTrustedProxySecretHeader, testTrustedProxySecret)
}
func TestGeneratePlaintextKeyAndExtractSubjectID(t *testing.T) {
t.Parallel()
plaintext, fingerprint := generatePlaintextKey()
@@ -20,16 +43,49 @@ func TestGeneratePlaintextKeyAndExtractSubjectID(t *testing.T) {
t.Fatalf("fingerprint = %q, want sha256 prefix", fingerprint)
}
h := &UserKeyHandler{}
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
req.Header.Set("Authorization", "Bearer abcdefgh12345678")
subjectID, httpErr := h.extractSubjectID(req)
if httpErr != nil {
t.Fatalf("extractSubjectID() unexpected error: %+v", httpErr)
}
if subjectID != "skeleton_user_abcdefgh" {
t.Fatalf("subjectID = %q, want skeleton_user_abcdefgh", subjectID)
}
t.Run("rejects bearer fallback when trusted proxy auth is not configured", func(t *testing.T) {
h := &UserKeyHandler{}
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
req.Header.Set("Authorization", "Bearer abcdefgh12345678")
_, httpErr := h.extractSubjectID(req)
if httpErr == nil {
t.Fatal("expected unauthorized error when trusted proxy auth is not configured")
}
if httpErr.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want 401", httpErr.StatusCode)
}
})
t.Run("rejects portal subject header when trusted proxy auth is not configured", func(t *testing.T) {
h := &UserKeyHandler{}
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
req.Header.Set("X-Portal-Subject", "portal-user:1")
_, httpErr := h.extractSubjectID(req)
if httpErr == nil {
t.Fatal("expected unauthorized error when trusted proxy auth is not configured")
}
if httpErr.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want 401", httpErr.StatusCode)
}
})
t.Run("accepts trusted proxy subject when proxy secret matches", func(t *testing.T) {
h := &UserKeyHandler{
TrustedSubjectHeader: "X-CRM-Authenticated-Subject",
TrustedProxySecretHeader: "X-CRM-Trusted-Proxy",
TrustedProxySecret: "shared-secret",
}
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
req.Header.Set("X-CRM-Authenticated-Subject", "portal-user:1")
req.Header.Set("X-CRM-Trusted-Proxy", "shared-secret")
subjectID, httpErr := h.extractSubjectID(req)
if httpErr != nil {
t.Fatalf("extractSubjectID() unexpected error: %+v", httpErr)
}
if subjectID != "portal-user:1" {
t.Fatalf("subjectID = %q, want portal-user:1", subjectID)
}
})
}
func TestHandleUserKeyListNotImplemented(t *testing.T) {
@@ -46,16 +102,16 @@ func TestHandleUserKeyListNotImplemented(t *testing.T) {
func TestHandleUserKeyListSuccess(t *testing.T) {
t.Parallel()
h := &UserKeyHandler{
h := withTrustedProxyAuth(&UserKeyHandler{
listFn: func(ctx context.Context, subjectID string) ([]UserKeyMeta, error) {
if subjectID != "portal-user:1" {
t.Fatalf("subjectID = %q, want portal-user:1", subjectID)
}
return []UserKeyMeta{{KeyID: "key_1", AdminStatus: "active"}}, nil
},
}
})
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
req.Header.Set("X-Portal-Subject", "portal-user:1")
applyTrustedProxyHeaders(req, "portal-user:1")
rr := httptest.NewRecorder()
serveWithMetrics(t, req, rr, func(w http.ResponseWriter, r *http.Request) {
handleListUserKeys(w, r, h)
@@ -70,12 +126,12 @@ func TestHandleUserKeyListSuccess(t *testing.T) {
func TestHandleGetUserKeyMissingKeyID(t *testing.T) {
t.Parallel()
h := &UserKeyHandler{getFn: func(context.Context, string, string) (UserKeyMeta, error) {
h := withTrustedProxyAuth(&UserKeyHandler{getFn: func(context.Context, string, string) (UserKeyMeta, error) {
t.Fatal("getFn should not be called when key_id is missing")
return UserKeyMeta{}, nil
}}
}})
req := httptest.NewRequest(http.MethodGet, "/api/keys/", nil)
req.Header.Set("X-Portal-Subject", "portal-user:1")
applyTrustedProxyHeaders(req, "portal-user:1")
rr := httptest.NewRecorder()
serveWithMetrics(t, req, rr, func(w http.ResponseWriter, r *http.Request) {
handleGetUserKey(w, r, h)
@@ -131,7 +187,7 @@ func TestHandleUserKeyMutationHandlers(t *testing.T) {
path: "/api/keys/key_1/pause",
handlerFn: handlePauseUserKey,
userHandler: &UserKeyHandler{pauseFn: func(ctx context.Context, keyID, subjectID, reason string) (UserKeyMeta, error) {
if keyID != "key_1" || subjectID != "portal-user:1" || reason != "" {
if keyID != "key_1" || subjectID != "portal-user:1" || reason != "user requested pause" {
t.Fatalf("pauseFn args = (%q,%q,%q)", keyID, subjectID, reason)
}
paused := meta
@@ -174,12 +230,18 @@ func TestHandleUserKeyMutationHandlers(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.method, tc.path, nil)
req := httptest.NewRequest(tc.method, tc.path, func() io.Reader {
if tc.name == "pause-success" {
return strings.NewReader(`{"reason":"user requested pause"}`)
}
return nil
}())
req.Header.Set("X-Portal-Subject", "portal-user:1")
req.SetPathValue("key_id", "key_1")
rr := httptest.NewRecorder()
applyTrustedProxyHeaders(req, "portal-user:1")
serveWithMetrics(t, req, rr, func(w http.ResponseWriter, r *http.Request) {
tc.handlerFn(w, r, tc.userHandler)
tc.handlerFn(w, r, withTrustedProxyAuth(tc.userHandler))
})
if rr.Code != tc.wantStatus {
t.Fatalf("status = %d, want %d body=%s", rr.Code, tc.wantStatus, rr.Body.String())
@@ -200,11 +262,11 @@ func serveWithMetrics(t *testing.T, req *http.Request, rr *httptest.ResponseReco
func TestHandleListUserKeysResponseShape(t *testing.T) {
t.Parallel()
h := &UserKeyHandler{listFn: func(context.Context, string) ([]UserKeyMeta, error) {
h := withTrustedProxyAuth(&UserKeyHandler{listFn: func(context.Context, string) ([]UserKeyMeta, error) {
return []UserKeyMeta{{KeyID: "key_json", AdminStatus: "active"}}, nil
}}
}})
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
req.Header.Set("X-Portal-Subject", "portal-user:json")
applyTrustedProxyHeaders(req, "portal-user:json")
rr := httptest.NewRecorder()
handleListUserKeys(rr, req, h)
var payload struct {
@@ -263,10 +325,10 @@ func TestHandleUserKeyMutationHandlersErrorPaths(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/api/keys/key_1", nil)
req.Header.Set("X-Portal-Subject", "portal-user:1")
applyTrustedProxyHeaders(req, "portal-user:1")
req.SetPathValue("key_id", "key_1")
rr := httptest.NewRecorder()
tc.handlerFn(rr, req, tc.userHandler)
tc.handlerFn(rr, req, withTrustedProxyAuth(tc.userHandler))
if rr.Code != tc.wantStatus {
t.Fatalf("status = %d, want %d body=%s", rr.Code, tc.wantStatus, rr.Body.String())
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"sub2api-cn-relay-manager/internal/config"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/metrics"
"sub2api-cn-relay-manager/internal/store/sqlite"
@@ -104,13 +105,34 @@ func ensureSubjectHasAccess(ctx context.Context, client *sub2api.Client, subject
return apiKey, nil
}
func buildManagedIdentitySelector(subjectID, keyID string) string {
return strings.TrimSpace(subjectID) + "|key:" + strings.TrimSpace(keyID) + "|rot:" + generateKeyID()
}
func managedIdentitySelectorForRecord(rec *sqlite.UserKeyRecord) string {
if rec == nil {
return ""
}
if selector := strings.TrimSpace(rec.ManagedIdentitySelector); selector != "" {
return selector
}
return strings.TrimSpace(rec.OwnerSubjectID)
}
func recordUserKeyFailure(operation, result string, err error) error {
metrics.RecordUserKeyOperation(operation, result)
return err
}
func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
func buildUserKeyHandler(sqliteDSN string, authCfg ...config.UserKeyAuthConfig) *UserKeyHandler {
var cfg config.UserKeyAuthConfig
if len(authCfg) > 0 {
cfg = authCfg[0]
}
return &UserKeyHandler{
TrustedSubjectHeader: strings.TrimSpace(cfg.TrustedSubjectHeader),
TrustedProxySecretHeader: strings.TrimSpace(cfg.TrustedProxySecretHeader),
TrustedProxySecret: strings.TrimSpace(cfg.TrustedProxySecret),
createFn: func(ctx context.Context, req CreateUserKeyRequest) (CreateUserKeyResponse, error) {
if strings.TrimSpace(req.SubjectID) == "" {
metrics.RecordUserKeyOperation("create", "unauthorized")
@@ -136,6 +158,9 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
return CreateUserKeyResponse{}, &httpError{StatusCode: 429, Code: "rate_limited", Message: "create key rate limit exceeded"}
}
keyID := generateKeyID()
managedIdentitySelector := buildManagedIdentitySelector(req.SubjectID, keyID)
// Resolve logical group → host → group ID → ensure subscription access
_, route, hostRow, client, err := resolveLogicalGroupHost(ctx, store, req.LogicalGroupID)
if err != nil {
@@ -145,26 +170,26 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
if err != nil {
return CreateUserKeyResponse{}, recordUserKeyFailure("create", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for %q: %w", route.ShadowGroupID, err))
}
apiKey, err := ensureSubjectHasAccess(ctx, client, req.SubjectID, hostGroupID)
apiKey, err := ensureSubjectHasAccess(ctx, client, managedIdentitySelector, hostGroupID)
if err != nil {
return CreateUserKeyResponse{}, recordUserKeyFailure("create", "ensure_access_error", fmt.Errorf("ensure access for %q: %w", req.LogicalGroupID, err))
}
fingerprint := "sha256:" + sha256Hex(apiKey)
keyID := generateKeyID()
masked := "sk-****" + apiKey[len(apiKey)-4:]
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
if _, err := q.UserKeys.Create(ctx, sqlite.UserKeyRecord{
KeyID: keyID,
OwnerSubjectID: req.SubjectID,
KeyFingerprint: fingerprint,
MaskedPreview: masked,
DisplayName: strings.TrimSpace(req.DisplayName),
LogicalGroupID: strings.TrimSpace(req.LogicalGroupID),
AllowedModels: req.AllowedModels,
AdminStatus: "active",
QuotaStatus: "ok",
KeyID: keyID,
OwnerSubjectID: req.SubjectID,
ManagedIdentitySelector: managedIdentitySelector,
KeyFingerprint: fingerprint,
MaskedPreview: masked,
DisplayName: strings.TrimSpace(req.DisplayName),
LogicalGroupID: strings.TrimSpace(req.LogicalGroupID),
AllowedModels: req.AllowedModels,
AdminStatus: "active",
QuotaStatus: "ok",
}); err != nil {
return fmt.Errorf("create key: %w", err)
}
@@ -288,7 +313,8 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
if err != nil {
return ResetUserKeyResponse{}, recordUserKeyFailure("reset", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for %q: %w", route.ShadowGroupID, err))
}
newPlaintext, err := ensureSubjectHasAccess(ctx, client, rec.OwnerSubjectID, hostGroupID)
managedIdentitySelector := buildManagedIdentitySelector(rec.OwnerSubjectID, keyID)
newPlaintext, err := ensureSubjectHasAccess(ctx, client, managedIdentitySelector, hostGroupID)
if err != nil {
return ResetUserKeyResponse{}, recordUserKeyFailure("reset", "ensure_access_error", fmt.Errorf("ensure access on reset for %q: %w", rec.LogicalGroupID, err))
}
@@ -297,7 +323,7 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
masked := "sk-****" + newPlaintext[len(newPlaintext)-4:]
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
if err := q.UserKeys.UpdateSecret(ctx, keyID, hostFingerprint, masked, "active"); err != nil {
if err := q.UserKeys.UpdateSecret(ctx, keyID, managedIdentitySelector, hostFingerprint, masked, "active"); err != nil {
return fmt.Errorf("reset key: %w", err)
}
if _, err := q.UserKeyAuditEvents.Create(ctx, sqlite.UserKeyAuditEvent{
@@ -341,7 +367,7 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
if err != nil {
return UserKeyMeta{}, recordUserKeyFailure("pause", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for pause %q: %w", route.ShadowGroupID, err))
}
if err := client.PauseManagedSubscriptionAccess(ctx, rec.OwnerSubjectID, hostGroupID); err != nil {
if err := client.PauseManagedSubscriptionAccess(ctx, managedIdentitySelectorForRecord(rec), hostGroupID); err != nil {
return UserKeyMeta{}, recordUserKeyFailure("pause", "pause_access_error", fmt.Errorf("pause managed subscription access: %w", err))
}
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
@@ -384,7 +410,7 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
if err != nil {
return UserKeyMeta{}, recordUserKeyFailure("resume", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for resume %q: %w", route.ShadowGroupID, err))
}
if err := client.ResumeManagedSubscriptionAccess(ctx, rec.OwnerSubjectID, hostGroupID); err != nil {
if err := client.ResumeManagedSubscriptionAccess(ctx, managedIdentitySelectorForRecord(rec), hostGroupID); err != nil {
return UserKeyMeta{}, recordUserKeyFailure("resume", "resume_access_error", fmt.Errorf("resume managed subscription access: %w", err))
}
err = store.WithTx(ctx, func(q *sqlite.Queries) error {

View File

@@ -11,10 +11,24 @@ import (
"strings"
"testing"
"sub2api-cn-relay-manager/internal/config"
"sub2api-cn-relay-manager/internal/metrics"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func testUserKeyAuthConfig() config.UserKeyAuthConfig {
return config.UserKeyAuthConfig{
TrustedSubjectHeader: testTrustedSubjectHeader,
TrustedProxySecretHeader: testTrustedProxySecretHeader,
TrustedProxySecret: testTrustedProxySecret,
}
}
func applyTrustedProxyAuthHeaders(req *http.Request, subjectID string) {
req.Header.Set(testTrustedSubjectHeader, subjectID)
req.Header.Set(testTrustedProxySecretHeader, testTrustedProxySecret)
}
func makeCreateBody(groupID, displayName string, models []string) io.Reader {
b, _ := json.Marshal(map[string]any{
"logical_group_id": groupID,
@@ -60,11 +74,11 @@ func TestUserKeyAPIUsesPortalSubjectHeader(t *testing.T) {
})
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
})
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("gpt-shared", "portal key", []string{"gpt-5.4"}))
req.Header.Set("X-Portal-Subject", "smoke-user")
applyTrustedProxyAuthHeaders(req, "smoke-user")
resp := httptestRecorder(handler, req)
// We expect 500 because test host is unreachable (port 1), but the important
@@ -107,11 +121,11 @@ func TestUserKeyCreateRejectsMissingSubject(t *testing.T) {
func TestUserKeyCreateRejectsMissingGroup(t *testing.T) {
t.Parallel()
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t))),
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t)), testUserKeyAuthConfig()),
})
body := bytes.NewReader([]byte(`{"display_name":"portal key"}`))
req := makeCreateRequest(t, http.MethodPost, "/api/keys", body)
req.Header.Set("X-Portal-Subject", "smoke-user")
applyTrustedProxyAuthHeaders(req, "smoke-user")
resp := httptestRecorder(handler, req)
if resp.code != http.StatusBadRequest {
t.Fatalf("status code = %d, want 400", resp.code)
@@ -142,18 +156,18 @@ func TestUserKeyRateLimitNoDB(t *testing.T) {
})
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
})
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("gpt-shared", "rate-test", nil))
req.Header.Set("X-Portal-Subject", "rate-user")
applyTrustedProxyAuthHeaders(req, "rate-user")
resp := httptestRecorder(handler, req)
if resp.code == http.StatusUnauthorized || resp.code == http.StatusNotImplemented {
t.Fatalf("status code = %d, expected to pass auth layer", resp.code)
}
}
func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testing.T) {
func TestUserKeyCreateUsesPerRecordManagedKeyAndConsistentMetadata(t *testing.T) {
t.Parallel()
store := openAppTestStore(t)
@@ -163,6 +177,8 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
const hostGroupID = "999"
const subjectID = "portal-user:13"
var loginEmail string
var customKey string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
@@ -180,9 +196,9 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode login request: %v", err)
}
expected := expectedManagedIdentity(subjectID, hostGroupID)
if got := fmt.Sprint(req["email"]); got != expected.Email {
t.Fatalf("login email = %q, want subject-scoped %q", got, expected.Email)
loginEmail = fmt.Sprint(req["email"])
if !strings.Contains(loginEmail, "@sub2api.local") || strings.Contains(loginEmail, subjectID) {
t.Fatalf("login email = %q, want synthesized per-record managed identity", loginEmail)
}
w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
@@ -190,9 +206,9 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode managed key request: %v", err)
}
expected := expectedManagedIdentity(subjectID, hostGroupID)
if got := fmt.Sprint(req["custom_key"]); got != expected.CustomKey {
t.Fatalf("custom_key = %q, want subject-scoped %q", got, expected.CustomKey)
customKey = fmt.Sprint(req["custom_key"])
if !strings.HasPrefix(customKey, "sk-relay-") {
t.Fatalf("custom_key = %q, want sk-relay-*", customKey)
}
w.Write([]byte(`{"data":{"id":501,"key":"placeholder-from-host","name":"managed-key"}}`))
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
@@ -226,11 +242,11 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
})
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
})
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody(logicalGroupID, "portal key", []string{"gpt-5.4"}))
req.Header.Set("X-Portal-Subject", subjectID)
applyTrustedProxyAuthHeaders(req, subjectID)
resp := httptestRecorder(handler, req)
if resp.code != http.StatusCreated {
t.Fatalf("status code = %d, want 201, body=%s", resp.code, resp.Body().String())
@@ -241,11 +257,10 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
t.Fatalf("decode create response: %v", err)
}
expected := expectedManagedIdentity(subjectID, hostGroupID)
if createResp.PlaintextKey != expected.CustomKey {
t.Fatalf("plaintext_key = %q, want subject-scoped %q", createResp.PlaintextKey, expected.CustomKey)
if createResp.PlaintextKey != customKey {
t.Fatalf("plaintext_key = %q, want host custom_key %q", createResp.PlaintextKey, customKey)
}
wantMasked := "sk-****" + expected.CustomKey[len(expected.CustomKey)-4:]
wantMasked := "sk-****" + customKey[len(customKey)-4:]
if createResp.Key.MaskedPreview != wantMasked {
t.Fatalf("masked_preview = %q, want %q", createResp.Key.MaskedPreview, wantMasked)
}
@@ -254,12 +269,18 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
if err != nil {
t.Fatalf("UserKeys().GetByID() error = %v", err)
}
if record.KeyFingerprint != "sha256:"+sha256Hex(expected.CustomKey) {
if strings.TrimSpace(record.ManagedIdentitySelector) == "" || !strings.Contains(record.ManagedIdentitySelector, createResp.Key.KeyID) {
t.Fatalf("managed_identity_selector = %q, want non-empty selector tied to key id", record.ManagedIdentitySelector)
}
if record.KeyFingerprint != "sha256:"+sha256Hex(customKey) {
t.Fatalf("key_fingerprint = %q, want sha256 of returned plaintext key", record.KeyFingerprint)
}
if record.MaskedPreview != wantMasked {
t.Fatalf("stored masked_preview = %q, want %q", record.MaskedPreview, wantMasked)
}
if loginEmail == "" {
t.Fatal("login email was not observed")
}
}
type managedIdentityExpectation struct {
@@ -302,13 +323,145 @@ func expectedManagedPrefix(value string) string {
return prefix
}
func TestUserKeyCreateAndResetDoNotReuseSameManagedKeyWithinSubjectGroup(t *testing.T) {
t.Parallel()
store := openAppTestStore(t)
defer closeAppTestStore(t, store)
const logicalGroupID = "gpt-shared"
const hostGroupID = "999"
const subjectID = "portal-user:multi"
type managedUser struct {
ID int64
Email string
}
usersByEmail := map[string]managedUser{}
nextUserID := int64(100)
nextKeyID := int64(500)
createdCustomKeys := make([]string, 0, 4)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
search := strings.TrimSpace(r.URL.Query().Get("search"))
items := make([]map[string]any, 0, 1)
if user, ok := usersByEmail[search]; ok {
items = append(items, map[string]any{"id": user.ID, "email": user.Email})
}
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"items": items}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode create user request: %v", err)
}
email := strings.TrimSpace(fmt.Sprint(req["email"]))
nextUserID++
usersByEmail[email] = managedUser{ID: nextUserID, Email: email}
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": nextUserID, "email": email}})
case r.Method == http.MethodPut && strings.HasPrefix(r.URL.Path, "/api/v1/admin/users/"):
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 1}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/101/balance":
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 101}})
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/api/v1/admin/users/") && strings.HasSuffix(r.URL.Path, "/balance"):
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 1}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 401}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"access_token": "user-jwt"}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode create key request: %v", err)
}
customKey := strings.TrimSpace(fmt.Sprint(req["custom_key"]))
createdCustomKeys = append(createdCustomKeys, customKey)
nextKeyID++
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": nextKeyID, "key": customKey, "name": fmt.Sprint(req["name"])}})
case r.Method == http.MethodPut && strings.HasPrefix(r.URL.Path, "/api/v1/admin/api-keys/"):
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
_, _ = store.Hosts().Create(context.Background(), sqlite.Host{
HostID: "test-host",
BaseURL: server.URL,
HostVersion: "0.0.1",
CapabilityProbeJSON: "{}",
AuthType: "apikey",
AuthToken: "test-token",
})
_, _ = store.LogicalGroups().Create(context.Background(), sqlite.LogicalGroup{
LogicalGroupID: logicalGroupID,
DisplayName: "GPT Shared",
Status: "active",
})
_, _ = store.LogicalGroupRoutes().Create(context.Background(), sqlite.LogicalGroupRoute{
RouteID: "test-route",
LogicalGroupID: logicalGroupID,
Name: "Test Route",
Status: "active",
ShadowHostID: "test-host",
ShadowGroupID: hostGroupID,
})
handler := buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig())
first, err := handler.createFn(context.Background(), CreateUserKeyRequest{SubjectID: subjectID, LogicalGroupID: logicalGroupID, DisplayName: "first", AllowedModels: []string{"gpt-5.4"}})
if err != nil {
t.Fatalf("first createFn() error = %v", err)
}
second, err := handler.createFn(context.Background(), CreateUserKeyRequest{SubjectID: subjectID, LogicalGroupID: logicalGroupID, DisplayName: "second", AllowedModels: []string{"gpt-5.4"}})
if err != nil {
t.Fatalf("second createFn() error = %v", err)
}
if first.PlaintextKey == second.PlaintextKey {
t.Fatalf("createFn() reused plaintext key across records: first=%q second=%q", first.PlaintextKey, second.PlaintextKey)
}
reset, err := handler.resetFn(context.Background(), first.Key.KeyID, subjectID)
if err != nil {
t.Fatalf("resetFn() error = %v", err)
}
if reset.PlaintextKey == first.PlaintextKey {
t.Fatalf("resetFn() reused original plaintext key: before=%q after=%q", first.PlaintextKey, reset.PlaintextKey)
}
if reset.PlaintextKey == second.PlaintextKey {
t.Fatalf("resetFn() collided with sibling key: reset=%q sibling=%q", reset.PlaintextKey, second.PlaintextKey)
}
firstRecord, err := store.UserKeys().GetByID(context.Background(), first.Key.KeyID)
if err != nil {
t.Fatalf("GetByID(first) error = %v", err)
}
secondRecord, err := store.UserKeys().GetByID(context.Background(), second.Key.KeyID)
if err != nil {
t.Fatalf("GetByID(second) error = %v", err)
}
if firstRecord.KeyFingerprint == secondRecord.KeyFingerprint {
t.Fatalf("distinct key records share fingerprint: %q", firstRecord.KeyFingerprint)
}
if firstRecord.KeyFingerprint != "sha256:"+sha256Hex(reset.PlaintextKey) {
t.Fatalf("first record fingerprint = %q, want reset plaintext fingerprint", firstRecord.KeyFingerprint)
}
if len(createdCustomKeys) < 3 {
t.Fatalf("createdCustomKeys len = %d, want at least 3", len(createdCustomKeys))
}
if createdCustomKeys[0] == createdCustomKeys[1] || createdCustomKeys[0] == createdCustomKeys[2] {
t.Fatalf("host custom keys were reused unexpectedly: %#v", createdCustomKeys)
}
}
func TestUserKeyAPIMetricsMiddlewareAndCreateMetric(t *testing.T) {
t.Parallel()
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t))),
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t)), testUserKeyAuthConfig()),
})
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("", "portal key", nil))
req.Header.Set("X-Portal-Subject", "smoke-user")
applyTrustedProxyAuthHeaders(req, "smoke-user")
_ = httptestRecorder(handler, req)
metricsReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)

311
internal/app/portal_auth.go Normal file
View File

@@ -0,0 +1,311 @@
package app
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"net/http"
"strconv"
"strings"
"time"
)
const (
portalSubjectCookieName = "crm_subject"
portalSessionCookieName = "crm_session"
defaultPortalSessionTTL = 30 * 24 * time.Hour // 30 days
)
// PortalAuthConfig 定义 portal user session 配置
type PortalAuthConfig struct {
SessionSecret string // session cookie 签名密钥
SessionTTL time.Duration // session 有效期
Now func() time.Time
}
// portalSessionInfo 存储 session 信息
type portalSessionInfo struct {
SubjectID string
Email string
ExpiresAt time.Time
}
// portalLoginRequest 登录请求
type portalLoginRequest struct {
Email string `json:"email"`
Password string `json:"password"` // 仅用于验证portal 采用"登录即注册"模式
}
// normalized 返回规范化配置
func (c PortalAuthConfig) normalized() PortalAuthConfig {
if c.SessionTTL <= 0 {
c.SessionTTL = defaultPortalSessionTTL
}
if c.Now == nil {
c.Now = time.Now
}
return c
}
// normalizedSubjectID 规范化 subject ID
func normalizedSubjectID(email string) string {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return ""
}
return "portal-email:" + email
}
// signSessionCookie 签名 session cookie 值
func signSessionCookie(secret, subjectID string, expiresAt time.Time) string {
if secret == "" || subjectID == "" {
return ""
}
payload := subjectID + "|" + strconv.FormatInt(expiresAt.Unix(), 10)
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(payload))
sig := hex.EncodeToString(mac.Sum(nil))
return base64.RawURLEncoding.EncodeToString([]byte(payload + "|" + sig))
}
// verifySessionCookie 验证并解析 session cookie
func verifySessionCookie(secret, raw string, now time.Time) (*portalSessionInfo, bool) {
if secret == "" || raw == "" {
return nil, false
}
b, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return nil, false
}
parts := strings.SplitN(string(b), "|", 3)
if len(parts) != 3 {
return nil, false
}
subjectID, tsStr, sigHex := parts[0], parts[1], parts[2]
// 验证签名
payload := subjectID + "|" + tsStr
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(payload))
expectedSig := hex.EncodeToString(mac.Sum(nil))
if subtle.ConstantTimeCompare([]byte(sigHex), []byte(expectedSig)) != 1 {
return nil, false
}
// 解析过期时间
unixSec, err := strconv.ParseInt(tsStr, 10, 64)
if err != nil {
return nil, false
}
expiresAt := time.Unix(unixSec, 0)
if now.After(expiresAt) {
return nil, false
}
return &portalSessionInfo{
SubjectID: subjectID,
ExpiresAt: expiresAt,
}, true
}
// generateSessionSecret 生成随机 session secret32字节
func generateSessionSecret() string {
b := make([]byte, 32)
rand.Read(b)
return hex.EncodeToString(b)
}
// extractSubjectFromCookie 从请求 cookie 中提取 subject
func extractSubjectFromCookie(r *http.Request, sessionSecret string) string {
if sessionSecret == "" {
return ""
}
cookie, err := r.Cookie(portalSessionCookieName)
if err != nil || cookie == nil || cookie.Value == "" {
return ""
}
info, ok := verifySessionCookie(sessionSecret, cookie.Value, time.Now())
if !ok {
return ""
}
return info.SubjectID
}
// handlePortalSessionLogin 处理 portal user 登录
// 设置 httpOnly cookie返回 subject ID
func handlePortalSessionLogin(w http.ResponseWriter, r *http.Request, cfg PortalAuthConfig) {
cfg = cfg.normalized()
if cfg.SessionSecret == "" {
writeHTTPError(w, &httpError{
StatusCode: http.StatusServiceUnavailable,
Code: "portal_auth_not_configured",
Message: "Portal session authentication is not configured",
})
return
}
var req portalLoginRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
email := strings.TrimSpace(req.Email)
if email == "" || !strings.Contains(email, "@") {
writeHTTPError(w, &httpError{
StatusCode: http.StatusBadRequest,
Code: "invalid_email",
Message: "Valid email is required",
})
return
}
subjectID := normalizedSubjectID(email)
expiresAt := cfg.Now().Add(cfg.SessionTTL)
// 生成签名 cookie
sessionValue := signSessionCookie(cfg.SessionSecret, subjectID, expiresAt)
if sessionValue == "" {
writeHTTPError(w, &httpError{
StatusCode: http.StatusInternalServerError,
Code: "session_sign_failed",
Message: "Failed to sign session",
})
return
}
// 设置 httpOnly cookieSameSite=LaxSecure 建议生产启用 HTTPS
cookie := &http.Cookie{
Name: portalSessionCookieName,
Value: sessionValue,
Path: "/",
Expires: expiresAt,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
}
http.SetCookie(w, cookie)
// 同时设置非 httpOnly cookie 供前端 JS 读取 subject用于显示
subjectCookie := &http.Cookie{
Name: portalSubjectCookieName,
Value: subjectID,
Path: "/",
Expires: expiresAt,
HttpOnly: false, // 允许前端读取
SameSite: http.SameSiteLaxMode,
Secure: r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
}
http.SetCookie(w, subjectCookie)
writeJSON(w, http.StatusOK, map[string]any{
"authenticated": true,
"subject_id": subjectID,
"email": email,
"expires_at": expiresAt.Format(time.RFC3339),
})
}
// handlePortalSessionLogout 处理 portal user 登出
// 清除 session cookie
func handlePortalSessionLogout(w http.ResponseWriter, r *http.Request) {
// 清除 session cookie
sessionCookie := &http.Cookie{
Name: portalSessionCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HttpOnly: true,
}
http.SetCookie(w, sessionCookie)
// 清除 subject cookie
subjectCookie := &http.Cookie{
Name: portalSubjectCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HttpOnly: false,
}
http.SetCookie(w, subjectCookie)
writeJSON(w, http.StatusOK, map[string]any{
"authenticated": false,
})
}
// handlePortalSessionState 处理 portal session 状态查询
func handlePortalSessionState(w http.ResponseWriter, r *http.Request, cfg PortalAuthConfig) {
cfg = cfg.normalized()
if cfg.SessionSecret == "" {
writeJSON(w, http.StatusOK, map[string]any{
"authenticated": false,
"login_enabled": false,
})
return
}
subjectID := extractSubjectFromCookie(r, cfg.SessionSecret)
if subjectID == "" {
writeJSON(w, http.StatusOK, map[string]any{
"authenticated": false,
"login_enabled": true,
})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"authenticated": true,
"login_enabled": true,
"subject_id": subjectID,
})
}
// requirePortalSubject 中间件:要求 portal session 认证
// 与 trusted proxy header 的验证流程配合
func requirePortalSubject(cfg PortalAuthConfig, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cfg = cfg.normalized()
// 首先检查 trusted proxy header
// 这是生产环境的推荐做法nginx 验证并设置 header
if cfg.SessionSecret == "" {
writeHTTPError(w, &httpError{
StatusCode: http.StatusUnauthorized,
Code: "unauthorized",
Message: "Portal authentication not configured",
})
return
}
subjectID := extractSubjectFromCookie(r, cfg.SessionSecret)
if subjectID == "" {
writeHTTPError(w, &httpError{
StatusCode: http.StatusUnauthorized,
Code: "unauthorized",
Message: "Valid session required",
})
return
}
// 将 subject 放入 context 供后续使用
ctx := context.WithValue(r.Context(), "portal_subject", subjectID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// getPortalSubjectFromContext 从 context 获取 subject
func getPortalSubjectFromContext(ctx context.Context) string {
if v, ok := ctx.Value("portal_subject").(string); ok {
return v
}
return ""
}

View File

@@ -0,0 +1,135 @@
package app
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestPortalSessionLoginSetsCookieAndReturnsSubject(t *testing.T) {
cfg := PortalAuthConfig{
SessionSecret: "test-secret-32-bytes-long-for-hmac",
Now: func() time.Time { return time.Unix(1_717_000_000, 0) },
}
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/login", strings.NewReader(`{"email":"user@example.com"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handlePortalSessionLogin(rec, req, cfg)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
// 检查响应体包含 subject
body := rec.Body.String()
if !strings.Contains(body, `"subject_id":"portal-email:user@example.com"`) {
t.Fatalf("response body missing subject_id: %s", body)
}
// 检查设置了 cookie
cookies := rec.Result().Cookies()
if len(cookies) < 2 {
t.Fatalf("expected at least 2 cookies (session + subject), got %d", len(cookies))
}
// 检查 session cookie 是 httpOnly
var foundSessionCookie bool
for _, c := range cookies {
if c.Name == portalSessionCookieName {
foundSessionCookie = true
if !c.HttpOnly {
t.Fatal("session cookie should be HttpOnly")
}
}
}
if !foundSessionCookie {
t.Fatalf("session cookie %s not found", portalSessionCookieName)
}
}
func TestPortalSessionLoginRejectsMissingEmail(t *testing.T) {
cfg := PortalAuthConfig{
SessionSecret: "test-secret",
}
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/login", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handlePortalSessionLogin(rec, req, cfg)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestPortalSessionLoginRequiresSecret(t *testing.T) {
cfg := PortalAuthConfig{
SessionSecret: "", // 未配置
}
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/login", strings.NewReader(`{"email":"user@example.com"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handlePortalSessionLogin(rec, req, cfg)
if rec.Code != http.StatusServiceUnavailable {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
}
}
func TestPortalSessionLogoutClearsCookies(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/logout", nil)
rec := httptest.NewRecorder()
handlePortalSessionLogout(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
// 检查清除了 cookie
cookies := rec.Result().Cookies()
var clearedSession bool
var clearedSubject bool
for _, c := range cookies {
if c.Name == portalSessionCookieName && c.MaxAge == -1 {
clearedSession = true
}
if c.Name == portalSubjectCookieName && c.MaxAge == -1 {
clearedSubject = true
}
}
if !clearedSession {
t.Fatal("session cookie should be cleared")
}
if !clearedSubject {
t.Fatal("subject cookie should be cleared")
}
}
func TestPortalSessionStateUnauthenticatedWhenNoCookie(t *testing.T) {
cfg := PortalAuthConfig{
SessionSecret: "test-secret",
Now: func() time.Time { return time.Unix(1_717_000_000, 0) },
}
req := httptest.NewRequest(http.MethodGet, "/api/portal/session", nil)
rec := httptest.NewRecorder()
handlePortalSessionState(rec, req, cfg)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, `"authenticated":false`) {
t.Fatalf("expected unauthenticated, got: %s", body)
}
}

View File

@@ -58,6 +58,190 @@ func TestPublicV1ChatCompletionsQuotaExhaustedRecordsMetric(t *testing.T) {
}
}
func TestPublicV1ChatCompletionsPropagatesUpstreamFailureStatusAndMetric(t *testing.T) {
t.Parallel()
store := openAppTestStore(t)
defer closeAppTestStore(t, store)
const plaintextKey = "sk-test-upstream-429"
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
KeyID: "key_upstream_429",
OwnerSubjectID: "portal-user",
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
MaskedPreview: "sk-****-429",
DisplayName: "upstream 429 key",
LogicalGroupID: "gpt-shared",
AllowedModels: []string{"gpt-5.4"},
AdminStatus: "active",
QuotaStatus: "ok",
}); err != nil {
t.Fatalf("UserKeys().Create() error = %v", err)
}
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
return ProxyRouteChatCompletionsResult{
Forward: RouteChatCompletionsForwardInfo{
OK: false,
UpstreamStatus: http.StatusTooManyRequests,
ErrorClass: "gateway_rate_limited",
Response: map[string]any{
"error": map[string]any{
"code": "upstream_rate_limited",
"message": "upstream rejected request",
},
},
},
}, nil
},
}, appTestDSN(t, store))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
req.Header.Set("Authorization", "Bearer "+plaintextKey)
req.Header.Set("Content-Type", "application/json")
resp := httptestRecorder(handler, req)
if resp.code != http.StatusTooManyRequests {
t.Fatalf("status code = %d, want 429 body=%s", resp.code, resp.Body().String())
}
assertJSONContains(t, resp.Body().Bytes(), "error.code", "upstream_rate_limited")
metricsReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
metricsResp := httptest.NewRecorder()
metrics.Handler().ServeHTTP(metricsResp, metricsReq)
body := metricsResp.Body.String()
if !strings.Contains(body, `user_key_chat_requests_total{result="gateway_rate_limited"}`) {
t.Fatalf("metrics body missing gateway_rate_limited metric: %s", body)
}
}
func TestPublicV1ChatCompletionsRejectsDisallowedModel(t *testing.T) {
t.Parallel()
store := openAppTestStore(t)
defer closeAppTestStore(t, store)
const plaintextKey = "sk-test-disallowed-model"
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
KeyID: "key_disallowed_model",
OwnerSubjectID: "portal-user",
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
MaskedPreview: "sk-****odel",
DisplayName: "model restricted key",
LogicalGroupID: "gpt-shared",
AllowedModels: []string{"gpt-4.1"},
AdminStatus: "active",
QuotaStatus: "ok",
}); err != nil {
t.Fatalf("UserKeys().Create() error = %v", err)
}
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
t.Fatal("proxy should not be called for disallowed model")
return ProxyRouteChatCompletionsResult{}, nil
},
}, appTestDSN(t, store))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
req.Header.Set("Authorization", "Bearer "+plaintextKey)
req.Header.Set("Content-Type", "application/json")
resp := httptestRecorder(handler, req)
if resp.code != http.StatusForbidden {
t.Fatalf("status code = %d, want 403 body=%s", resp.code, resp.Body().String())
}
assertJSONContains(t, resp.Body().Bytes(), "error.code", "model_not_allowed")
}
func TestPublicV1ChatCompletionsRejectsExpiredKey(t *testing.T) {
t.Parallel()
store := openAppTestStore(t)
defer closeAppTestStore(t, store)
const plaintextKey = "sk-test-expired-key"
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
KeyID: "key_expired",
OwnerSubjectID: "portal-user",
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
MaskedPreview: "sk-****ired",
DisplayName: "expired key",
LogicalGroupID: "gpt-shared",
AllowedModels: []string{"gpt-5.4"},
AdminStatus: "active",
QuotaStatus: "ok",
}); err != nil {
t.Fatalf("UserKeys().Create() error = %v", err)
}
if _, err := store.SQLDB().ExecContext(context.Background(), `UPDATE user_keys SET expires_at = ? WHERE key_id = ?`, "2020-01-01T00:00:00Z", "key_expired"); err != nil {
t.Fatalf("set expires_at error = %v", err)
}
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
t.Fatal("proxy should not be called for expired key")
return ProxyRouteChatCompletionsResult{}, nil
},
}, appTestDSN(t, store))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
req.Header.Set("Authorization", "Bearer "+plaintextKey)
req.Header.Set("Content-Type", "application/json")
resp := httptestRecorder(handler, req)
if resp.code != http.StatusForbidden {
t.Fatalf("status code = %d, want 403 body=%s", resp.code, resp.Body().String())
}
assertJSONContains(t, resp.Body().Bytes(), "error.code", "key_expired")
}
func TestPublicV1ChatCompletionsTouchesLastUsedAtOnSuccess(t *testing.T) {
t.Parallel()
store := openAppTestStore(t)
defer closeAppTestStore(t, store)
const plaintextKey = "sk-test-last-used"
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
KeyID: "key_last_used",
OwnerSubjectID: "portal-user",
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
MaskedPreview: "sk-****used",
DisplayName: "active key",
LogicalGroupID: "gpt-shared",
AllowedModels: []string{"gpt-5.4"},
AdminStatus: "active",
QuotaStatus: "ok",
}); err != nil {
t.Fatalf("UserKeys().Create() error = %v", err)
}
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
return ProxyRouteChatCompletionsResult{Forward: RouteChatCompletionsForwardInfo{OK: true, UpstreamStatus: http.StatusOK, Response: map[string]any{"id": "chatcmpl_ok", "object": "chat.completion", "model": "gpt-5.4", "choices": []map[string]any{{"index": 0, "message": map[string]any{"role": "assistant", "content": "pong"}, "finish_reason": "stop"}}}}}, nil
},
}, appTestDSN(t, store))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
req.Header.Set("Authorization", "Bearer "+plaintextKey)
req.Header.Set("Content-Type", "application/json")
resp := httptestRecorder(handler, req)
if resp.code != http.StatusOK {
t.Fatalf("status code = %d, want 200 body=%s", resp.code, resp.Body().String())
}
record, err := store.UserKeys().GetByID(context.Background(), "key_last_used")
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if strings.TrimSpace(record.LastUsedAt) == "" {
t.Fatalf("LastUsedAt = %q, want non-empty after successful chat", record.LastUsedAt)
}
}
func TestMetricsMiddlewareUsesRoutePatternForKeyReset(t *testing.T) {
t.Parallel()

View File

@@ -16,11 +16,11 @@ func TestUserKeyCreateResolveHostErrorRecordsMetric(t *testing.T) {
defer closeAppTestStore(t, store)
handler := NewAPIHandler("t", ActionSet{
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
})
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("missing-group", "portal key", []string{"gpt-5.4"}))
req.Header.Set("X-Portal-Subject", "portal-user")
applyTrustedProxyAuthHeaders(req, "portal-user")
resp := httptestRecorder(handler, req)
if resp.code != http.StatusInternalServerError {
t.Fatalf("status code = %d, want 500 body=%s", resp.code, resp.Body().String())

View File

@@ -9,26 +9,30 @@ import (
)
const (
EnvListenAddr = "SUB2API_CRM_LISTEN_ADDR"
EnvSQLiteDSN = "SUB2API_CRM_SQLITE_DSN"
EnvAdminToken = "SUB2API_CRM_ADMIN_TOKEN"
EnvAdminUsername = "SUB2API_CRM_ADMIN_USERNAME"
EnvAdminPassword = "SUB2API_CRM_ADMIN_PASSWORD"
EnvAdminSessionTTL = "SUB2API_CRM_ADMIN_SESSION_TTL"
EnvRepoRoot = "SUB2API_CRM_REPO_ROOT"
EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED"
EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL"
EnvRouteRuntimeBackend = "SUB2API_CRM_ROUTE_RUNTIME_BACKEND"
EnvRedisAddr = "SUB2API_CRM_REDIS_ADDR"
EnvRedisPassword = "SUB2API_CRM_REDIS_PASSWORD"
EnvRedisDB = "SUB2API_CRM_REDIS_DB"
EnvListenAddr = "SUB2API_CRM_LISTEN_ADDR"
EnvSQLiteDSN = "SUB2API_CRM_SQLITE_DSN"
EnvAdminToken = "SUB2API_CRM_ADMIN_TOKEN"
EnvAdminUsername = "SUB2API_CRM_ADMIN_USERNAME"
EnvAdminPassword = "SUB2API_CRM_ADMIN_PASSWORD"
EnvAdminSessionTTL = "SUB2API_CRM_ADMIN_SESSION_TTL"
EnvRepoRoot = "SUB2API_CRM_REPO_ROOT"
EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED"
EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL"
EnvRouteRuntimeBackend = "SUB2API_CRM_ROUTE_RUNTIME_BACKEND"
EnvRedisAddr = "SUB2API_CRM_REDIS_ADDR"
EnvRedisPassword = "SUB2API_CRM_REDIS_PASSWORD"
EnvRedisDB = "SUB2API_CRM_REDIS_DB"
EnvTrustedSubjectHeader = "SUB2API_CRM_TRUSTED_SUBJECT_HEADER"
EnvTrustedProxySecretHeader = "SUB2API_CRM_TRUSTED_PROXY_SECRET_HEADER"
EnvTrustedProxySecret = "SUB2API_CRM_TRUSTED_PROXY_SECRET"
DefaultListenAddr = ":8080"
DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000"
DefaultAdminUsername = "admin"
DefaultAdminSessionTTL = 12 * time.Hour
DefaultReconcilePollInterval = 10 * time.Minute
DefaultRouteRuntimeBackend = "memory"
DefaultListenAddr = ":8080"
DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000"
DefaultAdminUsername = "admin"
DefaultAdminSessionTTL = 12 * time.Hour
DefaultReconcilePollInterval = 10 * time.Minute
DefaultRouteRuntimeBackend = "memory"
DefaultTrustedProxySecretHeader = "X-CRM-Trusted-Proxy"
)
type ServerConfig struct {
@@ -59,10 +63,17 @@ type RepositoryConfig struct {
RepoRoot string
}
type UserKeyAuthConfig struct {
TrustedSubjectHeader string
TrustedProxySecretHeader string
TrustedProxySecret string
}
type StartupConfig struct {
Server ServerConfig
Database DatabaseConfig
Repository RepositoryConfig
UserKeyAuth UserKeyAuthConfig
RouteRuntime RouteRuntimeConfig
Reconcile ReconcileConfig
}
@@ -96,6 +107,11 @@ func loadStartupFromLookupEnv(lookup func(string) (string, bool)) (StartupConfig
Repository: RepositoryConfig{
RepoRoot: readOptionalEnv(lookup, EnvRepoRoot, ""),
},
UserKeyAuth: UserKeyAuthConfig{
TrustedSubjectHeader: readOptionalEnv(lookup, EnvTrustedSubjectHeader, ""),
TrustedProxySecretHeader: readOptionalEnv(lookup, EnvTrustedProxySecretHeader, DefaultTrustedProxySecretHeader),
TrustedProxySecret: readOptionalEnv(lookup, EnvTrustedProxySecret, ""),
},
RouteRuntime: RouteRuntimeConfig{
Backend: readOptionalEnv(lookup, EnvRouteRuntimeBackend, DefaultRouteRuntimeBackend),
Redis: RedisRuntimeConfig{

View File

@@ -77,6 +77,12 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
return " redis-pass ", true
case EnvRedisDB:
return "5", true
case EnvTrustedSubjectHeader:
return "X-CRM-Authenticated-Subject", true
case EnvTrustedProxySecretHeader:
return "X-CRM-Trusted-Proxy", true
case EnvTrustedProxySecret:
return "proxy-secret", true
default:
return "", false
}
@@ -112,6 +118,15 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
if cfg.RouteRuntime.Redis.DB != 5 {
t.Fatalf("RouteRuntime.Redis.DB = %d, want 5", cfg.RouteRuntime.Redis.DB)
}
if cfg.UserKeyAuth.TrustedSubjectHeader != "X-CRM-Authenticated-Subject" {
t.Fatalf("UserKeyAuth.TrustedSubjectHeader = %q, want X-CRM-Authenticated-Subject", cfg.UserKeyAuth.TrustedSubjectHeader)
}
if cfg.UserKeyAuth.TrustedProxySecretHeader != "X-CRM-Trusted-Proxy" {
t.Fatalf("UserKeyAuth.TrustedProxySecretHeader = %q, want X-CRM-Trusted-Proxy", cfg.UserKeyAuth.TrustedProxySecretHeader)
}
if cfg.UserKeyAuth.TrustedProxySecret != "proxy-secret" {
t.Fatalf("UserKeyAuth.TrustedProxySecret = %q, want proxy-secret", cfg.UserKeyAuth.TrustedProxySecret)
}
})
t.Run("default values", func(t *testing.T) {
lookup := func(k string) (string, bool) {
@@ -142,6 +157,15 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
if cfg.RouteRuntime.Redis.Addr != "" || cfg.RouteRuntime.Redis.Password != "" || cfg.RouteRuntime.Redis.DB != 0 {
t.Fatalf("RouteRuntime.Redis = %+v, want zero value", cfg.RouteRuntime.Redis)
}
if cfg.UserKeyAuth.TrustedSubjectHeader != "" {
t.Fatalf("UserKeyAuth.TrustedSubjectHeader = %q, want empty by default", cfg.UserKeyAuth.TrustedSubjectHeader)
}
if cfg.UserKeyAuth.TrustedProxySecretHeader != DefaultTrustedProxySecretHeader {
t.Fatalf("UserKeyAuth.TrustedProxySecretHeader = %q, want %q", cfg.UserKeyAuth.TrustedProxySecretHeader, DefaultTrustedProxySecretHeader)
}
if cfg.UserKeyAuth.TrustedProxySecret != "" {
t.Fatalf("UserKeyAuth.TrustedProxySecret = %q, want empty by default", cfg.UserKeyAuth.TrustedProxySecret)
}
})
t.Run("invalid reconcile interval", func(t *testing.T) {
lookup := func(k string) (string, bool) {

View File

@@ -0,0 +1,2 @@
ALTER TABLE user_keys ADD COLUMN managed_identity_selector TEXT NOT NULL DEFAULT '';
CREATE INDEX IF NOT EXISTS idx_user_keys_managed_identity_selector ON user_keys(managed_identity_selector);

View File

@@ -85,15 +85,15 @@ func TestUserKeysRepoUpdateSecret(t *testing.T) {
t.Fatalf("Create() error = %v", err)
}
if err := store.UserKeys().UpdateSecret(ctx, "key_rotate_001", "sha256:new", "sk-****new1", "active"); err != nil {
if err := store.UserKeys().UpdateSecret(ctx, "key_rotate_001", "subject|key:key_rotate_001|rot:key_nonce", "sha256:new", "sk-****new1", "active"); err != nil {
t.Fatalf("UpdateSecret() error = %v", err)
}
key, err := store.UserKeys().GetByID(ctx, "key_rotate_001")
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if key.KeyFingerprint != "sha256:new" || key.MaskedPreview != "sk-****new1" || key.AdminStatus != "active" {
t.Fatalf("updated key = %+v, want new fingerprint/mask/status", key)
if key.ManagedIdentitySelector != "subject|key:key_rotate_001|rot:key_nonce" || key.KeyFingerprint != "sha256:new" || key.MaskedPreview != "sk-****new1" || key.AdminStatus != "active" {
t.Fatalf("updated key = %+v, want new selector/fingerprint/mask/status", key)
}
if strings.TrimSpace(key.UpdatedAt) == "" {
t.Fatalf("UpdatedAt = %q, want non-empty", key.UpdatedAt)

View File

@@ -9,20 +9,21 @@ import (
)
type UserKeyRecord struct {
ID int64 `json:"-"`
KeyID string `json:"key_id"`
OwnerSubjectID string `json:"owner_subject_id"`
KeyFingerprint string `json:"key_fingerprint"`
MaskedPreview string `json:"masked_preview"`
DisplayName string `json:"display_name"`
LogicalGroupID string `json:"logical_group_id"`
AllowedModels []string `json:"allowed_models"`
AdminStatus string `json:"admin_status"`
QuotaStatus string `json:"quota_status"`
LastUsedAt string `json:"last_used_at,omitempty"`
CreatedAt string `json:"created_at"`
ExpiresAt string `json:"expires_at,omitempty"`
UpdatedAt string `json:"updated_at"`
ID int64 `json:"-"`
KeyID string `json:"key_id"`
OwnerSubjectID string `json:"owner_subject_id"`
ManagedIdentitySelector string `json:"-"`
KeyFingerprint string `json:"key_fingerprint"`
MaskedPreview string `json:"masked_preview"`
DisplayName string `json:"display_name"`
LogicalGroupID string `json:"logical_group_id"`
AllowedModels []string `json:"allowed_models"`
AdminStatus string `json:"admin_status"`
QuotaStatus string `json:"quota_status"`
LastUsedAt string `json:"last_used_at,omitempty"`
CreatedAt string `json:"created_at"`
ExpiresAt string `json:"expires_at,omitempty"`
UpdatedAt string `json:"updated_at"`
}
type UserKeysRepo struct {
@@ -40,11 +41,11 @@ func (r *UserKeysRepo) Create(ctx context.Context, key UserKeyRecord) (int64, er
}
result, err := r.db.ExecContext(ctx, `
INSERT INTO user_keys (
key_id, owner_subject_id, key_fingerprint, masked_preview,
key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
display_name, logical_group_id, allowed_models,
admin_status, quota_status
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
key.KeyID, key.OwnerSubjectID, key.KeyFingerprint, key.MaskedPreview,
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
key.KeyID, key.OwnerSubjectID, key.ManagedIdentitySelector, key.KeyFingerprint, key.MaskedPreview,
key.DisplayName, key.LogicalGroupID, string(modelsJSON),
key.AdminStatus, key.QuotaStatus,
)
@@ -60,7 +61,7 @@ func scanUserKeys(rows *sql.Rows) ([]UserKeyRecord, error) {
var k UserKeyRecord
var modelsJSON, lastUsedAt, expiresAt sql.NullString
err := rows.Scan(
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.KeyFingerprint, &k.MaskedPreview,
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.ManagedIdentitySelector, &k.KeyFingerprint, &k.MaskedPreview,
&k.DisplayName, &k.LogicalGroupID, &modelsJSON,
&k.AdminStatus, &k.QuotaStatus, &lastUsedAt, &k.CreatedAt, &expiresAt, &k.UpdatedAt,
)
@@ -81,7 +82,7 @@ func scanOneUserKey(row *sql.Row) (*UserKeyRecord, error) {
var k UserKeyRecord
var modelsJSON, lastUsedAt, expiresAt sql.NullString
err := row.Scan(
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.KeyFingerprint, &k.MaskedPreview,
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.ManagedIdentitySelector, &k.KeyFingerprint, &k.MaskedPreview,
&k.DisplayName, &k.LogicalGroupID, &modelsJSON,
&k.AdminStatus, &k.QuotaStatus, &lastUsedAt, &k.CreatedAt, &expiresAt, &k.UpdatedAt,
)
@@ -98,7 +99,7 @@ func scanOneUserKey(row *sql.Row) (*UserKeyRecord, error) {
func (r *UserKeysRepo) ListByOwner(ctx context.Context, subjectID string) ([]UserKeyRecord, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, key_id, owner_subject_id, key_fingerprint, masked_preview,
SELECT id, key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
display_name, logical_group_id, allowed_models,
admin_status, quota_status, last_used_at, created_at, expires_at, updated_at
FROM user_keys WHERE owner_subject_id = ? ORDER BY created_at DESC`, subjectID)
@@ -111,7 +112,7 @@ func (r *UserKeysRepo) ListByOwner(ctx context.Context, subjectID string) ([]Use
func (r *UserKeysRepo) ListByFingerprint(ctx context.Context, fingerprint string) ([]UserKeyRecord, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, key_id, owner_subject_id, key_fingerprint, masked_preview,
SELECT id, key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
display_name, logical_group_id, allowed_models,
admin_status, quota_status, last_used_at, created_at, expires_at, updated_at
FROM user_keys WHERE key_fingerprint = ? ORDER BY created_at DESC`, fingerprint)
@@ -124,7 +125,7 @@ func (r *UserKeysRepo) ListByFingerprint(ctx context.Context, fingerprint string
func (r *UserKeysRepo) GetByID(ctx context.Context, keyID string) (*UserKeyRecord, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, key_id, owner_subject_id, key_fingerprint, masked_preview,
SELECT id, key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
display_name, logical_group_id, allowed_models,
admin_status, quota_status, last_used_at, created_at, expires_at, updated_at
FROM user_keys WHERE key_id = ?`, keyID)
@@ -154,8 +155,9 @@ func (r *UserKeysRepo) UpdateStatus(ctx context.Context, keyID string, adminStat
return nil
}
func (r *UserKeysRepo) UpdateSecret(ctx context.Context, keyID, fingerprint, maskedPreview, adminStatus string) error {
func (r *UserKeysRepo) UpdateSecret(ctx context.Context, keyID, managedIdentitySelector, fingerprint, maskedPreview, adminStatus string) error {
keyID = strings.TrimSpace(keyID)
managedIdentitySelector = strings.TrimSpace(managedIdentitySelector)
fingerprint = strings.TrimSpace(fingerprint)
maskedPreview = strings.TrimSpace(maskedPreview)
adminStatus = strings.ToLower(strings.TrimSpace(adminStatus))
@@ -174,9 +176,9 @@ func (r *UserKeysRepo) UpdateSecret(ctx context.Context, keyID, fingerprint, mas
}
result, err := r.db.ExecContext(ctx,
`UPDATE user_keys
SET key_fingerprint = ?, masked_preview = ?, admin_status = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')
SET managed_identity_selector = ?, key_fingerprint = ?, masked_preview = ?, admin_status = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')
WHERE key_id = ?`,
fingerprint, maskedPreview, adminStatus, keyID,
managedIdentitySelector, fingerprint, maskedPreview, adminStatus, keyID,
)
if err != nil {
return fmt.Errorf("update user_key secret: %w", err)