package providers import ( "context" "fmt" "io" "net/http" "strings" "testing" ) type roundTripFunc func(*http.Request) (*http.Response, error) func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return fn(req) } func useDefaultTransport(t *testing.T, fn roundTripFunc) { t.Helper() originalTransport := http.DefaultTransport http.DefaultTransport = fn t.Cleanup(func() { http.DefaultTransport = originalTransport }) } func oauthResponse(body string) *http.Response { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header), } } func TestQQProviderGetOpenIDAndUserInfoWithDefaultTransport(t *testing.T) { ctx := context.Background() provider := NewQQProvider("qq-app", "qq-secret", "https://example.com/callback") t.Run("get openid success", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "graph.qq.com" || req.URL.Path != "/oauth2.0/me" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`{"client_id":"qq-app","openid":"openid-123"}`), nil })) resp, err := provider.GetOpenID(ctx, "access-token") if err != nil { t.Fatalf("expected openid success, got error %v", err) } if resp.OpenID != "openid-123" || resp.ClientID != "qq-app" { t.Fatalf("unexpected openid response: %#v", resp) } }) t.Run("get openid parse error", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "graph.qq.com" || req.URL.Path != "/oauth2.0/me" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`not-json`), nil })) _, err := provider.GetOpenID(ctx, "access-token") if err == nil || !strings.Contains(err.Error(), "parse openid response failed") { t.Fatalf("expected openid parse error, got %v", err) } }) t.Run("get user info api error", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "graph.qq.com" || req.URL.Path != "/user/get_user_info" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`{"ret":1001,"msg":"invalid token"}`), nil })) _, err := provider.GetUserInfo(ctx, "access-token", "openid-123") if err == nil || !strings.Contains(err.Error(), "qq api error: invalid token") { t.Fatalf("expected qq api error, got %v", err) } }) t.Run("get user info success", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "graph.qq.com" || req.URL.Path != "/user/get_user_info" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`{"ret":0,"msg":"","nickname":"tester","gender":"male","city":"Shanghai"}`), nil })) info, err := provider.GetUserInfo(ctx, "access-token", "openid-123") if err != nil { t.Fatalf("expected user info success, got error %v", err) } if info.Nickname != "tester" || info.City != "Shanghai" { t.Fatalf("unexpected user info response: %#v", info) } }) } func TestWeiboProviderValidateTokenWithDefaultTransport(t *testing.T) { ctx := context.Background() provider := NewWeiboProvider("weibo-app", "weibo-secret", "https://example.com/callback") tests := []struct { name string body string wantValid bool wantErrContains string }{ { name: "rejects error response", body: `{"error":"invalid_token"}`, wantValid: false, }, { name: "accepts expire_in response", body: `{"expire_in":3600}`, wantValid: true, }, { name: "rejects ambiguous response", body: `{"uid":"123"}`, wantValid: false, }, { name: "returns parse error", body: `not-json`, wantErrContains: "parse response failed", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "api.weibo.com" || req.URL.Path != "/oauth2/get_token_info" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(tt.body), nil })) valid, err := provider.ValidateToken(ctx, "access-token") if tt.wantErrContains != "" { if err == nil || !strings.Contains(err.Error(), tt.wantErrContains) { t.Fatalf("expected error containing %q, got %v", tt.wantErrContains, err) } return } if err != nil { t.Fatalf("expected no error, got %v", err) } if valid != tt.wantValid { t.Fatalf("expected valid=%v, got %v", tt.wantValid, valid) } }) } } func TestWeChatProviderValidateTokenWithDefaultTransport(t *testing.T) { ctx := context.Background() provider := NewWeChatProvider("wx-app", "wx-secret", "web") tests := []struct { name string body string wantValid bool wantErrContains string }{ { name: "accepts errcode zero", body: `{"errcode":0,"errmsg":"ok"}`, wantValid: true, }, { name: "rejects non-zero errcode", body: `{"errcode":40003,"errmsg":"invalid openid"}`, wantValid: false, }, { name: "returns parse error", body: `not-json`, wantErrContains: "parse response failed", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "api.weixin.qq.com" || req.URL.Path != "/sns/auth" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(tt.body), nil })) valid, err := provider.ValidateToken(ctx, "access-token", "openid-123") if tt.wantErrContains != "" { if err == nil || !strings.Contains(err.Error(), tt.wantErrContains) { t.Fatalf("expected error containing %q, got %v", tt.wantErrContains, err) } return } if err != nil { t.Fatalf("expected no error, got %v", err) } if valid != tt.wantValid { t.Fatalf("expected valid=%v, got %v", tt.wantValid, valid) } }) } } func TestGoogleProviderValidateTokenWithDefaultTransport(t *testing.T) { ctx := context.Background() provider := NewGoogleProvider("google-client", "google-secret", "https://example.com/callback") t.Run("validate token success", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "www.googleapis.com" || req.URL.Path != "/oauth2/v2/userinfo" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`{"id":"user-1","email":"user@example.com","name":"Google User"}`), nil })) valid, err := provider.ValidateToken(ctx, "access-token") if err != nil { t.Fatalf("expected success, got error %v", err) } if !valid { t.Fatal("expected token to be valid") } }) t.Run("validate token parse error", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "www.googleapis.com" || req.URL.Path != "/oauth2/v2/userinfo" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`not-json`), nil })) valid, err := provider.ValidateToken(ctx, "access-token") if err == nil || !strings.Contains(err.Error(), "parse user info failed") { t.Fatalf("expected user info parse error, got valid=%v err=%v", valid, err) } }) } func TestFacebookProviderGetUserInfoWithDefaultTransport(t *testing.T) { ctx := context.Background() provider := NewFacebookProvider("facebook-app", "facebook-secret", "https://example.com/callback") t.Run("facebook api error", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "graph.facebook.com" || req.URL.Path != "/v18.0/me" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`{"error":{"message":"token expired","type":"OAuthException","code":190}}`), nil })) _, err := provider.GetUserInfo(ctx, "access-token") if err == nil || !strings.Contains(err.Error(), "facebook api error: token expired") { t.Fatalf("expected facebook api error, got %v", err) } }) t.Run("facebook success", func(t *testing.T) { useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.URL.Host != "graph.facebook.com" || req.URL.Path != "/v18.0/me" { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) } return oauthResponse(`{"id":"user-1","name":"Facebook User","email":"fb@example.com","picture":{"data":{"url":"https://cdn.example.com/a.png"}}}`), nil })) info, err := provider.GetUserInfo(ctx, "access-token") if err != nil { t.Fatalf("expected user info success, got error %v", err) } if info.ID != "user-1" || info.Picture.Data.URL == "" { t.Fatalf("unexpected facebook user info response: %#v", info) } }) }