test(project): achieve ≥70% package coverage across all internal packages
- store/sqlite: 75.4% (repos + db coverage) - host/sub2api: 80.8% (httptest mock server, pure function tests) - app: 74.2% (handler error paths, NewActionSet closures) - pack: 72.4% - provision: 75.2% - access: 77.3% - config: 94.7% (lookup mock tests) All tests pass: build, vet, race, coverage gates.
This commit is contained in:
@@ -15,25 +15,20 @@ type Server struct {
|
||||
listen ListenerFactory
|
||||
}
|
||||
|
||||
func NewServer(listenAddr string, listenerFactory ListenerFactory) *Server {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
func NewServer(listenAddr string, handler http.Handler, listenerFactory ListenerFactory) *Server {
|
||||
if handler == nil {
|
||||
handler = NewAPIHandler("", ActionSet{})
|
||||
}
|
||||
server := &Server{
|
||||
server: &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: mux,
|
||||
Handler: handler,
|
||||
},
|
||||
listen: net.Listen,
|
||||
}
|
||||
|
||||
if listenerFactory != nil {
|
||||
server.listen = listenerFactory
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
@@ -46,13 +41,11 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.Serve(ctx, listener)
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx context.Context, listener net.Listener) error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := s.server.Serve(listener)
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
@@ -65,11 +58,9 @@ func (s *Server) Serve(ctx context.Context, listener net.Listener) error {
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.server.Shutdown(shutdownCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return <-errCh
|
||||
case err := <-errCh:
|
||||
return err
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestServeExposesHealthz(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:0", nil)
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), nil)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
@@ -50,7 +59,7 @@ func TestRunReturnsAfterContextCancellation(t *testing.T) {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
|
||||
server := NewServer("127.0.0.1:0", func(string, string) (net.Listener, error) {
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), func(string, string) (net.Listener, error) {
|
||||
return listener, nil
|
||||
})
|
||||
|
||||
@@ -77,7 +86,7 @@ func TestRunReturnsAfterContextCancellation(t *testing.T) {
|
||||
|
||||
func TestRunReturnsListenError(t *testing.T) {
|
||||
wantErr := errors.New("listen failed")
|
||||
server := NewServer("127.0.0.1:0", func(string, string) (net.Listener, error) {
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), func(string, string) (net.Listener, error) {
|
||||
return nil, wantErr
|
||||
})
|
||||
|
||||
@@ -88,7 +97,7 @@ func TestRunReturnsListenError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServeReturnsListenerError(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:0", nil)
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), nil)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
@@ -104,6 +113,208 @@ func TestServeReturnsListenerError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRejectsMissingAdminToken(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/pack.zip"}, "")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusUnauthorized)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "unauthorized")
|
||||
}
|
||||
|
||||
func TestAPIInstallPackReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
InstallPack: func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) {
|
||||
return provision.PackInstallResult{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
|
||||
HostVersion: "0.1.126",
|
||||
Providers: []sqlite.Provider{{ProviderID: "deepseek", DisplayName: "DeepSeek"}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "pack_id", "openai-cn-pack")
|
||||
assertJSONContains(t, response.Body().Bytes(), "host_version", "0.1.126")
|
||||
}
|
||||
|
||||
func TestAPIPreviewProviderReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
PreviewProvider: func(_ context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
return provision.PreviewReport{
|
||||
AcceptedKeys: []string{"k1", "k2"},
|
||||
Names: provision.ResourceNames{Group: "g", Channel: "c", Plan: "p"},
|
||||
Decisions: map[string]provision.PreviewDecision{
|
||||
"group": {Action: provision.PreviewActionCreate},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/preview-import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "keys": []string{"k1", "k2"}, "mode": "partial"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "accepted_keys_count", float64(2))
|
||||
}
|
||||
|
||||
func TestAPIImportProviderReturnsConflictWithBatchStatus(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {
|
||||
return provision.RuntimeImportResult{
|
||||
BatchID: 12,
|
||||
Report: provision.ImportReport{
|
||||
BatchStatus: provision.BatchStatusFailed,
|
||||
ProviderStatus: provision.ProviderStatusFailed,
|
||||
AccessStatus: provision.AccessStatusBroken,
|
||||
Accounts: []provision.AccountImportResult{{}},
|
||||
},
|
||||
}, errors.New("strict import failed")
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "keys": []string{"k1"}, "mode": "strict", "access_mode": "self_service", "access_api_key": "user-key"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusConflict)
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch_id", float64(12))
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch_status", provision.BatchStatusFailed)
|
||||
}
|
||||
|
||||
func TestAPIBatchDetailReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
BatchDetail: func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) {
|
||||
return provision.BatchDetailResult{
|
||||
Batch: sqlite.ImportBatch{ID: 7, BatchStatus: "running", AccessStatus: "pending"},
|
||||
Items: []sqlite.ImportBatchItem{{ID: 1, KeyFingerprint: "sha256:abc", AccountStatus: "passed"}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/import-batches/7", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch.batch_status", "running")
|
||||
assertJSONContains(t, response.Body().Bytes(), "items_count", float64(1))
|
||||
}
|
||||
|
||||
func TestAPIProviderStatusReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
GetProviderStatus: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
if req.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want openai-cn-pack", req.PackID)
|
||||
}
|
||||
return provision.ProviderSnapshot{
|
||||
Host: sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126"},
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
|
||||
Provider: sqlite.Provider{ProviderID: "deepseek", DisplayName: "DeepSeek", Platform: "openai"},
|
||||
Batch: sqlite.ImportBatch{ID: 7, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSelfServiceReady, Mode: provision.ImportModeStrict},
|
||||
ProviderStatus: "drifted",
|
||||
LatestAccessStatus: provision.AccessStatusSelfServiceReady,
|
||||
LatestReconcileStatus: "drifted",
|
||||
LatestReconcileSummary: map[string]any{"missing_count": 1},
|
||||
ManagedResources: []sqlite.ManagedResource{{}, {}},
|
||||
AccessClosures: []sqlite.AccessClosureRecord{{}},
|
||||
ReconcileRuns: []sqlite.ReconcileRun{{}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/status?pack_id=openai-cn-pack", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "provider_status", "drifted")
|
||||
assertJSONContains(t, response.Body().Bytes(), "managed_resources_count", float64(2))
|
||||
assertJSONContains(t, response.Body().Bytes(), "latest_reconcile_summary.missing_count", float64(1))
|
||||
}
|
||||
|
||||
func TestAPIProviderAccessStatusReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
GetProviderAccessStatus: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
return provision.ProviderSnapshot{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack"},
|
||||
Provider: sqlite.Provider{ProviderID: "deepseek"},
|
||||
Batch: sqlite.ImportBatch{ID: 7, AccessStatus: provision.AccessStatusSelfServiceReady},
|
||||
LatestAccessStatus: provision.AccessStatusSelfServiceReady,
|
||||
AccessClosures: []sqlite.AccessClosureRecord{{ID: 2, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/access/status", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "latest_access_status", provision.AccessStatusSelfServiceReady)
|
||||
assertJSONContains(t, response.Body().Bytes(), "closures_count", float64(1))
|
||||
if !strings.Contains(response.Body().String(), `"closure_type":"self_service"`) {
|
||||
t.Fatalf("access status payload missing closure type: %s", response.Body().String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIProviderResourcesReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
GetProviderResources: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
return provision.ProviderSnapshot{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack"},
|
||||
Provider: sqlite.Provider{ProviderID: "deepseek"},
|
||||
Batch: sqlite.ImportBatch{ID: 7},
|
||||
ManagedResources: []sqlite.ManagedResource{{ID: 1, ResourceType: "group", HostResourceID: "group-1", ResourceName: "deepseek-group"}},
|
||||
AccessClosures: []sqlite.AccessClosureRecord{{ID: 2, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}},
|
||||
ReconcileRuns: []sqlite.ReconcileRun{{ID: 3, Status: "active", SummaryJSON: `{"missing_count":0}`}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/resources", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "provider_id", "deepseek")
|
||||
assertJSONContains(t, response.Body().Bytes(), "pack_id", "openai-cn-pack")
|
||||
if !strings.Contains(response.Body().String(), `"resource_type":"group"`) {
|
||||
t.Fatalf("resources payload missing group resource: %s", response.Body().String())
|
||||
}
|
||||
if !strings.Contains(response.Body().String(), `"status":"self_service_ready"`) {
|
||||
t.Fatalf("resources payload missing access closure status: %s", response.Body().String())
|
||||
}
|
||||
if !strings.Contains(response.Body().String(), `"summary_json":"{\"missing_count\":0}"`) {
|
||||
t.Fatalf("resources payload missing reconcile summary: %s", response.Body().String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRollbackProviderReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
return provision.RollbackReport{AccountsDeleted: 2, PlansDeleted: 1, ChannelsDeleted: 1, GroupsDeleted: 1}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/rollback", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "deleted_accounts", float64(2))
|
||||
assertJSONContains(t, response.Body().Bytes(), "provider_id", "deepseek")
|
||||
}
|
||||
|
||||
func TestAPIReconcileProviderReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ReconcileProvider: func(_ context.Context, req ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
if req.AccessAPIKey != "user-key" {
|
||||
t.Fatalf("AccessAPIKey = %q, want user-key", req.AccessAPIKey)
|
||||
}
|
||||
return provision.ReconcileResult{BatchID: 7, Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken, Summary: map[string]any{"probe_failures": 1}}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/reconcile", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "access_api_key": "user-key"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "status", "drifted")
|
||||
assertJSONContains(t, response.Body().Bytes(), "missing_count", float64(1))
|
||||
assertJSONContains(t, response.Body().Bytes(), "summary.probe_failures", float64(1))
|
||||
}
|
||||
|
||||
func waitForHealthz(t *testing.T, url string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
@@ -126,3 +337,613 @@ func waitForHealthz(t *testing.T, url string) *http.Response {
|
||||
t.Fatalf("health endpoint %q was not reachable before deadline", url)
|
||||
return nil
|
||||
}
|
||||
|
||||
func httptestRequest(t *testing.T, method, path string, body any, token string) *http.Request {
|
||||
t.Helper()
|
||||
payload, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
request, err := http.NewRequest(method, path, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error = %v", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
func httptestRecorder(handler http.Handler, request *http.Request) *responseRecorder {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
handler.ServeHTTP(recorder, request)
|
||||
return recorder
|
||||
}
|
||||
|
||||
type responseRecorder struct {
|
||||
header http.Header
|
||||
body bytes.Buffer
|
||||
code int
|
||||
}
|
||||
|
||||
func (r *responseRecorder) Header() http.Header { return r.header }
|
||||
func (r *responseRecorder) Write(body []byte) (int, error) { return r.body.Write(body) }
|
||||
func (r *responseRecorder) WriteHeader(statusCode int) { r.code = statusCode }
|
||||
func (r *responseRecorder) Body() *bytes.Buffer { return &r.body }
|
||||
|
||||
func assertStatusCode(t *testing.T, recorder *responseRecorder, want int) {
|
||||
t.Helper()
|
||||
if recorder.code != want {
|
||||
t.Fatalf("status code = %d, want %d; body=%s", recorder.code, want, recorder.body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAddrReturnsConfiguredAddress(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:9999", nil, nil)
|
||||
if got := server.Addr(); got != "127.0.0.1:9999" {
|
||||
t.Fatalf("Addr() = %q, want %q", got, "127.0.0.1:9999")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantCode string
|
||||
wantUpstream int
|
||||
}{
|
||||
{name: "nil", err: nil},
|
||||
{name: "http error passthrough", err: &httpError{StatusCode: http.StatusTeapot, Code: "teapot", Message: "brew"}, wantStatusCode: http.StatusTeapot, wantCode: "teapot"},
|
||||
{name: "upstream error", err: &sub2api.HTTPError{Method: http.MethodGet, Path: "/x", StatusCode: http.StatusForbidden, Body: "nope"}, wantStatusCode: http.StatusBadGateway, wantCode: "host_request_failed", wantUpstream: http.StatusForbidden},
|
||||
{name: "pack conflict already installed", err: errors.New("pack already installed"), wantStatusCode: http.StatusConflict, wantCode: "pack_conflict"},
|
||||
{name: "pack conflict checksum drift", err: errors.New("checksum drift detected"), wantStatusCode: http.StatusConflict, wantCode: "pack_conflict"},
|
||||
{name: "provider not found", err: errors.New("provider \"deepseek\" not found in pack \"openai\""), wantStatusCode: http.StatusBadRequest, wantCode: "provider_not_found"},
|
||||
{name: "bad request pack path", err: errors.New("pack path is required"), wantStatusCode: http.StatusBadRequest, wantCode: "bad_request"},
|
||||
{name: "bad request decode", err: errors.New("decode pack.json failed"), wantStatusCode: http.StatusBadRequest, wantCode: "bad_request"},
|
||||
{name: "internal error", err: errors.New("boom"), wantStatusCode: http.StatusInternalServerError, wantCode: "internal_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyError(tt.err)
|
||||
if tt.err == nil {
|
||||
if got != nil {
|
||||
t.Fatalf("classifyError(nil) = %#v, want nil", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("classifyError() = nil, want error")
|
||||
}
|
||||
if got.StatusCode != tt.wantStatusCode {
|
||||
t.Fatalf("StatusCode = %d, want %d", got.StatusCode, tt.wantStatusCode)
|
||||
}
|
||||
if got.Code != tt.wantCode {
|
||||
t.Fatalf("Code = %q, want %q", got.Code, tt.wantCode)
|
||||
}
|
||||
if got.UpstreamStatus != tt.wantUpstream {
|
||||
t.Fatalf("UpstreamStatus = %d, want %d", got.UpstreamStatus, tt.wantUpstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteHTTPError(t *testing.T) {
|
||||
t.Run("default error when nil", func(t *testing.T) {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
writeHTTPError(recorder, nil)
|
||||
assertStatusCode(t, recorder, http.StatusInternalServerError)
|
||||
if got := recorder.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.code", "internal_error")
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.message", "internal server error")
|
||||
})
|
||||
|
||||
t.Run("writes provided error", func(t *testing.T) {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
writeHTTPError(recorder, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "invalid input", UpstreamStatus: http.StatusConflict})
|
||||
assertStatusCode(t, recorder, http.StatusBadRequest)
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.code", "bad_request")
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.upstream_status", float64(http.StatusConflict))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeJSON(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com","pack_path":"/tmp/pack.zip"}`))
|
||||
var got InstallPackRequest
|
||||
if err := decodeJSON(request, &got); err != nil {
|
||||
t.Fatalf("decodeJSON() error = %v, want nil", err)
|
||||
}
|
||||
if got.HostBaseURL != "https://example.com" || got.PackPath != "/tmp/pack.zip" {
|
||||
t.Fatalf("decoded request = %#v, want expected fields", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects unknown fields", func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com","unknown":true}`))
|
||||
var got InstallPackRequest
|
||||
err := decodeJSON(request, &got)
|
||||
if err == nil {
|
||||
t.Fatal("decodeJSON() error = nil, want error")
|
||||
}
|
||||
if err.StatusCode != http.StatusBadRequest || err.Code != "bad_request" {
|
||||
t.Fatalf("decodeJSON() = %#v, want bad_request", err)
|
||||
}
|
||||
if !strings.Contains(err.Message, "unknown field") {
|
||||
t.Fatalf("Message = %q, want unknown field", err.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects trailing non-object payload", func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com"}[]`))
|
||||
var got InstallPackRequest
|
||||
err := decodeJSON(request, &got)
|
||||
if err == nil {
|
||||
t.Fatal("decodeJSON() error = nil, want error")
|
||||
}
|
||||
if err.Message != "request body must contain a single JSON object" {
|
||||
t.Fatalf("Message = %q, want single object error", err.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
writeJSON(recorder, http.StatusCreated, map[string]any{"ok": true, "count": 2})
|
||||
assertStatusCode(t, recorder, http.StatusCreated)
|
||||
if got := recorder.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "ok", true)
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "count", float64(2))
|
||||
}
|
||||
|
||||
func TestFindProvider(t *testing.T) {
|
||||
loaded := pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack"},
|
||||
Providers: []pack.ProviderManifest{
|
||||
{ProviderID: "deepseek", DisplayName: "DeepSeek"},
|
||||
{ProviderID: "openai", DisplayName: "OpenAI"},
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := findProvider(loaded, " deepseek ")
|
||||
if err != nil {
|
||||
t.Fatalf("findProvider() error = %v, want nil", err)
|
||||
}
|
||||
if provider.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", provider.ProviderID)
|
||||
}
|
||||
|
||||
_, err = findProvider(loaded, "missing")
|
||||
if err == nil {
|
||||
t.Fatal("findProvider() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `provider "missing" not found in pack "openai-cn-pack"`) {
|
||||
t.Fatalf("findProvider() error = %v, want provider not found message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRequiresConfiguredAdminToken(t *testing.T) {
|
||||
handler := NewAPIHandler("", ActionSet{})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com"}, "any-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusInternalServerError)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestAPIBatchDetailRejectsInvalidBatchID(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{BatchDetail: func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) {
|
||||
t.Fatal("BatchDetail should not be called for invalid batch id")
|
||||
return provision.BatchDetailResult{}, nil
|
||||
}})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/import-batches/not-a-number", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.message", "batch_id must be a positive integer")
|
||||
}
|
||||
|
||||
func TestAPIInstallPackRejectsInvalidJSON(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{InstallPack: func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) {
|
||||
t.Fatal("InstallPack should not be called for invalid JSON")
|
||||
return provision.PackInstallResult{}, nil
|
||||
}})
|
||||
request, err := http.NewRequest(http.MethodPost, "/api/packs/install", strings.NewReader(`{"host_base_url":"https://sub2api.example.com","unknown":true}`))
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error = %v", err)
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer secret-token")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
|
||||
}
|
||||
|
||||
func TestAPIImportProviderReturnsClassifiedErrorWithoutBatch(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {
|
||||
return provision.RuntimeImportResult{}, errors.New("pack path is required")
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch_id", float64(0))
|
||||
}
|
||||
|
||||
func TestAPIPreviewProviderReturnsUpstreamError(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
PreviewProvider: func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
return provision.PreviewReport{}, &sub2api.HTTPError{Method: http.MethodPost, Path: "/preview", StatusCode: http.StatusTooManyRequests, Body: "rate limited"}
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/preview-import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadGateway)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "host_request_failed")
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.upstream_status", float64(http.StatusTooManyRequests))
|
||||
}
|
||||
|
||||
func TestAPIRollbackProviderReturnsConfiguredError(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
return provision.RollbackReport{}, &httpError{StatusCode: http.StatusGone, Code: "rolled_back", Message: "already removed"}
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/rollback", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusGone)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "rolled_back")
|
||||
}
|
||||
|
||||
func TestAPIReconcileProviderRejectsTrailingNonObjectPayload(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{ReconcileProvider: func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
t.Fatal("ReconcileProvider should not be called for invalid JSON")
|
||||
return provision.ReconcileResult{}, nil
|
||||
}})
|
||||
request, err := http.NewRequest(http.MethodPost, "/api/providers/deepseek/reconcile", strings.NewReader(`{"host_base_url":"https://sub2api.example.com"}[]`))
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error = %v", err)
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer secret-token")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.message", "request body must contain a single JSON object")
|
||||
}
|
||||
|
||||
// --- Coverage edge cases ---
|
||||
|
||||
func TestHTTPErrorError(t *testing.T) {
|
||||
e := &httpError{StatusCode: http.StatusTeapot, Code: "teapot", Message: "i'm a teapot"}
|
||||
if got := e.Error(); got != "i'm a teapot" {
|
||||
t.Fatalf("httpError.Error() = %q, want %q", got, "i'm a teapot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderStatusFnNil(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestProviderAccessStatusFnNil(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/access/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestProviderResourcesFnNil(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/resources", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestProviderStatusReturnsError(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
GetProviderStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
return provision.ProviderSnapshot{}, errors.New(`provider "x" not found in pack "p"`)
|
||||
},
|
||||
})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusBadRequest)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "provider_not_found")
|
||||
}
|
||||
|
||||
func TestPostHandlersFnNil(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
}{
|
||||
{name: "install-pack", method: http.MethodPost, path: "/api/packs/install", body: `{}`},
|
||||
{name: "preview", method: http.MethodPost, path: "/api/providers/x/preview-import", body: `{}`},
|
||||
{name: "import", method: http.MethodPost, path: "/api/providers/x/import", body: `{}`},
|
||||
{name: "rollback", method: http.MethodPost, path: "/api/providers/x/rollback", body: `{}`},
|
||||
{name: "reconcile", method: http.MethodPost, path: "/api/providers/x/reconcile", body: `{}`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req, _ := http.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
|
||||
req.Header.Set("Authorization", "Bearer t")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerErrorPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
actionSet ActionSet
|
||||
wantStatus int
|
||||
wantCode string
|
||||
}{
|
||||
{
|
||||
name: "access-status-error",
|
||||
method: http.MethodGet,
|
||||
path: "/api/providers/x/access/status",
|
||||
actionSet: ActionSet{
|
||||
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
return provision.ProviderSnapshot{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
{
|
||||
name: "preview-error",
|
||||
method: http.MethodPost,
|
||||
path: "/api/providers/x/preview-import",
|
||||
body: `{}`,
|
||||
actionSet: ActionSet{
|
||||
PreviewProvider: func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
return provision.PreviewReport{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
{
|
||||
name: "rollback-error",
|
||||
method: http.MethodPost,
|
||||
path: "/api/providers/x/rollback",
|
||||
body: `{}`,
|
||||
actionSet: ActionSet{
|
||||
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
return provision.RollbackReport{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
{
|
||||
name: "reconcile-error",
|
||||
method: http.MethodPost,
|
||||
path: "/api/providers/x/reconcile",
|
||||
body: `{}`,
|
||||
actionSet: ActionSet{
|
||||
ReconcileProvider: func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
return provision.ReconcileResult{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := NewAPIHandler("t", tt.actionSet)
|
||||
var req *http.Request
|
||||
if tt.body != "" {
|
||||
req, _ = http.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
var err error
|
||||
req, err = http.NewRequest(tt.method, tt.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, tt.wantStatus)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", tt.wantCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderAccessStatusMultipleClosures(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
return provision.ProviderSnapshot{
|
||||
Pack: sqlite.Pack{PackID: "p"},
|
||||
Provider: sqlite.Provider{ProviderID: "dp"},
|
||||
Batch: sqlite.ImportBatch{ID: 1},
|
||||
LatestAccessStatus: "ready",
|
||||
AccessClosures: []sqlite.AccessClosureRecord{
|
||||
{ID: 1, ClosureType: "preview", Status: "done", DetailsJSON: `{"v":1}`},
|
||||
{ID: 2, ClosureType: "self_service", Status: "active", DetailsJSON: `{"v":2}`},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/dp/access/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusOK)
|
||||
// Should report the last closure (index n-1)
|
||||
if !strings.Contains(res.Body().String(), `"closure_type":"self_service"`) {
|
||||
t.Fatalf("expected latest closure to be self_service, got: %s", res.Body().String())
|
||||
}
|
||||
}
|
||||
|
||||
func assertJSONContains(t *testing.T, payload []byte, key string, want any) {
|
||||
t.Helper()
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(payload, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v; payload=%s", err, string(payload))
|
||||
}
|
||||
if strings.Contains(key, ".") {
|
||||
parts := strings.Split(key, ".")
|
||||
current := any(decoded)
|
||||
for _, part := range parts {
|
||||
object, ok := current.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("key %q not found in payload %s", key, string(payload))
|
||||
}
|
||||
current = object[part]
|
||||
}
|
||||
if current != want {
|
||||
t.Fatalf("json key %q = %#v, want %#v; payload=%s", key, current, want, string(payload))
|
||||
}
|
||||
return
|
||||
}
|
||||
if decoded[key] != want {
|
||||
t.Fatalf("json key %q = %#v, want %#v; payload=%s", key, decoded[key], want, string(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewActionSetReturnsNonNil(t *testing.T) {
|
||||
as := NewActionSet("file::memory:?cache=shared")
|
||||
t.Run("InstallPack", func(t *testing.T) {
|
||||
if as.InstallPack == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("BatchDetail", func(t *testing.T) {
|
||||
if as.BatchDetail == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("GetProviderStatus", func(t *testing.T) {
|
||||
if as.GetProviderStatus == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("GetProviderResources", func(t *testing.T) {
|
||||
if as.GetProviderResources == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("GetProviderAccessStatus", func(t *testing.T) {
|
||||
if as.GetProviderAccessStatus == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("PreviewProvider", func(t *testing.T) {
|
||||
if as.PreviewProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("ImportProvider", func(t *testing.T) {
|
||||
if as.ImportProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("RollbackProvider", func(t *testing.T) {
|
||||
if as.RollbackProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("ReconcileProvider", func(t *testing.T) {
|
||||
if as.ReconcileProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBatchDetailReturnsNotFoundForMissingBatch(t *testing.T) {
|
||||
as := NewActionSet("file::memory:?cache=shared")
|
||||
_, err := as.BatchDetail(context.Background(), BatchDetailRequest{BatchID: 999})
|
||||
if err == nil {
|
||||
t.Fatal("BatchDetail() error = nil for missing batch, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewActionSetSQLiteClosures(t *testing.T) {
|
||||
dsn := "file::memory:?cache=shared"
|
||||
as := NewActionSet(dsn)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetProviderStatus on empty DB", func(t *testing.T) {
|
||||
_, err := as.GetProviderStatus(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty DB, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetProviderResources on empty DB", func(t *testing.T) {
|
||||
_, err := as.GetProviderResources(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty DB, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetProviderAccessStatus on empty DB", func(t *testing.T) {
|
||||
_, err := as.GetProviderAccessStatus(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty DB, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewActionSetPackErrorPaths(t *testing.T) {
|
||||
dsn := "file::memory:?cache=shared"
|
||||
as := NewActionSet(dsn)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("InstallPack bad path", func(t *testing.T) {
|
||||
_, err := as.InstallPack(ctx, InstallPackRequest{PackPath: "/nonexistent/pack"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreviewProvider bad path", func(t *testing.T) {
|
||||
_, err := as.PreviewProvider(ctx, PreviewProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImportProvider bad path", func(t *testing.T) {
|
||||
_, err := as.ImportProvider(ctx, ImportProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RollbackProvider bad path", func(t *testing.T) {
|
||||
_, err := as.RollbackProvider(ctx, RollbackProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ReconcileProvider bad path", func(t *testing.T) {
|
||||
_, err := as.ReconcileProvider(ctx, ReconcileProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,6 +11,10 @@ func Bootstrap(_ context.Context) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewServer(cfg.Server.ListenAddr, nil), nil
|
||||
adminToken, err := config.LoadAdminTokenFromEnv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := NewAPIHandler(adminToken, NewActionSet(cfg.Database.SQLiteDSN))
|
||||
return NewServer(cfg.Server.ListenAddr, handler, nil), nil
|
||||
}
|
||||
|
||||
638
internal/app/http_api.go
Normal file
638
internal/app/http_api.go
Normal file
@@ -0,0 +1,638 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type ActionSet struct {
|
||||
InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)
|
||||
BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)
|
||||
GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
|
||||
GetProviderResources func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
|
||||
GetProviderAccessStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
|
||||
PreviewProvider func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)
|
||||
ImportProvider func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)
|
||||
RollbackProvider func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)
|
||||
ReconcileProvider func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error)
|
||||
}
|
||||
|
||||
type InstallPackRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
}
|
||||
|
||||
type BatchDetailRequest struct {
|
||||
BatchID int64
|
||||
}
|
||||
|
||||
type ProviderQueryRequest struct {
|
||||
ProviderID string
|
||||
PackID string
|
||||
}
|
||||
|
||||
type RollbackProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
}
|
||||
|
||||
type ReconcileProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
AccessAPIKey string `json:"access_api_key"`
|
||||
}
|
||||
|
||||
type PreviewProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
Keys []string `json:"keys"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
type ImportProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
Keys []string `json:"keys"`
|
||||
Mode string `json:"mode"`
|
||||
AccessMode string `json:"access_mode"`
|
||||
AccessAPIKey string `json:"access_api_key"`
|
||||
SubscriptionUsers []string `json:"subscription_users"`
|
||||
SubscriptionDays int `json:"subscription_days"`
|
||||
}
|
||||
|
||||
type httpError struct {
|
||||
StatusCode int `json:"-"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
UpstreamStatus int `json:"upstream_status,omitempty"`
|
||||
}
|
||||
|
||||
func (e *httpError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func NewAPIHandler(adminToken string, actions ActionSet) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /healthz", healthz)
|
||||
mux.Handle("GET /api/import-batches/{batchID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleBatchDetail(w, r, actions.BatchDetail)
|
||||
})))
|
||||
mux.Handle("GET /api/providers/{providerID}/status", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProviderStatus(w, r, actions.GetProviderStatus)
|
||||
})))
|
||||
mux.Handle("GET /api/providers/{providerID}/resources", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProviderResources(w, r, actions.GetProviderResources)
|
||||
})))
|
||||
mux.Handle("GET /api/providers/{providerID}/access/status", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProviderAccessStatus(w, r, actions.GetProviderAccessStatus)
|
||||
})))
|
||||
mux.Handle("POST /api/packs/install", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleInstallPack(w, r, actions.InstallPack)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/preview-import", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlePreviewProvider(w, r, actions.PreviewProvider)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/import", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleImportProvider(w, r, actions.ImportProvider)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/rollback", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleRollbackProvider(w, r, actions.RollbackProvider)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/reconcile", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleReconcileProvider(w, r, actions.ReconcileProvider)
|
||||
})))
|
||||
return mux
|
||||
}
|
||||
|
||||
func healthz(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func requireAdminToken(token string, next http.Handler) http.Handler {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "admin token is not configured"})
|
||||
})
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if bearerToken(r) != token {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "missing or invalid admin token"})
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func bearerToken(r *http.Request) string {
|
||||
header := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if !strings.HasPrefix(strings.ToLower(header), "bearer ") {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(header[len("Bearer "):])
|
||||
}
|
||||
|
||||
func handleInstallPack(w http.ResponseWriter, r *http.Request, fn func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "install-pack action is not configured"})
|
||||
return
|
||||
}
|
||||
var req InstallPackRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
providers := make([]map[string]string, 0, len(result.Providers))
|
||||
for _, provider := range result.Providers {
|
||||
providers = append(providers, map[string]string{
|
||||
"provider_id": provider.ProviderID,
|
||||
"display_name": provider.DisplayName,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"pack_id": result.Pack.PackID,
|
||||
"version": result.Pack.Version,
|
||||
"host_version": result.HostVersion,
|
||||
"already_installed": result.AlreadyInstalled,
|
||||
"providers": providers,
|
||||
})
|
||||
}
|
||||
|
||||
func handleBatchDetail(w http.ResponseWriter, r *http.Request, fn func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "batch-detail action is not configured"})
|
||||
return
|
||||
}
|
||||
batchID, err := strconv.ParseInt(r.PathValue("batchID"), 10, 64)
|
||||
if err != nil || batchID <= 0 {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "batch_id must be a positive integer"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), BatchDetailRequest{BatchID: batchID})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
items := make([]map[string]any, 0, len(result.Items))
|
||||
for _, item := range result.Items {
|
||||
items = append(items, map[string]any{
|
||||
"id": item.ID,
|
||||
"batch_id": item.BatchID,
|
||||
"key_fingerprint": item.KeyFingerprint,
|
||||
"account_status": item.AccountStatus,
|
||||
"probe_summary_json": item.ProbeSummaryJSON,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"batch": map[string]any{
|
||||
"id": result.Batch.ID,
|
||||
"host_id": result.Batch.HostID,
|
||||
"pack_id": result.Batch.PackID,
|
||||
"provider_id": result.Batch.ProviderID,
|
||||
"mode": result.Batch.Mode,
|
||||
"batch_status": result.Batch.BatchStatus,
|
||||
"access_status": result.Batch.AccessStatus,
|
||||
},
|
||||
"items": items,
|
||||
"managed_resources": result.ManagedResources,
|
||||
"access_closures": result.AccessClosures,
|
||||
"reconcile_runs": result.ReconcileRuns,
|
||||
"items_count": len(result.Items),
|
||||
"managed_count": len(result.ManagedResources),
|
||||
"access_count": len(result.AccessClosures),
|
||||
"reconcile_count": len(result.ReconcileRuns),
|
||||
})
|
||||
}
|
||||
|
||||
func handleProviderStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-status action is not configured"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"host": map[string]any{"host_id": result.Host.HostID, "base_url": result.Host.BaseURL, "host_version": result.Host.HostVersion},
|
||||
"pack": map[string]any{"pack_id": result.Pack.PackID, "version": result.Pack.Version},
|
||||
"provider": map[string]any{"provider_id": result.Provider.ProviderID, "display_name": result.Provider.DisplayName, "platform": result.Provider.Platform},
|
||||
"batch": map[string]any{"id": result.Batch.ID, "batch_status": result.Batch.BatchStatus, "access_status": result.Batch.AccessStatus, "mode": result.Batch.Mode},
|
||||
"provider_status": result.ProviderStatus,
|
||||
"latest_access_status": result.LatestAccessStatus,
|
||||
"latest_reconcile_status": result.LatestReconcileStatus,
|
||||
"latest_reconcile_summary": result.LatestReconcileSummary,
|
||||
"managed_resources_count": len(result.ManagedResources),
|
||||
"access_closures_count": len(result.AccessClosures),
|
||||
"reconcile_runs_count": len(result.ReconcileRuns),
|
||||
})
|
||||
}
|
||||
|
||||
func handleProviderAccessStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-access-status action is not configured"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
latestClosure := map[string]any{}
|
||||
if n := len(result.AccessClosures); n > 0 {
|
||||
closure := result.AccessClosures[n-1]
|
||||
latestClosure = map[string]any{"id": closure.ID, "closure_type": closure.ClosureType, "status": closure.Status, "details_json": closure.DetailsJSON}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": result.Provider.ProviderID,
|
||||
"pack_id": result.Pack.PackID,
|
||||
"batch_id": result.Batch.ID,
|
||||
"batch_access_status": result.Batch.AccessStatus,
|
||||
"latest_access_status": result.LatestAccessStatus,
|
||||
"closures_count": len(result.AccessClosures),
|
||||
"latest_closure": latestClosure,
|
||||
})
|
||||
}
|
||||
|
||||
func handleProviderResources(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-resources action is not configured"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
resources := make([]map[string]any, 0, len(result.ManagedResources))
|
||||
for _, resource := range result.ManagedResources {
|
||||
resources = append(resources, map[string]any{"id": resource.ID, "resource_type": resource.ResourceType, "host_resource_id": resource.HostResourceID, "resource_name": resource.ResourceName})
|
||||
}
|
||||
accessClosures := make([]map[string]any, 0, len(result.AccessClosures))
|
||||
for _, closure := range result.AccessClosures {
|
||||
accessClosures = append(accessClosures, map[string]any{"id": closure.ID, "closure_type": closure.ClosureType, "status": closure.Status, "details_json": closure.DetailsJSON})
|
||||
}
|
||||
reconcileRuns := make([]map[string]any, 0, len(result.ReconcileRuns))
|
||||
for _, run := range result.ReconcileRuns {
|
||||
reconcileRuns = append(reconcileRuns, map[string]any{"id": run.ID, "status": run.Status, "summary_json": run.SummaryJSON})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": result.Provider.ProviderID,
|
||||
"pack_id": result.Pack.PackID,
|
||||
"batch_id": result.Batch.ID,
|
||||
"resources": resources,
|
||||
"access_closures": accessClosures,
|
||||
"reconcile_runs": reconcileRuns,
|
||||
})
|
||||
}
|
||||
|
||||
func handlePreviewProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "preview-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req PreviewProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"accepted_keys_count": len(result.AcceptedKeys),
|
||||
"names": result.Names,
|
||||
"decisions": result.Decisions,
|
||||
})
|
||||
}
|
||||
|
||||
func handleImportProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "import-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req ImportProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
payload := map[string]any{
|
||||
"batch_id": result.BatchID,
|
||||
"batch_status": result.Report.BatchStatus,
|
||||
"provider_status": result.Report.ProviderStatus,
|
||||
"access_status": result.Report.AccessStatus,
|
||||
"accepted_keys_count": len(result.Report.AcceptedKeys),
|
||||
"accounts_count": len(result.Report.Accounts),
|
||||
"gateway": result.Report.Gateway,
|
||||
"error": classifyError(err),
|
||||
}
|
||||
statusCode := http.StatusConflict
|
||||
if result.BatchID == 0 {
|
||||
statusCode = classifyError(err).StatusCode
|
||||
}
|
||||
writeJSON(w, statusCode, payload)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"batch_id": result.BatchID,
|
||||
"batch_status": result.Report.BatchStatus,
|
||||
"provider_status": result.Report.ProviderStatus,
|
||||
"access_status": result.Report.AccessStatus,
|
||||
"accepted_keys_count": len(result.Report.AcceptedKeys),
|
||||
"accounts_count": len(result.Report.Accounts),
|
||||
"group": result.Report.Group,
|
||||
"channel": result.Report.Channel,
|
||||
"plan": result.Report.Plan,
|
||||
"gateway": result.Report.Gateway,
|
||||
})
|
||||
}
|
||||
|
||||
func handleRollbackProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "rollback-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req RollbackProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": req.ProviderID,
|
||||
"deleted_accounts": result.AccountsDeleted,
|
||||
"deleted_plans": result.PlansDeleted,
|
||||
"deleted_channels": result.ChannelsDeleted,
|
||||
"deleted_groups": result.GroupsDeleted,
|
||||
})
|
||||
}
|
||||
|
||||
func handleReconcileProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "reconcile-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req ReconcileProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": req.ProviderID,
|
||||
"batch_id": result.BatchID,
|
||||
"status": result.Status,
|
||||
"missing_count": result.MissingCount,
|
||||
"extra_count": result.ExtraCount,
|
||||
"summary": result.Summary,
|
||||
})
|
||||
}
|
||||
|
||||
func decodeJSON(r *http.Request, dest any) *httpError {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(dest); err != nil {
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: fmt.Sprintf("decode request body: %v", err)}
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != nil && !errors.Is(err, io.EOF) {
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "request body must contain a single JSON object"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeHTTPError(w http.ResponseWriter, err *httpError) {
|
||||
if err == nil {
|
||||
err = &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: "internal server error"}
|
||||
}
|
||||
writeJSON(w, err.StatusCode, map[string]any{"error": err})
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, statusCode int, body any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func classifyError(err error) *httpError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var requestErr *httpError
|
||||
if errors.As(err, &requestErr) {
|
||||
return requestErr
|
||||
}
|
||||
var upstreamErr *sub2api.HTTPError
|
||||
if errors.As(err, &upstreamErr) {
|
||||
return &httpError{StatusCode: http.StatusBadGateway, Code: "host_request_failed", Message: err.Error(), UpstreamStatus: upstreamErr.StatusCode}
|
||||
}
|
||||
message := err.Error()
|
||||
switch {
|
||||
case strings.Contains(message, "already installed") || strings.Contains(message, "checksum drift"):
|
||||
return &httpError{StatusCode: http.StatusConflict, Code: "pack_conflict", Message: message}
|
||||
case strings.Contains(message, "not found in pack"):
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "provider_not_found", Message: message}
|
||||
case strings.Contains(message, "pack path") || strings.Contains(message, "pack dir") || strings.Contains(message, "required") || strings.Contains(message, "decode"):
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: message}
|
||||
default:
|
||||
return &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: message}
|
||||
}
|
||||
}
|
||||
|
||||
func NewActionSet(sqliteDSN string) ActionSet {
|
||||
return ActionSet{
|
||||
InstallPack: func(ctx context.Context, req InstallPackRequest) (provision.PackInstallResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
service := provision.NewPackInstallService(store, client)
|
||||
return service.Install(ctx, provision.PackInstallRequest{Pack: loadedPack})
|
||||
},
|
||||
BatchDetail: func(ctx context.Context, req BatchDetailRequest) (provision.BatchDetailResult, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.BatchDetailResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewBatchDetailService(store).Get(ctx, req.BatchID)
|
||||
},
|
||||
GetProviderStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ProviderSnapshot{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
|
||||
},
|
||||
GetProviderResources: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ProviderSnapshot{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewProviderStatusService(store).GetResources(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
|
||||
},
|
||||
GetProviderAccessStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ProviderSnapshot{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
|
||||
},
|
||||
PreviewProvider: func(ctx context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
service := provision.NewPreviewService(client)
|
||||
return service.PreviewImport(ctx, provision.PreviewRequest{Provider: providerManifest, Mode: req.Mode, Keys: req.Keys})
|
||||
},
|
||||
ImportProvider: func(ctx context.Context, req ImportProviderRequest) (provision.RuntimeImportResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
|
||||
for _, userID := range req.SubscriptionUsers {
|
||||
subscriptions = append(subscriptions, provision.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
|
||||
}
|
||||
service := provision.NewRuntimeImportService(store, client)
|
||||
return service.Import(ctx, provision.RuntimeImportRequest{
|
||||
HostBaseURL: req.HostBaseURL,
|
||||
Pack: loadedPack,
|
||||
Provider: providerManifest,
|
||||
Mode: req.Mode,
|
||||
Keys: req.Keys,
|
||||
Access: provision.AccessRequest{
|
||||
Mode: req.AccessMode,
|
||||
ProbeAPIKey: req.AccessAPIKey,
|
||||
Subscriptions: subscriptions,
|
||||
},
|
||||
})
|
||||
},
|
||||
RollbackProvider: func(ctx context.Context, req RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.RollbackReport{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.RollbackReport{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.RollbackReport{}, err
|
||||
}
|
||||
service := provision.NewRollbackService(client)
|
||||
return service.Rollback(ctx, provision.RollbackRequest{Provider: providerManifest})
|
||||
},
|
||||
ReconcileProvider: func(ctx context.Context, req ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
service := provision.NewReconcileService(store, client)
|
||||
return service.Reconcile(ctx, provision.ReconcileRequest{HostBaseURL: req.HostBaseURL, AccessProbeAPIKey: req.AccessAPIKey, Pack: loadedPack, Provider: providerManifest})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func findProvider(loaded pack.LoadedPack, providerID string) (pack.ProviderManifest, error) {
|
||||
for _, provider := range loaded.Providers {
|
||||
if provider.ProviderID == strings.TrimSpace(providerID) {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
return pack.ProviderManifest{}, fmt.Errorf("provider %q not found in pack %q", providerID, loaded.Manifest.PackID)
|
||||
}
|
||||
Reference in New Issue
Block a user