189 lines
5.3 KiB
Go
189 lines
5.3 KiB
Go
//go:build llm_script
|
|
|
|
package main
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"strings"
	"time"
)
|
|
|
|
// intradayProviderConfig describes where a provider's JSON payload comes from.
//
// Mode picks the source: "fixture" reads the file named by Fixture,
// "command_json" runs Command through `sh -c` and captures stdout, and
// "http_json" requests URL. Command and URL are templates whose {{date}} and
// {{query}} placeholders are expanded before use. Timeout is handed to the
// http.Client used in "http_json" mode; zero means no client-side limit.
type intradayProviderConfig struct {
	Mode, Command, URL, Fixture string
	Timeout                     time.Duration
}
|
|
|
|
// intradaySearchRecord is one search hit returned by the search provider.
//
// loadIntradaySearchRecords decodes a JSON array of these for each query, and
// the accumulated list is forwarded unchanged to the LLM provider as the
// "search_results" field of intradayLLMRequest. Every field is a plain string;
// the loaders do not validate or normalize any of them.
type intradaySearchRecord struct {
	Title       string `json:"title"`
	Summary     string `json:"summary"`
	URL         string `json:"url"`          // presumably the article link — confirm
	Provider    string `json:"provider"`     // presumably the publishing outlet's name — confirm
	ProviderURL string `json:"provider_url"` // presumably the outlet's homepage — confirm
	PublishedAt string `json:"published_at"` // layout not enforced here; TODO confirm expected format
}
|
|
|
|
// intradayLLMRecord is one entry of the JSON array returned by the LLM
// provider and decoded by loadIntradayLLMRecords.
//
// NOTE(review): the field meanings below are inferred from their names only;
// confirm them against the provider's output contract.
type intradayLLMRecord struct {
	EventType       string   `json:"event_type"`       // presumably a classification label for the event
	ProviderName    string   `json:"provider_name"`    // presumably the organisation behind ModelName — confirm
	ModelName       string   `json:"model_name"`
	ProviderCountry string   `json:"provider_country"` // presumably the country of ProviderName — confirm
	Title           string   `json:"title"`
	Summary         string   `json:"summary"`
	CandidateURLs   []string `json:"candidate_urls"` // presumably source links taken from the search results
}
|
|
|
|
// intradayLLMRequest is the JSON document loadIntradayLLMRecords sends to the
// LLM provider: on stdin in "command_json" mode, as a POST body in
// "http_json" mode (it is not used in "fixture" mode).
type intradayLLMRequest struct {
	Date string `json:"date"`
	// SearchResults encodes as JSON null when the slice is nil.
	SearchResults []intradaySearchRecord `json:"search_results"`
}
|
|
|
|
func loadIntradaySearchRecords(cfg intradayProviderConfig, date string, queries []string) ([]intradaySearchRecord, error) {
|
|
var all []intradaySearchRecord
|
|
for _, query := range queries {
|
|
payload, err := loadIntradayProviderPayload(cfg, intradayProviderPayloadInput{
|
|
Date: date,
|
|
Query: query,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(bytes.TrimSpace(payload)) == 0 {
|
|
continue
|
|
}
|
|
var records []intradaySearchRecord
|
|
if err := json.Unmarshal(payload, &records); err != nil {
|
|
return nil, fmt.Errorf("unmarshal search records for query %q: %w", query, err)
|
|
}
|
|
all = append(all, records...)
|
|
if cfg.Mode == "fixture" {
|
|
break
|
|
}
|
|
}
|
|
return all, nil
|
|
}
|
|
|
|
func loadIntradayLLMRecords(cfg intradayProviderConfig, date string, searchResults []intradaySearchRecord) ([]intradayLLMRecord, error) {
|
|
request := intradayLLMRequest{Date: date, SearchResults: searchResults}
|
|
body, err := json.Marshal(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal llm request: %w", err)
|
|
}
|
|
payload, err := loadIntradayProviderPayload(cfg, intradayProviderPayloadInput{
|
|
Date: date,
|
|
RequestBody: body,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(bytes.TrimSpace(payload)) == 0 {
|
|
return nil, nil
|
|
}
|
|
var records []intradayLLMRecord
|
|
if err := json.Unmarshal(payload, &records); err != nil {
|
|
return nil, fmt.Errorf("unmarshal llm records: %w", err)
|
|
}
|
|
return records, nil
|
|
}
|
|
|
|
// intradayProviderPayloadInput carries the per-call values for a provider.
//
// Date and Query replace the {{date}} and {{query}} placeholders in the
// command or URL template and are also exported to commands as environment
// variables. A non-empty RequestBody is written to a command's stdin, or sent
// as a JSON POST body (instead of a GET) in "http_json" mode.
type intradayProviderPayloadInput struct {
	Date, Query string
	RequestBody []byte
}
|
|
|
|
func loadIntradayProviderPayload(cfg intradayProviderConfig, input intradayProviderPayloadInput) ([]byte, error) {
|
|
mode := strings.TrimSpace(cfg.Mode)
|
|
switch mode {
|
|
case "fixture":
|
|
if strings.TrimSpace(cfg.Fixture) == "" {
|
|
return nil, fmt.Errorf("provider fixture 未设置")
|
|
}
|
|
return os.ReadFile(cfg.Fixture)
|
|
case "command_json":
|
|
if strings.TrimSpace(cfg.Command) == "" {
|
|
return nil, fmt.Errorf("provider command 未设置")
|
|
}
|
|
return runIntradayCommand(cfg, input)
|
|
case "http_json":
|
|
if strings.TrimSpace(cfg.URL) == "" {
|
|
return nil, fmt.Errorf("provider url 未设置")
|
|
}
|
|
return fetchIntradayHTTP(cfg, input)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported provider mode %q", mode)
|
|
}
|
|
}
|
|
|
|
func runIntradayCommand(cfg intradayProviderConfig, input intradayProviderPayloadInput) ([]byte, error) {
|
|
command := strings.TrimSpace(cfg.Command)
|
|
command = strings.ReplaceAll(command, "{{date}}", input.Date)
|
|
command = strings.ReplaceAll(command, "{{query}}", shellEscapeSingleArg(input.Query))
|
|
cmd := exec.Command("sh", "-c", command)
|
|
cmd.Env = append(os.Environ(),
|
|
"INTRADAY_DISCOVERY_DATE="+input.Date,
|
|
"INTRADAY_DISCOVERY_QUERY="+input.Query,
|
|
)
|
|
if len(input.RequestBody) > 0 {
|
|
cmd.Stdin = bytes.NewReader(input.RequestBody)
|
|
}
|
|
out, err := cmd.Output()
|
|
if err != nil {
|
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
return nil, fmt.Errorf("run provider command: %w: %s", err, strings.TrimSpace(string(exitErr.Stderr)))
|
|
}
|
|
return nil, fmt.Errorf("run provider command: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func fetchIntradayHTTP(cfg intradayProviderConfig, input intradayProviderPayloadInput) ([]byte, error) {
|
|
client := &http.Client{Timeout: cfg.Timeout}
|
|
rawURL := strings.TrimSpace(cfg.URL)
|
|
rawURL = strings.ReplaceAll(rawURL, "{{date}}", input.Date)
|
|
rawURL = strings.ReplaceAll(rawURL, "{{query}}", input.Query)
|
|
|
|
method := http.MethodGet
|
|
var body io.Reader
|
|
if len(input.RequestBody) > 0 {
|
|
method = http.MethodPost
|
|
body = bytes.NewReader(input.RequestBody)
|
|
}
|
|
req, err := http.NewRequest(method, rawURL, body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build provider request: %w", err)
|
|
}
|
|
if len(input.RequestBody) > 0 {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("call provider url: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
payload, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("call provider url: unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(payload)))
|
|
}
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read provider response: %w", err)
|
|
}
|
|
return payload, nil
|
|
}
|
|
|
|
// shellEscapeSingleArg wraps value in single quotes so that a POSIX shell
// reads it as one literal word. Each embedded single quote is emitted as
// '"'"' (close quote, a double-quoted quote, reopen quote). An empty value
// becomes ''. The input is processed byte by byte, so arbitrary (even
// non-UTF-8) bytes pass through unchanged.
func shellEscapeSingleArg(value string) string {
	const quotedQuote = `'"'"'`

	var sb strings.Builder
	sb.WriteByte('\'')
	for i := 0; i < len(value); i++ {
		if c := value[i]; c != '\'' {
			sb.WriteByte(c)
		} else {
			sb.WriteString(quotedQuote)
		}
	}
	sb.WriteByte('\'')
	return sb.String()
}
|