安全修复: - CRITICAL: SSO重定向URL注入漏洞 - 修复redirect_uri白名单验证 - HIGH: SSO ClientSecret未验证 - 使用crypto/subtle.ConstantTimeCompare验证 - HIGH: 邮件验证码熵值过低(3字节) - 提升到6字节(48位熵) - HIGH: 短信验证码熵值过低(4字节) - 提升到6字节 - HIGH: Goroutine使用已取消上下文 - auth_email.go使用独立context+超时 - HIGH: SQL LIKE查询注入风险 - permission/role仓库使用escapeLikePattern 新功能: - Go SDK: sdk/go/user-management/ 完整SDK实现 - CAS SSO框架: internal/auth/cas.go CAS协议支持 其他: - L1Cache实例问题修复 - AuthMiddleware共享l1Cache - 设备指纹XSS防护 - 内存存储替代localStorage - 响应格式协议中间件 - 导出无界查询修复
465 lines
13 KiB
Go
465 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
cryptorand "crypto/rand"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
aliyunopenapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils"
|
||
aliyunsms "github.com/alibabacloud-go/dysmsapi-20170525/v5/client"
|
||
"github.com/alibabacloud-go/tea/dara"
|
||
tccommon "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||
tcprofile "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||
tcsms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
|
||
)
|
||
|
||
// validPhonePattern accepts either a mainland China mobile number (11 digits
// starting with 1[3-9], optionally prefixed with "+86" or "86") or a
// "+"-prefixed international number of 7 to 15 digits.
var validPhonePattern = regexp.MustCompile(`^((\+86|86)?1[3-9]\d{9}|\+[1-9]\d{6,14})$`)

// mainlandPhonePattern matches a bare 11-digit mainland China mobile number.
var mainlandPhonePattern = regexp.MustCompile(`^1[3-9]\d{9}$`)

// mainlandPhone86Pattern matches a mainland number prefixed with "86";
// capture group 1 is the bare 11-digit number.
var mainlandPhone86Pattern = regexp.MustCompile(`^86(1[3-9]\d{9})$`)

// mainlandPhone0086Pattern matches a mainland number prefixed with "0086";
// capture group 1 is the bare 11-digit number.
var mainlandPhone0086Pattern = regexp.MustCompile(`^0086(1[3-9]\d{9})$`)

// verificationCodeCharset10 is the modulus (10^6) used to reduce a random
// number to at most six decimal digits.
var verificationCodeCharset10 = 1000000
|
||
|
||
// SMSProvider is the delivery backend for verification codes: one call
// sends one code to one phone number and reports failure through the
// returned error.
type SMSProvider interface {
	// SendVerificationCode delivers code to phone. A nil error means the
	// provider accepted the message.
	SendVerificationCode(ctx context.Context, phone string, code string) error
}
|
||
|
||
// MockSMSProvider is a test helper and is not wired into the server runtime.
type MockSMSProvider struct{}

