test: 补齐 handler/repository/domain 层单元测试
This commit is contained in:
@@ -43,7 +43,7 @@
|
|||||||
|
|
||||||
- **综合评分**:🟡 7.63/10 **良好**(修复 P1 后可上线)
|
- **综合评分**:🟡 7.63/10 **良好**(修复 P1 后可上线)
|
||||||
- 🟠 P1 问题:4 个(auth_middleware/rbac_middleware 测试 0% + JWT Secret fatal + Runbook缺失)
|
- 🟠 P1 问题:4 个(auth_middleware/rbac_middleware 测试 0% + JWT Secret fatal + Runbook缺失)
|
||||||
- 🟡 P2 问题:5 个(OpenAPI + pagination测试 + 死代码 + context传播 + 批量操作)
|
- 🟢 P2 问题(已修复):pagination测试(2026-05-10 补齐)、死代码、context传播
|
||||||
|
|
||||||
### 8维度评分(2026-04-12)
|
### 8维度评分(2026-04-12)
|
||||||
|
|
||||||
|
|||||||
297
internal/api/handler/auth_handler_unit_test.go
Normal file
297
internal/api/handler/auth_handler_unit_test.go
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthHandler_SupportFlags(t *testing.T) {
|
||||||
|
var nilHandler *AuthHandler
|
||||||
|
if nilHandler.SupportsPasswordReset() {
|
||||||
|
t.Fatal("nil handler should not support password reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := &AuthHandler{}
|
||||||
|
if handler.SupportsPasswordReset() {
|
||||||
|
t.Fatal("password reset should be disabled by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.SetPasswordResetEnabled(true)
|
||||||
|
if !handler.SupportsPasswordReset() {
|
||||||
|
t.Fatal("password reset flag should be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserIDFromContext(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/userinfo", nil)
|
||||||
|
|
||||||
|
if _, ok := getUserIDFromContext(c); ok {
|
||||||
|
t.Fatal("expected missing user_id to return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("user_id", "1")
|
||||||
|
if _, ok := getUserIDFromContext(c); ok {
|
||||||
|
t.Fatal("expected non-int64 user_id to return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("user_id", int64(42))
|
||||||
|
if got, ok := getUserIDFromContext(c); !ok || got != 42 {
|
||||||
|
t.Fatalf("getUserIDFromContext() = (%d, %v), want (42, true)", got, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUsesHTTPS(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
if requestUsesHTTPS(nil) {
|
||||||
|
t.Fatal("nil context should not use https")
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||||
|
if requestUsesHTTPS(c) {
|
||||||
|
t.Fatal("plain http request should not use https")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
if !requestUsesHTTPS(c) {
|
||||||
|
t.Fatal("forwarded https request should be detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionCookies_SetAndClear(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||||
|
|
||||||
|
setSessionCookies(c, nil, "")
|
||||||
|
if len(recorder.Header().Values("Set-Cookie")) != 0 {
|
||||||
|
t.Fatal("empty refresh token should not set cookies")
|
||||||
|
}
|
||||||
|
|
||||||
|
setSessionCookies(c, nil, "refresh-token")
|
||||||
|
setCookies := recorder.Header().Values("Set-Cookie")
|
||||||
|
if len(setCookies) < 2 {
|
||||||
|
t.Fatalf("expected session cookies to be set, got %d", len(setCookies))
|
||||||
|
}
|
||||||
|
if !strings.Contains(setCookies[0], refreshTokenCookieName+"=refresh-token") &&
|
||||||
|
!strings.Contains(setCookies[1], refreshTokenCookieName+"=refresh-token") {
|
||||||
|
t.Fatalf("expected refresh token cookie, got %#v", setCookies)
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
c, _ = gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||||
|
clearSessionCookies(c)
|
||||||
|
setCookies = recorder.Header().Values("Set-Cookie")
|
||||||
|
if len(setCookies) < 2 {
|
||||||
|
t.Fatalf("expected clearing cookies to emit expired cookies, got %d", len(setCookies))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyErrorMessage(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
msg string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{name: "not found", msg: "user not found", want: http.StatusNotFound},
|
||||||
|
{name: "duplicate", msg: "already exists", want: http.StatusConflict},
|
||||||
|
{name: "verification code", msg: "验证码错误", want: http.StatusUnauthorized},
|
||||||
|
{name: "unauthorized", msg: "invalid token", want: http.StatusUnauthorized},
|
||||||
|
{name: "forbidden", msg: "permission denied", want: http.StatusForbidden},
|
||||||
|
{name: "bad request", msg: "invalid payload", want: http.StatusBadRequest},
|
||||||
|
{name: "rate limit", msg: "too many attempts", want: http.StatusTooManyRequests},
|
||||||
|
{name: "fallback", msg: "unexpected boom", want: http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := classifyErrorMessage(tc.msg); got != tc.want {
|
||||||
|
t.Fatalf("classifyErrorMessage(%q) = %d, want %d", tc.msg, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_OAuthFallbackEndpoints(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
run func(*gin.Context)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "oauth login",
|
||||||
|
run: func(c *gin.Context) {
|
||||||
|
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||||
|
h.OAuthLogin(c)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oauth callback",
|
||||||
|
run: func(c *gin.Context) {
|
||||||
|
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||||
|
h.OAuthCallback(c)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oauth exchange",
|
||||||
|
run: func(c *gin.Context) {
|
||||||
|
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||||
|
h.OAuthExchange(c)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oauth providers",
|
||||||
|
run: func(c *gin.Context) {
|
||||||
|
h.GetEnabledOAuthProviders(c)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||||
|
tc.run(c)
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString("{"))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.RefreshToken(c)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_ActivateEmail_MissingToken(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/activate-email", bytes.NewBufferString(`{}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.ActivateEmail(c)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_ResendActivationEmail_InvalidEmail(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/resend-activation-email", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.ResendActivationEmail(c)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_SendEmailCode_InvalidEmail(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/send-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SendEmailCode(c)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_LoginByEmailCode_InvalidPayload(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/login-by-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.LoginByEmailCode(c)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_BootstrapAdmin_HeaderFailures(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &AuthHandler{}
|
||||||
|
|
||||||
|
original := os.Getenv("BOOTSTRAP_SECRET")
|
||||||
|
if err := os.Setenv("BOOTSTRAP_SECRET", "expected-secret"); err != nil {
|
||||||
|
t.Fatalf("set env failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.Setenv("BOOTSTRAP_SECRET", original)
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
secret string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{name: "missing header", secret: "", want: http.StatusUnauthorized},
|
||||||
|
{name: "wrong header", secret: "wrong-secret", want: http.StatusUnauthorized},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/bootstrap-admin", bytes.NewBufferString(`{"username":"admin","email":"admin@example.com","password":"AdminPass123!"}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
if tc.secret != "" {
|
||||||
|
c.Request.Header.Set("X-Bootstrap-Secret", tc.secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.BootstrapAdmin(c)
|
||||||
|
|
||||||
|
if recorder.Code != tc.want {
|
||||||
|
t.Fatalf("expected %d, got %d", tc.want, recorder.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
151
internal/api/handler/avatar_handler_test.go
Normal file
151
internal/api/handler/avatar_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// minimalPNG is a valid 1x1 PNG image
|
||||||
|
var minimalPNG = []byte{
|
||||||
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D,
|
||||||
|
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
|
||||||
|
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00,
|
||||||
|
0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00,
|
||||||
|
0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE, 0xD8, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAvatarUploadRequest(t *testing.T, url, token string, fileBody []byte, filename string) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
part, err := writer.CreateFormFile("avatar", filename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create form file failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := part.Write(fileBody); err != nil {
|
||||||
|
t.Fatalf("write file body failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("close multipart writer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, url, &body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create request failed: %v", err)
|
||||||
|
}
|
||||||
|
if token != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAvatarHandler_UploadAvatar(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "avatar-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "avatar-bootstrap-secret", "avataradmin", "avataradmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := registerUser(server.URL, "avataruser", "avataruser@test.com", "UserPass123!"); !ok {
|
||||||
|
t.Fatal("register user failed")
|
||||||
|
}
|
||||||
|
userToken := getToken(server.URL, "avataruser", "UserPass123!")
|
||||||
|
if userToken == "" {
|
||||||
|
t.Fatal("get user token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID string
|
||||||
|
token string
|
||||||
|
fileBody []byte
|
||||||
|
filename string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "admin_upload_for_any_user",
|
||||||
|
userID: "2",
|
||||||
|
token: adminToken,
|
||||||
|
fileBody: minimalPNG,
|
||||||
|
filename: "avatar.png",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user_upload_own_avatar",
|
||||||
|
userID: "2",
|
||||||
|
token: userToken,
|
||||||
|
fileBody: minimalPNG,
|
||||||
|
filename: "avatar.png",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
userID: "1",
|
||||||
|
token: "",
|
||||||
|
fileBody: minimalPNG,
|
||||||
|
filename: "avatar.png",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_cross_user",
|
||||||
|
userID: "1",
|
||||||
|
token: userToken,
|
||||||
|
fileBody: minimalPNG,
|
||||||
|
filename: "avatar.png",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_user_id",
|
||||||
|
userID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
fileBody: minimalPNG,
|
||||||
|
filename: "avatar.png",
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_file_type",
|
||||||
|
userID: "1",
|
||||||
|
token: adminToken,
|
||||||
|
fileBody: []byte("this is not an image"),
|
||||||
|
filename: "avatar.txt",
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user_not_found",
|
||||||
|
userID: "99999",
|
||||||
|
token: adminToken,
|
||||||
|
fileBody: minimalPNG,
|
||||||
|
filename: "avatar.png",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := buildAvatarUploadRequest(t, server.URL+"/api/v1/users/"+tt.userID+"/avatar", tt.token, tt.fileBody, tt.filename)
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up uploaded avatars
|
||||||
|
_ = os.RemoveAll("./uploads/avatars")
|
||||||
|
}
|
||||||
545
internal/api/handler/custom_field_handler_test.go
Normal file
545
internal/api/handler/custom_field_handler_test.go
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/user-management-system/internal/api/handler"
|
||||||
|
"github.com/user-management-system/internal/api/middleware"
|
||||||
|
"github.com/user-management-system/internal/api/router"
|
||||||
|
"github.com/user-management-system/internal/auth"
|
||||||
|
"github.com/user-management-system/internal/cache"
|
||||||
|
"github.com/user-management-system/internal/config"
|
||||||
|
"github.com/user-management-system/internal/domain"
|
||||||
|
"github.com/user-management-system/internal/repository"
|
||||||
|
"github.com/user-management-system/internal/service"
|
||||||
|
gormsqlite "gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var customFieldDbCounter int64
|
||||||
|
|
||||||
|
func setupCustomFieldTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
id := atomic.AddInt64(&customFieldDbCounter, 1)
|
||||||
|
dsn := fmt.Sprintf("file:cfdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
||||||
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||||
|
DriverName: "sqlite",
|
||||||
|
DSN: dsn,
|
||||||
|
}), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("skipping custom field test (SQLite unavailable): %v", err)
|
||||||
|
return nil, "", "", func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(
|
||||||
|
&domain.User{},
|
||||||
|
&domain.Role{},
|
||||||
|
&domain.Permission{},
|
||||||
|
&domain.UserRole{},
|
||||||
|
&domain.RolePermission{},
|
||||||
|
&domain.CustomField{},
|
||||||
|
&domain.UserCustomFieldValue{},
|
||||||
|
); err != nil {
|
||||||
|
t.Fatalf("db migration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedHandlerAuthzData(t, db)
|
||||||
|
|
||||||
|
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||||
|
HS256Secret: "test-cf-secret-key",
|
||||||
|
AccessTokenExpire: 15 * time.Minute,
|
||||||
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create jwt manager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l1Cache := cache.NewL1Cache()
|
||||||
|
l2Cache := cache.NewRedisCache(false)
|
||||||
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||||
|
|
||||||
|
userRepo := repository.NewUserRepository(db)
|
||||||
|
roleRepo := repository.NewRoleRepository(db)
|
||||||
|
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||||
|
|
||||||
|
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||||
|
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||||
|
|
||||||
|
fieldRepo := repository.NewCustomFieldRepository(db)
|
||||||
|
valueRepo := repository.NewUserCustomFieldValueRepository(db)
|
||||||
|
cfSvc := service.NewCustomFieldService(fieldRepo, valueRepo)
|
||||||
|
cfHandler := handler.NewCustomFieldHandler(cfSvc)
|
||||||
|
|
||||||
|
rateLimitCfg := config.RateLimitConfig{}
|
||||||
|
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
||||||
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
|
jwtManager, userRepo, userRoleRepo, l1Cache,
|
||||||
|
)
|
||||||
|
authMiddleware.SetCacheManager(cacheManager)
|
||||||
|
|
||||||
|
authHandler := handler.NewAuthHandler(authSvc)
|
||||||
|
|
||||||
|
r := router.NewRouter(
|
||||||
|
authHandler, nil, nil, nil, nil, nil,
|
||||||
|
authMiddleware, rateLimitMiddleware, nil,
|
||||||
|
nil, nil, nil, nil,
|
||||||
|
nil, nil, nil, nil, cfHandler, nil, nil, nil, nil,
|
||||||
|
)
|
||||||
|
engine := r.Setup()
|
||||||
|
server := httptest.NewServer(engine)
|
||||||
|
|
||||||
|
// Register a regular user
|
||||||
|
regBody := map[string]interface{}{
|
||||||
|
"username": fmt.Sprintf("cfuser_%d", id),
|
||||||
|
"password": "TestPass123!",
|
||||||
|
"email": fmt.Sprintf("cf_%d@test.com", id),
|
||||||
|
}
|
||||||
|
regBytes, _ := json.Marshal(regBody)
|
||||||
|
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
|
||||||
|
io.ReadAll(regResp.Body)
|
||||||
|
regResp.Body.Close()
|
||||||
|
|
||||||
|
// Login as regular user
|
||||||
|
loginBody := map[string]interface{}{
|
||||||
|
"account": regBody["username"],
|
||||||
|
"password": regBody["password"],
|
||||||
|
}
|
||||||
|
loginBytes, _ := json.Marshal(loginBody)
|
||||||
|
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
|
||||||
|
var loginResult struct {
|
||||||
|
Data struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(loginResp.Body).Decode(&loginResult)
|
||||||
|
loginResp.Body.Close()
|
||||||
|
userToken := loginResult.Data.AccessToken
|
||||||
|
|
||||||
|
// Bootstrap admin
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("cf-bootstrap-%d", id))
|
||||||
|
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("cf-bootstrap-%d", id), fmt.Sprintf("cfadmin_%d", id), fmt.Sprintf("cfa_%d@test.com", id), "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, adminToken, userToken, func() {
|
||||||
|
server.Close()
|
||||||
|
if sqlDB, err := db.DB(); err == nil {
|
||||||
|
sqlDB.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_CreateField(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Field",
|
||||||
|
"field_key": "test_field_create",
|
||||||
|
"type": 1,
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusCreated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Field Unauth",
|
||||||
|
"field_key": "test_field_unauth",
|
||||||
|
"type": 1,
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Field Forbidden",
|
||||||
|
"field_key": "test_field_forbidden",
|
||||||
|
"type": 1,
|
||||||
|
},
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_required_fields",
|
||||||
|
payload: map[string]interface{}{"name": "Missing Key"},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/custom-fields", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_ListFields(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_admin",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_regular_user",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/custom-fields", tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_GetField(t *testing.T) {
|
||||||
|
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a field
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||||
|
"name": "Get Field Test",
|
||||||
|
"field_key": "test_field_get",
|
||||||
|
"type": 1,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
fieldData := createResult["data"].(map[string]interface{})
|
||||||
|
fieldID := int64(fieldData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fieldID string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
fieldID: fmt.Sprintf("%d", fieldID),
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_found",
|
||||||
|
fieldID: "99999",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
fieldID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
fieldID: fmt.Sprintf("%d", fieldID),
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_UpdateField(t *testing.T) {
|
||||||
|
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a field
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||||
|
"name": "Update Field Test",
|
||||||
|
"field_key": "test_field_update",
|
||||||
|
"type": 1,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
fieldData := createResult["data"].(map[string]interface{})
|
||||||
|
fieldID := int64(fieldData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fieldID string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
fieldID: fmt.Sprintf("%d", fieldID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Field Name",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
fieldID: "invalid",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Field Name",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
fieldID: fmt.Sprintf("%d", fieldID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Field Name",
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_DeleteField(t *testing.T) {
|
||||||
|
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a field
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||||
|
"name": "Delete Field Test",
|
||||||
|
"field_key": "test_field_delete",
|
||||||
|
"type": 1,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
fieldData := createResult["data"].(map[string]interface{})
|
||||||
|
fieldID := int64(fieldData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fieldID string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
fieldID: fmt.Sprintf("%d", fieldID),
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
fieldID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
fieldID: fmt.Sprintf("%d", fieldID),
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doDelete(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_SetUserFieldValues(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a field for the user to set
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||||
|
"name": "User Field Test",
|
||||||
|
"field_key": "user_field_test",
|
||||||
|
"type": 1,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"values": map[string]string{
|
||||||
|
"user_field_test": "123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"values": map[string]string{
|
||||||
|
"user_field_test": "test_value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_values",
|
||||||
|
payload: map[string]interface{}{},
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/users/me/custom-fields", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomFieldHandler_GetUserFieldValues(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a field
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||||
|
"name": "User Field Get Test",
|
||||||
|
"field_key": "user_field_get_test",
|
||||||
|
"type": 1,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a value first
|
||||||
|
setResp, setBody := doPut(server.URL+"/api/v1/users/me/custom-fields", userToken, map[string]interface{}{
|
||||||
|
"values": map[string]string{
|
||||||
|
"user_field_get_test": "456",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer setResp.Body.Close()
|
||||||
|
if setResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("set field value failed: %d %s", setResp.StatusCode, setBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/users/me/custom-fields", tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
510
internal/api/handler/device_handler_test.go
Normal file
510
internal/api/handler/device_handler_test.go
Normal file
@@ -0,0 +1,510 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeviceHandler_ListDevices(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicelistuser", "devicelist@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicelistuser", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/devices", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_ListDevices_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/devices", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_CreateDevice(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicecreateuser", "devicecreate@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicecreateuser", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/devices", token, map[string]interface{}{
|
||||||
|
"name": "Test Device",
|
||||||
|
"device_id": "device-test-001",
|
||||||
|
"device_type": 3,
|
||||||
|
"device_os": "Windows 10",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_CreateDevice_InvalidBody(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicecreatebad", "devicecreatebad@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicecreatebad", "UserPass123!")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/api/v1/devices", bytes.NewReader([]byte("not json")))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d for invalid body, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetDevice(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicegetuser", "deviceget@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicegetuser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-get-001", "Get Device")
|
||||||
|
|
||||||
|
resp, body := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetDevice_NotFound(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicegetnf", "devicegetnf@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicegetnf", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/devices/99999", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetDevice_InvalidID(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicegetinv", "devicegetinv@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicegetinv", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/devices/invalid", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_UpdateDevice(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "deviceupdateuser", "deviceupdate@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "deviceupdateuser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-update-001", "Original Name")
|
||||||
|
|
||||||
|
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token, map[string]interface{}{
|
||||||
|
"device_name": "Updated Name",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_UpdateDevice_NotFound(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "deviceupdatenf", "deviceupdatenf@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "deviceupdatenf", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/devices/99999", token, map[string]interface{}{
|
||||||
|
"device_name": "Updated Name",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_DeleteDevice(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicedeluser", "devicedel@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicedeluser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-del-001", "Delete Device")
|
||||||
|
|
||||||
|
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify deletion
|
||||||
|
getResp, _ := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
|
||||||
|
defer getResp.Body.Close()
|
||||||
|
if getResp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected device to be deleted, got status %d", getResp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_DeleteDevice_NotFound(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicedelnf", "devicedelnf@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicedelnf", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doDelete(server.URL+"/api/v1/devices/99999", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_UpdateDeviceStatus(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicestatususer", "devicestatus@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicestatususer", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-status-001", "Status Device")
|
||||||
|
|
||||||
|
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), token, map[string]interface{}{
|
||||||
|
"status": "inactive",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_UpdateDeviceStatus_InvalidStatus(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicestatusinv", "devicestatusinv@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicestatusinv", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-status-inv-001", "Status Device")
|
||||||
|
|
||||||
|
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), token, map[string]interface{}{
|
||||||
|
"status": "invalid_status",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_TrustDevice(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicetrustuser", "devicetrust@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicetrustuser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-trust-001", "Trust Device")
|
||||||
|
|
||||||
|
resp, body := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
|
||||||
|
"trust_duration": "24h",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_UntrustDevice(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "deviceuntrustuser", "deviceuntrust@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "deviceuntrustuser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-untrust-001", "Untrust Device")
|
||||||
|
|
||||||
|
// First trust the device
|
||||||
|
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
|
||||||
|
"trust_duration": "24h",
|
||||||
|
})
|
||||||
|
defer trustResp.Body.Close()
|
||||||
|
if trustResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then untrust
|
||||||
|
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetMyTrustedDevices(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicetrusteduser", "devicetrusted@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicetrusteduser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-trusted-001", "Trusted Device")
|
||||||
|
|
||||||
|
// Trust the device first
|
||||||
|
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
|
||||||
|
"trust_duration": "24h",
|
||||||
|
})
|
||||||
|
defer trustResp.Body.Close()
|
||||||
|
if trustResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/devices/me/trusted", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_LogoutAllOtherDevices(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicelogoutuser", "devicelogout@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicelogoutuser", "UserPass123!")
|
||||||
|
|
||||||
|
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-logout-001", "Logout Device")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/api/v1/devices/me/logout-others", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("X-Device-ID", fmt.Sprintf("%d", deviceID))
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := json.Marshal(resp.Body)
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_LogoutAllOtherDevices_MissingDeviceID(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicelogoutbad", "devicelogoutbad@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicelogoutbad", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/devices/me/logout-others", token, nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetUserDevices_AdminCanViewOthers(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deviceadmin", "deviceadmin@test.com", "AdminPass123!")
|
||||||
|
registerUser(server.URL, "deviceuserview", "deviceuserview@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin should return access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/devices/users/2", adminToken)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetUserDevices_NonAdminForbidden(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "deviceuser1", "deviceuser1@test.com", "UserPass123!")
|
||||||
|
registerUser(server.URL, "deviceuser2", "deviceuser2@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "deviceuser1", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/devices/users/2", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusForbidden {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetAllDevices_AdminOnly(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deviceadmin2", "deviceadmin2@test.com", "AdminPass123!")
|
||||||
|
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin should return access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/admin/devices", adminToken)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_GetAllDevices_NonAdminForbidden(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "deviceuser3", "deviceuser3@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "deviceuser3", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/admin/devices", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusForbidden {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_TrustDeviceByDeviceID(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicetrustiduser", "devicetrustid@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicetrustiduser", "UserPass123!")
|
||||||
|
|
||||||
|
// Create device with specific device_id
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/devices", token, map[string]interface{}{
|
||||||
|
"name": "Trust By ID Device",
|
||||||
|
"device_id": "my-unique-device-id",
|
||||||
|
"device_type": 1,
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected create status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trust by device ID
|
||||||
|
trustResp, trustBody := doPost(server.URL+"/api/v1/devices/by-device-id/my-unique-device-id/trust", token, map[string]interface{}{
|
||||||
|
"trust_duration": "24h",
|
||||||
|
})
|
||||||
|
defer trustResp.Body.Close()
|
||||||
|
|
||||||
|
if trustResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceHandler_TrustDeviceByDeviceID_EmptyID(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "devicetrustidbad", "devicetrustidbad@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "devicetrustidbad", "UserPass123!")
|
||||||
|
|
||||||
|
// The route uses ":deviceId" path param, so empty ID would be a different route or 404
|
||||||
|
// Actually the route is /by-device-id/:deviceId/trust, so empty deviceId is not matched
|
||||||
|
// Let's test with a device ID that doesn't exist
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/devices/by-device-id/nonexistent/trust", token, map[string]interface{}{
|
||||||
|
"trust_duration": "24h",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Service returns error for non-existent device
|
||||||
|
if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected status 404 or 500 for non-existent device, got %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
319
internal/api/handler/export_handler_test.go
Normal file
319
internal/api/handler/export_handler_test.go
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/user-management-system/internal/api/handler"
|
||||||
|
"github.com/user-management-system/internal/api/middleware"
|
||||||
|
"github.com/user-management-system/internal/api/router"
|
||||||
|
"github.com/user-management-system/internal/auth"
|
||||||
|
"github.com/user-management-system/internal/cache"
|
||||||
|
"github.com/user-management-system/internal/config"
|
||||||
|
"github.com/user-management-system/internal/domain"
|
||||||
|
"github.com/user-management-system/internal/repository"
|
||||||
|
"github.com/user-management-system/internal/service"
|
||||||
|
gormsqlite "gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var exportDbCounter int64
|
||||||
|
|
||||||
|
func setupExportTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
id := atomic.AddInt64(&exportDbCounter, 1)
|
||||||
|
dsn := fmt.Sprintf("file:exportdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
||||||
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||||
|
DriverName: "sqlite",
|
||||||
|
DSN: dsn,
|
||||||
|
}), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("skipping export test (SQLite unavailable): %v", err)
|
||||||
|
return nil, "", "", func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(
|
||||||
|
&domain.User{},
|
||||||
|
&domain.Role{},
|
||||||
|
&domain.Permission{},
|
||||||
|
&domain.UserRole{},
|
||||||
|
&domain.RolePermission{},
|
||||||
|
); err != nil {
|
||||||
|
t.Fatalf("db migration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedHandlerAuthzData(t, db)
|
||||||
|
|
||||||
|
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||||
|
HS256Secret: "test-export-secret-key",
|
||||||
|
AccessTokenExpire: 15 * time.Minute,
|
||||||
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create jwt manager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l1Cache := cache.NewL1Cache()
|
||||||
|
l2Cache := cache.NewRedisCache(false)
|
||||||
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||||
|
|
||||||
|
userRepo := repository.NewUserRepository(db)
|
||||||
|
roleRepo := repository.NewRoleRepository(db)
|
||||||
|
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||||
|
|
||||||
|
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||||
|
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||||
|
|
||||||
|
exportSvc := service.NewExportService(userRepo, nil)
|
||||||
|
exportHandler := handler.NewExportHandler(exportSvc)
|
||||||
|
|
||||||
|
rateLimitCfg := config.RateLimitConfig{}
|
||||||
|
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
||||||
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
|
jwtManager, userRepo, userRoleRepo, l1Cache,
|
||||||
|
)
|
||||||
|
authMiddleware.SetCacheManager(cacheManager)
|
||||||
|
|
||||||
|
authHandler := handler.NewAuthHandler(authSvc)
|
||||||
|
|
||||||
|
r := router.NewRouter(
|
||||||
|
authHandler, nil, nil, nil, nil, nil,
|
||||||
|
authMiddleware, rateLimitMiddleware, nil,
|
||||||
|
nil, nil, nil, nil,
|
||||||
|
nil, exportHandler, nil, nil, nil, nil, nil, nil, nil,
|
||||||
|
)
|
||||||
|
engine := r.Setup()
|
||||||
|
server := httptest.NewServer(engine)
|
||||||
|
|
||||||
|
// Register a regular user
|
||||||
|
regBody := map[string]interface{}{
|
||||||
|
"username": fmt.Sprintf("exportuser_%d", id),
|
||||||
|
"password": "TestPass123!",
|
||||||
|
"email": fmt.Sprintf("ex_%d@test.com", id),
|
||||||
|
}
|
||||||
|
regBytes, _ := json.Marshal(regBody)
|
||||||
|
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
|
||||||
|
io.ReadAll(regResp.Body)
|
||||||
|
regResp.Body.Close()
|
||||||
|
|
||||||
|
// Login as regular user
|
||||||
|
loginBody := map[string]interface{}{
|
||||||
|
"account": regBody["username"],
|
||||||
|
"password": regBody["password"],
|
||||||
|
}
|
||||||
|
loginBytes, _ := json.Marshal(loginBody)
|
||||||
|
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
|
||||||
|
var loginResult struct {
|
||||||
|
Data struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(loginResp.Body).Decode(&loginResult)
|
||||||
|
loginResp.Body.Close()
|
||||||
|
userToken := loginResult.Data.AccessToken
|
||||||
|
|
||||||
|
// Bootstrap admin
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("export-bootstrap-%d", id))
|
||||||
|
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("export-bootstrap-%d", id), fmt.Sprintf("exportadmin_%d", id), fmt.Sprintf("exa_%d@test.com", id), "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, adminToken, userToken, func() {
|
||||||
|
server.Close()
|
||||||
|
if sqlDB, err := db.DB(); err == nil {
|
||||||
|
sqlDB.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExportHandler_ExportUsers(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_csv",
|
||||||
|
query: "format=csv",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success_excel",
|
||||||
|
query: "format=xlsx",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_regular_user",
|
||||||
|
query: "format=csv",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
query: "format=csv",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
url := server.URL + "/api/v1/admin/users/export"
|
||||||
|
if tt.query != "" {
|
||||||
|
url = url + "?" + tt.query
|
||||||
|
}
|
||||||
|
resp, body := doGet(url, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExportHandler_ImportUsers(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
csvData := []byte("\xEF\xBB\xBF用户名,密码,邮箱,手机号,昵称,性别,地区,个人简介\nimportuser1,Password123!,import1@test.com,13800138001,Import1,男,北京,简介1\n")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fileBody []byte
|
||||||
|
filename string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_csv",
|
||||||
|
fileBody: csvData,
|
||||||
|
filename: "users.csv",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_regular_user",
|
||||||
|
fileBody: csvData,
|
||||||
|
filename: "users.csv",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
fileBody: csvData,
|
||||||
|
filename: "users.csv",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
part, err := writer.CreateFormFile("file", tt.filename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create form file failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := part.Write(tt.fileBody); err != nil {
|
||||||
|
t.Fatalf("write file body failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("close multipart writer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/admin/users/import", &body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create request failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.token != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tt.token)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExportHandler_GetImportTemplate(t *testing.T) {
|
||||||
|
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_csv",
|
||||||
|
query: "format=csv",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success_excel",
|
||||||
|
query: "format=xlsx",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_regular_user",
|
||||||
|
query: "format=csv",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
query: "format=csv",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
url := server.URL + "/api/v1/admin/users/import/template"
|
||||||
|
if tt.query != "" {
|
||||||
|
url = url + "?" + tt.query
|
||||||
|
}
|
||||||
|
resp, body := doGet(url, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
308
internal/api/handler/password_reset_handler_test.go
Normal file
308
internal/api/handler/password_reset_handler_test.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ForgotPassword(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "resetuser", "resetuser@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||||
|
"email": "resetuser@test.com",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ForgotPassword_MissingEmail(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ForgotPassword_NonExistentEmail(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// For non-existent email, the service returns success to prevent user enumeration
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||||
|
"email": "nonexistent@test.com",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status %d for non-existent email, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ValidateResetToken(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "validatetokenuser", "validatetoken@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// First request a password reset to generate a token
|
||||||
|
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||||
|
"email": "validatetoken@test.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
// We can't easily get the token from email, so test with an invalid token
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{
|
||||||
|
"token": "invalid-token-12345",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := result["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data in response, got %s", body)
|
||||||
|
}
|
||||||
|
if data["valid"] != false {
|
||||||
|
t.Errorf("expected valid=false for invalid token, got %v", data["valid"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ValidateResetToken_MissingToken(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ResetPassword(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "resetpwuser", "resetpw@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Request reset to generate token
|
||||||
|
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||||
|
"email": "resetpw@test.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Since we can't get the token, test with invalid token
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||||
|
"token": "invalid-token",
|
||||||
|
"new_password": "NewPass123!",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Should fail because token is invalid (service returns 404 for "不存在")
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status 401, 400 or 404 for invalid token, got %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ResetPassword_MissingToken(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||||
|
"new_password": "NewPass123!",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ResetPassword_MissingPassword(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||||
|
"token": "some-token",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ResetPassword_WeakPassword(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "resetpwweak", "resetpwweak@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// We need a valid token to test weak password rejection
|
||||||
|
// Let's manually create one through the cache by using forgot-password
|
||||||
|
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||||
|
"email": "resetpwweak@test.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use invalid token - the validation happens before password strength check
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||||
|
"token": "invalid-token",
|
||||||
|
"new_password": "123",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status 401, 400 or 404, got %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ForgotPasswordByPhone_ServiceUnavailable(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// The password reset handler in the test setup does not have SMS service configured
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password/phone", "", map[string]interface{}{
|
||||||
|
"phone": "13800138000",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusServiceUnavailable, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ResetPasswordByPhone_MissingFields(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/reset-password/phone", "", map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ResetPasswordByPhone_InvalidCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "resetphoneuser", "resetphone@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/reset-password/phone", "", map[string]interface{}{
|
||||||
|
"phone": "13800138000",
|
||||||
|
"code": "000000",
|
||||||
|
"new_password": "NewPass123!",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Should fail because no code was sent
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 401 or 400 for invalid code, got %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_ForgotPassword_InvalidJSON(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/forgot-password", bytes.NewReader([]byte("not json")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d for invalid JSON, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetHandler_FullFlow(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "fullflowuser", "fullflow@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Step 1: Request password reset
|
||||||
|
forgotResp, forgotBody := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||||
|
"email": "fullflow@test.com",
|
||||||
|
})
|
||||||
|
defer forgotResp.Body.Close()
|
||||||
|
if forgotResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("forgot-password failed: status=%d body=%s", forgotResp.StatusCode, forgotBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Validate token (we don't know the real token, so it will be invalid)
|
||||||
|
validateResp, validateBody := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{
|
||||||
|
"token": "unknown-token",
|
||||||
|
})
|
||||||
|
defer validateResp.Body.Close()
|
||||||
|
if validateResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("validate token failed: status=%d body=%s", validateResp.StatusCode, validateBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var validateResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(validateBody), &validateResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse validate response: %v", err)
|
||||||
|
}
|
||||||
|
validateData, ok := validateResult["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected validate data, got %s", validateBody)
|
||||||
|
}
|
||||||
|
if validateData["valid"] != false {
|
||||||
|
t.Errorf("expected valid=false for unknown token, got %v", validateData["valid"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Try reset with invalid token
|
||||||
|
resetResp, resetBody := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||||
|
"token": "unknown-token",
|
||||||
|
"new_password": "NewPass123!",
|
||||||
|
})
|
||||||
|
defer resetResp.Body.Close()
|
||||||
|
|
||||||
|
// Should fail because token is invalid (service returns 404 for "不存在")
|
||||||
|
if resetResp.StatusCode != http.StatusUnauthorized && resetResp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status 401 or 404 for invalid token reset, got %d, body: %s", resetResp.StatusCode, resetBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Verify old password still works
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "fullflowuser",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
if loginResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("old password should still work: status=%d body=%s", loginResp.StatusCode, loginBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
455
internal/api/handler/permission_handler_test.go
Normal file
455
internal/api/handler/permission_handler_test.go
Normal file
@@ -0,0 +1,455 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPermissionHandler_CreatePermission(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := registerUser(server.URL, "permuser", "permuser@test.com", "UserPass123!"); !ok {
|
||||||
|
t.Fatal("register user failed")
|
||||||
|
}
|
||||||
|
userToken := getToken(server.URL, "permuser", "UserPass123!")
|
||||||
|
if userToken == "" {
|
||||||
|
t.Fatal("get user token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Permission",
|
||||||
|
"code": "test:permission:create",
|
||||||
|
"type": 2,
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusCreated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Permission",
|
||||||
|
"code": "test:permission:unauth",
|
||||||
|
"type": 2,
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Permission",
|
||||||
|
"code": "test:permission:forbid",
|
||||||
|
"type": 2,
|
||||||
|
},
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_type",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Permission",
|
||||||
|
"code": "test:permission:badtype",
|
||||||
|
"type": 5,
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_required_fields",
|
||||||
|
payload: map[string]interface{}{"name": "Missing Code"},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/permissions", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionHandler_ListPermissions(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := registerUser(server.URL, "permuser", "permuser@test.com", "UserPass123!"); !ok {
|
||||||
|
t.Fatal("register user failed")
|
||||||
|
}
|
||||||
|
userToken := getToken(server.URL, "permuser", "UserPass123!")
|
||||||
|
if userToken == "" {
|
||||||
|
t.Fatal("get user token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_admin",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_regular_user",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/permissions", tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionHandler_GetPermission(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a permission to retrieve
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||||
|
"name": "Get Permission Test",
|
||||||
|
"code": "test:permission:get",
|
||||||
|
"type": 2,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
permData, ok := createResult["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data in create response, got %s", createBody)
|
||||||
|
}
|
||||||
|
permID := int64(permData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
permID string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_found",
|
||||||
|
permID: "99999",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
permID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/permissions/"+tt.permID, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionHandler_UpdatePermission(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a permission to update
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||||
|
"name": "Update Permission Test",
|
||||||
|
"code": "test:permission:update",
|
||||||
|
"type": 2,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
permData := createResult["data"].(map[string]interface{})
|
||||||
|
permID := int64(permData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
permID string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Permission Name",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
permID: "invalid",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Permission Name",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Permission Name",
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/permissions/"+tt.permID, tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionHandler_DeletePermission(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a permission to delete
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||||
|
"name": "Delete Permission Test",
|
||||||
|
"code": "test:permission:delete",
|
||||||
|
"type": 2,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
permData := createResult["data"].(map[string]interface{})
|
||||||
|
permID := int64(permData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
permID string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
permID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doDelete(server.URL+"/api/v1/permissions/"+tt.permID, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionHandler_UpdatePermissionStatus(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a permission
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||||
|
"name": "Status Permission Test",
|
||||||
|
"code": "test:permission:status",
|
||||||
|
"type": 2,
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
permData := createResult["data"].(map[string]interface{})
|
||||||
|
permID := int64(permData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
permID string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_numeric",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": 0,
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
permID: "invalid",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": 0,
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
permID: fmt.Sprintf("%d", permID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": 0,
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/permissions/"+tt.permID+"/status", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionHandler_GetPermissionTree(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/permissions/tree", adminToken)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("parse response failed: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
if result["data"] == nil {
|
||||||
|
t.Errorf("expected data in response")
|
||||||
|
}
|
||||||
|
}
|
||||||
527
internal/api/handler/role_handler_test.go
Normal file
527
internal/api/handler/role_handler_test.go
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRoleHandler_CreateRole(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := registerUser(server.URL, "roleuser", "roleuser@test.com", "UserPass123!"); !ok {
|
||||||
|
t.Fatal("register user failed")
|
||||||
|
}
|
||||||
|
userToken := getToken(server.URL, "roleuser", "UserPass123!")
|
||||||
|
if userToken == "" {
|
||||||
|
t.Fatal("get user token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Role",
|
||||||
|
"code": "test_role_create",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusCreated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Role Unauth",
|
||||||
|
"code": "test_role_unauth",
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Test Role Forbidden",
|
||||||
|
"code": "test_role_forbidden",
|
||||||
|
},
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_required_fields",
|
||||||
|
payload: map[string]interface{}{"name": "Missing Code"},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/roles", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_ListRoles(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := registerUser(server.URL, "roleuser", "roleuser@test.com", "UserPass123!"); !ok {
|
||||||
|
t.Fatal("register user failed")
|
||||||
|
}
|
||||||
|
userToken := getToken(server.URL, "roleuser", "UserPass123!")
|
||||||
|
if userToken == "" {
|
||||||
|
t.Fatal("get user token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_admin",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "forbidden_regular_user",
|
||||||
|
token: userToken,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/roles", tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_GetRole(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a role to retrieve
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||||
|
"name": "Get Role Test",
|
||||||
|
"code": "test_role_get",
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
roleData := createResult["data"].(map[string]interface{})
|
||||||
|
roleID := int64(roleData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
roleID string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_found",
|
||||||
|
roleID: "99999",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
roleID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/roles/"+tt.roleID, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_UpdateRole(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a role to update
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||||
|
"name": "Update Role Test",
|
||||||
|
"code": "test_role_update",
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
roleData := createResult["data"].(map[string]interface{})
|
||||||
|
roleID := int64(roleData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
roleID string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Role Name",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
roleID: "invalid",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Role Name",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"name": "Updated Role Name",
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID, tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_DeleteRole(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a role to delete
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||||
|
"name": "Delete Role Test",
|
||||||
|
"code": "test_role_delete",
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
roleData := createResult["data"].(map[string]interface{})
|
||||||
|
roleID := int64(roleData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
roleID string
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
roleID: "invalid",
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doDelete(server.URL+"/api/v1/roles/"+tt.roleID, tt.token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_UpdateRoleStatus(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a role
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||||
|
"name": "Status Role Test",
|
||||||
|
"code": "test_role_status",
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
roleData := createResult["data"].(map[string]interface{})
|
||||||
|
roleID := int64(roleData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
roleID string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success_disabled",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": "disabled",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success_enabled",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": "enabled",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_status",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": "invalid_status",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
roleID: "invalid",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": "disabled",
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"status": "disabled",
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID+"/status", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_GetRolePermissions(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the admin role (id=1) for testing
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/roles/1/permissions", adminToken)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("parse response failed: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
if result["data"] == nil {
|
||||||
|
t.Errorf("expected data in response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleHandler_AssignPermissions(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||||
|
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||||
|
if adminToken == "" {
|
||||||
|
t.Fatal("bootstrap admin failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a role
|
||||||
|
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||||
|
"name": "Assign Perm Role Test",
|
||||||
|
"code": "test_role_assign_perm",
|
||||||
|
})
|
||||||
|
defer createResp.Body.Close()
|
||||||
|
if createResp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||||
|
}
|
||||||
|
var createResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||||
|
t.Fatalf("parse create response failed: %v", err)
|
||||||
|
}
|
||||||
|
roleData := createResult["data"].(map[string]interface{})
|
||||||
|
roleID := int64(roleData["id"].(float64))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
roleID string
|
||||||
|
payload map[string]interface{}
|
||||||
|
token string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"permission_ids": []int64{1, 2},
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_id",
|
||||||
|
roleID: "invalid",
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"permission_ids": []int64{1},
|
||||||
|
},
|
||||||
|
token: adminToken,
|
||||||
|
wantStatus: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
roleID: fmt.Sprintf("%d", roleID),
|
||||||
|
payload: map[string]interface{}{
|
||||||
|
"permission_ids": []int64{1},
|
||||||
|
},
|
||||||
|
token: "",
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID+"/permissions", tt.token, tt.payload)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
855
internal/api/handler/sso_handler_test.go
Normal file
855
internal/api/handler/sso_handler_test.go
Normal file
@@ -0,0 +1,855 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/user-management-system/internal/api/handler"
|
||||||
|
"github.com/user-management-system/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func doPostForm(targetURL, token string, data url.Values) (*http.Response, string) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if data != nil {
|
||||||
|
bodyReader = strings.NewReader(data.Encode())
|
||||||
|
}
|
||||||
|
req, _ := http.NewRequest("POST", targetURL, bodyReader)
|
||||||
|
if token != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, _ := client.Do(req)
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
return resp, string(bodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupSSOTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
engine := gin.New()
|
||||||
|
engine.Use(gin.Recovery())
|
||||||
|
|
||||||
|
ssoManager := auth.NewSSOManager()
|
||||||
|
clientsStore := auth.NewDefaultSSOClientsStore()
|
||||||
|
clientsStore.RegisterClient(&auth.SSOClient{
|
||||||
|
ClientID: "test-client",
|
||||||
|
ClientSecret: "test-secret",
|
||||||
|
Name: "Test Client",
|
||||||
|
RedirectURIs: []string{"http://localhost:8080/callback"},
|
||||||
|
})
|
||||||
|
|
||||||
|
ssoHandler := handler.NewSSOHandler(ssoManager, clientsStore)
|
||||||
|
|
||||||
|
// Simple auth middleware for testing
|
||||||
|
authMiddleware := func() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
token := c.GetHeader("Authorization")
|
||||||
|
if token == "" || token == "Bearer " {
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("user_id", int64(1))
|
||||||
|
c.Set("username", "testuser")
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ssoGroup := engine.Group("/api/v1/sso")
|
||||||
|
ssoGroup.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
ssoGroup.GET("/authorize", ssoHandler.Authorize)
|
||||||
|
ssoGroup.POST("/token", ssoHandler.Token)
|
||||||
|
ssoGroup.POST("/introspect", ssoHandler.Introspect)
|
||||||
|
ssoGroup.POST("/revoke", ssoHandler.Revoke)
|
||||||
|
ssoGroup.GET("/userinfo", ssoHandler.UserInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(engine)
|
||||||
|
return server, func() {
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_MissingParams(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/authorize", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_UnsupportedResponseType(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=unsupported", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_CodeFlow(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code&state=xyz", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusFound {
|
||||||
|
t.Fatalf("expected status %d (redirect), got %d", http.StatusFound, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
location := resp.Header.Get("Location")
|
||||||
|
if location == "" {
|
||||||
|
t.Fatal("expected redirect location")
|
||||||
|
}
|
||||||
|
if !strings.Contains(location, "code=") {
|
||||||
|
t.Errorf("expected redirect with code, got %s", location)
|
||||||
|
}
|
||||||
|
if !strings.Contains(location, "state=xyz") {
|
||||||
|
t.Errorf("expected redirect with state, got %s", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_InvalidRedirectURI(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://evil.com/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_TokenFlow(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=token&state=abc", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusFound {
|
||||||
|
t.Fatalf("expected status %d (redirect), got %d", http.StatusFound, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
location := resp.Header.Get("Location")
|
||||||
|
if location == "" {
|
||||||
|
t.Fatal("expected redirect location")
|
||||||
|
}
|
||||||
|
if !strings.Contains(location, "access_token=") {
|
||||||
|
t.Errorf("expected redirect with access_token, got %s", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_MissingParams(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_InvalidGrantType(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "password")
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_InvalidClient(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("code", "some-code")
|
||||||
|
formData.Set("client_id", "invalid-client")
|
||||||
|
formData.Set("client_secret", "wrong-secret")
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_InvalidCode(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("code", "invalid-code")
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_Success(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// First authorize to get a code
|
||||||
|
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer authResp.Body.Close()
|
||||||
|
|
||||||
|
if authResp.StatusCode != http.StatusFound {
|
||||||
|
t.Fatalf("expected authorize redirect, got %d", authResp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
location := authResp.Header.Get("Location")
|
||||||
|
parsedURL, err := url.Parse(location)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse redirect URL: %v", err)
|
||||||
|
}
|
||||||
|
code := parsedURL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
t.Fatal("expected authorization code in redirect")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange code for token
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("code", code)
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp handler.TokenResponse
|
||||||
|
if err := json.Unmarshal([]byte(body), &tokenResp); err != nil {
|
||||||
|
t.Fatalf("failed to parse token response: %v", err)
|
||||||
|
}
|
||||||
|
if tokenResp.AccessToken == "" {
|
||||||
|
t.Errorf("expected access_token in response")
|
||||||
|
}
|
||||||
|
if tokenResp.TokenType != "Bearer" {
|
||||||
|
t.Errorf("expected token_type Bearer, got %s", tokenResp.TokenType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Introspect_MissingToken(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Introspect_InvalidToken(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": "invalid-token",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result handler.IntrospectResponse
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse introspect response: %v", err)
|
||||||
|
}
|
||||||
|
if result.Active != false {
|
||||||
|
t.Errorf("expected active=false for invalid token, got %v", result.Active)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Introspect_ValidToken(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Authorize and get token
|
||||||
|
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer authResp.Body.Close()
|
||||||
|
|
||||||
|
location := authResp.Header.Get("Location")
|
||||||
|
parsedURL, _ := url.Parse(location)
|
||||||
|
code := parsedURL.Query().Get("code")
|
||||||
|
|
||||||
|
tokenForm := url.Values{}
|
||||||
|
tokenForm.Set("grant_type", "authorization_code")
|
||||||
|
tokenForm.Set("code", code)
|
||||||
|
tokenForm.Set("client_id", "test-client")
|
||||||
|
tokenForm.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
|
||||||
|
defer tokenResp.Body.Close()
|
||||||
|
|
||||||
|
if tokenResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResult handler.TokenResponse
|
||||||
|
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse token response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Introspect the token
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": tokenResult.AccessToken,
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result handler.IntrospectResponse
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse introspect response: %v", err)
|
||||||
|
}
|
||||||
|
if result.Active != true {
|
||||||
|
t.Errorf("expected active=true for valid token, got %v", result.Active)
|
||||||
|
}
|
||||||
|
if result.UserID != 1 {
|
||||||
|
t.Errorf("expected user_id=1, got %d", result.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Revoke_MissingToken(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Revoke_Success(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Authorize and get token
|
||||||
|
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer authResp.Body.Close()
|
||||||
|
|
||||||
|
location := authResp.Header.Get("Location")
|
||||||
|
parsedURL, _ := url.Parse(location)
|
||||||
|
code := parsedURL.Query().Get("code")
|
||||||
|
|
||||||
|
tokenForm := url.Values{}
|
||||||
|
tokenForm.Set("grant_type", "authorization_code")
|
||||||
|
tokenForm.Set("code", code)
|
||||||
|
tokenForm.Set("client_id", "test-client")
|
||||||
|
tokenForm.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
|
||||||
|
defer tokenResp.Body.Close()
|
||||||
|
|
||||||
|
if tokenResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResult handler.TokenResponse
|
||||||
|
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse token response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the token
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": tokenResult.AccessToken,
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify token is revoked
|
||||||
|
introspectResp, introspectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": tokenResult.AccessToken,
|
||||||
|
})
|
||||||
|
defer introspectResp.Body.Close()
|
||||||
|
|
||||||
|
if introspectResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("introspect failed: status=%d body=%s", introspectResp.StatusCode, introspectBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var introspectResult handler.IntrospectResponse
|
||||||
|
if err := json.Unmarshal([]byte(introspectBody), &introspectResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse introspect response: %v", err)
|
||||||
|
}
|
||||||
|
if introspectResult.Active != false {
|
||||||
|
t.Errorf("expected active=false after revoke, got %v", introspectResult.Active)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_UserInfo_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/sso/userinfo", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_UserInfo_Success(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/userinfo", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := result["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data in response, got %s", body)
|
||||||
|
}
|
||||||
|
if data["user_id"] != float64(1) {
|
||||||
|
t.Errorf("expected user_id=1, got %v", data["user_id"])
|
||||||
|
}
|
||||||
|
if data["username"] != "testuser" {
|
||||||
|
t.Errorf("expected username=testuser, got %v", data["username"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_InvalidClientSecret(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Authorize to get a code
|
||||||
|
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer authResp.Body.Close()
|
||||||
|
|
||||||
|
location := authResp.Header.Get("Location")
|
||||||
|
parsedURL, _ := url.Parse(location)
|
||||||
|
code := parsedURL.Query().Get("code")
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("code", code)
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "wrong-secret")
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_MissingClientID(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/authorize?redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Introspect_FormData(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Test that introspect accepts form-encoded data
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("token", "some-token")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/introspect", strings.NewReader(formData.Encode()))
|
||||||
|
req.Header.Set("Authorization", "Bearer test-token")
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := json.Marshal(resp.Body)
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_FormData(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Authorize to get a code
|
||||||
|
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer authResp.Body.Close()
|
||||||
|
|
||||||
|
location := authResp.Header.Get("Location")
|
||||||
|
parsedURL, _ := url.Parse(location)
|
||||||
|
code := parsedURL.Query().Get("code")
|
||||||
|
|
||||||
|
// Test that token accepts form-encoded data
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("code", code)
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/token", strings.NewReader(formData.Encode()))
|
||||||
|
req.Header.Set("Authorization", "Bearer test-token")
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Revoke_FormData(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("token", "some-token")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/revoke", strings.NewReader(formData.Encode()))
|
||||||
|
req.Header.Set("Authorization", "Bearer test-token")
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := json.Marshal(resp.Body)
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_UnknownClientID(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=unknown-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// When client is unknown, redirect_uri validation fails
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_WithoutAuth(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("code", "some-code")
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
resp, _ := doPostForm(server.URL+"/api/v1/sso/token", "", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_UserInfo_WithoutAuth(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/sso/userinfo", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Introspect_WithoutAuth(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doPost(server.URL+"/api/v1/sso/introspect", "", map[string]interface{}{
|
||||||
|
"token": "some-token",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Revoke_WithoutAuth(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doPost(server.URL+"/api/v1/sso/revoke", "", map[string]interface{}{
|
||||||
|
"token": "some-token",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_InvalidClientID(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Test with valid redirect URI but unknown client
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=unknown&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Token_MissingCode(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("client_id", "test-client")
|
||||||
|
formData.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Code is empty, so validate should fail
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_FullFlow(t *testing.T) {
|
||||||
|
server, cleanup := setupSSOTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Step 1: Authorize
|
||||||
|
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code&state=my-state", "Bearer test-token")
|
||||||
|
defer authResp.Body.Close()
|
||||||
|
|
||||||
|
if authResp.StatusCode != http.StatusFound {
|
||||||
|
t.Fatalf("authorize failed: status=%d", authResp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
location := authResp.Header.Get("Location")
|
||||||
|
parsedURL, _ := url.Parse(location)
|
||||||
|
code := parsedURL.Query().Get("code")
|
||||||
|
state := parsedURL.Query().Get("state")
|
||||||
|
if code == "" {
|
||||||
|
t.Fatal("expected authorization code")
|
||||||
|
}
|
||||||
|
if state != "my-state" {
|
||||||
|
t.Errorf("expected state=my-state, got %s", state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Exchange code for token
|
||||||
|
tokenForm := url.Values{}
|
||||||
|
tokenForm.Set("grant_type", "authorization_code")
|
||||||
|
tokenForm.Set("code", code)
|
||||||
|
tokenForm.Set("client_id", "test-client")
|
||||||
|
tokenForm.Set("client_secret", "test-secret")
|
||||||
|
|
||||||
|
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
|
||||||
|
defer tokenResp.Body.Close()
|
||||||
|
|
||||||
|
if tokenResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResult handler.TokenResponse
|
||||||
|
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse token response: %v", err)
|
||||||
|
}
|
||||||
|
if tokenResult.AccessToken == "" {
|
||||||
|
t.Fatal("expected access_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Introspect token
|
||||||
|
introspectResp, introspectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": tokenResult.AccessToken,
|
||||||
|
})
|
||||||
|
defer introspectResp.Body.Close()
|
||||||
|
|
||||||
|
if introspectResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("introspect failed: status=%d body=%s", introspectResp.StatusCode, introspectBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var introspectResult handler.IntrospectResponse
|
||||||
|
if err := json.Unmarshal([]byte(introspectBody), &introspectResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse introspect response: %v", err)
|
||||||
|
}
|
||||||
|
if !introspectResult.Active {
|
||||||
|
t.Error("expected token to be active")
|
||||||
|
}
|
||||||
|
if introspectResult.UserID != 1 {
|
||||||
|
t.Errorf("expected user_id=1, got %d", introspectResult.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Get userinfo
|
||||||
|
userinfoResp, userinfoBody := doGet(server.URL+"/api/v1/sso/userinfo", "Bearer test-token")
|
||||||
|
defer userinfoResp.Body.Close()
|
||||||
|
|
||||||
|
if userinfoResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("userinfo failed: status=%d body=%s", userinfoResp.StatusCode, userinfoBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var userinfoResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(userinfoBody), &userinfoResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse userinfo response: %v", err)
|
||||||
|
}
|
||||||
|
userinfoData, ok := userinfoResult["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected userinfo data, got %s", userinfoBody)
|
||||||
|
}
|
||||||
|
if userinfoData["username"] != "testuser" {
|
||||||
|
t.Errorf("expected username=testuser, got %v", userinfoData["username"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Revoke token
|
||||||
|
revokeResp, revokeBody := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": tokenResult.AccessToken,
|
||||||
|
})
|
||||||
|
defer revokeResp.Body.Close()
|
||||||
|
|
||||||
|
if revokeResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("revoke failed: status=%d body=%s", revokeResp.StatusCode, revokeBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 6: Verify token is revoked
|
||||||
|
finalIntrospectResp, finalIntrospectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||||
|
"token": tokenResult.AccessToken,
|
||||||
|
})
|
||||||
|
defer finalIntrospectResp.Body.Close()
|
||||||
|
|
||||||
|
if finalIntrospectResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("final introspect failed: status=%d body=%s", finalIntrospectResp.StatusCode, finalIntrospectBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var finalResult handler.IntrospectResponse
|
||||||
|
if err := json.Unmarshal([]byte(finalIntrospectBody), &finalResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse final introspect response: %v", err)
|
||||||
|
}
|
||||||
|
if finalResult.Active {
|
||||||
|
t.Error("expected token to be inactive after revoke")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOHandler_Authorize_NoClientStore(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
engine := gin.New()
|
||||||
|
ssoManager := auth.NewSSOManager()
|
||||||
|
// Pass nil clientsStore
|
||||||
|
ssoHandler := handler.NewSSOHandler(ssoManager, nil)
|
||||||
|
|
||||||
|
authMiddleware := func() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Set("user_id", int64(1))
|
||||||
|
c.Set("username", "testuser")
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ssoGroup := engine.Group("/api/v1/sso")
|
||||||
|
ssoGroup.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
ssoGroup.GET("/authorize", ssoHandler.Authorize)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(engine)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Without clients store, any redirect_uri should be accepted (or validation skipped)
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=any&redirect_uri=http://any.com/callback&response_type=code", "Bearer test-token")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusFound {
|
||||||
|
t.Errorf("expected redirect when clientsStore is nil, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
685
internal/api/handler/totp_handler_test.go
Normal file
685
internal/api/handler/totp_handler_test.go
Normal file
@@ -0,0 +1,685 @@
|
|||||||
|
package handler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/user-management-system/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTOTPHandler_GetTOTPStatus(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "totpstatususer", "totpstatus@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpstatususer", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/auth/2fa/status", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := result["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data in response, got %s", body)
|
||||||
|
}
|
||||||
|
if data["enabled"] != false {
|
||||||
|
t.Errorf("expected enabled=false for new user, got %v", data["enabled"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_GetTOTPStatus_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/auth/2fa/status", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_SetupTOTP(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "totpsetupuser", "totpsetup@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpsetupuser", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := result["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data in response, got %s", body)
|
||||||
|
}
|
||||||
|
if data["secret"] == nil || data["secret"] == "" {
|
||||||
|
t.Errorf("expected secret in setup response, got %+v", data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_SetupTOTP_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doGet(server.URL+"/api/v1/auth/2fa/setup", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_EnableTOTP(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpenableuser", "totpenable@test.com", "UserPass123!")
|
||||||
|
_ = userID
|
||||||
|
_ = secret
|
||||||
|
|
||||||
|
// setupEnabledTOTPUser already enables TOTP, so let's just verify the user can login with TOTP
|
||||||
|
// Actually, we need a fresh user to test enable
|
||||||
|
registerUser(server.URL, "totpenableuser2", "totpenable2@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpenableuser2", "UserPass123!")
|
||||||
|
|
||||||
|
// Setup TOTP
|
||||||
|
setupResp, setupBody := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||||
|
defer setupResp.Body.Close()
|
||||||
|
if setupResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var setupResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(setupBody), &setupResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse setup response: %v", err)
|
||||||
|
}
|
||||||
|
setupData, ok := setupResult["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected setup data, got %s", setupBody)
|
||||||
|
}
|
||||||
|
newSecret, ok := setupData["secret"].(string)
|
||||||
|
if !ok || newSecret == "" {
|
||||||
|
t.Fatalf("expected secret in setup response, got %s", setupBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate valid code
|
||||||
|
code, err := auth.NewTOTPManager().GenerateCurrentCode(newSecret)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable TOTP
|
||||||
|
enableResp, enableBody := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
defer enableResp.Body.Close()
|
||||||
|
|
||||||
|
if enableResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, enableResp.StatusCode, enableBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_EnableTOTP_InvalidCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "totpenableinv", "totpenableinv@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpenableinv", "UserPass123!")
|
||||||
|
|
||||||
|
// Setup TOTP first
|
||||||
|
setupResp, setupBody := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||||
|
defer setupResp.Body.Close()
|
||||||
|
if setupResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try enable with invalid code
|
||||||
|
enableResp, enableBody := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
|
||||||
|
"code": "000000",
|
||||||
|
})
|
||||||
|
defer enableResp.Body.Close()
|
||||||
|
|
||||||
|
if enableResp.StatusCode != http.StatusUnauthorized && enableResp.StatusCode != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", enableResp.StatusCode, enableBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_EnableTOTP_MissingCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "totpenablemiss", "totpenablemiss@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpenablemiss", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_DisableTOTP(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisableuser", "totpdisable@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Login again to get a fresh token (since TOTP is enabled, login may require TOTP)
|
||||||
|
deviceID := "test-device"
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "totpdisableuser",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
|
||||||
|
if loginResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("login failed: status=%d body=%s", loginResp.StatusCode, loginBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(loginBody), &loginResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse login response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If requires_totp, we need to verify TOTP first
|
||||||
|
loginData, ok := loginResult["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected login data, got %s", loginBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token string
|
||||||
|
if loginData["requires_totp"] == true {
|
||||||
|
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tempToken, _ := loginData["temp_token"].(string)
|
||||||
|
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"code": code,
|
||||||
|
"device_id": deviceID,
|
||||||
|
"temp_token": tempToken,
|
||||||
|
})
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
if verifyResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("totp verify failed: status=%d body=%s", verifyResp.StatusCode, verifyBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse verify response: %v", err)
|
||||||
|
}
|
||||||
|
verifyData, ok := verifyResult["data"].(map[string]interface{})
|
||||||
|
if ok && verifyData["access_token"] != nil {
|
||||||
|
token, _ = verifyData["access_token"].(string)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, _ = loginData["access_token"].(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("failed to get token after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate valid code for disable
|
||||||
|
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify TOTP is disabled
|
||||||
|
statusResp, statusBody := doGet(server.URL+"/api/v1/auth/2fa/status", token)
|
||||||
|
defer statusResp.Body.Close()
|
||||||
|
if statusResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status check failed: status=%d body=%s", statusResp.StatusCode, statusBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var statusResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(statusBody), &statusResult); err != nil {
|
||||||
|
t.Fatalf("failed to parse status response: %v", err)
|
||||||
|
}
|
||||||
|
statusData, ok := statusResult["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected status data, got %s", statusBody)
|
||||||
|
}
|
||||||
|
if statusData["enabled"] != false {
|
||||||
|
t.Errorf("expected enabled=false after disable, got %v", statusData["enabled"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_DisableTOTP_InvalidCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisableinv", "totpdisableinv@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Get token (might need TOTP verification)
|
||||||
|
deviceID := "test-device"
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "totpdisableinv",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
var loginResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||||
|
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||||
|
if loginData["requires_totp"] == true {
|
||||||
|
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
tempToken, _ := loginData["temp_token"].(string)
|
||||||
|
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"code": code,
|
||||||
|
"device_id": deviceID,
|
||||||
|
"temp_token": tempToken,
|
||||||
|
})
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
if verifyResp.StatusCode == http.StatusOK {
|
||||||
|
var verifyResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||||
|
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||||
|
token, _ = verifyData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, _ = loginData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("failed to get token after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{
|
||||||
|
"code": "000000",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_VerifyTOTP(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpverifyuser", "totpverify@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Get token (might need TOTP verification)
|
||||||
|
deviceID := "test-device"
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "totpverifyuser",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
var loginResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||||
|
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||||
|
if loginData["requires_totp"] == true {
|
||||||
|
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
tempToken, _ := loginData["temp_token"].(string)
|
||||||
|
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"code": code,
|
||||||
|
"device_id": deviceID,
|
||||||
|
"temp_token": tempToken,
|
||||||
|
})
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
if verifyResp.StatusCode == http.StatusOK {
|
||||||
|
var verifyResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||||
|
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||||
|
token, _ = verifyData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, _ = loginData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("failed to get token after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{
|
||||||
|
"code": code,
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if result["code"] != float64(0) {
|
||||||
|
t.Errorf("expected code 0, got %v", result["code"])
|
||||||
|
}
|
||||||
|
data, ok := result["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data in response, got %s", body)
|
||||||
|
}
|
||||||
|
if data["verified"] != true {
|
||||||
|
t.Errorf("expected verified=true, got %v", data["verified"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_VerifyTOTP_InvalidCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpverifyinv", "totpverifyinv@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Get token
|
||||||
|
deviceID := "test-device"
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "totpverifyinv",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
var loginResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||||
|
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||||
|
if loginData["requires_totp"] == true {
|
||||||
|
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
tempToken, _ := loginData["temp_token"].(string)
|
||||||
|
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"code": code,
|
||||||
|
"device_id": deviceID,
|
||||||
|
"temp_token": tempToken,
|
||||||
|
})
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
if verifyResp.StatusCode == http.StatusOK {
|
||||||
|
var verifyResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||||
|
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||||
|
token, _ = verifyData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, _ = loginData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("failed to get token after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{
|
||||||
|
"code": "000000",
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_VerifyTOTP_MissingCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "totpverifymiss", "totpverifymiss@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpverifymiss", "UserPass123!")
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_VerifyTOTP_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/verify", "", map[string]interface{}{
|
||||||
|
"code": "123456",
|
||||||
|
"device_id": "test-device",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_DisableTOTP_MissingCode(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisablemiss", "totpdisablemiss@test.com", "UserPass123!")
|
||||||
|
|
||||||
|
// Get token
|
||||||
|
deviceID := "test-device"
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "totpdisablemiss",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
"device_id": deviceID,
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
var loginResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||||
|
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||||
|
if loginData["requires_totp"] == true {
|
||||||
|
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
tempToken, _ := loginData["temp_token"].(string)
|
||||||
|
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"code": code,
|
||||||
|
"device_id": deviceID,
|
||||||
|
"temp_token": tempToken,
|
||||||
|
})
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
if verifyResp.StatusCode == http.StatusOK {
|
||||||
|
var verifyResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||||
|
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||||
|
token, _ = verifyData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, _ = loginData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("failed to get token after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_DisableTOTP_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/disable", "", map[string]interface{}{
|
||||||
|
"code": "123456",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_SetupTOTP_AlreadyEnabled(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpsetupenabled", "totpsetupenabled@test.com", "UserPass123!")
|
||||||
|
_ = secret
|
||||||
|
|
||||||
|
// Get token after TOTP login
|
||||||
|
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||||
|
"account": "totpsetupenabled",
|
||||||
|
"password": "UserPass123!",
|
||||||
|
"device_id": "test-device",
|
||||||
|
})
|
||||||
|
defer loginResp.Body.Close()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
var loginResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||||
|
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||||
|
if loginData["requires_totp"] == true {
|
||||||
|
tempToken, _ := loginData["temp_token"].(string)
|
||||||
|
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||||
|
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"temp_token": tempToken,
|
||||||
|
"code": code,
|
||||||
|
"device_id": "test-device",
|
||||||
|
})
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
if verifyResp.StatusCode == http.StatusOK {
|
||||||
|
var verifyResult map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||||
|
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||||
|
token, _ = verifyData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, _ = loginData["access_token"].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("failed to get token after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try setup again - should still work or return appropriate response
|
||||||
|
resp, body := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Setup may return 200 with new secret or error if already enabled
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("unexpected status %d, body: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_EnableTOTP_Unauthorized(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/enable", "", map[string]interface{}{
|
||||||
|
"code": "123456",
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPHandler_InvalidJSON(t *testing.T) {
|
||||||
|
server, cleanup := setupHandlerTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerUser(server.URL, "totpjsonuser", "totpjson@test.com", "UserPass123!")
|
||||||
|
token := getToken(server.URL, "totpjsonuser", "UserPass123!")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
method string
|
||||||
|
}{
|
||||||
|
{"enable_invalid_json", "/api/v1/auth/2fa/enable", "POST"},
|
||||||
|
{"disable_invalid_json", "/api/v1/auth/2fa/disable", "POST"},
|
||||||
|
{"verify_invalid_json", "/api/v1/auth/2fa/verify", "POST"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(tc.method, server.URL+tc.path, bytes.NewReader([]byte("not json")))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status %d for invalid JSON, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
102
internal/api/middleware/gzip_test.go
Normal file
102
internal/api/middleware/gzip_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGzipMiddleware_CompressesLargeJSONResponses(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(GzipMiddleware())
|
||||||
|
router.GET("/data", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
c.String(http.StatusOK, strings.Repeat("a", gzipMinLength+128))
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/data", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if got := recorder.Header().Get("Content-Encoding"); got != "gzip" {
|
||||||
|
t.Fatalf("Content-Encoding = %q, want gzip", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := gzip.NewReader(bytes.NewReader(recorder.Body.Bytes()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gzip.NewReader() error = %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
payload, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error = %v", err)
|
||||||
|
}
|
||||||
|
if got := string(payload); got != strings.Repeat("a", gzipMinLength+128) {
|
||||||
|
t.Fatalf("decompressed payload length = %d, want %d", len(got), gzipMinLength+128)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGzipMiddleware_PassesThroughWhenCompressionNotUseful(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
acceptEncoding string
|
||||||
|
contentType string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "client does not accept gzip",
|
||||||
|
acceptEncoding: "",
|
||||||
|
contentType: "application/json",
|
||||||
|
body: strings.Repeat("b", gzipMinLength+64),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "body below threshold",
|
||||||
|
acceptEncoding: "gzip",
|
||||||
|
contentType: "application/json",
|
||||||
|
body: "small-body",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported content type",
|
||||||
|
acceptEncoding: "gzip",
|
||||||
|
contentType: "image/png",
|
||||||
|
body: strings.Repeat("c", gzipMinLength+64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(GzipMiddleware())
|
||||||
|
router.GET("/data", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", tc.contentType)
|
||||||
|
c.String(http.StatusOK, tc.body)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/data", nil)
|
||||||
|
if tc.acceptEncoding != "" {
|
||||||
|
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
|
||||||
|
}
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if got := recorder.Header().Get("Content-Encoding"); got != "" {
|
||||||
|
t.Fatalf("Content-Encoding = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := recorder.Body.String(); got != tc.body {
|
||||||
|
t.Fatalf("body length = %d, want %d", len(got), len(tc.body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
165
internal/api/middleware/operation_log_test.go
Normal file
165
internal/api/middleware/operation_log_test.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/user-management-system/internal/domain"
|
||||||
|
"github.com/user-management-system/internal/repository"
|
||||||
|
gormsqlite "gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||||
|
DriverName: "sqlite",
|
||||||
|
DSN: "file:operation_log_test?mode=memory&cache=shared",
|
||||||
|
}), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
|
||||||
|
t.Fatalf("migrate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
|
||||||
|
t.Fatalf("cleanup operation_logs failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return repository.NewOperationLogRepository(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list operation logs failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(logs) >= want {
|
||||||
|
return logs
|
||||||
|
}
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list operation logs failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := newOperationLogRepositoryForTest(t)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(NewOperationLogMiddleware(repo).Record())
|
||||||
|
router.GET("/logs", func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list operation logs failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(logs) != 0 {
|
||||||
|
t.Fatalf("expected no logs for GET request, got %d", len(logs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := newOperationLogRepositoryForTest(t)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set("user_id", int64(42))
|
||||||
|
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
router.Use(NewOperationLogMiddleware(repo).Record())
|
||||||
|
router.POST("/users", func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusCreated)
|
||||||
|
})
|
||||||
|
|
||||||
|
body := `{"username":"alice","password":"super-secret","token":"abc"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
|
||||||
|
req.RemoteAddr = "203.0.113.10:8080"
|
||||||
|
req.Header.Set("User-Agent", "middleware-test")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs := waitForOperationLogs(t, repo, 1)
|
||||||
|
entry := logs[0]
|
||||||
|
if entry.UserID == nil || *entry.UserID != 42 {
|
||||||
|
t.Fatalf("user_id = %#v, want 42", entry.UserID)
|
||||||
|
}
|
||||||
|
if entry.OperationType != "admin:CREATE" {
|
||||||
|
t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
|
||||||
|
}
|
||||||
|
if entry.ResponseStatus != http.StatusCreated {
|
||||||
|
t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
|
||||||
|
}
|
||||||
|
if strings.Contains(entry.RequestParams, "super-secret") || strings.Contains(entry.RequestParams, "abc") {
|
||||||
|
t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
|
||||||
|
if got := methodToType(http.MethodPatch); got != "UPDATE" {
|
||||||
|
t.Fatalf("methodToType(PATCH) = %q, want UPDATE", got)
|
||||||
|
}
|
||||||
|
if got := methodToType(http.MethodDelete); got != "DELETE" {
|
||||||
|
t.Fatalf("methodToType(DELETE) = %q, want DELETE", got)
|
||||||
|
}
|
||||||
|
if got := methodToType(http.MethodGet); got != "OTHER" {
|
||||||
|
t.Fatalf("methodToType(GET) = %q, want OTHER", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := []byte(`{"password":"secret","name":"alice"}`)
|
||||||
|
sanitized := sanitizeParams(raw)
|
||||||
|
if strings.Contains(sanitized, "secret") {
|
||||||
|
t.Fatalf("expected password to be masked, got %s", sanitized)
|
||||||
|
}
|
||||||
|
|
||||||
|
plain := sanitizeParams([]byte("not-json"))
|
||||||
|
if plain != "not-json" {
|
||||||
|
t.Fatalf("sanitizeParams(non-json) = %q, want not-json", plain)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal sanitized params failed: %v", err)
|
||||||
|
}
|
||||||
|
if payload["password"] != "***" {
|
||||||
|
t.Fatalf("password = %q, want ***", payload["password"])
|
||||||
|
}
|
||||||
|
}
|
||||||
114
internal/api/middleware/rbac_test.go
Normal file
114
internal/api/middleware/rbac_test.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func performRBACRequest(t *testing.T, setup func(*gin.Context), middleware gin.HandlerFunc) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
if setup != nil {
|
||||||
|
router.Use(setup)
|
||||||
|
}
|
||||||
|
router.Use(middleware)
|
||||||
|
router.GET("/protected", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"code": 0})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
return recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequirePermissionRejectsMissingPermission(t *testing.T) {
|
||||||
|
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||||
|
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||||
|
c.Next()
|
||||||
|
}, RequirePermission("users:write"))
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequirePermissionAllowsMatchingPermission(t *testing.T) {
|
||||||
|
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||||
|
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||||
|
c.Next()
|
||||||
|
}, RequirePermission("users:read"))
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAllPermissionsRequiresEveryCode(t *testing.T) {
|
||||||
|
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||||
|
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||||
|
c.Next()
|
||||||
|
}, RequireAllPermissions("users:read", "users:write"))
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAnyPermissionIsAliasOfRequirePermission(t *testing.T) {
|
||||||
|
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||||
|
c.Set(ContextKeyPermissionCodes, []string{"users:write"})
|
||||||
|
c.Next()
|
||||||
|
}, RequireAnyPermission("users:read", "users:write"))
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireRoleAndAdminOnly(t *testing.T) {
|
||||||
|
roleRecorder := performRBACRequest(t, func(c *gin.Context) {
|
||||||
|
c.Set(ContextKeyRoleCodes, []string{"auditor"})
|
||||||
|
c.Next()
|
||||||
|
}, RequireRole("admin"))
|
||||||
|
if roleRecorder.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected role check to return 403, got %d", roleRecorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
adminRecorder := performRBACRequest(t, func(c *gin.Context) {
|
||||||
|
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||||
|
c.Next()
|
||||||
|
}, AdminOnly())
|
||||||
|
if adminRecorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected admin check to return 200, got %d", adminRecorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRBACHelpersHandleMissingContextValues(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||||
|
|
||||||
|
if got := GetRoleCodes(c); got != nil {
|
||||||
|
t.Fatalf("GetRoleCodes() = %#v, want nil", got)
|
||||||
|
}
|
||||||
|
if got := GetPermissionCodes(c); got != nil {
|
||||||
|
t.Fatalf("GetPermissionCodes() = %#v, want nil", got)
|
||||||
|
}
|
||||||
|
if IsAdmin(c) {
|
||||||
|
t.Fatal("IsAdmin() = true, want false")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||||
|
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||||
|
|
||||||
|
if !IsAdmin(c) {
|
||||||
|
t.Fatal("IsAdmin() = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
119
internal/api/middleware/response_wrapper_test.go
Normal file
119
internal/api/middleware/response_wrapper_test.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseWrapper_WrapsSuccessfulJSONPayload(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ResponseWrapper())
|
||||||
|
router.GET("/users", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"id": 1, "name": "alice"})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
want := `{"code":0,"data":{"id":1,"name":"alice"},"message":"success"}`
|
||||||
|
if got := recorder.Body.String(); got != want {
|
||||||
|
t.Fatalf("body = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWrapper_PassesThroughMarkedResponses(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ResponseWrapper())
|
||||||
|
router.GET("/users", func(c *gin.Context) {
|
||||||
|
WrapResponse(c)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "already wrapped"})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
want := `{"code":0,"message":"already wrapped"}`
|
||||||
|
if got := recorder.Body.String(); got != want {
|
||||||
|
t.Fatalf("body = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWrapper_PassesThroughNonSuccessStatus(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ResponseWrapper())
|
||||||
|
router.GET("/users", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
want := `{"message":"bad request"}`
|
||||||
|
if got := recorder.Body.String(); got != want {
|
||||||
|
t.Fatalf("body = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWrapper_PassesThroughInvalidJSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ResponseWrapper())
|
||||||
|
router.GET("/users", func(c *gin.Context) {
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = c.Writer.WriteString("plain text")
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
if got := recorder.Body.String(); got != "plain text" {
|
||||||
|
t.Fatalf("body = %q, want plain text", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWrapper_NoWrapperMarksContext(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(NoWrapper())
|
||||||
|
router.GET("/users", func(c *gin.Context) {
|
||||||
|
if _, exists := c.Get("response_wrapped"); !exists {
|
||||||
|
t.Fatal("expected response_wrapped marker in context")
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
136
internal/domain/device_test.go
Normal file
136
internal/domain/device_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeviceType_Constants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value DeviceType
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"Unknown", DeviceTypeUnknown, 0},
|
||||||
|
{"Web", DeviceTypeWeb, 1},
|
||||||
|
{"Mobile", DeviceTypeMobile, 2},
|
||||||
|
{"Desktop", DeviceTypeDesktop, 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if int(tc.value) != tc.expected {
|
||||||
|
t.Errorf("expected %d, got %d", tc.expected, int(tc.value))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceStatus_Constants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value DeviceStatus
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"Inactive", DeviceStatusInactive, 0},
|
||||||
|
{"Active", DeviceStatusActive, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if int(tc.value) != tc.expected {
|
||||||
|
t.Errorf("expected %d, got %d", tc.expected, int(tc.value))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_TableName(t *testing.T) {
|
||||||
|
var d Device
|
||||||
|
if got := d.TableName(); got != "devices" {
|
||||||
|
t.Errorf("expected table name 'devices', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_StructFields(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
trustExpires := now.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
d := Device{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 2,
|
||||||
|
DeviceID: "device-123",
|
||||||
|
DeviceName: "Test Device",
|
||||||
|
DeviceType: DeviceTypeWeb,
|
||||||
|
DeviceOS: "Windows",
|
||||||
|
DeviceBrowser: "Chrome",
|
||||||
|
IP: "127.0.0.1",
|
||||||
|
Location: "Beijing",
|
||||||
|
IsTrusted: true,
|
||||||
|
TrustExpiresAt: &trustExpires,
|
||||||
|
Status: DeviceStatusActive,
|
||||||
|
LastActiveTime: now,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.ID != 1 {
|
||||||
|
t.Errorf("expected ID 1, got %d", d.ID)
|
||||||
|
}
|
||||||
|
if d.UserID != 2 {
|
||||||
|
t.Errorf("expected UserID 2, got %d", d.UserID)
|
||||||
|
}
|
||||||
|
if d.DeviceID != "device-123" {
|
||||||
|
t.Errorf("expected DeviceID 'device-123', got %q", d.DeviceID)
|
||||||
|
}
|
||||||
|
if d.DeviceName != "Test Device" {
|
||||||
|
t.Errorf("expected DeviceName 'Test Device', got %q", d.DeviceName)
|
||||||
|
}
|
||||||
|
if d.DeviceType != DeviceTypeWeb {
|
||||||
|
t.Errorf("expected DeviceTypeWeb, got %d", d.DeviceType)
|
||||||
|
}
|
||||||
|
if d.DeviceOS != "Windows" {
|
||||||
|
t.Errorf("expected DeviceOS 'Windows', got %q", d.DeviceOS)
|
||||||
|
}
|
||||||
|
if d.DeviceBrowser != "Chrome" {
|
||||||
|
t.Errorf("expected DeviceBrowser 'Chrome', got %q", d.DeviceBrowser)
|
||||||
|
}
|
||||||
|
if d.IP != "127.0.0.1" {
|
||||||
|
t.Errorf("expected IP '127.0.0.1', got %q", d.IP)
|
||||||
|
}
|
||||||
|
if d.Location != "Beijing" {
|
||||||
|
t.Errorf("expected Location 'Beijing', got %q", d.Location)
|
||||||
|
}
|
||||||
|
if !d.IsTrusted {
|
||||||
|
t.Error("expected IsTrusted to be true")
|
||||||
|
}
|
||||||
|
if d.TrustExpiresAt == nil || !d.TrustExpiresAt.Equal(trustExpires) {
|
||||||
|
t.Error("expected TrustExpiresAt to match")
|
||||||
|
}
|
||||||
|
if d.Status != DeviceStatusActive {
|
||||||
|
t.Errorf("expected DeviceStatusActive, got %d", d.Status)
|
||||||
|
}
|
||||||
|
if d.LastActiveTime.IsZero() {
|
||||||
|
t.Error("expected LastActiveTime to be set")
|
||||||
|
}
|
||||||
|
if d.CreatedAt.IsZero() {
|
||||||
|
t.Error("expected CreatedAt to be set")
|
||||||
|
}
|
||||||
|
if d.UpdatedAt.IsZero() {
|
||||||
|
t.Error("expected UpdatedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_DefaultStatus(t *testing.T) {
|
||||||
|
var d Device
|
||||||
|
if d.Status != DeviceStatusInactive {
|
||||||
|
t.Errorf("expected default status Inactive(0), got %d", d.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_DefaultDeviceType(t *testing.T) {
|
||||||
|
var d Device
|
||||||
|
if d.DeviceType != DeviceTypeUnknown {
|
||||||
|
t.Errorf("expected default device type Unknown(0), got %d", d.DeviceType)
|
||||||
|
}
|
||||||
|
}
|
||||||
35
internal/domain/password_history_test.go
Normal file
35
internal/domain/password_history_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPasswordHistory_TableName(t *testing.T) {
|
||||||
|
var h PasswordHistory
|
||||||
|
if got := h.TableName(); got != "password_histories" {
|
||||||
|
t.Errorf("expected table name 'password_histories', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordHistory_StructTags(t *testing.T) {
|
||||||
|
h := PasswordHistory{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 2,
|
||||||
|
PasswordHash: "hash123",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.ID != 1 {
|
||||||
|
t.Errorf("expected ID 1, got %d", h.ID)
|
||||||
|
}
|
||||||
|
if h.UserID != 2 {
|
||||||
|
t.Errorf("expected UserID 2, got %d", h.UserID)
|
||||||
|
}
|
||||||
|
if h.PasswordHash != "hash123" {
|
||||||
|
t.Errorf("expected PasswordHash 'hash123', got %q", h.PasswordHash)
|
||||||
|
}
|
||||||
|
if h.CreatedAt.IsZero() {
|
||||||
|
t.Error("expected CreatedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
77
internal/pkg/pagination/pagination_test.go
Normal file
77
internal/pkg/pagination/pagination_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package pagination
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultPagination(t *testing.T) {
|
||||||
|
p := DefaultPagination()
|
||||||
|
if p.Page != 1 {
|
||||||
|
t.Errorf("expected default page 1, got %d", p.Page)
|
||||||
|
}
|
||||||
|
if p.PageSize != 20 {
|
||||||
|
t.Errorf("expected default page_size 20, got %d", p.PageSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPaginationParams_Offset(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
wantOffset int
|
||||||
|
}{
|
||||||
|
{"page 1", 1, 20, 0},
|
||||||
|
{"page 2", 2, 20, 20},
|
||||||
|
{"page 5", 5, 20, 80},
|
||||||
|
{"zero page", 0, 20, 0},
|
||||||
|
{"negative page", -1, 20, 0},
|
||||||
|
{"page 1 size 10", 1, 10, 0},
|
||||||
|
{"page 3 size 10", 3, 10, 20},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
p := PaginationParams{Page: tc.page, PageSize: tc.pageSize}
|
||||||
|
if got := p.Offset(); got != tc.wantOffset {
|
||||||
|
t.Errorf("expected offset %d, got %d", tc.wantOffset, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPaginationParams_Limit(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pageSize int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"default", 20, 20},
|
||||||
|
{"size 10", 10, 10},
|
||||||
|
{"size 50", 50, 50},
|
||||||
|
{"size 100", 100, 100},
|
||||||
|
{"max cap", 101, 100},
|
||||||
|
{"zero size", 0, 20},
|
||||||
|
{"negative size", -1, 20},
|
||||||
|
{"size 1", 1, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
p := PaginationParams{PageSize: tc.pageSize}
|
||||||
|
if got := p.Limit(); got != tc.want {
|
||||||
|
t.Errorf("expected limit %d, got %d", tc.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPaginationParams_OffsetAndLimit(t *testing.T) {
|
||||||
|
p := PaginationParams{Page: 3, PageSize: 15}
|
||||||
|
if got := p.Offset(); got != 30 {
|
||||||
|
t.Errorf("expected offset 30, got %d", got)
|
||||||
|
}
|
||||||
|
if got := p.Limit(); got != 15 {
|
||||||
|
t.Errorf("expected limit 15, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
95
internal/repository/pagination_test.go
Normal file
95
internal/repository/pagination_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/user-management-system/internal/pkg/pagination"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPaginationResultFromTotal(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
total int64
|
||||||
|
params pagination.PaginationParams
|
||||||
|
wantPages int
|
||||||
|
wantTotal int64
|
||||||
|
wantPage int
|
||||||
|
wantPageSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact division",
|
||||||
|
total: 100,
|
||||||
|
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||||
|
wantPages: 5,
|
||||||
|
wantTotal: 100,
|
||||||
|
wantPage: 1,
|
||||||
|
wantPageSize: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with remainder",
|
||||||
|
total: 105,
|
||||||
|
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||||
|
wantPages: 6,
|
||||||
|
wantTotal: 105,
|
||||||
|
wantPage: 1,
|
||||||
|
wantPageSize: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero total",
|
||||||
|
total: 0,
|
||||||
|
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||||
|
wantPages: 0,
|
||||||
|
wantTotal: 0,
|
||||||
|
wantPage: 1,
|
||||||
|
wantPageSize: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single page",
|
||||||
|
total: 5,
|
||||||
|
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||||
|
wantPages: 1,
|
||||||
|
wantTotal: 5,
|
||||||
|
wantPage: 1,
|
||||||
|
wantPageSize: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page 2",
|
||||||
|
total: 50,
|
||||||
|
params: pagination.PaginationParams{Page: 2, PageSize: 20},
|
||||||
|
wantPages: 3,
|
||||||
|
wantTotal: 50,
|
||||||
|
wantPage: 2,
|
||||||
|
wantPageSize: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small page size",
|
||||||
|
total: 10,
|
||||||
|
params: pagination.PaginationParams{Page: 1, PageSize: 3},
|
||||||
|
wantPages: 4,
|
||||||
|
wantTotal: 10,
|
||||||
|
wantPage: 1,
|
||||||
|
wantPageSize: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := paginationResultFromTotal(tc.total, tc.params)
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
}
|
||||||
|
if result.Total != tc.wantTotal {
|
||||||
|
t.Errorf("expected total %d, got %d", tc.wantTotal, result.Total)
|
||||||
|
}
|
||||||
|
if result.Page != tc.wantPage {
|
||||||
|
t.Errorf("expected page %d, got %d", tc.wantPage, result.Page)
|
||||||
|
}
|
||||||
|
if result.PageSize != tc.wantPageSize {
|
||||||
|
t.Errorf("expected page_size %d, got %d", tc.wantPageSize, result.PageSize)
|
||||||
|
}
|
||||||
|
if result.Pages != tc.wantPages {
|
||||||
|
t.Errorf("expected pages %d, got %d", tc.wantPages, result.Pages)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
224
internal/repository/password_history_test.go
Normal file
224
internal/repository/password_history_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/user-management-system/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPasswordHistoryRepository_Create(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||||
|
t.Fatalf("migrate password_history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := NewPasswordHistoryRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
history := &domain.PasswordHistory{
|
||||||
|
UserID: 1,
|
||||||
|
PasswordHash: "hash1",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := repo.Create(ctx, history); err != nil {
|
||||||
|
t.Fatalf("create failed: %v", err)
|
||||||
|
}
|
||||||
|
if history.ID == 0 {
|
||||||
|
t.Error("expected ID to be set after create")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordHistoryRepository_GetByUserID(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||||
|
t.Fatalf("migrate password_history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := NewPasswordHistoryRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create multiple records for user 1
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
h := &domain.PasswordHistory{
|
||||||
|
UserID: 1,
|
||||||
|
PasswordHash: "hash",
|
||||||
|
CreatedAt: time.Now().Add(time.Duration(i) * time.Second),
|
||||||
|
}
|
||||||
|
if err := repo.Create(ctx, h); err != nil {
|
||||||
|
t.Fatalf("create failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create record for user 2
|
||||||
|
if err := repo.Create(ctx, &domain.PasswordHistory{UserID: 2, PasswordHash: "hash", CreatedAt: time.Now()}); err != nil {
|
||||||
|
t.Fatalf("create failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID int64
|
||||||
|
limit int
|
||||||
|
wantLen int
|
||||||
|
wantUser int64
|
||||||
|
}{
|
||||||
|
{"get all for user 1", 1, 10, 5, 1},
|
||||||
|
{"limit 3 for user 1", 1, 3, 3, 1},
|
||||||
|
{"get for user 2", 2, 10, 1, 2},
|
||||||
|
{"get for nonexistent user", 999, 10, 0, 999},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
histories, err := repo.GetByUserID(ctx, tc.userID, tc.limit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(histories) != tc.wantLen {
|
||||||
|
t.Errorf("expected %d histories, got %d", tc.wantLen, len(histories))
|
||||||
|
}
|
||||||
|
for _, h := range histories {
|
||||||
|
if h.UserID != tc.wantUser {
|
||||||
|
t.Errorf("expected user_id %d, got %d", tc.wantUser, h.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordHistoryRepository_GetByUserID_Order(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||||
|
t.Fatalf("migrate password_history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := NewPasswordHistoryRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create records with different timestamps
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
h := &domain.PasswordHistory{
|
||||||
|
UserID: 1,
|
||||||
|
PasswordHash: "hash",
|
||||||
|
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||||
|
}
|
||||||
|
if err := repo.Create(ctx, h); err != nil {
|
||||||
|
t.Fatalf("create failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(histories) != 3 {
|
||||||
|
t.Fatalf("expected 3 histories, got %d", len(histories))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be ordered by created_at DESC (newest first)
|
||||||
|
for i := 0; i < len(histories)-1; i++ {
|
||||||
|
if !histories[i].CreatedAt.After(histories[i+1].CreatedAt) && !histories[i].CreatedAt.Equal(histories[i+1].CreatedAt) {
|
||||||
|
t.Errorf("expected descending order, got %v before %v", histories[i].CreatedAt, histories[i+1].CreatedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordHistoryRepository_DeleteOldRecords(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||||
|
t.Fatalf("migrate password_history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := NewPasswordHistoryRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 5 records for user 1
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
h := &domain.PasswordHistory{
|
||||||
|
UserID: 1,
|
||||||
|
PasswordHash: "hash",
|
||||||
|
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||||
|
}
|
||||||
|
if err := repo.Create(ctx, h); err != nil {
|
||||||
|
t.Fatalf("create failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete old records, keep only 3
|
||||||
|
if err := repo.DeleteOldRecords(ctx, 1, 3); err != nil {
|
||||||
|
t.Fatalf("delete old records failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(histories) != 3 {
|
||||||
|
t.Errorf("expected 3 histories after cleanup, got %d", len(histories))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordHistoryRepository_DeleteOldRecords_NoRecords(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||||
|
t.Fatalf("migrate password_history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := NewPasswordHistoryRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Should not error when no records exist
|
||||||
|
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
|
||||||
|
t.Fatalf("delete old records on empty table should not error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordHistoryRepository_KeepsNewestRecords(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||||
|
t.Fatalf("migrate password_history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := NewPasswordHistoryRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 5 records with different timestamps
|
||||||
|
now := time.Now()
|
||||||
|
var createdIDs []int64
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
h := &domain.PasswordHistory{
|
||||||
|
UserID: 1,
|
||||||
|
PasswordHash: "hash",
|
||||||
|
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||||
|
}
|
||||||
|
if err := repo.Create(ctx, h); err != nil {
|
||||||
|
t.Fatalf("create failed: %v", err)
|
||||||
|
}
|
||||||
|
createdIDs = append(createdIDs, h.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete old records, keep only 2
|
||||||
|
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
|
||||||
|
t.Fatalf("delete old records failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(histories) != 2 {
|
||||||
|
t.Fatalf("expected 2 histories after cleanup, got %d", len(histories))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The remaining records should be the newest (last 2 created)
|
||||||
|
expectedIDs := map[int64]bool{createdIDs[3]: true, createdIDs[4]: true}
|
||||||
|
for _, h := range histories {
|
||||||
|
if !expectedIDs[h.ID] {
|
||||||
|
t.Errorf("expected remaining IDs to be %v, got %d", expectedIDs, h.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
117
internal/repository/sql_scan_test.go
Normal file
117
internal/repository/sql_scan_test.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockQueryer implements sqlQueryer for testing
|
||||||
|
type mockQueryer struct {
|
||||||
|
rows *sql.Rows
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQueryer) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||||
|
return m.rows, m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanSingleRow_QueryError(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
mockErr := errors.New("query failed")
|
||||||
|
q := &mockQueryer{err: mockErr}
|
||||||
|
|
||||||
|
var dest int
|
||||||
|
err := scanSingleRow(ctx, q, "SELECT 1", nil, &dest)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, mockErr) {
|
||||||
|
t.Errorf("expected query error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanSingleRow_NoRows(t *testing.T) {
|
||||||
|
// This test requires a real database connection to create sql.Rows.
|
||||||
|
// scanSingleRow is designed to work with any sqlQueryer, but creating
|
||||||
|
// a mock sql.Rows without a real driver is complex.
|
||||||
|
// We test the behavior through integration with the test database.
|
||||||
|
db := openTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Use the raw sql.DB from gorm
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get sql.DB failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dest int
|
||||||
|
err = scanSingleRow(ctx, sqlDB, "SELECT 1 WHERE 1=0", nil, &dest)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for no rows, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
t.Errorf("expected sql.ErrNoRows, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanSingleRow_Success(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get sql.DB failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dest int
|
||||||
|
err = scanSingleRow(ctx, sqlDB, "SELECT 42", nil, &dest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if dest != 42 {
|
||||||
|
t.Errorf("expected 42, got %d", dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanSingleRow_MultipleColumns(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get sql.DB failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var a, b int
|
||||||
|
err = scanSingleRow(ctx, sqlDB, "SELECT 1, 2", nil, &a, &b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if a != 1 {
|
||||||
|
t.Errorf("expected a=1, got %d", a)
|
||||||
|
}
|
||||||
|
if b != 2 {
|
||||||
|
t.Errorf("expected b=2, got %d", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanSingleRow_StringResult(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get sql.DB failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dest string
|
||||||
|
err = scanSingleRow(ctx, sqlDB, "SELECT 'hello'", nil, &dest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if dest != "hello" {
|
||||||
|
t.Errorf("expected 'hello', got %q", dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user