feat(vNext.4): implement trusted-subject security chain for portal user key self-service
- Add portal_auth.go: Portal user session auth with HMAC-signed cookies
- Add /api/portal/session/{login,logout,state} endpoints
- Update nginx config template: cookie-to-header trusted proxy pattern
- Update frontend: sync CRM session on login/logout
- Add TRUSTED_SUBJECT_DEPLOY_GUIDE.md with remote43 deployment steps
- Update EXECUTION_BOARD.md: mark trusted-subject blocking issue as resolved
This implements the secure chain:
Browser → Portal → nginx (cookie→header) → CRM (verify proxy secret)
Required remote43 actions:
1. Generate 64-char hex secret
2. Update .env.crm with TRUSTED_* config
3. Update nginx with cookie map and header injection
4. Restart services
Fixes EXECUTION_BOARD.md 2026-06-08 blocking issue
This commit is contained in:
@@ -141,7 +141,7 @@ func TestAPIAdminSessionLoginSetsCookieAndAuthorizesSubsequentRequest(t *testing
|
||||
ListPacks: func(context.Context) ([]PackInfo, error) {
|
||||
return []PackInfo{{PackID: "openai-cn-pack", Version: "1.1.6"}}, nil
|
||||
},
|
||||
}, "")
|
||||
}, "", "")
|
||||
|
||||
loginRequest := httptestRequest(t, http.MethodPost, "/api/admin/session/login", map[string]any{
|
||||
"username": "admin",
|
||||
@@ -177,7 +177,7 @@ func TestAPIAdminSessionRejectsInvalidPassword(t *testing.T) {
|
||||
Token: "secret-token",
|
||||
Username: "admin",
|
||||
Password: "pass-123",
|
||||
}, ActionSet{}, "")
|
||||
}, ActionSet{}, "", "")
|
||||
request := httptestRequest(t, http.MethodPost, "/api/admin/session/login", map[string]any{
|
||||
"username": "admin",
|
||||
"password": "wrong",
|
||||
@@ -192,7 +192,7 @@ func TestAPIAdminSessionLogoutClearsCookie(t *testing.T) {
|
||||
Token: "secret-token",
|
||||
Username: "admin",
|
||||
Password: "pass-123",
|
||||
}, ActionSet{}, "")
|
||||
}, ActionSet{}, "", "")
|
||||
request := httptestRequest(t, http.MethodPost, "/api/admin/session/logout", nil, "")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusNoContent)
|
||||
@@ -219,7 +219,7 @@ func TestAPIAdminSessionMeReportsAuthenticationState(t *testing.T) {
|
||||
Now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}, ActionSet{}, "")
|
||||
}, ActionSet{}, "", "")
|
||||
|
||||
request := httptestRequest(t, http.MethodGet, "/api/admin/session", nil, "")
|
||||
response := httptestRecorder(handler, request)
|
||||
|
||||
@@ -30,7 +30,7 @@ func Bootstrap(ctx context.Context) (*Server, error) {
|
||||
Username: adminSession.Username,
|
||||
Password: adminSession.Password,
|
||||
SessionTTL: adminSession.SessionTTL,
|
||||
}, NewActionSetWithStickyRuntime(cfg.Database.SQLiteDSN, stickyRuntime), cfg.Database.SQLiteDSN)
|
||||
}, NewActionSetWithStickyRuntime(cfg.Database.SQLiteDSN, stickyRuntime, cfg.UserKeyAuth), cfg.Database.SQLiteDSN, cfg.UserKeyAuth.TrustedProxySecret)
|
||||
return NewServer(cfg.Server.ListenAddr, handler, nil), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -336,14 +336,10 @@ func NewAPIHandler(adminToken string, actions ActionSet, dsn ...string) http.Han
|
||||
if len(dsn) > 0 {
|
||||
dsnVal = dsn[0]
|
||||
}
|
||||
return NewAPIHandlerWithAuth(AdminAuthConfig{Token: adminToken}, actions, dsnVal)
|
||||
return NewAPIHandlerWithAuth(AdminAuthConfig{Token: adminToken}, actions, dsnVal, "")
|
||||
}
|
||||
|
||||
func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet, dsn ...string) http.Handler {
|
||||
sqliteDSN := ""
|
||||
if len(dsn) > 0 {
|
||||
sqliteDSN = dsn[0]
|
||||
}
|
||||
func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet, sqliteDSN string, portalSessionSecret string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /healthz", healthz)
|
||||
mux.HandleFunc("GET /version", handleVersion)
|
||||
@@ -366,6 +362,19 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet, dsn ...
|
||||
mux.HandleFunc("GET /api/portal/logical-groups/{groupID}/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
handleListPortalLogicalGroupModels(w, r, actions.ListPortalLogicalGroupModels)
|
||||
})
|
||||
// Portal user session endpoints
|
||||
portalAuth := PortalAuthConfig{
|
||||
SessionSecret: portalSessionSecret,
|
||||
}
|
||||
mux.HandleFunc("GET /api/portal/session", func(w http.ResponseWriter, r *http.Request) {
|
||||
handlePortalSessionState(w, r, portalAuth)
|
||||
})
|
||||
mux.HandleFunc("POST /api/portal/session/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
handlePortalSessionLogin(w, r, portalAuth)
|
||||
})
|
||||
mux.HandleFunc("POST /api/portal/session/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
handlePortalSessionLogout(w, r)
|
||||
})
|
||||
mux.Handle("POST /api/batch-import/runs", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleCreateBatchImportRun(w, r, actions.CreateBatchImportRun)
|
||||
})))
|
||||
@@ -1296,6 +1305,14 @@ func writeJSON(w http.ResponseWriter, statusCode int, body any) {
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func nonEmptyString(value, fallback string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func classifyError(err error) *httpError {
|
||||
if err == nil {
|
||||
return nil
|
||||
@@ -1337,7 +1354,7 @@ func NewActionSet(sqliteDSN string) ActionSet {
|
||||
return NewActionSetWithStickyRuntime(sqliteDSN, defaultStickyStoreRuntime())
|
||||
}
|
||||
|
||||
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime) ActionSet {
|
||||
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime, authCfg ...config.UserKeyAuthConfig) ActionSet {
|
||||
routeLogWriter := newLazyRouteLogWriter(sqliteDSN)
|
||||
resolveRoute := buildResolveRouteAction(sqliteDSN, stickyRuntime, routeLogWriter)
|
||||
proxyRouteChatCompletions := buildProxyRouteChatCompletionsAction(sqliteDSN, resolveRoute, routeLogWriter)
|
||||
@@ -1383,7 +1400,7 @@ func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRu
|
||||
GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime),
|
||||
ListProviderAccounts: buildListProviderAccountsAction(sqliteDSN),
|
||||
GetProviderAccountBindingCandidates: buildGetProviderAccountBindingCandidatesAction(sqliteDSN),
|
||||
UserKeyHandler: buildUserKeyHandler(sqliteDSN),
|
||||
UserKeyHandler: buildUserKeyHandler(sqliteDSN, authCfg...),
|
||||
UpdateProviderAccountBinding: buildUpdateProviderAccountBindingAction(sqliteDSN),
|
||||
EnableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusActive),
|
||||
DisableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusDisabled),
|
||||
@@ -2741,6 +2758,19 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusForbidden, Code: "quota_exhausted", Message: "API key quota exhausted"})
|
||||
return
|
||||
}
|
||||
if key.ExpiresAt != "" {
|
||||
expiresAt, parseErr := time.Parse(time.RFC3339, key.ExpiresAt)
|
||||
if parseErr != nil {
|
||||
metrics.RecordUserKeyChatRequest("key_metadata_error")
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "key_metadata_error", Message: "invalid key expiry metadata"})
|
||||
return
|
||||
}
|
||||
if !expiresAt.After(time.Now().UTC()) {
|
||||
metrics.RecordUserKeyChatRequest("key_expired")
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusForbidden, Code: "key_expired", Message: "API key has expired"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Parse request body (OpenAI-compatible)
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, maxJSONBodyBytes))
|
||||
@@ -2768,6 +2798,20 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "model is required"})
|
||||
return
|
||||
}
|
||||
if len(key.AllowedModels) > 0 {
|
||||
modelAllowed := false
|
||||
for _, allowedModel := range key.AllowedModels {
|
||||
if strings.TrimSpace(allowedModel) == model {
|
||||
modelAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !modelAllowed {
|
||||
metrics.RecordUserKeyChatRequest("model_not_allowed")
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusForbidden, Code: "model_not_allowed", Message: "requested model is not allowed for this API key"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Map to proxy request
|
||||
proxyReq := ProxyRouteChatCompletionsRequest{
|
||||
@@ -2804,7 +2848,28 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
|
||||
}
|
||||
|
||||
if upstreamResp == nil {
|
||||
// Fallback: construct a minimal response from proxy info
|
||||
upstreamResp = map[string]any{}
|
||||
}
|
||||
|
||||
if !result.Forward.OK {
|
||||
statusCode := result.Forward.UpstreamStatus
|
||||
if statusCode <= 0 {
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
upstreamResp["upstream_http_code"] = statusCode
|
||||
if _, hasError := upstreamResp["error"]; !hasError {
|
||||
upstreamResp["error"] = map[string]any{
|
||||
"code": nonEmptyString(result.Forward.ErrorClass, "upstream_error"),
|
||||
"message": nonEmptyString(result.Forward.ErrorMessage, fmt.Sprintf("upstream request failed with status %d", statusCode)),
|
||||
}
|
||||
}
|
||||
metrics.RecordUserKeyChatRequest(nonEmptyString(result.Forward.ErrorClass, "upstream_error"))
|
||||
writeJSON(w, statusCode, upstreamResp)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback: construct a minimal success response from proxy info
|
||||
if len(upstreamResp) == 0 {
|
||||
upstreamResp = map[string]any{
|
||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()),
|
||||
"object": "chat.completion",
|
||||
@@ -2820,10 +2885,8 @@ func handlePublicV1ChatCompletions(w http.ResponseWriter, r *http.Request, dsn s
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure upstream HTTP code is reflected
|
||||
if !result.Forward.OK && result.Forward.UpstreamStatus > 0 {
|
||||
upstreamResp["upstream_http_code"] = result.Forward.UpstreamStatus
|
||||
if err := store.UserKeys().TouchLastUsed(r.Context(), key.KeyID); err != nil {
|
||||
log.Printf("gateway: touch last_used_at for key %s failed: %v", key.KeyID, err)
|
||||
}
|
||||
|
||||
// Wrap in OpenAI standard envelope if upstream didn't return one
|
||||
|
||||
@@ -19,13 +19,16 @@ func generatePlaintextKey() (string, string) {
|
||||
}
|
||||
|
||||
type UserKeyHandler struct {
|
||||
createFn func(ctx context.Context, req CreateUserKeyRequest) (CreateUserKeyResponse, error)
|
||||
listFn func(ctx context.Context, subjectID string) ([]UserKeyMeta, error)
|
||||
getFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
|
||||
resetFn func(ctx context.Context, keyID, subjectID string) (ResetUserKeyResponse, error)
|
||||
pauseFn func(ctx context.Context, keyID, subjectID, reason string) (UserKeyMeta, error)
|
||||
resumeFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
|
||||
deleteFn func(ctx context.Context, keyID, subjectID string) error
|
||||
TrustedSubjectHeader string
|
||||
TrustedProxySecretHeader string
|
||||
TrustedProxySecret string
|
||||
createFn func(ctx context.Context, req CreateUserKeyRequest) (CreateUserKeyResponse, error)
|
||||
listFn func(ctx context.Context, subjectID string) ([]UserKeyMeta, error)
|
||||
getFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
|
||||
resetFn func(ctx context.Context, keyID, subjectID string) (ResetUserKeyResponse, error)
|
||||
pauseFn func(ctx context.Context, keyID, subjectID, reason string) (UserKeyMeta, error)
|
||||
resumeFn func(ctx context.Context, keyID, subjectID string) (UserKeyMeta, error)
|
||||
deleteFn func(ctx context.Context, keyID, subjectID string) error
|
||||
}
|
||||
|
||||
type CreateUserKeyRequest struct {
|
||||
@@ -60,22 +63,22 @@ type UserKeyMeta struct {
|
||||
}
|
||||
|
||||
func (h *UserKeyHandler) extractSubjectID(r *http.Request) (string, *httpError) {
|
||||
for _, header := range []string{"X-Portal-Subject", "X-User-Subject", "X-Forwarded-User"} {
|
||||
if subjectID := strings.TrimSpace(r.Header.Get(header)); subjectID != "" {
|
||||
return subjectID, nil
|
||||
}
|
||||
if h == nil {
|
||||
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "user credentials required"}
|
||||
}
|
||||
if hdr := r.Header.Get("Authorization"); strings.HasPrefix(hdr, "Bearer ") {
|
||||
token := strings.TrimSpace(strings.TrimPrefix(hdr, "Bearer "))
|
||||
if token != "" {
|
||||
n := 8
|
||||
if len(token) < n {
|
||||
n = len(token)
|
||||
}
|
||||
return "skeleton_user_" + token[:n], nil
|
||||
}
|
||||
subjectHeader := strings.TrimSpace(h.TrustedSubjectHeader)
|
||||
secretHeader := strings.TrimSpace(h.TrustedProxySecretHeader)
|
||||
secret := strings.TrimSpace(h.TrustedProxySecret)
|
||||
if subjectHeader == "" || secretHeader == "" || secret == "" {
|
||||
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "trusted user identity proxy not configured"}
|
||||
}
|
||||
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "user credentials required"}
|
||||
if got := strings.TrimSpace(r.Header.Get(secretHeader)); got != secret {
|
||||
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "trusted proxy authentication required"}
|
||||
}
|
||||
if subjectID := strings.TrimSpace(r.Header.Get(subjectHeader)); subjectID != "" {
|
||||
return subjectID, nil
|
||||
}
|
||||
return "", &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "trusted subject header required"}
|
||||
}
|
||||
|
||||
func writeSvcNotImplError(w http.ResponseWriter) {
|
||||
@@ -181,7 +184,16 @@ func handlePauseUserKey(w http.ResponseWriter, r *http.Request, h *UserKeyHandle
|
||||
return
|
||||
}
|
||||
keyID := r.PathValue("key_id")
|
||||
key, svcErr := h.pauseFn(r.Context(), keyID, subjectID, "")
|
||||
var req struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
if r.Body != nil && r.ContentLength != 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_json", Message: err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
key, svcErr := h.pauseFn(r.Context(), keyID, subjectID, strings.TrimSpace(req.Reason))
|
||||
if svcErr != nil {
|
||||
writeHTTPError(w, classifyError(svcErr))
|
||||
return
|
||||
|
||||
@@ -4,12 +4,35 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
testTrustedSubjectHeader = "X-CRM-Authenticated-Subject"
|
||||
testTrustedProxySecretHeader = "X-CRM-Trusted-Proxy"
|
||||
testTrustedProxySecret = "shared-secret"
|
||||
)
|
||||
|
||||
func withTrustedProxyAuth(h *UserKeyHandler) *UserKeyHandler {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *h
|
||||
clone.TrustedSubjectHeader = testTrustedSubjectHeader
|
||||
clone.TrustedProxySecretHeader = testTrustedProxySecretHeader
|
||||
clone.TrustedProxySecret = testTrustedProxySecret
|
||||
return &clone
|
||||
}
|
||||
|
||||
func applyTrustedProxyHeaders(req *http.Request, subjectID string) {
|
||||
req.Header.Set(testTrustedSubjectHeader, subjectID)
|
||||
req.Header.Set(testTrustedProxySecretHeader, testTrustedProxySecret)
|
||||
}
|
||||
|
||||
func TestGeneratePlaintextKeyAndExtractSubjectID(t *testing.T) {
|
||||
t.Parallel()
|
||||
plaintext, fingerprint := generatePlaintextKey()
|
||||
@@ -20,16 +43,49 @@ func TestGeneratePlaintextKeyAndExtractSubjectID(t *testing.T) {
|
||||
t.Fatalf("fingerprint = %q, want sha256 prefix", fingerprint)
|
||||
}
|
||||
|
||||
h := &UserKeyHandler{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
|
||||
req.Header.Set("Authorization", "Bearer abcdefgh12345678")
|
||||
subjectID, httpErr := h.extractSubjectID(req)
|
||||
if httpErr != nil {
|
||||
t.Fatalf("extractSubjectID() unexpected error: %+v", httpErr)
|
||||
}
|
||||
if subjectID != "skeleton_user_abcdefgh" {
|
||||
t.Fatalf("subjectID = %q, want skeleton_user_abcdefgh", subjectID)
|
||||
}
|
||||
t.Run("rejects bearer fallback when trusted proxy auth is not configured", func(t *testing.T) {
|
||||
h := &UserKeyHandler{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
|
||||
req.Header.Set("Authorization", "Bearer abcdefgh12345678")
|
||||
_, httpErr := h.extractSubjectID(req)
|
||||
if httpErr == nil {
|
||||
t.Fatal("expected unauthorized error when trusted proxy auth is not configured")
|
||||
}
|
||||
if httpErr.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want 401", httpErr.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects portal subject header when trusted proxy auth is not configured", func(t *testing.T) {
|
||||
h := &UserKeyHandler{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
|
||||
req.Header.Set("X-Portal-Subject", "portal-user:1")
|
||||
_, httpErr := h.extractSubjectID(req)
|
||||
if httpErr == nil {
|
||||
t.Fatal("expected unauthorized error when trusted proxy auth is not configured")
|
||||
}
|
||||
if httpErr.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want 401", httpErr.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accepts trusted proxy subject when proxy secret matches", func(t *testing.T) {
|
||||
h := &UserKeyHandler{
|
||||
TrustedSubjectHeader: "X-CRM-Authenticated-Subject",
|
||||
TrustedProxySecretHeader: "X-CRM-Trusted-Proxy",
|
||||
TrustedProxySecret: "shared-secret",
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
|
||||
req.Header.Set("X-CRM-Authenticated-Subject", "portal-user:1")
|
||||
req.Header.Set("X-CRM-Trusted-Proxy", "shared-secret")
|
||||
subjectID, httpErr := h.extractSubjectID(req)
|
||||
if httpErr != nil {
|
||||
t.Fatalf("extractSubjectID() unexpected error: %+v", httpErr)
|
||||
}
|
||||
if subjectID != "portal-user:1" {
|
||||
t.Fatalf("subjectID = %q, want portal-user:1", subjectID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleUserKeyListNotImplemented(t *testing.T) {
|
||||
@@ -46,16 +102,16 @@ func TestHandleUserKeyListNotImplemented(t *testing.T) {
|
||||
|
||||
func TestHandleUserKeyListSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := &UserKeyHandler{
|
||||
h := withTrustedProxyAuth(&UserKeyHandler{
|
||||
listFn: func(ctx context.Context, subjectID string) ([]UserKeyMeta, error) {
|
||||
if subjectID != "portal-user:1" {
|
||||
t.Fatalf("subjectID = %q, want portal-user:1", subjectID)
|
||||
}
|
||||
return []UserKeyMeta{{KeyID: "key_1", AdminStatus: "active"}}, nil
|
||||
},
|
||||
}
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
|
||||
req.Header.Set("X-Portal-Subject", "portal-user:1")
|
||||
applyTrustedProxyHeaders(req, "portal-user:1")
|
||||
rr := httptest.NewRecorder()
|
||||
serveWithMetrics(t, req, rr, func(w http.ResponseWriter, r *http.Request) {
|
||||
handleListUserKeys(w, r, h)
|
||||
@@ -70,12 +126,12 @@ func TestHandleUserKeyListSuccess(t *testing.T) {
|
||||
|
||||
func TestHandleGetUserKeyMissingKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := &UserKeyHandler{getFn: func(context.Context, string, string) (UserKeyMeta, error) {
|
||||
h := withTrustedProxyAuth(&UserKeyHandler{getFn: func(context.Context, string, string) (UserKeyMeta, error) {
|
||||
t.Fatal("getFn should not be called when key_id is missing")
|
||||
return UserKeyMeta{}, nil
|
||||
}}
|
||||
}})
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys/", nil)
|
||||
req.Header.Set("X-Portal-Subject", "portal-user:1")
|
||||
applyTrustedProxyHeaders(req, "portal-user:1")
|
||||
rr := httptest.NewRecorder()
|
||||
serveWithMetrics(t, req, rr, func(w http.ResponseWriter, r *http.Request) {
|
||||
handleGetUserKey(w, r, h)
|
||||
@@ -131,7 +187,7 @@ func TestHandleUserKeyMutationHandlers(t *testing.T) {
|
||||
path: "/api/keys/key_1/pause",
|
||||
handlerFn: handlePauseUserKey,
|
||||
userHandler: &UserKeyHandler{pauseFn: func(ctx context.Context, keyID, subjectID, reason string) (UserKeyMeta, error) {
|
||||
if keyID != "key_1" || subjectID != "portal-user:1" || reason != "" {
|
||||
if keyID != "key_1" || subjectID != "portal-user:1" || reason != "user requested pause" {
|
||||
t.Fatalf("pauseFn args = (%q,%q,%q)", keyID, subjectID, reason)
|
||||
}
|
||||
paused := meta
|
||||
@@ -174,12 +230,18 @@ func TestHandleUserKeyMutationHandlers(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
req := httptest.NewRequest(tc.method, tc.path, func() io.Reader {
|
||||
if tc.name == "pause-success" {
|
||||
return strings.NewReader(`{"reason":"user requested pause"}`)
|
||||
}
|
||||
return nil
|
||||
}())
|
||||
req.Header.Set("X-Portal-Subject", "portal-user:1")
|
||||
req.SetPathValue("key_id", "key_1")
|
||||
rr := httptest.NewRecorder()
|
||||
applyTrustedProxyHeaders(req, "portal-user:1")
|
||||
serveWithMetrics(t, req, rr, func(w http.ResponseWriter, r *http.Request) {
|
||||
tc.handlerFn(w, r, tc.userHandler)
|
||||
tc.handlerFn(w, r, withTrustedProxyAuth(tc.userHandler))
|
||||
})
|
||||
if rr.Code != tc.wantStatus {
|
||||
t.Fatalf("status = %d, want %d body=%s", rr.Code, tc.wantStatus, rr.Body.String())
|
||||
@@ -200,11 +262,11 @@ func serveWithMetrics(t *testing.T, req *http.Request, rr *httptest.ResponseReco
|
||||
|
||||
func TestHandleListUserKeysResponseShape(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := &UserKeyHandler{listFn: func(context.Context, string) ([]UserKeyMeta, error) {
|
||||
h := withTrustedProxyAuth(&UserKeyHandler{listFn: func(context.Context, string) ([]UserKeyMeta, error) {
|
||||
return []UserKeyMeta{{KeyID: "key_json", AdminStatus: "active"}}, nil
|
||||
}}
|
||||
}})
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/keys", nil)
|
||||
req.Header.Set("X-Portal-Subject", "portal-user:json")
|
||||
applyTrustedProxyHeaders(req, "portal-user:json")
|
||||
rr := httptest.NewRecorder()
|
||||
handleListUserKeys(rr, req, h)
|
||||
var payload struct {
|
||||
@@ -263,10 +325,10 @@ func TestHandleUserKeyMutationHandlersErrorPaths(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/keys/key_1", nil)
|
||||
req.Header.Set("X-Portal-Subject", "portal-user:1")
|
||||
applyTrustedProxyHeaders(req, "portal-user:1")
|
||||
req.SetPathValue("key_id", "key_1")
|
||||
rr := httptest.NewRecorder()
|
||||
tc.handlerFn(rr, req, tc.userHandler)
|
||||
tc.handlerFn(rr, req, withTrustedProxyAuth(tc.userHandler))
|
||||
if rr.Code != tc.wantStatus {
|
||||
t.Fatalf("status = %d, want %d body=%s", rr.Code, tc.wantStatus, rr.Body.String())
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/config"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/metrics"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
@@ -104,13 +105,34 @@ func ensureSubjectHasAccess(ctx context.Context, client *sub2api.Client, subject
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
func buildManagedIdentitySelector(subjectID, keyID string) string {
|
||||
return strings.TrimSpace(subjectID) + "|key:" + strings.TrimSpace(keyID) + "|rot:" + generateKeyID()
|
||||
}
|
||||
|
||||
func managedIdentitySelectorForRecord(rec *sqlite.UserKeyRecord) string {
|
||||
if rec == nil {
|
||||
return ""
|
||||
}
|
||||
if selector := strings.TrimSpace(rec.ManagedIdentitySelector); selector != "" {
|
||||
return selector
|
||||
}
|
||||
return strings.TrimSpace(rec.OwnerSubjectID)
|
||||
}
|
||||
|
||||
func recordUserKeyFailure(operation, result string, err error) error {
|
||||
metrics.RecordUserKeyOperation(operation, result)
|
||||
return err
|
||||
}
|
||||
|
||||
func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
func buildUserKeyHandler(sqliteDSN string, authCfg ...config.UserKeyAuthConfig) *UserKeyHandler {
|
||||
var cfg config.UserKeyAuthConfig
|
||||
if len(authCfg) > 0 {
|
||||
cfg = authCfg[0]
|
||||
}
|
||||
return &UserKeyHandler{
|
||||
TrustedSubjectHeader: strings.TrimSpace(cfg.TrustedSubjectHeader),
|
||||
TrustedProxySecretHeader: strings.TrimSpace(cfg.TrustedProxySecretHeader),
|
||||
TrustedProxySecret: strings.TrimSpace(cfg.TrustedProxySecret),
|
||||
createFn: func(ctx context.Context, req CreateUserKeyRequest) (CreateUserKeyResponse, error) {
|
||||
if strings.TrimSpace(req.SubjectID) == "" {
|
||||
metrics.RecordUserKeyOperation("create", "unauthorized")
|
||||
@@ -136,6 +158,9 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
return CreateUserKeyResponse{}, &httpError{StatusCode: 429, Code: "rate_limited", Message: "create key rate limit exceeded"}
|
||||
}
|
||||
|
||||
keyID := generateKeyID()
|
||||
managedIdentitySelector := buildManagedIdentitySelector(req.SubjectID, keyID)
|
||||
|
||||
// Resolve logical group → host → group ID → ensure subscription access
|
||||
_, route, hostRow, client, err := resolveLogicalGroupHost(ctx, store, req.LogicalGroupID)
|
||||
if err != nil {
|
||||
@@ -145,26 +170,26 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
if err != nil {
|
||||
return CreateUserKeyResponse{}, recordUserKeyFailure("create", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for %q: %w", route.ShadowGroupID, err))
|
||||
}
|
||||
apiKey, err := ensureSubjectHasAccess(ctx, client, req.SubjectID, hostGroupID)
|
||||
apiKey, err := ensureSubjectHasAccess(ctx, client, managedIdentitySelector, hostGroupID)
|
||||
if err != nil {
|
||||
return CreateUserKeyResponse{}, recordUserKeyFailure("create", "ensure_access_error", fmt.Errorf("ensure access for %q: %w", req.LogicalGroupID, err))
|
||||
}
|
||||
|
||||
fingerprint := "sha256:" + sha256Hex(apiKey)
|
||||
keyID := generateKeyID()
|
||||
masked := "sk-****" + apiKey[len(apiKey)-4:]
|
||||
|
||||
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
|
||||
if _, err := q.UserKeys.Create(ctx, sqlite.UserKeyRecord{
|
||||
KeyID: keyID,
|
||||
OwnerSubjectID: req.SubjectID,
|
||||
KeyFingerprint: fingerprint,
|
||||
MaskedPreview: masked,
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
LogicalGroupID: strings.TrimSpace(req.LogicalGroupID),
|
||||
AllowedModels: req.AllowedModels,
|
||||
AdminStatus: "active",
|
||||
QuotaStatus: "ok",
|
||||
KeyID: keyID,
|
||||
OwnerSubjectID: req.SubjectID,
|
||||
ManagedIdentitySelector: managedIdentitySelector,
|
||||
KeyFingerprint: fingerprint,
|
||||
MaskedPreview: masked,
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
LogicalGroupID: strings.TrimSpace(req.LogicalGroupID),
|
||||
AllowedModels: req.AllowedModels,
|
||||
AdminStatus: "active",
|
||||
QuotaStatus: "ok",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("create key: %w", err)
|
||||
}
|
||||
@@ -288,7 +313,8 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
if err != nil {
|
||||
return ResetUserKeyResponse{}, recordUserKeyFailure("reset", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for %q: %w", route.ShadowGroupID, err))
|
||||
}
|
||||
newPlaintext, err := ensureSubjectHasAccess(ctx, client, rec.OwnerSubjectID, hostGroupID)
|
||||
managedIdentitySelector := buildManagedIdentitySelector(rec.OwnerSubjectID, keyID)
|
||||
newPlaintext, err := ensureSubjectHasAccess(ctx, client, managedIdentitySelector, hostGroupID)
|
||||
if err != nil {
|
||||
return ResetUserKeyResponse{}, recordUserKeyFailure("reset", "ensure_access_error", fmt.Errorf("ensure access on reset for %q: %w", rec.LogicalGroupID, err))
|
||||
}
|
||||
@@ -297,7 +323,7 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
masked := "sk-****" + newPlaintext[len(newPlaintext)-4:]
|
||||
|
||||
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
|
||||
if err := q.UserKeys.UpdateSecret(ctx, keyID, hostFingerprint, masked, "active"); err != nil {
|
||||
if err := q.UserKeys.UpdateSecret(ctx, keyID, managedIdentitySelector, hostFingerprint, masked, "active"); err != nil {
|
||||
return fmt.Errorf("reset key: %w", err)
|
||||
}
|
||||
if _, err := q.UserKeyAuditEvents.Create(ctx, sqlite.UserKeyAuditEvent{
|
||||
@@ -341,7 +367,7 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
if err != nil {
|
||||
return UserKeyMeta{}, recordUserKeyFailure("pause", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for pause %q: %w", route.ShadowGroupID, err))
|
||||
}
|
||||
if err := client.PauseManagedSubscriptionAccess(ctx, rec.OwnerSubjectID, hostGroupID); err != nil {
|
||||
if err := client.PauseManagedSubscriptionAccess(ctx, managedIdentitySelectorForRecord(rec), hostGroupID); err != nil {
|
||||
return UserKeyMeta{}, recordUserKeyFailure("pause", "pause_access_error", fmt.Errorf("pause managed subscription access: %w", err))
|
||||
}
|
||||
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
|
||||
@@ -384,7 +410,7 @@ func buildUserKeyHandler(sqliteDSN string) *UserKeyHandler {
|
||||
if err != nil {
|
||||
return UserKeyMeta{}, recordUserKeyFailure("resume", "resolve_shadow_group_error", fmt.Errorf("resolve shadow group id for resume %q: %w", route.ShadowGroupID, err))
|
||||
}
|
||||
if err := client.ResumeManagedSubscriptionAccess(ctx, rec.OwnerSubjectID, hostGroupID); err != nil {
|
||||
if err := client.ResumeManagedSubscriptionAccess(ctx, managedIdentitySelectorForRecord(rec), hostGroupID); err != nil {
|
||||
return UserKeyMeta{}, recordUserKeyFailure("resume", "resume_access_error", fmt.Errorf("resume managed subscription access: %w", err))
|
||||
}
|
||||
err = store.WithTx(ctx, func(q *sqlite.Queries) error {
|
||||
|
||||
@@ -11,10 +11,24 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/config"
|
||||
"sub2api-cn-relay-manager/internal/metrics"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func testUserKeyAuthConfig() config.UserKeyAuthConfig {
|
||||
return config.UserKeyAuthConfig{
|
||||
TrustedSubjectHeader: testTrustedSubjectHeader,
|
||||
TrustedProxySecretHeader: testTrustedProxySecretHeader,
|
||||
TrustedProxySecret: testTrustedProxySecret,
|
||||
}
|
||||
}
|
||||
|
||||
func applyTrustedProxyAuthHeaders(req *http.Request, subjectID string) {
|
||||
req.Header.Set(testTrustedSubjectHeader, subjectID)
|
||||
req.Header.Set(testTrustedProxySecretHeader, testTrustedProxySecret)
|
||||
}
|
||||
|
||||
func makeCreateBody(groupID, displayName string, models []string) io.Reader {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"logical_group_id": groupID,
|
||||
@@ -60,11 +74,11 @@ func TestUserKeyAPIUsesPortalSubjectHeader(t *testing.T) {
|
||||
})
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
|
||||
})
|
||||
|
||||
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("gpt-shared", "portal key", []string{"gpt-5.4"}))
|
||||
req.Header.Set("X-Portal-Subject", "smoke-user")
|
||||
applyTrustedProxyAuthHeaders(req, "smoke-user")
|
||||
resp := httptestRecorder(handler, req)
|
||||
|
||||
// We expect 500 because test host is unreachable (port 1), but the important
|
||||
@@ -107,11 +121,11 @@ func TestUserKeyCreateRejectsMissingSubject(t *testing.T) {
|
||||
func TestUserKeyCreateRejectsMissingGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t))),
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t)), testUserKeyAuthConfig()),
|
||||
})
|
||||
body := bytes.NewReader([]byte(`{"display_name":"portal key"}`))
|
||||
req := makeCreateRequest(t, http.MethodPost, "/api/keys", body)
|
||||
req.Header.Set("X-Portal-Subject", "smoke-user")
|
||||
applyTrustedProxyAuthHeaders(req, "smoke-user")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusBadRequest {
|
||||
t.Fatalf("status code = %d, want 400", resp.code)
|
||||
@@ -142,18 +156,18 @@ func TestUserKeyRateLimitNoDB(t *testing.T) {
|
||||
})
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
|
||||
})
|
||||
|
||||
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("gpt-shared", "rate-test", nil))
|
||||
req.Header.Set("X-Portal-Subject", "rate-user")
|
||||
applyTrustedProxyAuthHeaders(req, "rate-user")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code == http.StatusUnauthorized || resp.code == http.StatusNotImplemented {
|
||||
t.Fatalf("status code = %d, expected to pass auth layer", resp.code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testing.T) {
|
||||
func TestUserKeyCreateUsesPerRecordManagedKeyAndConsistentMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := openAppTestStore(t)
|
||||
@@ -163,6 +177,8 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
|
||||
const hostGroupID = "999"
|
||||
const subjectID = "portal-user:13"
|
||||
|
||||
var loginEmail string
|
||||
var customKey string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
|
||||
@@ -180,9 +196,9 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode login request: %v", err)
|
||||
}
|
||||
expected := expectedManagedIdentity(subjectID, hostGroupID)
|
||||
if got := fmt.Sprint(req["email"]); got != expected.Email {
|
||||
t.Fatalf("login email = %q, want subject-scoped %q", got, expected.Email)
|
||||
loginEmail = fmt.Sprint(req["email"])
|
||||
if !strings.Contains(loginEmail, "@sub2api.local") || strings.Contains(loginEmail, subjectID) {
|
||||
t.Fatalf("login email = %q, want synthesized per-record managed identity", loginEmail)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
|
||||
@@ -190,9 +206,9 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode managed key request: %v", err)
|
||||
}
|
||||
expected := expectedManagedIdentity(subjectID, hostGroupID)
|
||||
if got := fmt.Sprint(req["custom_key"]); got != expected.CustomKey {
|
||||
t.Fatalf("custom_key = %q, want subject-scoped %q", got, expected.CustomKey)
|
||||
customKey = fmt.Sprint(req["custom_key"])
|
||||
if !strings.HasPrefix(customKey, "sk-relay-") {
|
||||
t.Fatalf("custom_key = %q, want sk-relay-*", customKey)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":501,"key":"placeholder-from-host","name":"managed-key"}}`))
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
|
||||
@@ -226,11 +242,11 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
|
||||
})
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
|
||||
})
|
||||
|
||||
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody(logicalGroupID, "portal key", []string{"gpt-5.4"}))
|
||||
req.Header.Set("X-Portal-Subject", subjectID)
|
||||
applyTrustedProxyAuthHeaders(req, subjectID)
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusCreated {
|
||||
t.Fatalf("status code = %d, want 201, body=%s", resp.code, resp.Body().String())
|
||||
@@ -241,11 +257,10 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
|
||||
t.Fatalf("decode create response: %v", err)
|
||||
}
|
||||
|
||||
expected := expectedManagedIdentity(subjectID, hostGroupID)
|
||||
if createResp.PlaintextKey != expected.CustomKey {
|
||||
t.Fatalf("plaintext_key = %q, want subject-scoped %q", createResp.PlaintextKey, expected.CustomKey)
|
||||
if createResp.PlaintextKey != customKey {
|
||||
t.Fatalf("plaintext_key = %q, want host custom_key %q", createResp.PlaintextKey, customKey)
|
||||
}
|
||||
wantMasked := "sk-****" + expected.CustomKey[len(expected.CustomKey)-4:]
|
||||
wantMasked := "sk-****" + customKey[len(customKey)-4:]
|
||||
if createResp.Key.MaskedPreview != wantMasked {
|
||||
t.Fatalf("masked_preview = %q, want %q", createResp.Key.MaskedPreview, wantMasked)
|
||||
}
|
||||
@@ -254,12 +269,18 @@ func TestUserKeyCreateUsesSubjectScopedManagedKeyAndConsistentMetadata(t *testin
|
||||
if err != nil {
|
||||
t.Fatalf("UserKeys().GetByID() error = %v", err)
|
||||
}
|
||||
if record.KeyFingerprint != "sha256:"+sha256Hex(expected.CustomKey) {
|
||||
if strings.TrimSpace(record.ManagedIdentitySelector) == "" || !strings.Contains(record.ManagedIdentitySelector, createResp.Key.KeyID) {
|
||||
t.Fatalf("managed_identity_selector = %q, want non-empty selector tied to key id", record.ManagedIdentitySelector)
|
||||
}
|
||||
if record.KeyFingerprint != "sha256:"+sha256Hex(customKey) {
|
||||
t.Fatalf("key_fingerprint = %q, want sha256 of returned plaintext key", record.KeyFingerprint)
|
||||
}
|
||||
if record.MaskedPreview != wantMasked {
|
||||
t.Fatalf("stored masked_preview = %q, want %q", record.MaskedPreview, wantMasked)
|
||||
}
|
||||
if loginEmail == "" {
|
||||
t.Fatal("login email was not observed")
|
||||
}
|
||||
}
|
||||
|
||||
type managedIdentityExpectation struct {
|
||||
@@ -302,13 +323,145 @@ func expectedManagedPrefix(value string) string {
|
||||
return prefix
|
||||
}
|
||||
|
||||
func TestUserKeyCreateAndResetDoNotReuseSameManagedKeyWithinSubjectGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := openAppTestStore(t)
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
const logicalGroupID = "gpt-shared"
|
||||
const hostGroupID = "999"
|
||||
const subjectID = "portal-user:multi"
|
||||
|
||||
type managedUser struct {
|
||||
ID int64
|
||||
Email string
|
||||
}
|
||||
usersByEmail := map[string]managedUser{}
|
||||
nextUserID := int64(100)
|
||||
nextKeyID := int64(500)
|
||||
createdCustomKeys := make([]string, 0, 4)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
|
||||
search := strings.TrimSpace(r.URL.Query().Get("search"))
|
||||
items := make([]map[string]any, 0, 1)
|
||||
if user, ok := usersByEmail[search]; ok {
|
||||
items = append(items, map[string]any{"id": user.ID, "email": user.Email})
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"items": items}})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode create user request: %v", err)
|
||||
}
|
||||
email := strings.TrimSpace(fmt.Sprint(req["email"]))
|
||||
nextUserID++
|
||||
usersByEmail[email] = managedUser{ID: nextUserID, Email: email}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": nextUserID, "email": email}})
|
||||
case r.Method == http.MethodPut && strings.HasPrefix(r.URL.Path, "/api/v1/admin/users/"):
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 1}})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/101/balance":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 101}})
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/api/v1/admin/users/") && strings.HasSuffix(r.URL.Path, "/balance"):
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 1}})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": 401}})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"access_token": "user-jwt"}})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode create key request: %v", err)
|
||||
}
|
||||
customKey := strings.TrimSpace(fmt.Sprint(req["custom_key"]))
|
||||
createdCustomKeys = append(createdCustomKeys, customKey)
|
||||
nextKeyID++
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"id": nextKeyID, "key": customKey, "name": fmt.Sprint(req["name"])}})
|
||||
case r.Method == http.MethodPut && strings.HasPrefix(r.URL.Path, "/api/v1/admin/api-keys/"):
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, _ = store.Hosts().Create(context.Background(), sqlite.Host{
|
||||
HostID: "test-host",
|
||||
BaseURL: server.URL,
|
||||
HostVersion: "0.0.1",
|
||||
CapabilityProbeJSON: "{}",
|
||||
AuthType: "apikey",
|
||||
AuthToken: "test-token",
|
||||
})
|
||||
_, _ = store.LogicalGroups().Create(context.Background(), sqlite.LogicalGroup{
|
||||
LogicalGroupID: logicalGroupID,
|
||||
DisplayName: "GPT Shared",
|
||||
Status: "active",
|
||||
})
|
||||
_, _ = store.LogicalGroupRoutes().Create(context.Background(), sqlite.LogicalGroupRoute{
|
||||
RouteID: "test-route",
|
||||
LogicalGroupID: logicalGroupID,
|
||||
Name: "Test Route",
|
||||
Status: "active",
|
||||
ShadowHostID: "test-host",
|
||||
ShadowGroupID: hostGroupID,
|
||||
})
|
||||
|
||||
handler := buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig())
|
||||
first, err := handler.createFn(context.Background(), CreateUserKeyRequest{SubjectID: subjectID, LogicalGroupID: logicalGroupID, DisplayName: "first", AllowedModels: []string{"gpt-5.4"}})
|
||||
if err != nil {
|
||||
t.Fatalf("first createFn() error = %v", err)
|
||||
}
|
||||
second, err := handler.createFn(context.Background(), CreateUserKeyRequest{SubjectID: subjectID, LogicalGroupID: logicalGroupID, DisplayName: "second", AllowedModels: []string{"gpt-5.4"}})
|
||||
if err != nil {
|
||||
t.Fatalf("second createFn() error = %v", err)
|
||||
}
|
||||
if first.PlaintextKey == second.PlaintextKey {
|
||||
t.Fatalf("createFn() reused plaintext key across records: first=%q second=%q", first.PlaintextKey, second.PlaintextKey)
|
||||
}
|
||||
|
||||
reset, err := handler.resetFn(context.Background(), first.Key.KeyID, subjectID)
|
||||
if err != nil {
|
||||
t.Fatalf("resetFn() error = %v", err)
|
||||
}
|
||||
if reset.PlaintextKey == first.PlaintextKey {
|
||||
t.Fatalf("resetFn() reused original plaintext key: before=%q after=%q", first.PlaintextKey, reset.PlaintextKey)
|
||||
}
|
||||
if reset.PlaintextKey == second.PlaintextKey {
|
||||
t.Fatalf("resetFn() collided with sibling key: reset=%q sibling=%q", reset.PlaintextKey, second.PlaintextKey)
|
||||
}
|
||||
|
||||
firstRecord, err := store.UserKeys().GetByID(context.Background(), first.Key.KeyID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID(first) error = %v", err)
|
||||
}
|
||||
secondRecord, err := store.UserKeys().GetByID(context.Background(), second.Key.KeyID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID(second) error = %v", err)
|
||||
}
|
||||
if firstRecord.KeyFingerprint == secondRecord.KeyFingerprint {
|
||||
t.Fatalf("distinct key records share fingerprint: %q", firstRecord.KeyFingerprint)
|
||||
}
|
||||
if firstRecord.KeyFingerprint != "sha256:"+sha256Hex(reset.PlaintextKey) {
|
||||
t.Fatalf("first record fingerprint = %q, want reset plaintext fingerprint", firstRecord.KeyFingerprint)
|
||||
}
|
||||
if len(createdCustomKeys) < 3 {
|
||||
t.Fatalf("createdCustomKeys len = %d, want at least 3", len(createdCustomKeys))
|
||||
}
|
||||
if createdCustomKeys[0] == createdCustomKeys[1] || createdCustomKeys[0] == createdCustomKeys[2] {
|
||||
t.Fatalf("host custom keys were reused unexpectedly: %#v", createdCustomKeys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserKeyAPIMetricsMiddlewareAndCreateMetric(t *testing.T) {
|
||||
t.Parallel()
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t))),
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, openAppTestStore(t)), testUserKeyAuthConfig()),
|
||||
})
|
||||
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("", "portal key", nil))
|
||||
req.Header.Set("X-Portal-Subject", "smoke-user")
|
||||
applyTrustedProxyAuthHeaders(req, "smoke-user")
|
||||
_ = httptestRecorder(handler, req)
|
||||
|
||||
metricsReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
|
||||
311
internal/app/portal_auth.go
Normal file
311
internal/app/portal_auth.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
portalSubjectCookieName = "crm_subject"
|
||||
portalSessionCookieName = "crm_session"
|
||||
defaultPortalSessionTTL = 30 * 24 * time.Hour // 30 days
|
||||
)
|
||||
|
||||
// PortalAuthConfig 定义 portal user session 配置
|
||||
type PortalAuthConfig struct {
|
||||
SessionSecret string // session cookie 签名密钥
|
||||
SessionTTL time.Duration // session 有效期
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
// portalSessionInfo 存储 session 信息
|
||||
type portalSessionInfo struct {
|
||||
SubjectID string
|
||||
Email string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// portalLoginRequest 登录请求
|
||||
type portalLoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"` // 仅用于验证,portal 采用"登录即注册"模式
|
||||
}
|
||||
|
||||
// normalized 返回规范化配置
|
||||
func (c PortalAuthConfig) normalized() PortalAuthConfig {
|
||||
if c.SessionTTL <= 0 {
|
||||
c.SessionTTL = defaultPortalSessionTTL
|
||||
}
|
||||
if c.Now == nil {
|
||||
c.Now = time.Now
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// normalizedSubjectID 规范化 subject ID
|
||||
func normalizedSubjectID(email string) string {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
return "portal-email:" + email
|
||||
}
|
||||
|
||||
// signSessionCookie 签名 session cookie 值
|
||||
func signSessionCookie(secret, subjectID string, expiresAt time.Time) string {
|
||||
if secret == "" || subjectID == "" {
|
||||
return ""
|
||||
}
|
||||
payload := subjectID + "|" + strconv.FormatInt(expiresAt.Unix(), 10)
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(payload))
|
||||
sig := hex.EncodeToString(mac.Sum(nil))
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(payload + "|" + sig))
|
||||
}
|
||||
|
||||
// verifySessionCookie 验证并解析 session cookie
|
||||
func verifySessionCookie(secret, raw string, now time.Time) (*portalSessionInfo, bool) {
|
||||
if secret == "" || raw == "" {
|
||||
return nil, false
|
||||
}
|
||||
b, err := base64.RawURLEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
parts := strings.SplitN(string(b), "|", 3)
|
||||
if len(parts) != 3 {
|
||||
return nil, false
|
||||
}
|
||||
subjectID, tsStr, sigHex := parts[0], parts[1], parts[2]
|
||||
|
||||
// 验证签名
|
||||
payload := subjectID + "|" + tsStr
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(payload))
|
||||
expectedSig := hex.EncodeToString(mac.Sum(nil))
|
||||
if subtle.ConstantTimeCompare([]byte(sigHex), []byte(expectedSig)) != 1 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 解析过期时间
|
||||
unixSec, err := strconv.ParseInt(tsStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
expiresAt := time.Unix(unixSec, 0)
|
||||
if now.After(expiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &portalSessionInfo{
|
||||
SubjectID: subjectID,
|
||||
ExpiresAt: expiresAt,
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateSessionSecret 生成随机 session secret(32字节)
|
||||
func generateSessionSecret() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// extractSubjectFromCookie 从请求 cookie 中提取 subject
|
||||
func extractSubjectFromCookie(r *http.Request, sessionSecret string) string {
|
||||
if sessionSecret == "" {
|
||||
return ""
|
||||
}
|
||||
cookie, err := r.Cookie(portalSessionCookieName)
|
||||
if err != nil || cookie == nil || cookie.Value == "" {
|
||||
return ""
|
||||
}
|
||||
info, ok := verifySessionCookie(sessionSecret, cookie.Value, time.Now())
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return info.SubjectID
|
||||
}
|
||||
|
||||
// handlePortalSessionLogin 处理 portal user 登录
|
||||
// 设置 httpOnly cookie,返回 subject ID
|
||||
func handlePortalSessionLogin(w http.ResponseWriter, r *http.Request, cfg PortalAuthConfig) {
|
||||
cfg = cfg.normalized()
|
||||
|
||||
if cfg.SessionSecret == "" {
|
||||
writeHTTPError(w, &httpError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Code: "portal_auth_not_configured",
|
||||
Message: "Portal session authentication is not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req portalLoginRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(req.Email)
|
||||
if email == "" || !strings.Contains(email, "@") {
|
||||
writeHTTPError(w, &httpError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Code: "invalid_email",
|
||||
Message: "Valid email is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
subjectID := normalizedSubjectID(email)
|
||||
expiresAt := cfg.Now().Add(cfg.SessionTTL)
|
||||
|
||||
// 生成签名 cookie
|
||||
sessionValue := signSessionCookie(cfg.SessionSecret, subjectID, expiresAt)
|
||||
if sessionValue == "" {
|
||||
writeHTTPError(w, &httpError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Code: "session_sign_failed",
|
||||
Message: "Failed to sign session",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 httpOnly cookie(SameSite=Lax,Secure 建议生产启用 HTTPS)
|
||||
cookie := &http.Cookie{
|
||||
Name: portalSessionCookieName,
|
||||
Value: sessionValue,
|
||||
Path: "/",
|
||||
Expires: expiresAt,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
// 同时设置非 httpOnly cookie 供前端 JS 读取 subject(用于显示)
|
||||
subjectCookie := &http.Cookie{
|
||||
Name: portalSubjectCookieName,
|
||||
Value: subjectID,
|
||||
Path: "/",
|
||||
Expires: expiresAt,
|
||||
HttpOnly: false, // 允许前端读取
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
|
||||
}
|
||||
http.SetCookie(w, subjectCookie)
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"authenticated": true,
|
||||
"subject_id": subjectID,
|
||||
"email": email,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// handlePortalSessionLogout 处理 portal user 登出
|
||||
// 清除 session cookie
|
||||
func handlePortalSessionLogout(w http.ResponseWriter, r *http.Request) {
|
||||
// 清除 session cookie
|
||||
sessionCookie := &http.Cookie{
|
||||
Name: portalSessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
}
|
||||
http.SetCookie(w, sessionCookie)
|
||||
|
||||
// 清除 subject cookie
|
||||
subjectCookie := &http.Cookie{
|
||||
Name: portalSubjectCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: false,
|
||||
}
|
||||
http.SetCookie(w, subjectCookie)
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"authenticated": false,
|
||||
})
|
||||
}
|
||||
|
||||
// handlePortalSessionState 处理 portal session 状态查询
|
||||
func handlePortalSessionState(w http.ResponseWriter, r *http.Request, cfg PortalAuthConfig) {
|
||||
cfg = cfg.normalized()
|
||||
|
||||
if cfg.SessionSecret == "" {
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"authenticated": false,
|
||||
"login_enabled": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
subjectID := extractSubjectFromCookie(r, cfg.SessionSecret)
|
||||
if subjectID == "" {
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"authenticated": false,
|
||||
"login_enabled": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"authenticated": true,
|
||||
"login_enabled": true,
|
||||
"subject_id": subjectID,
|
||||
})
|
||||
}
|
||||
|
||||
// requirePortalSubject 中间件:要求 portal session 认证
|
||||
// 与 trusted proxy header 的验证流程配合
|
||||
func requirePortalSubject(cfg PortalAuthConfig, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cfg = cfg.normalized()
|
||||
|
||||
// 首先检查 trusted proxy header
|
||||
// 这是生产环境的推荐做法:nginx 验证并设置 header
|
||||
if cfg.SessionSecret == "" {
|
||||
writeHTTPError(w, &httpError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Code: "unauthorized",
|
||||
Message: "Portal authentication not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
subjectID := extractSubjectFromCookie(r, cfg.SessionSecret)
|
||||
if subjectID == "" {
|
||||
writeHTTPError(w, &httpError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Code: "unauthorized",
|
||||
Message: "Valid session required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 将 subject 放入 context 供后续使用
|
||||
ctx := context.WithValue(r.Context(), "portal_subject", subjectID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// getPortalSubjectFromContext 从 context 获取 subject
|
||||
func getPortalSubjectFromContext(ctx context.Context) string {
|
||||
if v, ok := ctx.Value("portal_subject").(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
135
internal/app/portal_auth_test.go
Normal file
135
internal/app/portal_auth_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPortalSessionLoginSetsCookieAndReturnsSubject(t *testing.T) {
|
||||
cfg := PortalAuthConfig{
|
||||
SessionSecret: "test-secret-32-bytes-long-for-hmac",
|
||||
Now: func() time.Time { return time.Unix(1_717_000_000, 0) },
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/login", strings.NewReader(`{"email":"user@example.com"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlePortalSessionLogin(rec, req, cfg)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 检查响应体包含 subject
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, `"subject_id":"portal-email:user@example.com"`) {
|
||||
t.Fatalf("response body missing subject_id: %s", body)
|
||||
}
|
||||
|
||||
// 检查设置了 cookie
|
||||
cookies := rec.Result().Cookies()
|
||||
if len(cookies) < 2 {
|
||||
t.Fatalf("expected at least 2 cookies (session + subject), got %d", len(cookies))
|
||||
}
|
||||
|
||||
// 检查 session cookie 是 httpOnly
|
||||
var foundSessionCookie bool
|
||||
for _, c := range cookies {
|
||||
if c.Name == portalSessionCookieName {
|
||||
foundSessionCookie = true
|
||||
if !c.HttpOnly {
|
||||
t.Fatal("session cookie should be HttpOnly")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundSessionCookie {
|
||||
t.Fatalf("session cookie %s not found", portalSessionCookieName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortalSessionLoginRejectsMissingEmail(t *testing.T) {
|
||||
cfg := PortalAuthConfig{
|
||||
SessionSecret: "test-secret",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/login", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlePortalSessionLogin(rec, req, cfg)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortalSessionLoginRequiresSecret(t *testing.T) {
|
||||
cfg := PortalAuthConfig{
|
||||
SessionSecret: "", // 未配置
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/login", strings.NewReader(`{"email":"user@example.com"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlePortalSessionLogin(rec, req, cfg)
|
||||
|
||||
if rec.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortalSessionLogoutClearsCookies(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/portal/session/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlePortalSessionLogout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 检查清除了 cookie
|
||||
cookies := rec.Result().Cookies()
|
||||
var clearedSession bool
|
||||
var clearedSubject bool
|
||||
for _, c := range cookies {
|
||||
if c.Name == portalSessionCookieName && c.MaxAge == -1 {
|
||||
clearedSession = true
|
||||
}
|
||||
if c.Name == portalSubjectCookieName && c.MaxAge == -1 {
|
||||
clearedSubject = true
|
||||
}
|
||||
}
|
||||
if !clearedSession {
|
||||
t.Fatal("session cookie should be cleared")
|
||||
}
|
||||
if !clearedSubject {
|
||||
t.Fatal("subject cookie should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortalSessionStateUnauthenticatedWhenNoCookie(t *testing.T) {
|
||||
cfg := PortalAuthConfig{
|
||||
SessionSecret: "test-secret",
|
||||
Now: func() time.Time { return time.Unix(1_717_000_000, 0) },
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/portal/session", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlePortalSessionState(rec, req, cfg)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, `"authenticated":false`) {
|
||||
t.Fatalf("expected unauthenticated, got: %s", body)
|
||||
}
|
||||
}
|
||||
@@ -58,6 +58,190 @@ func TestPublicV1ChatCompletionsQuotaExhaustedRecordsMetric(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicV1ChatCompletionsPropagatesUpstreamFailureStatusAndMetric(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := openAppTestStore(t)
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
const plaintextKey = "sk-test-upstream-429"
|
||||
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
|
||||
KeyID: "key_upstream_429",
|
||||
OwnerSubjectID: "portal-user",
|
||||
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
|
||||
MaskedPreview: "sk-****-429",
|
||||
DisplayName: "upstream 429 key",
|
||||
LogicalGroupID: "gpt-shared",
|
||||
AllowedModels: []string{"gpt-5.4"},
|
||||
AdminStatus: "active",
|
||||
QuotaStatus: "ok",
|
||||
}); err != nil {
|
||||
t.Fatalf("UserKeys().Create() error = %v", err)
|
||||
}
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
return ProxyRouteChatCompletionsResult{
|
||||
Forward: RouteChatCompletionsForwardInfo{
|
||||
OK: false,
|
||||
UpstreamStatus: http.StatusTooManyRequests,
|
||||
ErrorClass: "gateway_rate_limited",
|
||||
Response: map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "upstream_rate_limited",
|
||||
"message": "upstream rejected request",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, appTestDSN(t, store))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
|
||||
req.Header.Set("Authorization", "Bearer "+plaintextKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusTooManyRequests {
|
||||
t.Fatalf("status code = %d, want 429 body=%s", resp.code, resp.Body().String())
|
||||
}
|
||||
assertJSONContains(t, resp.Body().Bytes(), "error.code", "upstream_rate_limited")
|
||||
|
||||
metricsReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
metricsResp := httptest.NewRecorder()
|
||||
metrics.Handler().ServeHTTP(metricsResp, metricsReq)
|
||||
body := metricsResp.Body.String()
|
||||
if !strings.Contains(body, `user_key_chat_requests_total{result="gateway_rate_limited"}`) {
|
||||
t.Fatalf("metrics body missing gateway_rate_limited metric: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicV1ChatCompletionsRejectsDisallowedModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := openAppTestStore(t)
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
const plaintextKey = "sk-test-disallowed-model"
|
||||
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
|
||||
KeyID: "key_disallowed_model",
|
||||
OwnerSubjectID: "portal-user",
|
||||
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
|
||||
MaskedPreview: "sk-****odel",
|
||||
DisplayName: "model restricted key",
|
||||
LogicalGroupID: "gpt-shared",
|
||||
AllowedModels: []string{"gpt-4.1"},
|
||||
AdminStatus: "active",
|
||||
QuotaStatus: "ok",
|
||||
}); err != nil {
|
||||
t.Fatalf("UserKeys().Create() error = %v", err)
|
||||
}
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
t.Fatal("proxy should not be called for disallowed model")
|
||||
return ProxyRouteChatCompletionsResult{}, nil
|
||||
},
|
||||
}, appTestDSN(t, store))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
|
||||
req.Header.Set("Authorization", "Bearer "+plaintextKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusForbidden {
|
||||
t.Fatalf("status code = %d, want 403 body=%s", resp.code, resp.Body().String())
|
||||
}
|
||||
assertJSONContains(t, resp.Body().Bytes(), "error.code", "model_not_allowed")
|
||||
}
|
||||
|
||||
func TestPublicV1ChatCompletionsRejectsExpiredKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := openAppTestStore(t)
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
const plaintextKey = "sk-test-expired-key"
|
||||
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
|
||||
KeyID: "key_expired",
|
||||
OwnerSubjectID: "portal-user",
|
||||
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
|
||||
MaskedPreview: "sk-****ired",
|
||||
DisplayName: "expired key",
|
||||
LogicalGroupID: "gpt-shared",
|
||||
AllowedModels: []string{"gpt-5.4"},
|
||||
AdminStatus: "active",
|
||||
QuotaStatus: "ok",
|
||||
}); err != nil {
|
||||
t.Fatalf("UserKeys().Create() error = %v", err)
|
||||
}
|
||||
if _, err := store.SQLDB().ExecContext(context.Background(), `UPDATE user_keys SET expires_at = ? WHERE key_id = ?`, "2020-01-01T00:00:00Z", "key_expired"); err != nil {
|
||||
t.Fatalf("set expires_at error = %v", err)
|
||||
}
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
t.Fatal("proxy should not be called for expired key")
|
||||
return ProxyRouteChatCompletionsResult{}, nil
|
||||
},
|
||||
}, appTestDSN(t, store))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
|
||||
req.Header.Set("Authorization", "Bearer "+plaintextKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusForbidden {
|
||||
t.Fatalf("status code = %d, want 403 body=%s", resp.code, resp.Body().String())
|
||||
}
|
||||
assertJSONContains(t, resp.Body().Bytes(), "error.code", "key_expired")
|
||||
}
|
||||
|
||||
func TestPublicV1ChatCompletionsTouchesLastUsedAtOnSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := openAppTestStore(t)
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
const plaintextKey = "sk-test-last-used"
|
||||
if _, err := store.UserKeys().Create(context.Background(), sqlite.UserKeyRecord{
|
||||
KeyID: "key_last_used",
|
||||
OwnerSubjectID: "portal-user",
|
||||
KeyFingerprint: "sha256:" + sha256Hex(plaintextKey),
|
||||
MaskedPreview: "sk-****used",
|
||||
DisplayName: "active key",
|
||||
LogicalGroupID: "gpt-shared",
|
||||
AllowedModels: []string{"gpt-5.4"},
|
||||
AdminStatus: "active",
|
||||
QuotaStatus: "ok",
|
||||
}); err != nil {
|
||||
t.Fatalf("UserKeys().Create() error = %v", err)
|
||||
}
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
ProxyRouteChatCompletions: func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
return ProxyRouteChatCompletionsResult{Forward: RouteChatCompletionsForwardInfo{OK: true, UpstreamStatus: http.StatusOK, Response: map[string]any{"id": "chatcmpl_ok", "object": "chat.completion", "model": "gpt-5.4", "choices": []map[string]any{{"index": 0, "message": map[string]any{"role": "assistant", "content": "pong"}, "finish_reason": "stop"}}}}}, nil
|
||||
},
|
||||
}, appTestDSN(t, store))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"gpt-5.4","messages":[{"role":"user","content":"ping"}]}`))
|
||||
req.Header.Set("Authorization", "Bearer "+plaintextKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusOK {
|
||||
t.Fatalf("status code = %d, want 200 body=%s", resp.code, resp.Body().String())
|
||||
}
|
||||
|
||||
record, err := store.UserKeys().GetByID(context.Background(), "key_last_used")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if strings.TrimSpace(record.LastUsedAt) == "" {
|
||||
t.Fatalf("LastUsedAt = %q, want non-empty after successful chat", record.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMiddlewareUsesRoutePatternForKeyReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -16,11 +16,11 @@ func TestUserKeyCreateResolveHostErrorRecordsMetric(t *testing.T) {
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store)),
|
||||
UserKeyHandler: buildUserKeyHandler(appTestDSN(t, store), testUserKeyAuthConfig()),
|
||||
})
|
||||
|
||||
req := makeCreateRequest(t, http.MethodPost, "/api/keys", makeCreateBody("missing-group", "portal key", []string{"gpt-5.4"}))
|
||||
req.Header.Set("X-Portal-Subject", "portal-user")
|
||||
applyTrustedProxyAuthHeaders(req, "portal-user")
|
||||
resp := httptestRecorder(handler, req)
|
||||
if resp.code != http.StatusInternalServerError {
|
||||
t.Fatalf("status code = %d, want 500 body=%s", resp.code, resp.Body().String())
|
||||
|
||||
@@ -9,26 +9,30 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvListenAddr = "SUB2API_CRM_LISTEN_ADDR"
|
||||
EnvSQLiteDSN = "SUB2API_CRM_SQLITE_DSN"
|
||||
EnvAdminToken = "SUB2API_CRM_ADMIN_TOKEN"
|
||||
EnvAdminUsername = "SUB2API_CRM_ADMIN_USERNAME"
|
||||
EnvAdminPassword = "SUB2API_CRM_ADMIN_PASSWORD"
|
||||
EnvAdminSessionTTL = "SUB2API_CRM_ADMIN_SESSION_TTL"
|
||||
EnvRepoRoot = "SUB2API_CRM_REPO_ROOT"
|
||||
EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED"
|
||||
EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL"
|
||||
EnvRouteRuntimeBackend = "SUB2API_CRM_ROUTE_RUNTIME_BACKEND"
|
||||
EnvRedisAddr = "SUB2API_CRM_REDIS_ADDR"
|
||||
EnvRedisPassword = "SUB2API_CRM_REDIS_PASSWORD"
|
||||
EnvRedisDB = "SUB2API_CRM_REDIS_DB"
|
||||
EnvListenAddr = "SUB2API_CRM_LISTEN_ADDR"
|
||||
EnvSQLiteDSN = "SUB2API_CRM_SQLITE_DSN"
|
||||
EnvAdminToken = "SUB2API_CRM_ADMIN_TOKEN"
|
||||
EnvAdminUsername = "SUB2API_CRM_ADMIN_USERNAME"
|
||||
EnvAdminPassword = "SUB2API_CRM_ADMIN_PASSWORD"
|
||||
EnvAdminSessionTTL = "SUB2API_CRM_ADMIN_SESSION_TTL"
|
||||
EnvRepoRoot = "SUB2API_CRM_REPO_ROOT"
|
||||
EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED"
|
||||
EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL"
|
||||
EnvRouteRuntimeBackend = "SUB2API_CRM_ROUTE_RUNTIME_BACKEND"
|
||||
EnvRedisAddr = "SUB2API_CRM_REDIS_ADDR"
|
||||
EnvRedisPassword = "SUB2API_CRM_REDIS_PASSWORD"
|
||||
EnvRedisDB = "SUB2API_CRM_REDIS_DB"
|
||||
EnvTrustedSubjectHeader = "SUB2API_CRM_TRUSTED_SUBJECT_HEADER"
|
||||
EnvTrustedProxySecretHeader = "SUB2API_CRM_TRUSTED_PROXY_SECRET_HEADER"
|
||||
EnvTrustedProxySecret = "SUB2API_CRM_TRUSTED_PROXY_SECRET"
|
||||
|
||||
DefaultListenAddr = ":8080"
|
||||
DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000"
|
||||
DefaultAdminUsername = "admin"
|
||||
DefaultAdminSessionTTL = 12 * time.Hour
|
||||
DefaultReconcilePollInterval = 10 * time.Minute
|
||||
DefaultRouteRuntimeBackend = "memory"
|
||||
DefaultListenAddr = ":8080"
|
||||
DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000"
|
||||
DefaultAdminUsername = "admin"
|
||||
DefaultAdminSessionTTL = 12 * time.Hour
|
||||
DefaultReconcilePollInterval = 10 * time.Minute
|
||||
DefaultRouteRuntimeBackend = "memory"
|
||||
DefaultTrustedProxySecretHeader = "X-CRM-Trusted-Proxy"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -59,10 +63,17 @@ type RepositoryConfig struct {
|
||||
RepoRoot string
|
||||
}
|
||||
|
||||
type UserKeyAuthConfig struct {
|
||||
TrustedSubjectHeader string
|
||||
TrustedProxySecretHeader string
|
||||
TrustedProxySecret string
|
||||
}
|
||||
|
||||
type StartupConfig struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Repository RepositoryConfig
|
||||
UserKeyAuth UserKeyAuthConfig
|
||||
RouteRuntime RouteRuntimeConfig
|
||||
Reconcile ReconcileConfig
|
||||
}
|
||||
@@ -96,6 +107,11 @@ func loadStartupFromLookupEnv(lookup func(string) (string, bool)) (StartupConfig
|
||||
Repository: RepositoryConfig{
|
||||
RepoRoot: readOptionalEnv(lookup, EnvRepoRoot, ""),
|
||||
},
|
||||
UserKeyAuth: UserKeyAuthConfig{
|
||||
TrustedSubjectHeader: readOptionalEnv(lookup, EnvTrustedSubjectHeader, ""),
|
||||
TrustedProxySecretHeader: readOptionalEnv(lookup, EnvTrustedProxySecretHeader, DefaultTrustedProxySecretHeader),
|
||||
TrustedProxySecret: readOptionalEnv(lookup, EnvTrustedProxySecret, ""),
|
||||
},
|
||||
RouteRuntime: RouteRuntimeConfig{
|
||||
Backend: readOptionalEnv(lookup, EnvRouteRuntimeBackend, DefaultRouteRuntimeBackend),
|
||||
Redis: RedisRuntimeConfig{
|
||||
|
||||
@@ -77,6 +77,12 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
|
||||
return " redis-pass ", true
|
||||
case EnvRedisDB:
|
||||
return "5", true
|
||||
case EnvTrustedSubjectHeader:
|
||||
return "X-CRM-Authenticated-Subject", true
|
||||
case EnvTrustedProxySecretHeader:
|
||||
return "X-CRM-Trusted-Proxy", true
|
||||
case EnvTrustedProxySecret:
|
||||
return "proxy-secret", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
@@ -112,6 +118,15 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
|
||||
if cfg.RouteRuntime.Redis.DB != 5 {
|
||||
t.Fatalf("RouteRuntime.Redis.DB = %d, want 5", cfg.RouteRuntime.Redis.DB)
|
||||
}
|
||||
if cfg.UserKeyAuth.TrustedSubjectHeader != "X-CRM-Authenticated-Subject" {
|
||||
t.Fatalf("UserKeyAuth.TrustedSubjectHeader = %q, want X-CRM-Authenticated-Subject", cfg.UserKeyAuth.TrustedSubjectHeader)
|
||||
}
|
||||
if cfg.UserKeyAuth.TrustedProxySecretHeader != "X-CRM-Trusted-Proxy" {
|
||||
t.Fatalf("UserKeyAuth.TrustedProxySecretHeader = %q, want X-CRM-Trusted-Proxy", cfg.UserKeyAuth.TrustedProxySecretHeader)
|
||||
}
|
||||
if cfg.UserKeyAuth.TrustedProxySecret != "proxy-secret" {
|
||||
t.Fatalf("UserKeyAuth.TrustedProxySecret = %q, want proxy-secret", cfg.UserKeyAuth.TrustedProxySecret)
|
||||
}
|
||||
})
|
||||
t.Run("default values", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
@@ -142,6 +157,15 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
|
||||
if cfg.RouteRuntime.Redis.Addr != "" || cfg.RouteRuntime.Redis.Password != "" || cfg.RouteRuntime.Redis.DB != 0 {
|
||||
t.Fatalf("RouteRuntime.Redis = %+v, want zero value", cfg.RouteRuntime.Redis)
|
||||
}
|
||||
if cfg.UserKeyAuth.TrustedSubjectHeader != "" {
|
||||
t.Fatalf("UserKeyAuth.TrustedSubjectHeader = %q, want empty by default", cfg.UserKeyAuth.TrustedSubjectHeader)
|
||||
}
|
||||
if cfg.UserKeyAuth.TrustedProxySecretHeader != DefaultTrustedProxySecretHeader {
|
||||
t.Fatalf("UserKeyAuth.TrustedProxySecretHeader = %q, want %q", cfg.UserKeyAuth.TrustedProxySecretHeader, DefaultTrustedProxySecretHeader)
|
||||
}
|
||||
if cfg.UserKeyAuth.TrustedProxySecret != "" {
|
||||
t.Fatalf("UserKeyAuth.TrustedProxySecret = %q, want empty by default", cfg.UserKeyAuth.TrustedProxySecret)
|
||||
}
|
||||
})
|
||||
t.Run("invalid reconcile interval", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE user_keys ADD COLUMN managed_identity_selector TEXT NOT NULL DEFAULT '';
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_managed_identity_selector ON user_keys(managed_identity_selector);
|
||||
@@ -85,15 +85,15 @@ func TestUserKeysRepoUpdateSecret(t *testing.T) {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if err := store.UserKeys().UpdateSecret(ctx, "key_rotate_001", "sha256:new", "sk-****new1", "active"); err != nil {
|
||||
if err := store.UserKeys().UpdateSecret(ctx, "key_rotate_001", "subject|key:key_rotate_001|rot:key_nonce", "sha256:new", "sk-****new1", "active"); err != nil {
|
||||
t.Fatalf("UpdateSecret() error = %v", err)
|
||||
}
|
||||
key, err := store.UserKeys().GetByID(ctx, "key_rotate_001")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if key.KeyFingerprint != "sha256:new" || key.MaskedPreview != "sk-****new1" || key.AdminStatus != "active" {
|
||||
t.Fatalf("updated key = %+v, want new fingerprint/mask/status", key)
|
||||
if key.ManagedIdentitySelector != "subject|key:key_rotate_001|rot:key_nonce" || key.KeyFingerprint != "sha256:new" || key.MaskedPreview != "sk-****new1" || key.AdminStatus != "active" {
|
||||
t.Fatalf("updated key = %+v, want new selector/fingerprint/mask/status", key)
|
||||
}
|
||||
if strings.TrimSpace(key.UpdatedAt) == "" {
|
||||
t.Fatalf("UpdatedAt = %q, want non-empty", key.UpdatedAt)
|
||||
|
||||
@@ -9,20 +9,21 @@ import (
|
||||
)
|
||||
|
||||
type UserKeyRecord struct {
|
||||
ID int64 `json:"-"`
|
||||
KeyID string `json:"key_id"`
|
||||
OwnerSubjectID string `json:"owner_subject_id"`
|
||||
KeyFingerprint string `json:"key_fingerprint"`
|
||||
MaskedPreview string `json:"masked_preview"`
|
||||
DisplayName string `json:"display_name"`
|
||||
LogicalGroupID string `json:"logical_group_id"`
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
AdminStatus string `json:"admin_status"`
|
||||
QuotaStatus string `json:"quota_status"`
|
||||
LastUsedAt string `json:"last_used_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
ID int64 `json:"-"`
|
||||
KeyID string `json:"key_id"`
|
||||
OwnerSubjectID string `json:"owner_subject_id"`
|
||||
ManagedIdentitySelector string `json:"-"`
|
||||
KeyFingerprint string `json:"key_fingerprint"`
|
||||
MaskedPreview string `json:"masked_preview"`
|
||||
DisplayName string `json:"display_name"`
|
||||
LogicalGroupID string `json:"logical_group_id"`
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
AdminStatus string `json:"admin_status"`
|
||||
QuotaStatus string `json:"quota_status"`
|
||||
LastUsedAt string `json:"last_used_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserKeysRepo struct {
|
||||
@@ -40,11 +41,11 @@ func (r *UserKeysRepo) Create(ctx context.Context, key UserKeyRecord) (int64, er
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO user_keys (
|
||||
key_id, owner_subject_id, key_fingerprint, masked_preview,
|
||||
key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
|
||||
display_name, logical_group_id, allowed_models,
|
||||
admin_status, quota_status
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
key.KeyID, key.OwnerSubjectID, key.KeyFingerprint, key.MaskedPreview,
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
key.KeyID, key.OwnerSubjectID, key.ManagedIdentitySelector, key.KeyFingerprint, key.MaskedPreview,
|
||||
key.DisplayName, key.LogicalGroupID, string(modelsJSON),
|
||||
key.AdminStatus, key.QuotaStatus,
|
||||
)
|
||||
@@ -60,7 +61,7 @@ func scanUserKeys(rows *sql.Rows) ([]UserKeyRecord, error) {
|
||||
var k UserKeyRecord
|
||||
var modelsJSON, lastUsedAt, expiresAt sql.NullString
|
||||
err := rows.Scan(
|
||||
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.KeyFingerprint, &k.MaskedPreview,
|
||||
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.ManagedIdentitySelector, &k.KeyFingerprint, &k.MaskedPreview,
|
||||
&k.DisplayName, &k.LogicalGroupID, &modelsJSON,
|
||||
&k.AdminStatus, &k.QuotaStatus, &lastUsedAt, &k.CreatedAt, &expiresAt, &k.UpdatedAt,
|
||||
)
|
||||
@@ -81,7 +82,7 @@ func scanOneUserKey(row *sql.Row) (*UserKeyRecord, error) {
|
||||
var k UserKeyRecord
|
||||
var modelsJSON, lastUsedAt, expiresAt sql.NullString
|
||||
err := row.Scan(
|
||||
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.KeyFingerprint, &k.MaskedPreview,
|
||||
&k.ID, &k.KeyID, &k.OwnerSubjectID, &k.ManagedIdentitySelector, &k.KeyFingerprint, &k.MaskedPreview,
|
||||
&k.DisplayName, &k.LogicalGroupID, &modelsJSON,
|
||||
&k.AdminStatus, &k.QuotaStatus, &lastUsedAt, &k.CreatedAt, &expiresAt, &k.UpdatedAt,
|
||||
)
|
||||
@@ -98,7 +99,7 @@ func scanOneUserKey(row *sql.Row) (*UserKeyRecord, error) {
|
||||
|
||||
func (r *UserKeysRepo) ListByOwner(ctx context.Context, subjectID string) ([]UserKeyRecord, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, key_id, owner_subject_id, key_fingerprint, masked_preview,
|
||||
SELECT id, key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
|
||||
display_name, logical_group_id, allowed_models,
|
||||
admin_status, quota_status, last_used_at, created_at, expires_at, updated_at
|
||||
FROM user_keys WHERE owner_subject_id = ? ORDER BY created_at DESC`, subjectID)
|
||||
@@ -111,7 +112,7 @@ func (r *UserKeysRepo) ListByOwner(ctx context.Context, subjectID string) ([]Use
|
||||
|
||||
func (r *UserKeysRepo) ListByFingerprint(ctx context.Context, fingerprint string) ([]UserKeyRecord, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, key_id, owner_subject_id, key_fingerprint, masked_preview,
|
||||
SELECT id, key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
|
||||
display_name, logical_group_id, allowed_models,
|
||||
admin_status, quota_status, last_used_at, created_at, expires_at, updated_at
|
||||
FROM user_keys WHERE key_fingerprint = ? ORDER BY created_at DESC`, fingerprint)
|
||||
@@ -124,7 +125,7 @@ func (r *UserKeysRepo) ListByFingerprint(ctx context.Context, fingerprint string
|
||||
|
||||
func (r *UserKeysRepo) GetByID(ctx context.Context, keyID string) (*UserKeyRecord, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, key_id, owner_subject_id, key_fingerprint, masked_preview,
|
||||
SELECT id, key_id, owner_subject_id, managed_identity_selector, key_fingerprint, masked_preview,
|
||||
display_name, logical_group_id, allowed_models,
|
||||
admin_status, quota_status, last_used_at, created_at, expires_at, updated_at
|
||||
FROM user_keys WHERE key_id = ?`, keyID)
|
||||
@@ -154,8 +155,9 @@ func (r *UserKeysRepo) UpdateStatus(ctx context.Context, keyID string, adminStat
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserKeysRepo) UpdateSecret(ctx context.Context, keyID, fingerprint, maskedPreview, adminStatus string) error {
|
||||
func (r *UserKeysRepo) UpdateSecret(ctx context.Context, keyID, managedIdentitySelector, fingerprint, maskedPreview, adminStatus string) error {
|
||||
keyID = strings.TrimSpace(keyID)
|
||||
managedIdentitySelector = strings.TrimSpace(managedIdentitySelector)
|
||||
fingerprint = strings.TrimSpace(fingerprint)
|
||||
maskedPreview = strings.TrimSpace(maskedPreview)
|
||||
adminStatus = strings.ToLower(strings.TrimSpace(adminStatus))
|
||||
@@ -174,9 +176,9 @@ func (r *UserKeysRepo) UpdateSecret(ctx context.Context, keyID, fingerprint, mas
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx,
|
||||
`UPDATE user_keys
|
||||
SET key_fingerprint = ?, masked_preview = ?, admin_status = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')
|
||||
SET managed_identity_selector = ?, key_fingerprint = ?, masked_preview = ?, admin_status = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')
|
||||
WHERE key_id = ?`,
|
||||
fingerprint, maskedPreview, adminStatus, keyID,
|
||||
managedIdentitySelector, fingerprint, maskedPreview, adminStatus, keyID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update user_key secret: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user