// SendVerificationCode logs the request instead of sending an SMS and
// always returns nil.
//
// The code is masked before logging. At most the last two characters are
// shown, and only when the code is longer than four characters; shorter
// codes are logged as "****". This keeps a short code from being written
// to the log in full.
func (m *MockSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
	_ = ctx // the mock does nothing that could be cancelled

	// Number of trailing characters left readable for debugging.
	const visibleSuffix = 2

	maskedCode := "****"
	if len(code) > 4 {
		hidden := len(code) - visibleSuffix
		maskedCode = strings.Repeat("*", hidden) + code[hidden:]
	}
	log.Printf("[sms-mock] phone=%s code=%s ttl=5m", phone, maskedCode)
	return nil
}
|
||
|
||
type aliyunSMSClient interface {
|
||
SendSms(request *aliyunsms.SendSmsRequest) (*aliyunsms.SendSmsResponse, error)
|
||
}
|
||
|
||
type tencentSMSClient interface {
|
||
SendSmsWithContext(ctx context.Context, request *tcsms.SendSmsRequest) (*tcsms.SendSmsResponse, error)
|
||
}
|
||
|
||
// AliyunSMSConfig holds the settings needed to send verification codes
// through Aliyun SMS. NewAliyunSMSProvider trims every field and rejects
// a config whose AccessKeyID, AccessKeySecret, SignName or TemplateCode
// is empty.
type AliyunSMSConfig struct {
	// AccessKeyID and AccessKeySecret are passed to the SDK client config.
	AccessKeyID     string
	AccessKeySecret string

	// SignName and TemplateCode are set on every send request.
	SignName     string
	TemplateCode string

	// Endpoint is optional; when empty the SDK config's endpoint is left unset.
	Endpoint string

	// RegionID falls back to "cn-hangzhou" when empty.
	RegionID string

	// CodeParamName is the JSON key under which the code is passed as the
	// template parameter; it falls back to "code" when empty.
	CodeParamName string
}
|
||
|
||
type AliyunSMSProvider struct {
|
||
cfg AliyunSMSConfig
|
||
client aliyunSMSClient
|
||
}
|
||
|
||
func NewAliyunSMSProvider(cfg AliyunSMSConfig) (SMSProvider, error) {
|
||
cfg = normalizeAliyunSMSConfig(cfg)
|
||
if cfg.AccessKeyID == "" || cfg.AccessKeySecret == "" || cfg.SignName == "" || cfg.TemplateCode == "" {
|
||
return nil, fmt.Errorf("aliyun SMS config is incomplete")
|
||
}
|
||
|
||
client, err := newAliyunSMSClient(cfg)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create aliyun SMS client failed: %w", err)
|
||
}
|
||
|
||
return &AliyunSMSProvider{
|
||
cfg: cfg,
|
||
client: client,
|
||
}, nil
|
||
}
|
||
|
||
func newAliyunSMSClient(cfg AliyunSMSConfig) (aliyunSMSClient, error) {
|
||
client, err := aliyunsms.NewClient(&aliyunopenapiutil.Config{
|
||
AccessKeyId: dara.String(cfg.AccessKeyID),
|
||
AccessKeySecret: dara.String(cfg.AccessKeySecret),
|
||
Endpoint: stringPointerOrNil(cfg.Endpoint),
|
||
RegionId: dara.String(cfg.RegionID),
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return client, nil
|
||
}
|
||
|
||
func (a *AliyunSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
||
_ = ctx
|
||
|
||
templateParam, err := json.Marshal(map[string]string{
|
||
a.cfg.CodeParamName: code,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("marshal aliyun SMS template param failed: %w", err)
|
||
}
|
||
|
||
resp, err := a.client.SendSms(
|
||
new(aliyunsms.SendSmsRequest).
|
||
SetPhoneNumbers(normalizePhoneForSMS(phone)).
|
||
SetSignName(a.cfg.SignName).
|
||
SetTemplateCode(a.cfg.TemplateCode).
|
||
SetTemplateParam(string(templateParam)),
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("aliyun SMS request failed: %w", err)
|
||
}
|
||
if resp == nil || resp.Body == nil {
|
||
return fmt.Errorf("aliyun SMS returned empty response")
|
||
}
|
||
|
||
body := resp.Body
|
||
if !strings.EqualFold(dara.StringValue(body.Code), "OK") {
|
||
return fmt.Errorf(
|
||
"aliyun SMS rejected: code=%s message=%s request_id=%s",
|
||
valueOrDefault(dara.StringValue(body.Code), "unknown"),
|
||
valueOrDefault(dara.StringValue(body.Message), "unknown"),
|
||
valueOrDefault(dara.StringValue(body.RequestId), "unknown"),
|
||
)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// TencentSMSConfig holds the settings needed to send verification codes
// through Tencent Cloud SMS. NewTencentSMSProvider trims every field and
// rejects a config whose SecretID, SecretKey, AppID, SignName or
// TemplateID is empty.
type TencentSMSConfig struct {
	// SecretID and SecretKey form the SDK credential.
	SecretID  string
	SecretKey string

	// AppID is sent as the request's SmsSdkAppId.
	AppID string

	// SignName and TemplateID are set on every send request.
	SignName   string
	TemplateID string

	// Region falls back to "ap-guangzhou" when empty.
	Region string

	// Endpoint is optional; when non-empty it overrides the HTTP profile endpoint.
	Endpoint string
}
|
||
|
||
type TencentSMSProvider struct {
|
||
cfg TencentSMSConfig
|
||
client tencentSMSClient
|
||
}
|
||
|
||
func NewTencentSMSProvider(cfg TencentSMSConfig) (SMSProvider, error) {
|
||
cfg = normalizeTencentSMSConfig(cfg)
|
||
if cfg.SecretID == "" || cfg.SecretKey == "" || cfg.AppID == "" || cfg.SignName == "" || cfg.TemplateID == "" {
|
||
return nil, fmt.Errorf("tencent SMS config is incomplete")
|
||
}
|
||
|
||
client, err := newTencentSMSClient(cfg)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create tencent SMS client failed: %w", err)
|
||
}
|
||
|
||
return &TencentSMSProvider{
|
||
cfg: cfg,
|
||
client: client,
|
||
}, nil
|
||
}
|
||
|
||
func newTencentSMSClient(cfg TencentSMSConfig) (tencentSMSClient, error) {
|
||
clientProfile := tcprofile.NewClientProfile()
|
||
clientProfile.HttpProfile.ReqTimeout = 30
|
||
if cfg.Endpoint != "" {
|
||
clientProfile.HttpProfile.Endpoint = cfg.Endpoint
|
||
}
|
||
|
||
client, err := tcsms.NewClient(
|
||
tccommon.NewCredential(cfg.SecretID, cfg.SecretKey),
|
||
cfg.Region,
|
||
clientProfile,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return client, nil
|
||
}
|
||
|
||
func (t *TencentSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
||
req := tcsms.NewSendSmsRequest()
|
||
req.PhoneNumberSet = []*string{tccommon.StringPtr(normalizePhoneForSMS(phone))}
|
||
req.SmsSdkAppId = tccommon.StringPtr(t.cfg.AppID)
|
||
req.SignName = tccommon.StringPtr(t.cfg.SignName)
|
||
req.TemplateId = tccommon.StringPtr(t.cfg.TemplateID)
|
||
req.TemplateParamSet = []*string{tccommon.StringPtr(code)}
|
||
|
||
resp, err := t.client.SendSmsWithContext(ctx, req)
|
||
if err != nil {
|
||
return fmt.Errorf("tencent SMS request failed: %w", err)
|
||
}
|
||
if resp == nil || resp.Response == nil {
|
||
return fmt.Errorf("tencent SMS returned empty response")
|
||
}
|
||
if len(resp.Response.SendStatusSet) == 0 {
|
||
return fmt.Errorf(
|
||
"tencent SMS returned empty status list: request_id=%s",
|
||
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
|
||
)
|
||
}
|
||
|
||
status := resp.Response.SendStatusSet[0]
|
||
if !strings.EqualFold(pointerString(status.Code), "Ok") {
|
||
return fmt.Errorf(
|
||
"tencent SMS rejected: code=%s message=%s request_id=%s",
|
||
valueOrDefault(pointerString(status.Code), "unknown"),
|
||
valueOrDefault(pointerString(status.Message), "unknown"),
|
||
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
|
||
)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// SMSCodeConfig controls how long codes live and how often they may be
// requested. NewSMSCodeService replaces non-positive values with defaults.
type SMSCodeConfig struct {
	// CodeTTL is how long a stored verification code stays in the cache.
	CodeTTL time.Duration
	// ResendCooldown is how long the per-phone cooldown marker is kept
	// after a code is requested.
	ResendCooldown time.Duration
	// MaxDailyLimit is the number of codes one phone may request per
	// calendar day.
	MaxDailyLimit int
}
|
||
|
||
func DefaultSMSCodeConfig() SMSCodeConfig {
|
||
return SMSCodeConfig{
|
||
CodeTTL: 5 * time.Minute,
|
||
ResendCooldown: time.Minute,
|
||
MaxDailyLimit: 10,
|
||
}
|
||
}
|
||
|
||
type SMSCodeService struct {
|
||
provider SMSProvider
|
||
cache cacheInterface
|
||
cfg SMSCodeConfig
|
||
}
|
||
|
||
// cacheInterface is the key/value store SMSCodeService relies on.
type cacheInterface interface {
	// Get returns the value stored under key and whether it was found.
	Get(ctx context.Context, key string) (interface{}, bool)

	// Set stores value under key with two TTLs.
	// NOTE(review): l1TTL and l2TTL presumably apply to two cache tiers;
	// confirm against the implementation.
	Set(ctx context.Context, key string, value interface{}, l1TTL time.Duration, l2TTL time.Duration) error

	// Delete removes key.
	Delete(ctx context.Context, key string) error
}
|
||
|
||
func NewSMSCodeService(provider SMSProvider, cacheManager cacheInterface, cfg SMSCodeConfig) *SMSCodeService {
|
||
if cfg.CodeTTL <= 0 {
|
||
cfg.CodeTTL = 5 * time.Minute
|
||
}
|
||
if cfg.ResendCooldown <= 0 {
|
||
cfg.ResendCooldown = time.Minute
|
||
}
|
||
if cfg.MaxDailyLimit <= 0 {
|
||
cfg.MaxDailyLimit = 10
|
||
}
|
||
|
||
return &SMSCodeService{
|
||
provider: provider,
|
||
cache: cacheManager,
|
||
cfg: cfg,
|
||
}
|
||
}
|
||
|
||
// SendCodeRequest is the payload for requesting a verification code.
type SendCodeRequest struct {
	// Phone is the destination number; it is the only required field.
	Phone string `json:"phone" binding:"required"`
	// Purpose names what the code is for; it becomes part of the cache key.
	Purpose string `json:"purpose"`
	// Scene is used in place of Purpose when Purpose is empty.
	Scene string `json:"scene"`
}
|
||
|
||
// SendCodeResponse tells the caller how long the code is valid and how
// long to wait before asking for another one. Both values are in seconds.
type SendCodeResponse struct {
	// ExpiresIn is the code TTL in whole seconds.
	ExpiresIn int `json:"expires_in"`
	// Cooldown is the resend cooldown in whole seconds.
	Cooldown int `json:"cooldown"`
}
|
||
|
||
func (s *SMSCodeService) SendCode(ctx context.Context, req *SendCodeRequest) (*SendCodeResponse, error) {
|
||
if s == nil || s.provider == nil || s.cache == nil {
|
||
return nil, fmt.Errorf("sms code service is not configured")
|
||
}
|
||
if req == nil {
|
||
return nil, newValidationError("\u8bf7\u6c42\u4e0d\u80fd\u4e3a\u7a7a")
|
||
}
|
||
|
||
phone := strings.TrimSpace(req.Phone)
|
||
if !isValidPhone(phone) {
|
||
return nil, newValidationError("\u624b\u673a\u53f7\u7801\u683c\u5f0f\u4e0d\u6b63\u786e")
|
||
}
|
||
purpose := strings.TrimSpace(req.Purpose)
|
||
if purpose == "" {
|
||
purpose = strings.TrimSpace(req.Scene)
|
||
}
|
||
|
||
cooldownKey := fmt.Sprintf("sms_cooldown:%s", phone)
|
||
if _, ok := s.cache.Get(ctx, cooldownKey); ok {
|
||
return nil, newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
|
||
}
|
||
|
||
dailyKey := fmt.Sprintf("sms_daily:%s:%s", phone, time.Now().Format("2006-01-02"))
|
||
var dailyCount int
|
||
if val, ok := s.cache.Get(ctx, dailyKey); ok {
|
||
if n, ok := intValue(val); ok {
|
||
dailyCount = n
|
||
}
|
||
}
|
||
if dailyCount >= s.cfg.MaxDailyLimit {
|
||
return nil, newRateLimitError(fmt.Sprintf("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff08%d\u6b21\uff09\uff0c\u8bf7\u660e\u65e5\u518d\u8bd5", s.cfg.MaxDailyLimit))
|
||
}
|
||
|
||
code, err := generateSMSCode()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate sms code failed: %w", err)
|
||
}
|
||
|
||
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
|
||
if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
|
||
return nil, fmt.Errorf("store sms code failed: %w", err)
|
||
}
|
||
if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
|
||
_ = s.cache.Delete(ctx, codeKey)
|
||
return nil, fmt.Errorf("store sms cooldown failed: %w", err)
|
||
}
|
||
if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
|
||
_ = s.cache.Delete(ctx, codeKey)
|
||
_ = s.cache.Delete(ctx, cooldownKey)
|
||
return nil, fmt.Errorf("store sms daily counter failed: %w", err)
|
||
}
|
||
|
||
if err := s.provider.SendVerificationCode(ctx, phone, code); err != nil {
|
||
_ = s.cache.Delete(ctx, codeKey)
|
||
_ = s.cache.Delete(ctx, cooldownKey)
|
||
return nil, fmt.Errorf("\u77ed\u4fe1\u53d1\u9001\u5931\u8d25: %w", err)
|
||
}
|
||
|
||
return &SendCodeResponse{
|
||
ExpiresIn: int(s.cfg.CodeTTL.Seconds()),
|
||
Cooldown: int(s.cfg.ResendCooldown.Seconds()),
|
||
}, nil
|
||
}
|
||
|
||
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code string) error {
|
||
if s == nil || s.cache == nil {
|
||
return fmt.Errorf("sms code service is not configured")
|
||
}
|
||
if strings.TrimSpace(code) == "" {
|
||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u80fd\u4e3a\u7a7a")
|
||
}
|
||
|
||
phone = strings.TrimSpace(phone)
|
||
purpose = strings.TrimSpace(purpose)
|
||
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
|
||
val, ok := s.cache.Get(ctx, codeKey)
|
||
if !ok {
|
||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u5df2\u8fc7\u671f\u6216\u4e0d\u5b58\u5728")
|
||
}
|
||
|
||
stored, ok := val.(string)
|
||
if !ok || stored != code {
|
||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
|
||
}
|
||
|
||
if err := s.cache.Delete(ctx, codeKey); err != nil {
|
||
return fmt.Errorf("consume sms code failed: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func isValidPhone(phone string) bool {
|
||
return validPhonePattern.MatchString(strings.TrimSpace(phone))
|
||
}
|
||
|
||
// generateSMSCode returns a six-digit decimal verification code drawn
// uniformly from [100000, 999999] using crypto/rand.
//
// Each draw reads 6 random bytes (48 bits) into a uint64, so the result
// does not depend on the platform's int width. Draws at or above the
// largest multiple of the code span are rejected and retried, which
// removes modulo bias. Every accepted draw maps to exactly one code, so
// no first digit is more likely than another.
//
// It returns an error only when the random source fails.
func generateSMSCode() (string, error) {
	const (
		minCode  = 100000 // smallest six-digit code with a non-zero first digit
		codeSpan = 900000 // number of codes in [100000, 999999]

		randSpace = uint64(1) << 48
		// Draws at or above this bound are rejected so the rest divide evenly.
		acceptBelow = randSpace - randSpace%codeSpan
	)

	var b [6]byte
	for {
		if _, err := cryptorand.Read(b[:]); err != nil {
			return "", err
		}

		v := uint64(b[0])<<40 | uint64(b[1])<<32 | uint64(b[2])<<24 |
			uint64(b[3])<<16 | uint64(b[4])<<8 | uint64(b[5])
		if v >= acceptBelow {
			continue // rejection is extremely rare
		}

		return fmt.Sprintf("%06d", minCode+v%codeSpan), nil
	}
}
|
||
|
||
func normalizeAliyunSMSConfig(cfg AliyunSMSConfig) AliyunSMSConfig {
|
||
cfg.AccessKeyID = strings.TrimSpace(cfg.AccessKeyID)
|
||
cfg.AccessKeySecret = strings.TrimSpace(cfg.AccessKeySecret)
|
||
cfg.SignName = strings.TrimSpace(cfg.SignName)
|
||
cfg.TemplateCode = strings.TrimSpace(cfg.TemplateCode)
|
||
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
||
cfg.RegionID = strings.TrimSpace(cfg.RegionID)
|
||
cfg.CodeParamName = strings.TrimSpace(cfg.CodeParamName)
|
||
|
||
if cfg.RegionID == "" {
|
||
cfg.RegionID = "cn-hangzhou"
|
||
}
|
||
if cfg.CodeParamName == "" {
|
||
cfg.CodeParamName = "code"
|
||
}
|
||
|
||
return cfg
|
||
}
|
||
|
||
func normalizeTencentSMSConfig(cfg TencentSMSConfig) TencentSMSConfig {
|
||
cfg.SecretID = strings.TrimSpace(cfg.SecretID)
|
||
cfg.SecretKey = strings.TrimSpace(cfg.SecretKey)
|
||
cfg.AppID = strings.TrimSpace(cfg.AppID)
|
||
cfg.SignName = strings.TrimSpace(cfg.SignName)
|
||
cfg.TemplateID = strings.TrimSpace(cfg.TemplateID)
|
||
cfg.Region = strings.TrimSpace(cfg.Region)
|
||
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
||
|
||
if cfg.Region == "" {
|
||
cfg.Region = "ap-guangzhou"
|
||
}
|
||
|
||
return cfg
|
||
}
|
||
|
||
func normalizePhoneForSMS(phone string) string {
|
||
phone = strings.TrimSpace(phone)
|
||
|
||
switch {
|
||
case mainlandPhonePattern.MatchString(phone):
|
||
return "+86" + phone
|
||
case mainlandPhone86Pattern.MatchString(phone):
|
||
return "+" + phone
|
||
case mainlandPhone0086Pattern.MatchString(phone):
|
||
return "+86" + mainlandPhone0086Pattern.ReplaceAllString(phone, "$1")
|
||
default:
|
||
return phone
|
||
}
|
||
}
|
||
|
||
// stringPointerOrNil returns nil for an empty string and a pointer to a
// copy of value otherwise.
func stringPointerOrNil(value string) *string {
	if value != "" {
		// value is this call's own copy, so its address is safe to return.
		return &value
	}
	return nil
}
|
||
|
||
// pointerString returns the string value points to, or "" when value is nil.
func pointerString(value *string) string {
	var out string
	if value != nil {
		out = *value
	}
	return out
}
|
||
|
||
// valueOrDefault returns fallback when value is empty or only whitespace,
// and value unchanged (not trimmed) otherwise.
func valueOrDefault(value, fallback string) string {
	if strings.TrimSpace(value) != "" {
		return value
	}
	return fallback
}
|