diff --git a/.gitignore b/.gitignore index 183138f96..5c3bee713 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ cli-proxy-api # Configuration config.yaml +config.test.yaml .env # Generated content diff --git a/go.mod b/go.mod index 863d0413c..9abc69e3d 100644 --- a/go.mod +++ b/go.mod @@ -62,8 +62,10 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pjbgf/sha1cd v0.5.0 // indirect + github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/rs/xid v1.5.0 // indirect github.com/sergi/go-diff v1.4.0 // indirect + github.com/shadowsocks/go-shadowsocks2 v0.1.5 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 4705336bf..abc12c070 100644 --- a/go.sum +++ b/go.sum @@ -118,12 +118,16 @@ github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= +github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod 
h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= +github.com/shadowsocks/go-shadowsocks2 v0.1.5 h1:PDSQv9y2S85Fl7VBeOMF9StzeXZyK1HakRm86CUbr28= +github.com/shadowsocks/go-shadowsocks2 v0.1.5/go.mod h1:AGGpIoek4HRno4xzyFiAtLHkOpcoznZEkAccaI/rplM= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= @@ -160,20 +164,27 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys 
v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index b38677467..4baac2a94 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -32,6 +32,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/util" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "golang.org/x/oauth2" @@ -2383,3 +2385,278 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{"status": "wait"}) } + +// ModelHealth represents the health status of a model +type ModelHealth struct { + ModelID string `json:"model_id"` + DisplayName string 
`json:"display_name,omitempty"` + Status string `json:"status"` // "healthy", "unhealthy" + Message string `json:"message,omitempty"` + Latency int64 `json:"latency_ms,omitempty"` +} + +// CheckAuthFileModelsHealth performs health checks on all models supported by an auth file +// Mimics Cherry Studio's implementation: sends actual generation request and aborts after first chunk +// Automatically uses proxy if configured in auth.ProxyURL +func (h *Handler) CheckAuthFileModelsHealth(c *gin.Context) { + if h.authManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + + name := c.Query("name") + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"}) + return + } + + // Parse optional query parameters (like Cherry Studio) + isConcurrent := c.DefaultQuery("concurrent", "false") == "true" + timeoutSeconds := 15 + if ts := c.Query("timeout"); ts != "" { + if parsed, err := strconv.Atoi(ts); err == nil && parsed >= 5 && parsed <= 60 { + timeoutSeconds = parsed + } + } + + // Parse optional model filter parameters + // - model: single model to check + // - models: comma-separated list of models to check + // If neither is specified, all models are checked + modelFilter := strings.TrimSpace(c.Query("model")) + modelsFilter := strings.TrimSpace(c.Query("models")) + + // Find auth by name or ID + var targetAuth *coreauth.Auth + if auth, ok := h.authManager.GetByID(name); ok { + targetAuth = auth + } else { + auths := h.authManager.List() + for _, auth := range auths { + if auth.FileName == name || auth.ID == name { + targetAuth = auth + break + } + } + } + + if targetAuth == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"}) + return + } + + // Get models from registry + reg := registry.GetGlobalRegistry() + models := reg.GetModelsForClient(targetAuth.ID) + + // Apply model filter if specified + if modelFilter != "" || modelsFilter != "" { + filterSet := 
make(map[string]struct{}) + if modelFilter != "" { + filterSet[strings.ToLower(modelFilter)] = struct{}{} + } + if modelsFilter != "" { + for _, m := range strings.Split(modelsFilter, ",") { + trimmed := strings.TrimSpace(m) + if trimmed != "" { + filterSet[strings.ToLower(trimmed)] = struct{}{} + } + } + } + if len(filterSet) > 0 { + filtered := make([]*registry.ModelInfo, 0) + for _, model := range models { + if _, ok := filterSet[strings.ToLower(model.ID)]; ok { + filtered = append(filtered, model) + } + } + models = filtered + } + } + + if len(models) == 0 { + c.JSON(http.StatusOK, gin.H{ + "auth_id": targetAuth.ID, + "status": "healthy", + "healthy_count": 0, + "unhealthy_count": 0, + "total_count": 0, + "models": []ModelHealth{}, + }) + return + } + + // Prepare health check results + results := make([]ModelHealth, 0, len(models)) + var wg sync.WaitGroup + var mu sync.Mutex + + checkModel := func(model *registry.ModelInfo) { + defer wg.Done() + + startTime := time.Now() + checkCtx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds)*time.Second) + defer cancel() + + // Build minimal OpenAI-format request for health check (mimicking Cherry Studio) + // This will be translated by the executor to the appropriate provider format + openAIRequest := map[string]interface{}{ + "model": model.ID, + "messages": []map[string]interface{}{ + {"role": "user", "content": "hi"}, + {"role": "system", "content": "test"}, + }, + "stream": true, + "max_tokens": 1, + } + + requestJSON, err := json.Marshal(openAIRequest) + if err != nil { + mu.Lock() + results = append(results, ModelHealth{ + ModelID: model.ID, + DisplayName: model.DisplayName, + Status: "unhealthy", + Message: fmt.Sprintf("failed to build request: %v", err), + }) + mu.Unlock() + return + } + + // Build executor request + req := cliproxyexecutor.Request{ + Model: model.ID, + Payload: requestJSON, + Format: sdktranslator.FormatOpenAI, + } + + opts := cliproxyexecutor.Options{ + Stream: 
true, + SourceFormat: sdktranslator.FormatOpenAI, + OriginalRequest: requestJSON, + } + + // Execute stream directly with the specific auth (not load-balanced) + // This ensures we're testing this exact auth file, not any random auth of the same provider + stream, err := h.authManager.ExecuteStreamWithAuth(checkCtx, targetAuth, req, opts) + if err != nil { + mu.Lock() + results = append(results, ModelHealth{ + ModelID: model.ID, + DisplayName: model.DisplayName, + Status: "unhealthy", + Message: err.Error(), + }) + mu.Unlock() + return + } + + // Wait for first chunk or timeout + select { + case chunk, ok := <-stream: + if ok { + // Check for error in chunk + if chunk.Err != nil { + mu.Lock() + results = append(results, ModelHealth{ + ModelID: model.ID, + DisplayName: model.DisplayName, + Status: "unhealthy", + Message: chunk.Err.Error(), + }) + mu.Unlock() + cancel() + // Drain remaining chunks + go func() { + for range stream { + } + }() + return + } + + // Got first chunk - model is healthy + latency := time.Since(startTime).Milliseconds() + cancel() // Cancel immediately after first chunk + // Drain remaining chunks in background + go func() { + for range stream { + } + }() + + mu.Lock() + results = append(results, ModelHealth{ + ModelID: model.ID, + DisplayName: model.DisplayName, + Status: "healthy", + Latency: latency, + }) + mu.Unlock() + } else { + // Stream closed without data + mu.Lock() + results = append(results, ModelHealth{ + ModelID: model.ID, + DisplayName: model.DisplayName, + Status: "unhealthy", + Message: "stream closed without data", + }) + mu.Unlock() + } + case <-checkCtx.Done(): + // Timeout + mu.Lock() + results = append(results, ModelHealth{ + ModelID: model.ID, + DisplayName: model.DisplayName, + Status: "unhealthy", + Message: "health check timeout", + }) + mu.Unlock() + } + } + + // Execute health checks + if isConcurrent { + // Concurrent execution + for _, model := range models { + wg.Add(1) + go checkModel(model) + } + } else { + // 
Sequential execution + for _, model := range models { + wg.Add(1) + checkModel(model) + } + } + + wg.Wait() + + // Count results + healthyCount := 0 + unhealthyCount := 0 + for _, result := range results { + if result.Status == "healthy" { + healthyCount++ + } else { + unhealthyCount++ + } + } + + // Determine overall status + overallStatus := "healthy" + if unhealthyCount > 0 && healthyCount == 0 { + overallStatus = "unhealthy" + } else if unhealthyCount > 0 { + overallStatus = "partial" + } + + c.JSON(http.StatusOK, gin.H{ + "auth_id": targetAuth.ID, + "status": overallStatus, + "healthy_count": healthyCount, + "unhealthy_count": unhealthyCount, + "total_count": len(results), + "models": results, + }) +} diff --git a/internal/api/handlers/management/provider_health.go b/internal/api/handlers/management/provider_health.go new file mode 100644 index 000000000..4f5e35b5d --- /dev/null +++ b/internal/api/handlers/management/provider_health.go @@ -0,0 +1,438 @@ +package management + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" +) + +// ProviderInfo represents information about a configured provider +type ProviderInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Label string `json:"label,omitempty"` + Prefix string `json:"prefix,omitempty"` + BaseURL string `json:"base_url,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + APIKey string `json:"api_key,omitempty"` // masked + Status string `json:"status"` + Disabled bool `json:"disabled"` +} + +// ProviderHealth represents the health 
status of a provider +type ProviderHealth struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Label string `json:"label,omitempty"` + BaseURL string `json:"base_url,omitempty"` + Status string `json:"status"` // "healthy", "unhealthy" + Message string `json:"message,omitempty"` + Latency int64 `json:"latency_ms,omitempty"` + ModelTested string `json:"model_tested,omitempty"` +} + +// ListProviders returns all configured API key providers from configuration +func (h *Handler) ListProviders(c *gin.Context) { + if h.authManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + + // Filter by type if specified + typeFilter := strings.ToLower(strings.TrimSpace(c.Query("type"))) + + auths := h.authManager.List() + providers := make([]ProviderInfo, 0) + + for _, auth := range auths { + // Only include API key providers (those with api_key attribute) + if !isAPIKeyProvider(auth) { + continue + } + + providerType := getProviderType(auth) + if typeFilter != "" && !strings.EqualFold(providerType, typeFilter) { + continue + } + + info := ProviderInfo{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + Prefix: auth.Prefix, + BaseURL: authAttribute(auth, "base_url"), + ProxyURL: auth.ProxyURL, + APIKey: util.HideAPIKey(authAttribute(auth, "api_key")), + Status: string(auth.Status), + Disabled: auth.Disabled, + } + providers = append(providers, info) + } + + c.JSON(http.StatusOK, gin.H{ + "total": len(providers), + "providers": providers, + }) +} + +// CheckProvidersHealth performs health checks on configured API key providers +func (h *Handler) CheckProvidersHealth(c *gin.Context) { + if h.authManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + + // Parse query parameters + nameFilter := strings.TrimSpace(c.Query("name")) + typeFilter := 
strings.ToLower(strings.TrimSpace(c.Query("type"))) + modelFilter := strings.TrimSpace(c.Query("model")) + modelsFilter := strings.TrimSpace(c.Query("models")) + isConcurrent := c.DefaultQuery("concurrent", "false") == "true" + timeoutSeconds := 15 + if ts := c.Query("timeout"); ts != "" { + if parsed, err := strconv.Atoi(ts); err == nil && parsed >= 5 && parsed <= 60 { + timeoutSeconds = parsed + } + } + + // Build model filter set + var modelFilterSet map[string]struct{} + if modelFilter != "" || modelsFilter != "" { + modelFilterSet = make(map[string]struct{}) + if modelFilter != "" { + modelFilterSet[strings.ToLower(modelFilter)] = struct{}{} + } + if modelsFilter != "" { + for _, m := range strings.Split(modelsFilter, ",") { + trimmed := strings.TrimSpace(m) + if trimmed != "" { + modelFilterSet[strings.ToLower(trimmed)] = struct{}{} + } + } + } + } + + // Get all API key providers + auths := h.authManager.List() + targetAuths := make([]*coreauth.Auth, 0) + + for _, auth := range auths { + if !isAPIKeyProvider(auth) { + continue + } + if auth.Disabled { + continue + } + + // Apply name filter + if nameFilter != "" { + if !strings.EqualFold(auth.ID, nameFilter) && + !strings.EqualFold(auth.Provider, nameFilter) && + !strings.EqualFold(auth.Label, nameFilter) { + continue + } + } + + // Apply type filter + providerType := getProviderType(auth) + if typeFilter != "" && !strings.EqualFold(providerType, typeFilter) { + continue + } + + targetAuths = append(targetAuths, auth) + } + + if len(targetAuths) == 0 { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + "healthy_count": 0, + "unhealthy_count": 0, + "total_count": 0, + "providers": []ProviderHealth{}, + }) + return + } + + // Prepare health check results + results := make([]ProviderHealth, 0, len(targetAuths)) + var wg sync.WaitGroup + var mu sync.Mutex + + checkProvider := func(auth *coreauth.Auth) { + defer wg.Done() + + providerType := getProviderType(auth) + baseURL := authAttribute(auth, "base_url") 
+ + // Get models for this provider + reg := registry.GetGlobalRegistry() + models := reg.GetModelsForClient(auth.ID) + + // Apply model filter if specified + if len(modelFilterSet) > 0 { + filtered := make([]*registry.ModelInfo, 0) + for _, model := range models { + if _, ok := modelFilterSet[strings.ToLower(model.ID)]; ok { + filtered = append(filtered, model) + } + } + models = filtered + } + + // If no models available, report as unhealthy + if len(models) == 0 { + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "unhealthy", + Message: "no models available for this provider", + }) + mu.Unlock() + return + } + + // Use the first model for health check + testModel := models[0] + + startTime := time.Now() + checkCtx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds)*time.Second) + defer cancel() + + // Build minimal OpenAI-format request for health check + openAIRequest := map[string]interface{}{ + "model": testModel.ID, + "messages": []map[string]interface{}{ + {"role": "user", "content": "hi"}, + {"role": "system", "content": "test"}, + }, + "stream": true, + "max_tokens": 1, + } + + requestJSON, err := json.Marshal(openAIRequest) + if err != nil { + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "unhealthy", + Message: fmt.Sprintf("failed to build request: %v", err), + ModelTested: testModel.ID, + }) + mu.Unlock() + return + } + + // Build executor request + req := cliproxyexecutor.Request{ + Model: testModel.ID, + Payload: requestJSON, + Format: sdktranslator.FormatOpenAI, + } + + opts := cliproxyexecutor.Options{ + Stream: true, + SourceFormat: sdktranslator.FormatOpenAI, + OriginalRequest: requestJSON, + } + + // Execute stream directly with the specific auth + stream, err := 
h.authManager.ExecuteStreamWithAuth(checkCtx, auth, req, opts) + if err != nil { + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "unhealthy", + Message: err.Error(), + ModelTested: testModel.ID, + }) + mu.Unlock() + return + } + + // Wait for first chunk or timeout + select { + case chunk, ok := <-stream: + if ok { + if chunk.Err != nil { + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "unhealthy", + Message: chunk.Err.Error(), + ModelTested: testModel.ID, + }) + mu.Unlock() + cancel() + go func() { + for range stream { + } + }() + return + } + + // Got first chunk - provider is healthy + latency := time.Since(startTime).Milliseconds() + cancel() + go func() { + for range stream { + } + }() + + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "healthy", + Latency: latency, + ModelTested: testModel.ID, + }) + mu.Unlock() + } else { + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "unhealthy", + Message: "stream closed without data", + ModelTested: testModel.ID, + }) + mu.Unlock() + } + case <-checkCtx.Done(): + mu.Lock() + results = append(results, ProviderHealth{ + ID: auth.ID, + Name: auth.Provider, + Type: providerType, + Label: auth.Label, + BaseURL: baseURL, + Status: "unhealthy", + Message: "health check timeout", + ModelTested: testModel.ID, + }) + mu.Unlock() + } + } + + // Execute health checks + if isConcurrent { + for _, auth := range targetAuths { + wg.Add(1) + go checkProvider(auth) + } + } else { + for _, auth := range targetAuths { + wg.Add(1) + checkProvider(auth) + } + } + + wg.Wait() + 
+ // Count results + healthyCount := 0 + unhealthyCount := 0 + for _, result := range results { + if result.Status == "healthy" { + healthyCount++ + } else { + unhealthyCount++ + } + } + + // Determine overall status + overallStatus := "healthy" + if unhealthyCount > 0 && healthyCount == 0 { + overallStatus = "unhealthy" + } else if unhealthyCount > 0 { + overallStatus = "partial" + } + + c.JSON(http.StatusOK, gin.H{ + "status": overallStatus, + "healthy_count": healthyCount, + "unhealthy_count": unhealthyCount, + "total_count": len(results), + "providers": results, + }) +} + +// isAPIKeyProvider checks if an auth entry is an API key provider (from config) +func isAPIKeyProvider(auth *coreauth.Auth) bool { + if auth == nil || auth.Attributes == nil { + return false + } + // Check for api_key attribute or label containing "apikey" + if _, ok := auth.Attributes["api_key"]; ok { + return true + } + if strings.Contains(strings.ToLower(auth.Label), "apikey") { + return true + } + // Check source attribute for config-based providers + source := strings.ToLower(auth.Attributes["source"]) + return strings.HasPrefix(source, "config:") +} + +// getProviderType returns the type of provider (gemini, claude, codex, openai-compatibility, vertex) +func getProviderType(auth *coreauth.Auth) string { + if auth == nil { + return "unknown" + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + switch provider { + case "gemini": + return "gemini-api-key" + case "claude": + return "claude-api-key" + case "codex": + return "codex-api-key" + case "vertex": + return "vertex-api-key" + default: + // Check if it's openai-compatibility + if auth.Attributes != nil { + if _, ok := auth.Attributes["compat_name"]; ok { + return "openai-compatibility" + } + } + return provider + } +} diff --git a/internal/api/modules/unified-routing/config_service.go b/internal/api/modules/unified-routing/config_service.go new file mode 100644 index 000000000..35ef048a6 --- /dev/null +++ 
b/internal/api/modules/unified-routing/config_service.go @@ -0,0 +1,434 @@ +package unifiedrouting + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +// ConfigChangeEvent represents a configuration change event. +type ConfigChangeEvent struct { + Type string // "route_created", "route_updated", "route_deleted", "settings_updated", "pipeline_updated" + RouteID string + Payload any +} + +// ConfigChangeHandler is a callback function for configuration changes. +type ConfigChangeHandler func(event ConfigChangeEvent) + +// ConfigService manages unified routing configuration. +type ConfigService interface { + // Settings + GetSettings(ctx context.Context) (*Settings, error) + UpdateSettings(ctx context.Context, settings *Settings) error + + // Health check config + GetHealthCheckConfig(ctx context.Context) (*HealthCheckConfig, error) + UpdateHealthCheckConfig(ctx context.Context, config *HealthCheckConfig) error + + // Routes + ListRoutes(ctx context.Context) ([]*Route, error) + GetRoute(ctx context.Context, id string) (*Route, error) + CreateRoute(ctx context.Context, route *Route) error + UpdateRoute(ctx context.Context, route *Route) error + DeleteRoute(ctx context.Context, id string) error + + // Pipelines + GetPipeline(ctx context.Context, routeID string) (*Pipeline, error) + UpdatePipeline(ctx context.Context, routeID string, pipeline *Pipeline) error + + // Export/Import + Export(ctx context.Context) (*ExportData, error) + Import(ctx context.Context, data *ExportData, merge bool) error + + // Validation + Validate(ctx context.Context, route *Route, pipeline *Pipeline) []ValidationError + + // Subscriptions + Subscribe(handler ConfigChangeHandler) +} + +// DefaultConfigService implements ConfigService. +type DefaultConfigService struct { + store ConfigStore + mu sync.RWMutex + handlers []ConfigChangeHandler +} + +// NewConfigService creates a new configuration service. 
+func NewConfigService(store ConfigStore) *DefaultConfigService { + return &DefaultConfigService{ + store: store, + handlers: make([]ConfigChangeHandler, 0), + } +} + +func (s *DefaultConfigService) GetSettings(ctx context.Context) (*Settings, error) { + return s.store.LoadSettings(ctx) +} + +func (s *DefaultConfigService) UpdateSettings(ctx context.Context, settings *Settings) error { + if err := s.store.SaveSettings(ctx, settings); err != nil { + return err + } + + s.notify(ConfigChangeEvent{ + Type: "settings_updated", + Payload: settings, + }) + + return nil +} + +func (s *DefaultConfigService) GetHealthCheckConfig(ctx context.Context) (*HealthCheckConfig, error) { + return s.store.LoadHealthCheckConfig(ctx) +} + +func (s *DefaultConfigService) UpdateHealthCheckConfig(ctx context.Context, config *HealthCheckConfig) error { + if err := s.store.SaveHealthCheckConfig(ctx, config); err != nil { + return err + } + + s.notify(ConfigChangeEvent{ + Type: "health_config_updated", + Payload: config, + }) + + return nil +} + +func (s *DefaultConfigService) ListRoutes(ctx context.Context) ([]*Route, error) { + return s.store.ListRoutes(ctx) +} + +func (s *DefaultConfigService) GetRoute(ctx context.Context, id string) (*Route, error) { + return s.store.GetRoute(ctx, id) +} + +func (s *DefaultConfigService) CreateRoute(ctx context.Context, route *Route) error { + // Generate ID if not provided + if route.ID == "" { + route.ID = "route-" + generateShortID() + } + + // Validate route name + if route.Name == "" { + return fmt.Errorf("route name is required") + } + + // Check for duplicate name + routes, err := s.store.ListRoutes(ctx) + if err != nil { + return err + } + for _, r := range routes { + if strings.EqualFold(r.Name, route.Name) { + return fmt.Errorf("route with name '%s' already exists", route.Name) + } + } + + route.CreatedAt = time.Now() + route.UpdatedAt = route.CreatedAt + + if err := s.store.CreateRoute(ctx, route); err != nil { + return err + } + + 
s.notify(ConfigChangeEvent{ + Type: "route_created", + RouteID: route.ID, + Payload: route, + }) + + return nil +} + +func (s *DefaultConfigService) UpdateRoute(ctx context.Context, route *Route) error { + existing, err := s.store.GetRoute(ctx, route.ID) + if err != nil { + return err + } + + // Check for duplicate name if name changed + if !strings.EqualFold(existing.Name, route.Name) { + routes, err := s.store.ListRoutes(ctx) + if err != nil { + return err + } + for _, r := range routes { + if r.ID != route.ID && strings.EqualFold(r.Name, route.Name) { + return fmt.Errorf("route with name '%s' already exists", route.Name) + } + } + } + + route.CreatedAt = existing.CreatedAt + route.UpdatedAt = time.Now() + + if err := s.store.UpdateRoute(ctx, route); err != nil { + return err + } + + s.notify(ConfigChangeEvent{ + Type: "route_updated", + RouteID: route.ID, + Payload: route, + }) + + return nil +} + +func (s *DefaultConfigService) DeleteRoute(ctx context.Context, id string) error { + if err := s.store.DeleteRoute(ctx, id); err != nil { + return err + } + + s.notify(ConfigChangeEvent{ + Type: "route_deleted", + RouteID: id, + }) + + return nil +} + +func (s *DefaultConfigService) GetPipeline(ctx context.Context, routeID string) (*Pipeline, error) { + return s.store.GetPipeline(ctx, routeID) +} + +func (s *DefaultConfigService) UpdatePipeline(ctx context.Context, routeID string, pipeline *Pipeline) error { + // Validate pipeline + if errs := s.validatePipeline(pipeline); len(errs) > 0 { + return fmt.Errorf("pipeline validation failed: %s", errs[0].Message) + } + + // Ensure target IDs are set + for i := range pipeline.Layers { + for j := range pipeline.Layers[i].Targets { + if pipeline.Layers[i].Targets[j].ID == "" { + pipeline.Layers[i].Targets[j].ID = "target-" + generateShortID() + } + // Default weight to 1 + if pipeline.Layers[i].Targets[j].Weight <= 0 { + pipeline.Layers[i].Targets[j].Weight = 1 + } + } + } + + if err := s.store.SavePipeline(ctx, routeID, 
pipeline); err != nil { + return err + } + + s.notify(ConfigChangeEvent{ + Type: "pipeline_updated", + RouteID: routeID, + Payload: pipeline, + }) + + return nil +} + +func (s *DefaultConfigService) Export(ctx context.Context) (*ExportData, error) { + settings, err := s.store.LoadSettings(ctx) + if err != nil { + return nil, err + } + + healthConfig, err := s.store.LoadHealthCheckConfig(ctx) + if err != nil { + return nil, err + } + + routes, err := s.store.ListRoutes(ctx) + if err != nil { + return nil, err + } + + var routesWithPipelines []RouteWithPipeline + for _, route := range routes { + pipeline, err := s.store.GetPipeline(ctx, route.ID) + if err != nil { + pipeline = &Pipeline{RouteID: route.ID, Layers: []Layer{}} + } + routesWithPipelines = append(routesWithPipelines, RouteWithPipeline{ + Route: *route, + Pipeline: *pipeline, + }) + } + + return &ExportData{ + Version: "1.0", + ExportedAt: time.Now(), + Config: ExportedConfig{ + Settings: *settings, + HealthCheck: *healthConfig, + Routes: routesWithPipelines, + }, + }, nil +} + +func (s *DefaultConfigService) Import(ctx context.Context, data *ExportData, merge bool) error { + if !merge { + // Delete all existing routes first + routes, _ := s.store.ListRoutes(ctx) + for _, route := range routes { + _ = s.store.DeleteRoute(ctx, route.ID) + } + } + + // Import settings + if err := s.store.SaveSettings(ctx, &data.Config.Settings); err != nil { + return fmt.Errorf("failed to import settings: %w", err) + } + + // Import health config + if err := s.store.SaveHealthCheckConfig(ctx, &data.Config.HealthCheck); err != nil { + return fmt.Errorf("failed to import health config: %w", err) + } + + // Import routes and pipelines + for _, rwp := range data.Config.Routes { + route := rwp.Route + + if merge { + // Update if exists, create if not + _, err := s.store.GetRoute(ctx, route.ID) + if err != nil { + _ = s.store.CreateRoute(ctx, &route) + } else { + _ = s.store.UpdateRoute(ctx, &route) + } + } else { + _ = 
s.store.CreateRoute(ctx, &route) + } + + _ = s.store.SavePipeline(ctx, route.ID, &rwp.Pipeline) + } + + s.notify(ConfigChangeEvent{ + Type: "config_imported", + Payload: data, + }) + + return nil +} + +func (s *DefaultConfigService) Validate(ctx context.Context, route *Route, pipeline *Pipeline) []ValidationError { + var errors []ValidationError + + // Validate route + if route != nil { + if route.Name == "" { + errors = append(errors, ValidationError{Field: "name", Message: "route name is required"}) + } + if len(route.Name) > 64 { + errors = append(errors, ValidationError{Field: "name", Message: "route name must be 64 characters or less"}) + } + // Route name should be a valid model identifier + if !isValidModelName(route.Name) { + errors = append(errors, ValidationError{Field: "name", Message: "route name must be alphanumeric with dashes/underscores"}) + } + } + + // Validate pipeline + if pipeline != nil { + errors = append(errors, s.validatePipeline(pipeline)...) + } + + return errors +} + +func (s *DefaultConfigService) validatePipeline(pipeline *Pipeline) []ValidationError { + var errors []ValidationError + + if len(pipeline.Layers) == 0 { + errors = append(errors, ValidationError{Field: "layers", Message: "at least one layer is required"}) + return errors + } + + seenLevels := make(map[int]bool) + for i, layer := range pipeline.Layers { + // Check level uniqueness + if seenLevels[layer.Level] { + errors = append(errors, ValidationError{ + Field: fmt.Sprintf("layers[%d].level", i), + Message: fmt.Sprintf("duplicate level %d", layer.Level), + }) + } + seenLevels[layer.Level] = true + + // Check targets + if len(layer.Targets) == 0 { + errors = append(errors, ValidationError{ + Field: fmt.Sprintf("layers[%d].targets", i), + Message: "at least one target is required per layer", + }) + } + + for j, target := range layer.Targets { + if target.CredentialID == "" { + errors = append(errors, ValidationError{ + Field: 
fmt.Sprintf("layers[%d].targets[%d].credential_id", i, j), + Message: "credential_id is required", + }) + } + if target.Model == "" { + errors = append(errors, ValidationError{ + Field: fmt.Sprintf("layers[%d].targets[%d].model", i, j), + Message: "model is required", + }) + } + } + + // Validate strategy + switch layer.Strategy { + case StrategyRoundRobin, StrategyWeightedRound, StrategyLeastConn, StrategyRandom, StrategyFirstAvailable, "": + // Valid + default: + errors = append(errors, ValidationError{ + Field: fmt.Sprintf("layers[%d].strategy", i), + Message: fmt.Sprintf("invalid strategy: %s", layer.Strategy), + }) + } + } + + return errors +} + +func (s *DefaultConfigService) Subscribe(handler ConfigChangeHandler) { + s.mu.Lock() + defer s.mu.Unlock() + s.handlers = append(s.handlers, handler) +} + +func (s *DefaultConfigService) notify(event ConfigChangeEvent) { + s.mu.RLock() + handlers := s.handlers + s.mu.RUnlock() + + for _, handler := range handlers { + go handler(event) + } +} + +// Helper functions + +func generateShortID() string { + id := uuid.New().String() + return id[:8] +} + +func isValidModelName(name string) bool { + if name == "" { + return false + } + for _, c := range name { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' || c == '.') { + return false + } + } + return true +} diff --git a/internal/api/modules/unified-routing/engine.go b/internal/api/modules/unified-routing/engine.go new file mode 100644 index 000000000..747bae6b1 --- /dev/null +++ b/internal/api/modules/unified-routing/engine.go @@ -0,0 +1,668 @@ +package unifiedrouting + +import ( + "context" + "fmt" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + log "github.com/sirupsen/logrus" +) + +// RoutingEngine is the core routing engine for unified routing. 
type RoutingEngine interface {
	// Route determines the routing decision for a given model name.
	Route(ctx context.Context, modelName string) (*RoutingDecision, error)

	// IsEnabled returns whether unified routing is enabled.
	IsEnabled(ctx context.Context) bool

	// ShouldHideOriginalModels returns whether original models should be hidden.
	ShouldHideOriginalModels(ctx context.Context) bool

	// GetRouteNames returns all configured route names.
	GetRouteNames(ctx context.Context) []string

	// Reload reloads the engine configuration.
	Reload(ctx context.Context) error

	// GetRoutingTarget returns the target model and credential for a route alias.
	// Returns the target model name, credential ID, and any error.
	// If modelName is not a route alias, returns RouteNotFoundError.
	GetRoutingTarget(ctx context.Context, modelName string) (targetModel string, credentialID string, err error)

	// SelectTarget selects the next target from a layer based on the load balancing strategy.
	SelectTarget(ctx context.Context, routeID string, layer *Layer) (*Target, error)
}

// RoutingDecision represents the decision made by the routing engine.
type RoutingDecision struct {
	RouteID   string    // ID of the matched route
	RouteName string    // alias name of the matched route
	TraceID   string    // per-request trace identifier ("trace-" + short id)
	Pipeline  *Pipeline // the route's layered failover pipeline
}

// DefaultRoutingEngine implements RoutingEngine.
type DefaultRoutingEngine struct {
	configSvc   ConfigService
	stateMgr    StateManager
	metrics     MetricsCollector
	authManager *coreauth.Manager

	// mu guards routeIndex, pipelineIndex, and rrCounters.
	mu            sync.RWMutex
	routeIndex    map[string]*Route    // name -> route (keys lowercased on Reload)
	pipelineIndex map[string]*Pipeline // routeID -> pipeline

	// Round-robin state per layer
	rrCounters map[string]*atomic.Uint64 // layerKey ("routeID:level") -> counter
}

// NewRoutingEngine creates a new routing engine.
+func NewRoutingEngine( + configSvc ConfigService, + stateMgr StateManager, + metrics MetricsCollector, + authManager *coreauth.Manager, +) *DefaultRoutingEngine { + engine := &DefaultRoutingEngine{ + configSvc: configSvc, + stateMgr: stateMgr, + metrics: metrics, + authManager: authManager, + routeIndex: make(map[string]*Route), + pipelineIndex: make(map[string]*Pipeline), + rrCounters: make(map[string]*atomic.Uint64), + } + + // Subscribe to config changes + configSvc.Subscribe(func(event ConfigChangeEvent) { + _ = engine.Reload(context.Background()) + }) + + // Initial load + _ = engine.Reload(context.Background()) + + return engine +} + +func (e *DefaultRoutingEngine) Route(ctx context.Context, modelName string) (*RoutingDecision, error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // Look up route by name (case-insensitive) + route, ok := e.routeIndex[strings.ToLower(modelName)] + if !ok { + return nil, &RouteNotFoundError{ModelName: modelName} + } + + if !route.Enabled { + return nil, &RouteDisabledError{RouteName: route.Name} + } + + pipeline, ok := e.pipelineIndex[route.ID] + if !ok || len(pipeline.Layers) == 0 { + return nil, &PipelineEmptyError{RouteID: route.ID} + } + + return &RoutingDecision{ + RouteID: route.ID, + RouteName: route.Name, + TraceID: "trace-" + generateShortID(), + Pipeline: pipeline, + }, nil +} + +func (e *DefaultRoutingEngine) IsEnabled(ctx context.Context) bool { + settings, err := e.configSvc.GetSettings(ctx) + if err != nil { + return false + } + return settings.Enabled +} + +func (e *DefaultRoutingEngine) ShouldHideOriginalModels(ctx context.Context) bool { + settings, err := e.configSvc.GetSettings(ctx) + if err != nil { + return false + } + return settings.Enabled && settings.HideOriginalModels +} + +func (e *DefaultRoutingEngine) GetRouteNames(ctx context.Context) []string { + e.mu.RLock() + defer e.mu.RUnlock() + + names := make([]string, 0, len(e.routeIndex)) + for _, route := range e.routeIndex { + if route.Enabled { + names 
= append(names, route.Name) + } + } + return names +} + +// GetRoutingTarget returns the target model and credential for a route alias. +func (e *DefaultRoutingEngine) GetRoutingTarget(ctx context.Context, modelName string) (string, string, error) { + decision, err := e.Route(ctx, modelName) + if err != nil { + return "", "", err + } + + // Select target from the first available layer + for _, layer := range decision.Pipeline.Layers { + target, err := e.SelectTarget(ctx, decision.RouteID, &layer) + if err != nil { + continue // Try next layer + } + if target != nil { + return target.Model, target.CredentialID, nil + } + } + + return "", "", &NoAvailableTargetsError{Layer: 0} +} + +// GetRoutingDecision returns the full routing decision for a model name. +func (e *DefaultRoutingEngine) GetRoutingDecision(ctx context.Context, modelName string) (*RoutingDecision, error) { + return e.Route(ctx, modelName) +} + +func (e *DefaultRoutingEngine) Reload(ctx context.Context) error { + routes, err := e.configSvc.ListRoutes(ctx) + if err != nil { + return err + } + + newRouteIndex := make(map[string]*Route, len(routes)) + newPipelineIndex := make(map[string]*Pipeline, len(routes)) + + for _, route := range routes { + newRouteIndex[strings.ToLower(route.Name)] = route + + pipeline, err := e.configSvc.GetPipeline(ctx, route.ID) + if err != nil { + pipeline = &Pipeline{RouteID: route.ID, Layers: []Layer{}} + } + newPipelineIndex[route.ID] = pipeline + } + + e.mu.Lock() + e.routeIndex = newRouteIndex + e.pipelineIndex = newPipelineIndex + e.mu.Unlock() + + log.Debugf("unified routing engine reloaded: %d routes", len(routes)) + return nil +} + +// SelectTarget selects the next target from a layer based on the strategy. 
+func (e *DefaultRoutingEngine) SelectTarget(ctx context.Context, routeID string, layer *Layer) (*Target, error) { + // Get available targets + availableTargets := make([]Target, 0) + for _, target := range layer.Targets { + if !target.Enabled { + continue + } + state, _ := e.stateMgr.GetTargetState(ctx, target.ID) + if state != nil && state.Status != StatusHealthy { + continue + } + availableTargets = append(availableTargets, target) + } + + if len(availableTargets) == 0 { + return nil, &NoAvailableTargetsError{Layer: layer.Level} + } + + // Select based on strategy + var selected *Target + switch layer.Strategy { + case StrategyRoundRobin, "": + selected = e.selectRoundRobin(routeID, layer.Level, availableTargets) + case StrategyWeightedRound: + selected = e.selectWeightedRoundRobin(routeID, layer.Level, availableTargets) + case StrategyRandom: + selected = e.selectRandom(availableTargets) + case StrategyFirstAvailable: + selected = &availableTargets[0] + case StrategyLeastConn: + selected = e.selectLeastConnections(ctx, availableTargets) + default: + selected = e.selectRoundRobin(routeID, layer.Level, availableTargets) + } + + return selected, nil +} + +func (e *DefaultRoutingEngine) selectRoundRobin(routeID string, level int, targets []Target) *Target { + key := fmt.Sprintf("%s:%d", routeID, level) + + e.mu.Lock() + counter, ok := e.rrCounters[key] + if !ok { + counter = &atomic.Uint64{} + e.rrCounters[key] = counter + } + e.mu.Unlock() + + idx := counter.Add(1) - 1 + return &targets[int(idx)%len(targets)] +} + +func (e *DefaultRoutingEngine) selectWeightedRoundRobin(routeID string, level int, targets []Target) *Target { + // Calculate total weight + totalWeight := 0 + for _, t := range targets { + weight := t.Weight + if weight <= 0 { + weight = 1 + } + totalWeight += weight + } + + key := fmt.Sprintf("%s:%d:weighted", routeID, level) + + e.mu.Lock() + counter, ok := e.rrCounters[key] + if !ok { + counter = &atomic.Uint64{} + e.rrCounters[key] = counter + } + 
e.mu.Unlock() + + idx := int(counter.Add(1)-1) % totalWeight + + // Find the target + cumulative := 0 + for i := range targets { + weight := targets[i].Weight + if weight <= 0 { + weight = 1 + } + cumulative += weight + if idx < cumulative { + return &targets[i] + } + } + + return &targets[0] +} + +func (e *DefaultRoutingEngine) selectRandom(targets []Target) *Target { + idx := rand.Intn(len(targets)) + return &targets[idx] +} + +func (e *DefaultRoutingEngine) selectLeastConnections(ctx context.Context, targets []Target) *Target { + var minConn int64 = -1 + var selected *Target + + for i := range targets { + state, _ := e.stateMgr.GetTargetState(ctx, targets[i].ID) + conn := int64(0) + if state != nil { + conn = state.ActiveConnections + } + + if minConn < 0 || conn < minConn { + minConn = conn + selected = &targets[i] + } + } + + if selected == nil { + return &targets[0] + } + return selected +} + +// ExecuteWithFailover executes a request with automatic failover. +func (e *DefaultRoutingEngine) ExecuteWithFailover( + ctx context.Context, + decision *RoutingDecision, + executeFunc func(ctx context.Context, auth *coreauth.Auth, model string) error, +) error { + if decision == nil || decision.Pipeline == nil { + return fmt.Errorf("invalid routing decision") + } + + traceBuilder := NewTraceBuilder(decision.RouteID, decision.RouteName) + startTime := time.Now() + + // Get health check config for cooldown + healthConfig, _ := e.configSvc.GetHealthCheckConfig(ctx) + if healthConfig == nil { + cfg := DefaultHealthCheckConfig() + healthConfig = &cfg + } + + // Try each layer in order + for layerIdx, layer := range decision.Pipeline.Layers { + cooldownDuration := time.Duration(layer.CooldownSeconds) * time.Second + if cooldownDuration == 0 { + cooldownDuration = time.Duration(healthConfig.DefaultCooldownSeconds) * time.Second + } + + // Keep trying targets in this layer until no available targets remain + // SelectTarget automatically excludes cooling-down targets + for { 
+ target, err := e.SelectTarget(ctx, decision.RouteID, &layer) + if err != nil { + // No available targets in this layer, move to next layer + break + } + + // Find auth for this target + auth := e.findAuth(target.CredentialID) + if auth == nil { + traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model). + Failed("credential not found") + // Mark as cooldown so we don't keep trying this target + e.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration) + continue + } + + // Execute request + attemptStart := time.Now() + err = executeFunc(ctx, auth, target.Model) + attemptLatency := time.Since(attemptStart).Milliseconds() + + if err == nil { + // Success - record and return + e.stateMgr.RecordSuccess(ctx, target.ID, time.Since(attemptStart)) + traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model). + Success(attemptLatency) + + trace := traceBuilder.Build(time.Since(startTime).Milliseconds()) + e.metrics.RecordRequest(trace) + return nil + } + + // Failure - immediately start cooldown and try next target in this layer + e.stateMgr.RecordFailure(ctx, target.ID, err.Error()) + traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model). 
+ Failed(err.Error()) + + // Start cooldown immediately on failure + e.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration) + e.metrics.RecordEvent(&RoutingEvent{ + Type: EventCooldownStarted, + RouteID: decision.RouteID, + TargetID: target.ID, + Details: map[string]any{ + "duration_seconds": int(cooldownDuration.Seconds()), + "reason": err.Error(), + }, + }) + + // Continue loop - SelectTarget will automatically exclude cooling-down targets + } + + // Record layer fallback event when moving to next layer + if layerIdx < len(decision.Pipeline.Layers)-1 { + e.metrics.RecordEvent(&RoutingEvent{ + Type: EventLayerFallback, + RouteID: decision.RouteID, + Details: map[string]any{ + "from_layer": layer.Level, + "to_layer": layer.Level + 1, + }, + }) + } + } + + // All layers exhausted + trace := traceBuilder.Build(time.Since(startTime).Milliseconds()) + e.metrics.RecordRequest(trace) + + return &AllTargetsExhaustedError{RouteID: decision.RouteID} +} + +// StreamExecuteFunc is the function type for streaming execution. +// It returns a channel of StreamChunks and an error if connection fails. +type StreamExecuteFunc func(ctx context.Context, auth *coreauth.Auth, model string) (<-chan cliproxyexecutor.StreamChunk, error) + +// ExecuteStreamWithFailover executes a streaming request with automatic failover. +// Failover only occurs before the first successful chunk is received. +// Once streaming begins, the target is committed and cannot be changed. 
func (e *DefaultRoutingEngine) ExecuteStreamWithFailover(
	ctx context.Context,
	decision *RoutingDecision,
	executeFunc StreamExecuteFunc,
) (<-chan cliproxyexecutor.StreamChunk, error) {
	if decision == nil || decision.Pipeline == nil {
		return nil, fmt.Errorf("invalid routing decision")
	}

	traceBuilder := NewTraceBuilder(decision.RouteID, decision.RouteName)
	startTime := time.Now()

	// Get health check config for cooldown
	healthConfig, _ := e.configSvc.GetHealthCheckConfig(ctx)
	if healthConfig == nil {
		cfg := DefaultHealthCheckConfig()
		healthConfig = &cfg
	}

	// Try each layer in order
	for layerIdx, layer := range decision.Pipeline.Layers {
		// Per-layer cooldown; 0 means "use the configured default".
		cooldownDuration := time.Duration(layer.CooldownSeconds) * time.Second
		if cooldownDuration == 0 {
			cooldownDuration = time.Duration(healthConfig.DefaultCooldownSeconds) * time.Second
		}

		// Keep trying targets in this layer until no available targets remain
		for {
			target, err := e.SelectTarget(ctx, decision.RouteID, &layer)
			if err != nil {
				// No available targets in this layer, move to next layer
				break
			}

			// Find auth for this target
			auth := e.findAuth(target.CredentialID)
			if auth == nil {
				traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model).
					Failed("credential not found")
				// Cooling the target down excludes it from the next SelectTarget
				// call, so this loop cannot spin on a missing credential.
				e.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration)
				continue
			}

			// Try to execute streaming request
			attemptStart := time.Now()
			chunks, err := executeFunc(ctx, auth, target.Model)
			if err != nil {
				// Connection failed, try next target
				attemptLatency := time.Since(attemptStart).Milliseconds()
				e.stateMgr.RecordFailure(ctx, target.ID, err.Error())
				traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model).
					Failed(err.Error())
				e.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration)
				e.metrics.RecordEvent(&RoutingEvent{
					Type:     EventCooldownStarted,
					RouteID:  decision.RouteID,
					TargetID: target.ID,
					Details: map[string]any{
						"duration_seconds": int(cooldownDuration.Seconds()),
						"reason":           err.Error(),
						"latency_ms":       attemptLatency,
					},
				})
				continue
			}

			// Got a channel, now wait for the first chunk to validate the connection
			firstChunk, ok := <-chunks
			if !ok {
				// Channel closed immediately without any data - treat as failure
				attemptLatency := time.Since(attemptStart).Milliseconds()
				e.stateMgr.RecordFailure(ctx, target.ID, "stream closed without data")
				traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model).
					Failed("stream closed without data")
				e.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration)
				e.metrics.RecordEvent(&RoutingEvent{
					Type:     EventCooldownStarted,
					RouteID:  decision.RouteID,
					TargetID: target.ID,
					Details: map[string]any{
						"duration_seconds": int(cooldownDuration.Seconds()),
						"reason":           "stream closed without data",
						"latency_ms":       attemptLatency,
					},
				})
				continue
			}

			if firstChunk.Err != nil {
				// First chunk has error - try next target
				attemptLatency := time.Since(attemptStart).Milliseconds()
				errMsg := firstChunk.Err.Error()
				e.stateMgr.RecordFailure(ctx, target.ID, errMsg)
				traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model).
					Failed(errMsg)
				e.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration)
				e.metrics.RecordEvent(&RoutingEvent{
					Type:     EventCooldownStarted,
					RouteID:  decision.RouteID,
					TargetID: target.ID,
					Details: map[string]any{
						"duration_seconds": int(cooldownDuration.Seconds()),
						"reason":           errMsg,
						"latency_ms":       attemptLatency,
					},
				})
				// Drain remaining chunks to prevent goroutine leak
				go func() {
					for range chunks {
					}
				}()
				continue
			}

			// First chunk is successful! Create a new channel that forwards all chunks
			// and record success after stream completes
			// NOTE(review): outputChan has a 100-chunk buffer; if the consumer
			// stops reading, the forwarding goroutine below blocks forever once
			// the buffer fills — confirm callers always drain the channel.
			outputChan := make(chan cliproxyexecutor.StreamChunk, 100)

			// Write first chunk to buffer BEFORE starting goroutine to avoid race condition
			outputChan <- firstChunk

			// Capture loop variables for goroutine to avoid closure issues
			capturedTarget := target
			capturedAttemptStart := attemptStart

			go func() {
				defer close(outputChan)

				// Forward remaining chunks
				var streamErr error
				for chunk := range chunks {
					if chunk.Err != nil {
						streamErr = chunk.Err
					}
					outputChan <- chunk
				}

				// Record result after stream completes
				attemptLatency := time.Since(capturedAttemptStart).Milliseconds()
				if streamErr != nil {
					// Stream had an error mid-way, but we already committed to this target
					// Just log it, don't start cooldown since connection was initially good
					log.Warnf("[UnifiedRouting] Stream error after successful start: %v", streamErr)
				}
				e.stateMgr.RecordSuccess(ctx, capturedTarget.ID, time.Since(capturedAttemptStart))
				traceBuilder.AddAttempt(layer.Level, capturedTarget.ID, capturedTarget.CredentialID, capturedTarget.Model).
					Success(attemptLatency)

				trace := traceBuilder.Build(time.Since(startTime).Milliseconds())
				e.metrics.RecordRequest(trace)
			}()

			return outputChan, nil
		}

		// Record layer fallback event when moving to next layer
		if layerIdx < len(decision.Pipeline.Layers)-1 {
			e.metrics.RecordEvent(&RoutingEvent{
				Type:    EventLayerFallback,
				RouteID: decision.RouteID,
				Details: map[string]any{
					"from_layer": layer.Level,
					"to_layer":   layer.Level + 1,
				},
			})
		}
	}

	// All layers exhausted
	trace := traceBuilder.Build(time.Since(startTime).Milliseconds())
	e.metrics.RecordRequest(trace)

	return nil, &AllTargetsExhaustedError{RouteID: decision.RouteID}
}

// findAuth returns the auth whose ID matches credentialID, or nil when the
// auth manager is absent or no auth matches. Linear scan over the full list.
func (e *DefaultRoutingEngine) findAuth(credentialID string) *coreauth.Auth {
	if e.authManager == nil {
		return nil
	}

	auths := e.authManager.List()
	for _, auth := range auths {
		if auth.ID == credentialID {
			return auth
		}
	}
	return nil
}

// Error types

// RouteNotFoundError is returned when a route is not found.
type RouteNotFoundError struct {
	ModelName string
}

func (e *RouteNotFoundError) Error() string {
	return fmt.Sprintf("route not found for model: %s", e.ModelName)
}

// RouteDisabledError is returned when a route is disabled.
type RouteDisabledError struct {
	RouteName string
}

func (e *RouteDisabledError) Error() string {
	return fmt.Sprintf("route is disabled: %s", e.RouteName)
}

// PipelineEmptyError is returned when a pipeline has no layers.
type PipelineEmptyError struct {
	RouteID string
}

func (e *PipelineEmptyError) Error() string {
	return fmt.Sprintf("pipeline is empty for route: %s", e.RouteID)
}

// NoAvailableTargetsError is returned when no targets are available in a layer.
type NoAvailableTargetsError struct {
	Layer int
}

func (e *NoAvailableTargetsError) Error() string {
	return fmt.Sprintf("no available targets in layer %d", e.Layer)
}

// AllTargetsExhaustedError is returned when all targets in all layers are exhausted.
type AllTargetsExhaustedError struct {
	RouteID string
}

func (e *AllTargetsExhaustedError) Error() string {
	return fmt.Sprintf("all targets exhausted for route: %s", e.RouteID)
}
diff --git a/internal/api/modules/unified-routing/handlers.go b/internal/api/modules/unified-routing/handlers.go
new file mode 100644
index 000000000..11fa6226a
--- /dev/null
+++ b/internal/api/modules/unified-routing/handlers.go
@@ -0,0 +1,1008 @@
package unifiedrouting

import (
	"net/http"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	log "github.com/sirupsen/logrus"
)

// Handlers contains all HTTP handlers for unified routing.
type Handlers struct {
	configSvc     ConfigService
	stateMgr      StateManager
	metrics       MetricsCollector
	healthChecker HealthChecker
	authManager   *coreauth.Manager
	engine        RoutingEngine
}

// NewHandlers creates a new handlers instance.
func NewHandlers(
	configSvc ConfigService,
	stateMgr StateManager,
	metrics MetricsCollector,
	healthChecker HealthChecker,
	authManager *coreauth.Manager,
	engine RoutingEngine,
) *Handlers {
	return &Handlers{
		configSvc:     configSvc,
		stateMgr:      stateMgr,
		metrics:       metrics,
		healthChecker: healthChecker,
		authManager:   authManager,
		engine:        engine,
	}
}

// ================== Config: Settings ==================

// GetSettings returns the unified routing settings.
+func (h *Handlers) GetSettings(c *gin.Context) { + log.Info("[UnifiedRouting] GetSettings called") + settings, err := h.configSvc.GetSettings(c.Request.Context()) + if err != nil { + log.Errorf("[UnifiedRouting] GetSettings error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + log.Infof("[UnifiedRouting] GetSettings success: %+v", settings) + c.JSON(http.StatusOK, settings) +} + +// PutSettings updates the unified routing settings. +func (h *Handlers) PutSettings(c *gin.Context) { + var settings Settings + if err := c.ShouldBindJSON(&settings); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.configSvc.UpdateSettings(c.Request.Context(), &settings); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, settings) +} + +// GetHealthCheckConfig returns the health check configuration. +func (h *Handlers) GetHealthCheckConfig(c *gin.Context) { + config, err := h.configSvc.GetHealthCheckConfig(c.Request.Context()) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, config) +} + +// PutHealthCheckConfig updates the health check configuration. +func (h *Handlers) PutHealthCheckConfig(c *gin.Context) { + var config HealthCheckConfig + if err := c.ShouldBindJSON(&config); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.configSvc.UpdateHealthCheckConfig(c.Request.Context(), &config); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, config) +} + +// ================== Config: Routes ================== + +// ListRoutes returns all routes. 
+func (h *Handlers) ListRoutes(c *gin.Context) { + log.Info("[UnifiedRouting] ListRoutes called") + routes, err := h.configSvc.ListRoutes(c.Request.Context()) + if err != nil { + log.Errorf("[UnifiedRouting] ListRoutes error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + log.Infof("[UnifiedRouting] ListRoutes: found %d routes", len(routes)) + + // Build response with pipeline summary + type RouteResponse struct { + *Route + PipelineSummary struct { + TotalLayers int `json:"total_layers"` + TotalTargets int `json:"total_targets"` + } `json:"pipeline_summary"` + } + + response := make([]RouteResponse, 0, len(routes)) + for _, route := range routes { + rr := RouteResponse{Route: route} + + pipeline, err := h.configSvc.GetPipeline(c.Request.Context(), route.ID) + if err == nil { + rr.PipelineSummary.TotalLayers = len(pipeline.Layers) + for _, layer := range pipeline.Layers { + rr.PipelineSummary.TotalTargets += len(layer.Targets) + } + } + + response = append(response, rr) + } + + c.JSON(http.StatusOK, gin.H{ + "total": len(response), + "routes": response, + }) +} + +// GetRoute returns a single route. +func (h *Handlers) GetRoute(c *gin.Context) { + routeID := c.Param("route_id") + + route, err := h.configSvc.GetRoute(c.Request.Context(), routeID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + pipeline, _ := h.configSvc.GetPipeline(c.Request.Context(), routeID) + + c.JSON(http.StatusOK, gin.H{ + "route": route, + "pipeline": pipeline, + }) +} + +// CreateRoute creates a new route. 
+func (h *Handlers) CreateRoute(c *gin.Context) { + var req struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + Pipeline Pipeline `json:"pipeline"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Validate + route := &Route{ + Name: req.Name, + Description: req.Description, + Enabled: req.Enabled, + } + + // Only validate pipeline if it has layers (allow creating routes without pipeline) + var pipelineToValidate *Pipeline + if len(req.Pipeline.Layers) > 0 { + pipelineToValidate = &req.Pipeline + } + + if errs := h.configSvc.Validate(c.Request.Context(), route, pipelineToValidate); len(errs) > 0 { + c.JSON(http.StatusBadRequest, gin.H{"errors": errs}) + return + } + + // Create route + if err := h.configSvc.CreateRoute(c.Request.Context(), route); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Save pipeline if provided + if len(req.Pipeline.Layers) > 0 { + if err := h.configSvc.UpdatePipeline(c.Request.Context(), route.ID, &req.Pipeline); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + c.JSON(http.StatusCreated, gin.H{ + "id": route.ID, + "name": route.Name, + "message": "route created successfully", + }) +} + +// UpdateRoute updates a route. 
+func (h *Handlers) UpdateRoute(c *gin.Context) { + routeID := c.Param("route_id") + + var req struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + Pipeline Pipeline `json:"pipeline"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + route := &Route{ + ID: routeID, + Name: req.Name, + Description: req.Description, + Enabled: req.Enabled, + } + + if err := h.configSvc.UpdateRoute(c.Request.Context(), route); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Update pipeline if provided + if len(req.Pipeline.Layers) > 0 { + if err := h.configSvc.UpdatePipeline(c.Request.Context(), routeID, &req.Pipeline); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + c.JSON(http.StatusOK, gin.H{"message": "route updated successfully"}) +} + +// PatchRoute partially updates a route. +func (h *Handlers) PatchRoute(c *gin.Context) { + routeID := c.Param("route_id") + + existing, err := h.configSvc.GetRoute(c.Request.Context(), routeID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + var patch map[string]interface{} + if err := c.ShouldBindJSON(&patch); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Apply patch + if name, ok := patch["name"].(string); ok { + existing.Name = name + } + if desc, ok := patch["description"].(string); ok { + existing.Description = desc + } + if enabled, ok := patch["enabled"].(bool); ok { + existing.Enabled = enabled + } + + if err := h.configSvc.UpdateRoute(c.Request.Context(), existing); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "route updated successfully"}) +} + +// DeleteRoute deletes a route. 
+func (h *Handlers) DeleteRoute(c *gin.Context) { + routeID := c.Param("route_id") + + if err := h.configSvc.DeleteRoute(c.Request.Context(), routeID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "route deleted successfully"}) +} + +// ================== Config: Pipeline ================== + +// GetPipeline returns the pipeline for a route. +func (h *Handlers) GetPipeline(c *gin.Context) { + routeID := c.Param("route_id") + + pipeline, err := h.configSvc.GetPipeline(c.Request.Context(), routeID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, pipeline) +} + +// UpdatePipeline updates the pipeline for a route. +func (h *Handlers) UpdatePipeline(c *gin.Context) { + routeID := c.Param("route_id") + + var pipeline Pipeline + if err := c.ShouldBindJSON(&pipeline); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.configSvc.UpdatePipeline(c.Request.Context(), routeID, &pipeline); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "pipeline updated successfully"}) +} + +// ================== Config: Export/Import ================== + +// ExportConfig exports the configuration. +func (h *Handlers) ExportConfig(c *gin.Context) { + data, err := h.configSvc.Export(c.Request.Context()) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, data) +} + +// ImportConfig imports the configuration. 
+func (h *Handlers) ImportConfig(c *gin.Context) { + merge := c.DefaultQuery("merge", "false") == "true" + + var data ExportData + if err := c.ShouldBindJSON(&data); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.configSvc.Import(c.Request.Context(), &data, merge); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "configuration imported successfully"}) +} + +// ValidateConfig validates a configuration. +func (h *Handlers) ValidateConfig(c *gin.Context) { + var req struct { + Route *Route `json:"route"` + Pipeline *Pipeline `json:"pipeline"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + errors := h.configSvc.Validate(c.Request.Context(), req.Route, req.Pipeline) + + c.JSON(http.StatusOK, gin.H{ + "valid": len(errors) == 0, + "errors": errors, + }) +} + +// ================== State ================== + +// GetOverview returns the overall state overview. +func (h *Handlers) GetOverview(c *gin.Context) { + log.Info("[UnifiedRouting] GetOverview called") + overview, err := h.stateMgr.GetOverview(c.Request.Context()) + if err != nil { + log.Errorf("[UnifiedRouting] GetOverview error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + log.Infof("[UnifiedRouting] GetOverview success: %d routes", overview.TotalRoutes) + c.JSON(http.StatusOK, overview) +} + +// GetRouteStatus returns the status of a route. +func (h *Handlers) GetRouteStatus(c *gin.Context) { + routeID := c.Param("route_id") + + state, err := h.stateMgr.GetRouteState(c.Request.Context(), routeID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, state) +} + +// GetTargetStatus returns the status of a target. 
+func (h *Handlers) GetTargetStatus(c *gin.Context) { + targetID := c.Param("target_id") + + state, err := h.stateMgr.GetTargetState(c.Request.Context(), targetID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, state) +} + +// ResetTarget resets a target's state. +func (h *Handlers) ResetTarget(c *gin.Context) { + targetID := c.Param("target_id") + + if err := h.stateMgr.ResetTarget(c.Request.Context(), targetID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "target status reset successfully", + "target_id": targetID, + "new_status": "healthy", + }) +} + +// ForceCooldown forces a target into cooldown. +func (h *Handlers) ForceCooldown(c *gin.Context) { + targetID := c.Param("target_id") + + var req struct { + DurationSeconds int `json:"duration_seconds"` + } + if err := c.ShouldBindJSON(&req); err != nil { + req.DurationSeconds = 60 + } + + duration := time.Duration(req.DurationSeconds) * time.Second + if err := h.stateMgr.ForceCooldown(c.Request.Context(), targetID, duration); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "cooldown started", + "target_id": targetID, + "duration_seconds": req.DurationSeconds, + }) +} + +// ================== Health ================== + +// TriggerHealthCheck triggers a health check. 
+func (h *Handlers) TriggerHealthCheck(c *gin.Context) { + routeID := c.Param("route_id") + targetID := c.Query("target_id") + + var results []*HealthResult + var err error + + if targetID != "" { + result, e := h.healthChecker.CheckTarget(c.Request.Context(), targetID) + if e != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": e.Error()}) + return + } + results = []*HealthResult{result} + } else if routeID != "" { + results, err = h.healthChecker.CheckRoute(c.Request.Context(), routeID) + } else { + results, err = h.healthChecker.CheckAll(c.Request.Context()) + } + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "checked_at": time.Now(), + "results": results, + }) +} + +// GetHealthSettings returns health check settings. +func (h *Handlers) GetHealthSettings(c *gin.Context) { + settings, err := h.healthChecker.GetSettings(c.Request.Context()) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, settings) +} + +// UpdateHealthSettings updates health check settings. +func (h *Handlers) UpdateHealthSettings(c *gin.Context) { + var settings HealthCheckConfig + if err := c.ShouldBindJSON(&settings); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.healthChecker.UpdateSettings(c.Request.Context(), &settings); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, settings) +} + +// GetHealthHistory returns health check history. 
+func (h *Handlers) GetHealthHistory(c *gin.Context) { + filter := HealthHistoryFilter{ + TargetID: c.Query("target_id"), + Status: c.Query("status"), + } + + if limitStr := c.Query("limit"); limitStr != "" { + if limit, err := strconv.Atoi(limitStr); err == nil { + filter.Limit = limit + } + } + + history, err := h.healthChecker.GetHistory(c.Request.Context(), filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "total": len(history), + "history": history, + }) +} + +// ================== Metrics ================== + +// GetStats returns aggregated statistics. +func (h *Handlers) GetStats(c *gin.Context) { + filter := StatsFilter{ + Period: c.DefaultQuery("period", "1h"), + Granularity: c.DefaultQuery("granularity", "minute"), + } + + stats, err := h.metrics.GetStats(c.Request.Context(), filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, stats) +} + +// GetRouteStats returns statistics for a route. +func (h *Handlers) GetRouteStats(c *gin.Context) { + routeID := c.Param("route_id") + filter := StatsFilter{ + Period: c.DefaultQuery("period", "1h"), + Granularity: c.DefaultQuery("granularity", "minute"), + } + + stats, err := h.metrics.GetRouteStats(c.Request.Context(), routeID, filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + stats.Period = filter.Period + c.JSON(http.StatusOK, gin.H{ + "route_id": routeID, + "stats": stats, + }) +} + +// GetEvents returns routing events. 
+func (h *Handlers) GetEvents(c *gin.Context) { + filter := EventFilter{ + Type: c.DefaultQuery("type", "all"), + RouteID: c.Query("route_id"), + } + + if limitStr := c.Query("limit"); limitStr != "" { + if limit, err := strconv.Atoi(limitStr); err == nil { + filter.Limit = limit + } + } else { + filter.Limit = 100 + } + + events, err := h.metrics.GetEvents(c.Request.Context(), filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "total": len(events), + "events": events, + }) +} + +// GetTraces returns request traces. +func (h *Handlers) GetTraces(c *gin.Context) { + filter := TraceFilter{ + RouteID: c.Query("route_id"), + Status: c.Query("status"), + } + + if limitStr := c.Query("limit"); limitStr != "" { + if limit, err := strconv.Atoi(limitStr); err == nil { + filter.Limit = limit + } + } else { + filter.Limit = 50 + } + + traces, err := h.metrics.GetTraces(c.Request.Context(), filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "total": len(traces), + "traces": traces, + }) +} + +// GetTrace returns a single trace. +func (h *Handlers) GetTrace(c *gin.Context) { + traceID := c.Param("trace_id") + + trace, err := h.metrics.GetTrace(c.Request.Context(), traceID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, trace) +} + +// ================== Credentials ================== + +// ListCredentials returns all available credentials. 
+func (h *Handlers) ListCredentials(c *gin.Context) { + log.Info("[UnifiedRouting] ListCredentials called") + if h.authManager == nil { + log.Error("[UnifiedRouting] ListCredentials: auth manager is nil") + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth manager unavailable"}) + return + } + + typeFilter := c.Query("type") + providerFilter := c.Query("provider") + log.Infof("[UnifiedRouting] ListCredentials: type=%s, provider=%s", typeFilter, providerFilter) + + auths := h.authManager.List() + log.Infof("[UnifiedRouting] ListCredentials: found %d auths", len(auths)) + reg := registry.GetGlobalRegistry() + + credentials := make([]CredentialInfo, 0) + + for _, auth := range auths { + // Skip disabled/removed auth entries + if auth.Disabled || auth.Status == coreauth.StatusDisabled { + continue + } + + // Determine type + credType := "oauth" + if auth.Attributes != nil { + if _, ok := auth.Attributes["api_key"]; ok { + credType = "api-key" + } + } + + // Apply filters + if typeFilter != "" && credType != typeFilter && typeFilter != "all" { + continue + } + if providerFilter != "" && auth.Provider != providerFilter { + continue + } + + // Get models + models := reg.GetModelsForClient(auth.ID) + modelInfos := make([]ModelInfo, 0, len(models)) + for _, m := range models { + modelInfos = append(modelInfos, ModelInfo{ + ID: m.ID, + Name: m.ID, + Available: true, + }) + } + + cred := CredentialInfo{ + ID: auth.ID, + Provider: auth.Provider, + Type: credType, + Label: auth.Label, + Prefix: auth.Prefix, + Status: string(auth.Status), + Models: modelInfos, + } + + // Add masked API key if present + if auth.Attributes != nil { + if apiKey, ok := auth.Attributes["api_key"]; ok { + cred.APIKey = util.HideAPIKey(apiKey) + } + if baseURL, ok := auth.Attributes["base_url"]; ok { + cred.BaseURL = baseURL + } + } + + credentials = append(credentials, cred) + } + + c.JSON(http.StatusOK, gin.H{ + "total": len(credentials), + "credentials": credentials, + }) +} + +// 
GetCredential returns a single credential. +func (h *Handlers) GetCredential(c *gin.Context) { + credentialID := c.Param("credential_id") + + if h.authManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth manager unavailable"}) + return + } + + auths := h.authManager.List() + for _, auth := range auths { + if auth.ID == credentialID { + reg := registry.GetGlobalRegistry() + models := reg.GetModelsForClient(auth.ID) + + modelInfos := make([]ModelInfo, 0, len(models)) + for _, m := range models { + modelInfos = append(modelInfos, ModelInfo{ + ID: m.ID, + Name: m.ID, + Available: true, + }) + } + + credType := "oauth" + if auth.Attributes != nil { + if _, ok := auth.Attributes["api_key"]; ok { + credType = "api-key" + } + } + + cred := CredentialInfo{ + ID: auth.ID, + Provider: auth.Provider, + Type: credType, + Label: auth.Label, + Prefix: auth.Prefix, + Status: string(auth.Status), + Models: modelInfos, + } + + if auth.Attributes != nil { + if apiKey, ok := auth.Attributes["api_key"]; ok { + cred.APIKey = util.HideAPIKey(apiKey) + } + if baseURL, ok := auth.Attributes["base_url"]; ok { + cred.BaseURL = baseURL + } + } + + c.JSON(http.StatusOK, cred) + return + } + } + + c.JSON(http.StatusNotFound, gin.H{"error": "credential not found"}) +} + +// ================== Simulate Route ================== + +// SimulateRouteRequest represents a request to simulate routing. +type SimulateRouteRequest struct { + DryRun bool `json:"dry_run"` // If true, don't actually make requests, just check availability +} + +// SimulateRouteResponse represents the result of a route simulation. +type SimulateRouteResponse struct { + RouteID string `json:"route_id"` + RouteName string `json:"route_name"` + Success bool `json:"success"` + FinalTarget *SimulateTargetResult `json:"final_target,omitempty"` + Attempts []SimulateLayerResult `json:"attempts"` + TotalTimeMs int64 `json:"total_time_ms"` +} + +// SimulateLayerResult represents the result of trying a layer. 
+type SimulateLayerResult struct { + Layer int `json:"layer"` + Targets []SimulateTargetResult `json:"targets"` +} + +// SimulateTargetResult represents the result of trying a target. +type SimulateTargetResult struct { + TargetID string `json:"target_id"` + CredentialID string `json:"credential_id"` + Model string `json:"model"` + Status string `json:"status"` // "success", "failed", "skipped" + Message string `json:"message,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` +} + +// SimulateRoute simulates the routing process for a specific route. +// It follows the exact same logic as ExecuteWithFailover: +// - Uses the same load balancing strategy (round-robin, weighted, etc.) +// - Records success/failure statistics +// - Starts cooldown on failure +// - Records the request trace +// The only difference is it uses health check instead of a real API request. +func (h *Handlers) SimulateRoute(c *gin.Context) { + routeID := c.Param("route_id") + + var req SimulateRouteRequest + _ = c.ShouldBindJSON(&req) + + ctx := c.Request.Context() + + // Get route + route, err := h.configSvc.GetRoute(ctx, routeID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "route not found"}) + return + } + + // Get pipeline + pipeline, err := h.configSvc.GetPipeline(ctx, routeID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Get health check config for cooldown duration + healthConfig, _ := h.configSvc.GetHealthCheckConfig(ctx) + if healthConfig == nil { + cfg := DefaultHealthCheckConfig() + healthConfig = &cfg + } + + response := SimulateRouteResponse{ + RouteID: routeID, + RouteName: route.Name, + Attempts: make([]SimulateLayerResult, 0), + } + + startTime := time.Now() + traceBuilder := NewTraceBuilder(routeID, route.Name) + + // Follow the exact same logic as ExecuteWithFailover + for layerIdx, layer := range pipeline.Layers { + layerResult := SimulateLayerResult{ + Layer: layer.Level, + Targets: 
make([]SimulateTargetResult, 0), + } + + // Calculate cooldown duration for this layer + cooldownDuration := time.Duration(layer.CooldownSeconds) * time.Second + if cooldownDuration == 0 { + cooldownDuration = time.Duration(healthConfig.DefaultCooldownSeconds) * time.Second + } + + // Keep trying targets in this layer until no available targets remain + // SelectTarget automatically excludes cooling-down targets + for { + // Use engine's SelectTarget for proper load balancing (round-robin, weighted, etc.) + target, err := h.engine.SelectTarget(ctx, routeID, &layer) + if err != nil { + // No more available targets in this layer + break + } + + targetResult := SimulateTargetResult{ + TargetID: target.ID, + CredentialID: target.CredentialID, + Model: target.Model, + } + + if req.DryRun { + // Just check availability without making requests + targetResult.Status = "success" + targetResult.Message = "target is available (dry run)" + layerResult.Targets = append(layerResult.Targets, targetResult) + response.Attempts = append(response.Attempts, layerResult) + response.Success = true + response.FinalTarget = &targetResult + response.TotalTimeMs = time.Since(startTime).Milliseconds() + c.JSON(http.StatusOK, response) + return + } + + // Perform health check (simulating a real request) + checkStart := time.Now() + result, checkErr := h.healthChecker.CheckTarget(ctx, target.ID) + latency := time.Since(checkStart) + targetResult.LatencyMs = latency.Milliseconds() + + if checkErr != nil || result.Status != "healthy" { + // Failed - record failure and start cooldown (same as real request) + errMsg := "health check failed" + if checkErr != nil { + errMsg = checkErr.Error() + } else if result.Message != "" { + errMsg = result.Message + } + + targetResult.Status = "failed" + targetResult.Message = errMsg + layerResult.Targets = append(layerResult.Targets, targetResult) + + // Note: RecordFailure is already called by CheckTarget, so we don't call it again here + 
traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model). + Failed(errMsg) + + // Start cooldown immediately on failure (CheckTarget doesn't do this) + h.stateMgr.StartCooldown(ctx, target.ID, cooldownDuration) + h.metrics.RecordEvent(&RoutingEvent{ + Type: EventCooldownStarted, + RouteID: routeID, + TargetID: target.ID, + Details: map[string]any{ + "duration_seconds": int(cooldownDuration.Seconds()), + "reason": errMsg, + "source": "simulate", + }, + }) + + // Continue to next target in this layer + continue + } + + // Success! + targetResult.Status = "success" + targetResult.Message = "health check passed" + if result.LatencyMs > 0 { + targetResult.LatencyMs = result.LatencyMs + } + layerResult.Targets = append(layerResult.Targets, targetResult) + + // Note: RecordSuccess is already called by CheckTarget, so we don't call it again here + traceBuilder.AddAttempt(layer.Level, target.ID, target.CredentialID, target.Model). + Success(targetResult.LatencyMs) + + // Record the trace + trace := traceBuilder.Build(time.Since(startTime).Milliseconds()) + h.metrics.RecordRequest(trace) + + response.Attempts = append(response.Attempts, layerResult) + response.Success = true + response.FinalTarget = &targetResult + response.TotalTimeMs = time.Since(startTime).Milliseconds() + c.JSON(http.StatusOK, response) + return + } + + // Add layer result if we tried any targets in it + if len(layerResult.Targets) > 0 { + response.Attempts = append(response.Attempts, layerResult) + } + + // Record layer fallback event when moving to next layer + if layerIdx < len(pipeline.Layers)-1 && len(layerResult.Targets) > 0 { + h.metrics.RecordEvent(&RoutingEvent{ + Type: EventLayerFallback, + RouteID: routeID, + Details: map[string]any{ + "from_layer": layer.Level, + "to_layer": layer.Level + 1, + "source": "simulate", + }, + }) + } + } + + // All layers exhausted - record failed trace + trace := traceBuilder.Build(time.Since(startTime).Milliseconds()) + 
h.metrics.RecordRequest(trace) + + response.TotalTimeMs = time.Since(startTime).Milliseconds() + c.JSON(http.StatusOK, response) +} diff --git a/internal/api/modules/unified-routing/health_checker.go b/internal/api/modules/unified-routing/health_checker.go new file mode 100644 index 000000000..b04c91bf8 --- /dev/null +++ b/internal/api/modules/unified-routing/health_checker.go @@ -0,0 +1,425 @@ +package unifiedrouting + +import ( + "context" + "encoding/json" + "sync" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" +) + +// HealthChecker performs health checks on routing targets. +type HealthChecker interface { + // Trigger checks + CheckAll(ctx context.Context) ([]*HealthResult, error) + CheckRoute(ctx context.Context, routeID string) ([]*HealthResult, error) + CheckTarget(ctx context.Context, targetID string) (*HealthResult, error) + + // Configuration + GetSettings(ctx context.Context) (*HealthCheckConfig, error) + UpdateSettings(ctx context.Context, settings *HealthCheckConfig) error + + // History + GetHistory(ctx context.Context, filter HealthHistoryFilter) ([]*HealthResult, error) + + // Background task control + Start(ctx context.Context) error + Stop(ctx context.Context) error +} + +// DefaultHealthChecker implements HealthChecker. +type DefaultHealthChecker struct { + configSvc ConfigService + stateMgr StateManager + metrics MetricsCollector + authManager *coreauth.Manager + + mu sync.RWMutex + history []*HealthResult + maxHistory int + + stopChan chan struct{} + running bool +} + +// NewHealthChecker creates a new health checker. 
+func NewHealthChecker( + configSvc ConfigService, + stateMgr StateManager, + metrics MetricsCollector, + authManager *coreauth.Manager, +) *DefaultHealthChecker { + return &DefaultHealthChecker{ + configSvc: configSvc, + stateMgr: stateMgr, + metrics: metrics, + authManager: authManager, + history: make([]*HealthResult, 0, 1000), + maxHistory: 1000, + stopChan: make(chan struct{}), + } +} + +func (h *DefaultHealthChecker) CheckAll(ctx context.Context) ([]*HealthResult, error) { + routes, err := h.configSvc.ListRoutes(ctx) + if err != nil { + return nil, err + } + + var results []*HealthResult + for _, route := range routes { + routeResults, err := h.CheckRoute(ctx, route.ID) + if err != nil { + continue + } + results = append(results, routeResults...) + } + + return results, nil +} + +func (h *DefaultHealthChecker) CheckRoute(ctx context.Context, routeID string) ([]*HealthResult, error) { + pipeline, err := h.configSvc.GetPipeline(ctx, routeID) + if err != nil { + return nil, err + } + + var results []*HealthResult + for _, layer := range pipeline.Layers { + for _, target := range layer.Targets { + if !target.Enabled { + continue + } + result, err := h.CheckTarget(ctx, target.ID) + if err != nil { + results = append(results, &HealthResult{ + TargetID: target.ID, + CredentialID: target.CredentialID, + Model: target.Model, + Status: "unhealthy", + Message: err.Error(), + CheckedAt: time.Now(), + }) + continue + } + results = append(results, result) + } + } + + return results, nil +} + +func (h *DefaultHealthChecker) CheckTarget(ctx context.Context, targetID string) (*HealthResult, error) { + // Find the target configuration + routes, err := h.configSvc.ListRoutes(ctx) + if err != nil { + return nil, err + } + + var target *Target + for _, route := range routes { + pipeline, err := h.configSvc.GetPipeline(ctx, route.ID) + if err != nil { + continue + } + for _, layer := range pipeline.Layers { + for i := range layer.Targets { + if layer.Targets[i].ID == targetID { + 
target = &layer.Targets[i] + break + } + } + if target != nil { + break + } + } + if target != nil { + break + } + } + + if target == nil { + return nil, &TargetNotFoundError{TargetID: targetID} + } + + // Perform health check + result := h.performHealthCheck(ctx, target) + + // Record result + h.recordResult(result) + + // Update state based on result + if result.Status == "healthy" { + h.stateMgr.RecordSuccess(ctx, targetID, time.Duration(result.LatencyMs)*time.Millisecond) + } else { + h.stateMgr.RecordFailure(ctx, targetID, result.Message) + } + + // Record event + eventType := EventTargetRecovered + if result.Status == "unhealthy" { + eventType = EventTargetFailed + } + h.metrics.RecordEvent(&RoutingEvent{ + Type: eventType, + RouteID: "", + TargetID: targetID, + Details: map[string]any{ + "status": result.Status, + "latency_ms": result.LatencyMs, + "message": result.Message, + }, + }) + + return result, nil +} + +func (h *DefaultHealthChecker) performHealthCheck(ctx context.Context, target *Target) *HealthResult { + result := &HealthResult{ + TargetID: target.ID, + CredentialID: target.CredentialID, + Model: target.Model, + CheckedAt: time.Now(), + } + + if h.authManager == nil { + result.Status = "unhealthy" + result.Message = "auth manager unavailable" + return result + } + + // Find the auth entry for this credential + auths := h.authManager.List() + var targetAuth *coreauth.Auth + for _, auth := range auths { + if auth.ID == target.CredentialID { + targetAuth = auth + break + } + } + + if targetAuth == nil { + result.Status = "unhealthy" + result.Message = "credential not found" + return result + } + + // Build minimal request for health check + openAIRequest := map[string]interface{}{ + "model": target.Model, + "messages": []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, + "stream": true, + "max_tokens": 1, + } + + requestJSON, err := json.Marshal(openAIRequest) + if err != nil { + result.Status = "unhealthy" + result.Message = "failed 
to build request" + return result + } + + // Get health check config for timeout + healthConfig, _ := h.configSvc.GetHealthCheckConfig(ctx) + if healthConfig == nil { + cfg := DefaultHealthCheckConfig() + healthConfig = &cfg + } + + checkCtx, cancel := context.WithTimeout(ctx, time.Duration(healthConfig.CheckTimeoutSeconds)*time.Second) + defer cancel() + + startTime := time.Now() + + // Execute health check request + req := cliproxyexecutor.Request{ + Model: target.Model, + Payload: requestJSON, + Format: sdktranslator.FormatOpenAI, + } + + opts := cliproxyexecutor.Options{ + Stream: true, + SourceFormat: sdktranslator.FormatOpenAI, + OriginalRequest: requestJSON, + } + + stream, err := h.authManager.ExecuteStreamWithAuth(checkCtx, targetAuth, req, opts) + if err != nil { + result.Status = "unhealthy" + result.Message = err.Error() + return result + } + + // Wait for first chunk + select { + case chunk, ok := <-stream: + if ok { + if chunk.Err != nil { + result.Status = "unhealthy" + result.Message = chunk.Err.Error() + } else { + result.Status = "healthy" + result.LatencyMs = time.Since(startTime).Milliseconds() + } + // Drain remaining chunks + cancel() + go func() { + for range stream { + } + }() + } else { + result.Status = "unhealthy" + result.Message = "stream closed without data" + } + case <-checkCtx.Done(): + result.Status = "unhealthy" + result.Message = "health check timeout" + } + + return result +} + +func (h *DefaultHealthChecker) recordResult(result *HealthResult) { + h.mu.Lock() + defer h.mu.Unlock() + + // Ring buffer behavior + if len(h.history) >= h.maxHistory { + h.history = h.history[1:] + } + h.history = append(h.history, result) +} + +func (h *DefaultHealthChecker) GetSettings(ctx context.Context) (*HealthCheckConfig, error) { + return h.configSvc.GetHealthCheckConfig(ctx) +} + +func (h *DefaultHealthChecker) UpdateSettings(ctx context.Context, settings *HealthCheckConfig) error { + return h.configSvc.UpdateHealthCheckConfig(ctx, settings) 
+} + +func (h *DefaultHealthChecker) GetHistory(ctx context.Context, filter HealthHistoryFilter) ([]*HealthResult, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + var results []*HealthResult + for i := len(h.history) - 1; i >= 0; i-- { + result := h.history[i] + + // Apply filters + if filter.TargetID != "" && result.TargetID != filter.TargetID { + continue + } + if filter.Status != "" && result.Status != filter.Status { + continue + } + if !filter.Since.IsZero() && result.CheckedAt.Before(filter.Since) { + continue + } + + results = append(results, result) + + if filter.Limit > 0 && len(results) >= filter.Limit { + break + } + } + + return results, nil +} + +func (h *DefaultHealthChecker) Start(ctx context.Context) error { + h.mu.Lock() + if h.running { + h.mu.Unlock() + return nil + } + h.running = true + h.stopChan = make(chan struct{}) + h.mu.Unlock() + + go h.runBackgroundChecks(ctx) + return nil +} + +func (h *DefaultHealthChecker) Stop(ctx context.Context) error { + h.mu.Lock() + defer h.mu.Unlock() + + if !h.running { + return nil + } + + close(h.stopChan) + h.running = false + return nil +} + +func (h *DefaultHealthChecker) runBackgroundChecks(ctx context.Context) { + // Get check interval from config + healthConfig, _ := h.configSvc.GetHealthCheckConfig(ctx) + if healthConfig == nil { + cfg := DefaultHealthCheckConfig() + healthConfig = &cfg + } + + ticker := time.NewTicker(time.Duration(healthConfig.CheckIntervalSeconds) * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // Check targets that are in cooldown + h.checkCoolingTargets(ctx) + + case <-h.stopChan: + return + + case <-ctx.Done(): + return + } + } +} + +func (h *DefaultHealthChecker) checkCoolingTargets(ctx context.Context) { + states, err := h.stateMgr.ListTargetStates(ctx) + if err != nil { + return + } + + for _, state := range states { + if state.Status != StatusCooling { + continue + } + + // Check if cooldown has expired + if state.CooldownEndsAt == nil || 
time.Now().After(*state.CooldownEndsAt) { + // Perform health check + result, err := h.CheckTarget(ctx, state.TargetID) + if err != nil { + log.Debugf("health check failed for target %s: %v", state.TargetID, err) + continue + } + + if result.Status == "healthy" { + h.stateMgr.EndCooldown(ctx, state.TargetID) + log.Infof("target %s recovered after health check", state.TargetID) + } + } + } +} + +// TargetNotFoundError is returned when a target is not found. +type TargetNotFoundError struct { + TargetID string +} + +func (e *TargetNotFoundError) Error() string { + return "target not found: " + e.TargetID +} diff --git a/internal/api/modules/unified-routing/metrics.go b/internal/api/modules/unified-routing/metrics.go new file mode 100644 index 000000000..e3fc792fc --- /dev/null +++ b/internal/api/modules/unified-routing/metrics.go @@ -0,0 +1,356 @@ +package unifiedrouting + +import ( + "context" + "sync" + "time" + + "github.com/google/uuid" +) + +// MetricsCollector collects and provides metrics for unified routing. +type MetricsCollector interface { + // Recording + RecordRequest(trace *RequestTrace) + RecordEvent(event *RoutingEvent) + + // Queries + GetStats(ctx context.Context, filter StatsFilter) (*AggregatedStats, error) + GetRouteStats(ctx context.Context, routeID string, filter StatsFilter) (*AggregatedStats, error) + GetTargetStats(ctx context.Context, targetID string, filter StatsFilter) (*AggregatedStats, error) + + // Events + GetEvents(ctx context.Context, filter EventFilter) ([]*RoutingEvent, error) + + // Traces + GetTraces(ctx context.Context, filter TraceFilter) ([]*RequestTrace, error) + GetTrace(ctx context.Context, traceID string) (*RequestTrace, error) + + // Real-time subscriptions + Subscribe(ctx context.Context) (<-chan MetricUpdate, error) +} + +// MetricUpdate represents a real-time metric update. 
+type MetricUpdate struct { + Type string `json:"type"` // "trace", "event", "stats" + Timestamp time.Time `json:"timestamp"` + Data interface{} `json:"data"` +} + +// DefaultMetricsCollector implements MetricsCollector. +type DefaultMetricsCollector struct { + store MetricsStore + mu sync.RWMutex + subscribers []chan MetricUpdate +} + +// NewMetricsCollector creates a new metrics collector. +func NewMetricsCollector(store MetricsStore) *DefaultMetricsCollector { + return &DefaultMetricsCollector{ + store: store, + subscribers: make([]chan MetricUpdate, 0), + } +} + +func (c *DefaultMetricsCollector) RecordRequest(trace *RequestTrace) { + if trace.TraceID == "" { + trace.TraceID = "trace-" + uuid.New().String()[:8] + } + if trace.Timestamp.IsZero() { + trace.Timestamp = time.Now() + } + + ctx := context.Background() + _ = c.store.RecordTrace(ctx, trace) + + // Notify subscribers + c.broadcast(MetricUpdate{ + Type: "trace", + Timestamp: time.Now(), + Data: trace, + }) +} + +func (c *DefaultMetricsCollector) RecordEvent(event *RoutingEvent) { + if event.ID == "" { + event.ID = "evt-" + uuid.New().String()[:8] + } + if event.Timestamp.IsZero() { + event.Timestamp = time.Now() + } + + ctx := context.Background() + _ = c.store.RecordEvent(ctx, event) + + // Notify subscribers + c.broadcast(MetricUpdate{ + Type: "event", + Timestamp: time.Now(), + Data: event, + }) +} + +func (c *DefaultMetricsCollector) GetStats(ctx context.Context, filter StatsFilter) (*AggregatedStats, error) { + return c.store.GetStats(ctx, filter) +} + +func (c *DefaultMetricsCollector) GetRouteStats(ctx context.Context, routeID string, filter StatsFilter) (*AggregatedStats, error) { + // Get all traces for this route and calculate stats + traces, err := c.store.GetTraces(ctx, TraceFilter{RouteID: routeID, Limit: 10000}) + if err != nil { + return nil, err + } + + return c.calculateStats(traces, filter), nil +} + +func (c *DefaultMetricsCollector) GetTargetStats(ctx context.Context, targetID string, 
filter StatsFilter) (*AggregatedStats, error) { + // Get all traces and filter by target + traces, err := c.store.GetTraces(ctx, TraceFilter{Limit: 10000}) + if err != nil { + return nil, err + } + + // Filter traces that used this target + var filteredTraces []*RequestTrace + for _, trace := range traces { + for _, attempt := range trace.Attempts { + if attempt.TargetID == targetID { + filteredTraces = append(filteredTraces, trace) + break + } + } + } + + return c.calculateStats(filteredTraces, filter), nil +} + +func (c *DefaultMetricsCollector) calculateStats(traces []*RequestTrace, filter StatsFilter) *AggregatedStats { + stats := &AggregatedStats{ + Period: filter.Period, + } + + // Calculate time range + var since time.Time + switch filter.Period { + case "1h": + since = time.Now().Add(-1 * time.Hour) + case "24h": + since = time.Now().Add(-24 * time.Hour) + case "7d": + since = time.Now().Add(-7 * 24 * time.Hour) + case "30d": + since = time.Now().Add(-30 * 24 * time.Hour) + default: + since = time.Now().Add(-1 * time.Hour) + } + + var totalLatency int64 + layerCounts := make(map[int]int64) + targetStats := make(map[string]*TargetDistribution) + + for _, trace := range traces { + if trace.Timestamp.Before(since) { + continue + } + + stats.TotalRequests++ + totalLatency += trace.TotalLatencyMs + + switch trace.Status { + case TraceStatusSuccess, TraceStatusRetry, TraceStatusFallback: + stats.SuccessfulRequests++ + case TraceStatusFailed: + stats.FailedRequests++ + } + + // Track layer and target distribution + for _, attempt := range trace.Attempts { + if attempt.Status == AttemptStatusSuccess { + layerCounts[attempt.Layer]++ + + if _, ok := targetStats[attempt.TargetID]; !ok { + targetStats[attempt.TargetID] = &TargetDistribution{ + TargetID: attempt.TargetID, + CredentialID: attempt.CredentialID, + } + } + targetStats[attempt.TargetID].Requests++ + break + } + } + } + + if stats.TotalRequests > 0 { + stats.SuccessRate = float64(stats.SuccessfulRequests) / 
float64(stats.TotalRequests) + stats.AvgLatencyMs = totalLatency / stats.TotalRequests + } + + // Build distributions + for level, count := range layerCounts { + percentage := float64(0) + if stats.TotalRequests > 0 { + percentage = float64(count) / float64(stats.TotalRequests) * 100 + } + stats.LayerDistribution = append(stats.LayerDistribution, LayerDistribution{ + Level: level, + Requests: count, + Percentage: percentage, + }) + } + + for _, td := range targetStats { + stats.TargetDistribution = append(stats.TargetDistribution, *td) + } + + return stats +} + +func (c *DefaultMetricsCollector) GetEvents(ctx context.Context, filter EventFilter) ([]*RoutingEvent, error) { + return c.store.GetEvents(ctx, filter) +} + +func (c *DefaultMetricsCollector) GetTraces(ctx context.Context, filter TraceFilter) ([]*RequestTrace, error) { + return c.store.GetTraces(ctx, filter) +} + +func (c *DefaultMetricsCollector) GetTrace(ctx context.Context, traceID string) (*RequestTrace, error) { + return c.store.GetTrace(ctx, traceID) +} + +func (c *DefaultMetricsCollector) Subscribe(ctx context.Context) (<-chan MetricUpdate, error) { + ch := make(chan MetricUpdate, 100) + + c.mu.Lock() + c.subscribers = append(c.subscribers, ch) + c.mu.Unlock() + + // Clean up when context is done + go func() { + <-ctx.Done() + c.mu.Lock() + for i, sub := range c.subscribers { + if sub == ch { + c.subscribers = append(c.subscribers[:i], c.subscribers[i+1:]...) + break + } + } + c.mu.Unlock() + close(ch) + }() + + return ch, nil +} + +func (c *DefaultMetricsCollector) broadcast(update MetricUpdate) { + c.mu.RLock() + subscribers := c.subscribers + c.mu.RUnlock() + + for _, ch := range subscribers { + select { + case ch <- update: + default: + // Channel full, skip + } + } +} + +// TraceBuilder helps build request traces. +type TraceBuilder struct { + trace *RequestTrace +} + +// NewTraceBuilder creates a new trace builder. 
+func NewTraceBuilder(routeID, routeName string) *TraceBuilder { + return &TraceBuilder{ + trace: &RequestTrace{ + TraceID: "trace-" + uuid.New().String()[:8], + RouteID: routeID, + RouteName: routeName, + Timestamp: time.Now(), + Attempts: make([]AttemptTrace, 0), + }, + } +} + +// AddAttempt adds an attempt to the trace. +func (b *TraceBuilder) AddAttempt(layer int, targetID, credentialID, model string) *AttemptBuilder { + attempt := AttemptTrace{ + Attempt: len(b.trace.Attempts) + 1, + Layer: layer, + TargetID: targetID, + CredentialID: credentialID, + Model: model, + } + b.trace.Attempts = append(b.trace.Attempts, attempt) + return &AttemptBuilder{ + trace: b.trace, + attempt: &b.trace.Attempts[len(b.trace.Attempts)-1], + } +} + +// Build finalizes and returns the trace. +func (b *TraceBuilder) Build(totalLatencyMs int64) *RequestTrace { + b.trace.TotalLatencyMs = totalLatencyMs + + // Determine trace status based on attempts + hasSuccess := false + hasRetry := false + hasFallback := false + + for i, attempt := range b.trace.Attempts { + if attempt.Status == AttemptStatusSuccess { + hasSuccess = true + if i > 0 { + // Check if we retried or fell back + prevLayer := b.trace.Attempts[i-1].Layer + if attempt.Layer > prevLayer { + hasFallback = true + } else { + hasRetry = true + } + } + break + } + } + + if !hasSuccess { + b.trace.Status = TraceStatusFailed + } else if hasFallback { + b.trace.Status = TraceStatusFallback + } else if hasRetry { + b.trace.Status = TraceStatusRetry + } else { + b.trace.Status = TraceStatusSuccess + } + + return b.trace +} + +// AttemptBuilder helps build attempt traces. +type AttemptBuilder struct { + trace *RequestTrace + attempt *AttemptTrace +} + +// Success marks the attempt as successful. +func (b *AttemptBuilder) Success(latencyMs int64) *TraceBuilder { + b.attempt.Status = AttemptStatusSuccess + b.attempt.LatencyMs = latencyMs + return &TraceBuilder{trace: b.trace} +} + +// Failed marks the attempt as failed. 
+func (b *AttemptBuilder) Failed(err string) *TraceBuilder { + b.attempt.Status = AttemptStatusFailed + b.attempt.Error = err + return &TraceBuilder{trace: b.trace} +} + +// Skipped marks the attempt as skipped. +func (b *AttemptBuilder) Skipped(reason string) *TraceBuilder { + b.attempt.Status = AttemptStatusSkipped + b.attempt.Error = reason + return &TraceBuilder{trace: b.trace} +} diff --git a/internal/api/modules/unified-routing/module.go b/internal/api/modules/unified-routing/module.go new file mode 100644 index 000000000..16291c2a5 --- /dev/null +++ b/internal/api/modules/unified-routing/module.go @@ -0,0 +1,306 @@ +package unifiedrouting + +import ( + "os" + "path/filepath" + "sync" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// Option configures the Module. +type Option func(*Module) + +// Module implements the RouteModuleV2 interface for unified routing. +type Module struct { + authManager *coreauth.Manager + authMiddleware gin.HandlerFunc + + configStore ConfigStore + stateStore StateStore + metricsStore MetricsStore + + configSvc ConfigService + stateMgr StateManager + metrics MetricsCollector + healthChecker HealthChecker + engine RoutingEngine + handlers *Handlers + + initOnce sync.Once + routesOnce sync.Once + dataDir string + skipAutoRoutes bool // If true, routes won't be registered in Register() +} + +// New creates a new unified routing module. +func New(opts ...Option) *Module { + m := &Module{} + for _, opt := range opts { + opt(m) + } + return m +} + +// WithAuthManager sets the auth manager. +func WithAuthManager(am *coreauth.Manager) Option { + return func(m *Module) { + m.authManager = am + } +} + +// WithAuthMiddleware sets the authentication middleware. 
+func WithAuthMiddleware(middleware gin.HandlerFunc) Option { + return func(m *Module) { + m.authMiddleware = middleware + } +} + +// WithDataDir sets the data directory for configuration storage. +func WithDataDir(dir string) Option { + return func(m *Module) { + m.dataDir = dir + } +} + +// WithSkipAutoRoutes skips automatic route registration in Register(). +// Use this when you want to register routes manually via RegisterRoutes(). +func WithSkipAutoRoutes() Option { + return func(m *Module) { + m.skipAutoRoutes = true + } +} + +// Name returns the module identifier. +func (m *Module) Name() string { + return "unified-routing" +} + +// Register sets up unified routing routes. +func (m *Module) Register(ctx modules.Context) error { + log.Info("[UnifiedRouting] Register() called") + + // Initialize module (only once) + if err := m.init(ctx); err != nil { + return err + } + + // Register routes unless skipAutoRoutes is set + if !m.skipAutoRoutes { + auth := m.getAuthMiddleware(ctx) + log.Info("[UnifiedRouting] Auth middleware configured (auto)") + m.RegisterRoutes(ctx.Engine, auth) + } else { + log.Info("[UnifiedRouting] Skipping auto route registration (will be registered later)") + } + + log.Info("[UnifiedRouting] Module registered successfully") + return nil +} + +// init initializes the module services (only once). 
+func (m *Module) init(ctx modules.Context) error { + var initErr error + + m.initOnce.Do(func() { + log.Info("[UnifiedRouting] Initializing module...") + // Determine data directory + dataDir := m.dataDir + if dataDir == "" { + // Default to auth-dir/unified-routing + authDir := ctx.Config.AuthDir + if authDir == "" { + authDir = "~/.cli-proxy-api" + } + // Expand ~ if present + if authDir[0] == '~' { + home, _ := os.UserHomeDir() + authDir = filepath.Join(home, authDir[1:]) + } + dataDir = filepath.Join(authDir, "unified-routing") + } + log.Infof("[UnifiedRouting] Data directory: %s", dataDir) + + // Initialize stores + configStore, err := NewFileConfigStore(dataDir) + if err != nil { + initErr = err + return + } + m.configStore = configStore + m.stateStore = NewMemoryStateStore() + + // Use the shared logging module to resolve the logs directory + baseLogsDir := logging.ResolveLogDirectory(ctx.Config) + logsDir := filepath.Join(baseLogsDir, "unified-routing") + metricsStore, err := NewFileMetricsStore(logsDir, 100) // 100MB max for traces + if err != nil { + initErr = err + return + } + m.metricsStore = metricsStore + log.Infof("[UnifiedRouting] Logs directory: %s", logsDir) + + // Initialize services + m.configSvc = NewConfigService(m.configStore) + m.stateMgr = NewStateManager(m.stateStore, m.configSvc) + m.metrics = NewMetricsCollector(m.metricsStore) + m.healthChecker = NewHealthChecker(m.configSvc, m.stateMgr, m.metrics, m.authManager) + m.engine = NewRoutingEngine(m.configSvc, m.stateMgr, m.metrics, m.authManager) + + // Initialize handlers + m.handlers = NewHandlers(m.configSvc, m.stateMgr, m.metrics, m.healthChecker, m.authManager, m.engine) + + log.Info("[UnifiedRouting] Module initialization complete") + }) + + return initErr +} + +// getAuthMiddleware returns the authentication middleware. 
+func (m *Module) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { + if m.authMiddleware != nil { + return m.authMiddleware + } + if ctx.AuthMiddleware != nil { + return ctx.AuthMiddleware + } + // Fallback: no authentication + log.Warn("unified-routing module: no auth middleware provided, allowing all requests") + return func(c *gin.Context) { + c.Next() + } +} + +// RegisterRoutes registers all HTTP routes with the given auth middleware. +// This method can be called externally to register routes with custom auth. +// It will only register routes once (subsequent calls are no-ops). +func (m *Module) RegisterRoutes(engine *gin.Engine, auth gin.HandlerFunc) { + m.routesOnce.Do(func() { + log.Info("[UnifiedRouting] Registering routes...") + m.doRegisterRoutes(engine, auth) + log.Info("[UnifiedRouting] Routes registered") + }) +} + +// doRegisterRoutes performs the actual route registration. +func (m *Module) doRegisterRoutes(engine *gin.Engine, auth gin.HandlerFunc) { + // Base path: /v0/management/unified-routing + ur := engine.Group("/v0/management/unified-routing", auth) + + // Config: Settings + ur.GET("/config/settings", m.handlers.GetSettings) + ur.PUT("/config/settings", m.handlers.PutSettings) + + // Config: Health check settings + ur.GET("/config/health-check", m.handlers.GetHealthCheckConfig) + ur.PUT("/config/health-check", m.handlers.PutHealthCheckConfig) + + // Config: Routes + ur.GET("/config/routes", m.handlers.ListRoutes) + ur.POST("/config/routes", m.handlers.CreateRoute) + ur.GET("/config/routes/:route_id", m.handlers.GetRoute) + ur.PUT("/config/routes/:route_id", m.handlers.UpdateRoute) + ur.PATCH("/config/routes/:route_id", m.handlers.PatchRoute) + ur.DELETE("/config/routes/:route_id", m.handlers.DeleteRoute) + + // Config: Pipeline + ur.GET("/config/routes/:route_id/pipeline", m.handlers.GetPipeline) + ur.PUT("/config/routes/:route_id/pipeline", m.handlers.UpdatePipeline) + + // Config: Export/Import + ur.GET("/config/export", 
m.handlers.ExportConfig) + ur.POST("/config/import", m.handlers.ImportConfig) + ur.POST("/config/validate", m.handlers.ValidateConfig) + + // State + ur.GET("/state/overview", m.handlers.GetOverview) + ur.GET("/state/routes/:route_id", m.handlers.GetRouteStatus) + ur.GET("/state/targets/:target_id", m.handlers.GetTargetStatus) + ur.POST("/state/targets/:target_id/reset", m.handlers.ResetTarget) + ur.POST("/state/targets/:target_id/force-cooldown", m.handlers.ForceCooldown) + + // Health + ur.POST("/health/check", m.handlers.TriggerHealthCheck) + ur.POST("/health/check/routes/:route_id", m.handlers.TriggerHealthCheck) + ur.POST("/health/check/targets/:target_id", m.handlers.TriggerHealthCheck) + ur.GET("/health/settings", m.handlers.GetHealthSettings) + ur.PUT("/health/settings", m.handlers.UpdateHealthSettings) + ur.GET("/health/history", m.handlers.GetHealthHistory) + + // Simulate + ur.POST("/simulate/routes/:route_id", m.handlers.SimulateRoute) + + // Metrics + ur.GET("/metrics/stats", m.handlers.GetStats) + ur.GET("/metrics/stats/routes/:route_id", m.handlers.GetRouteStats) + ur.GET("/metrics/events", m.handlers.GetEvents) + ur.GET("/metrics/traces", m.handlers.GetTraces) + ur.GET("/metrics/traces/:trace_id", m.handlers.GetTrace) + + // Credentials + ur.GET("/credentials", m.handlers.ListCredentials) + ur.GET("/credentials/:credential_id", m.handlers.GetCredential) +} + +// OnConfigUpdated handles configuration updates. +func (m *Module) OnConfigUpdated(cfg *config.Config) error { + // Reload engine configuration + if m.engine != nil { + return m.engine.Reload(nil) + } + return nil +} + +// GetEngine returns the routing engine (for integration with main request handlers). +func (m *Module) GetEngine() RoutingEngine { + return m.engine +} + +// GetConfigService returns the config service. +func (m *Module) GetConfigService() ConfigService { + return m.configSvc +} + +// GetStateManager returns the state manager. 
+func (m *Module) GetStateManager() StateManager {
+	return m.stateMgr
+}
+
+// GetMetricsCollector returns the metrics collector.
+func (m *Module) GetMetricsCollector() MetricsCollector {
+	return m.metrics
+}
+
+// GetHealthChecker returns the health checker.
+func (m *Module) GetHealthChecker() HealthChecker {
+	return m.healthChecker
+}
+
+// Start starts background tasks.
+// Note: Background health checks are disabled by design.
+// Cooldown expiration is handled automatically when a target is selected.
+// Manual health checks can still be triggered via the API.
+func (m *Module) Start() error {
+	// Background health checker is intentionally NOT started.
+	// Targets automatically become available again when their cooldown expires
+	// (checked in GetTargetState and SelectTarget).
+	// Manual health checks can still be triggered via POST /health/check endpoints.
+	return nil
+}
+
+// Stop stops background tasks.
+// Both the state manager's cooldown monitor and the health checker are
+// stopped. (Previously the health-checker branch returned early, so the
+// state manager's monitor goroutine was leaked whenever a health checker
+// was present.)
+func (m *Module) Stop() error {
+	if sm, ok := m.stateMgr.(*DefaultStateManager); ok {
+		sm.Stop()
+	}
+	if m.healthChecker != nil {
+		return m.healthChecker.Stop(nil)
+	}
+	return nil
+}
diff --git a/internal/api/modules/unified-routing/state_manager.go b/internal/api/modules/unified-routing/state_manager.go
new file mode 100644
index 000000000..e25c8281b
--- /dev/null
+++ b/internal/api/modules/unified-routing/state_manager.go
@@ -0,0 +1,359 @@
+package unifiedrouting
+
+import (
+	"context"
+	"sync"
+	"time"
+)
+
+// StateManager manages runtime state for unified routing.
+type StateManager interface { + // State queries + GetOverview(ctx context.Context) (*StateOverview, error) + GetRouteState(ctx context.Context, routeID string) (*RouteState, error) + GetTargetState(ctx context.Context, targetID string) (*TargetState, error) + ListTargetStates(ctx context.Context) ([]*TargetState, error) + + // State changes (called by engine) + RecordSuccess(ctx context.Context, targetID string, latency time.Duration) + RecordFailure(ctx context.Context, targetID string, reason string) + StartCooldown(ctx context.Context, targetID string, duration time.Duration) + EndCooldown(ctx context.Context, targetID string) + + // Manual operations + ResetTarget(ctx context.Context, targetID string) error + ForceCooldown(ctx context.Context, targetID string, duration time.Duration) error + + // Initialize/cleanup + InitializeTarget(ctx context.Context, targetID string) error + RemoveTarget(ctx context.Context, targetID string) error +} + +// DefaultStateManager implements StateManager. +type DefaultStateManager struct { + store StateStore + configSvc ConfigService + mu sync.RWMutex + cooldownChan chan string // Channel for cooldown expiry notifications + stopChan chan struct{} +} + +// NewStateManager creates a new state manager. 
+func NewStateManager(store StateStore, configSvc ConfigService) *DefaultStateManager { + sm := &DefaultStateManager{ + store: store, + configSvc: configSvc, + cooldownChan: make(chan string, 100), + stopChan: make(chan struct{}), + } + + // Start cooldown monitor + go sm.monitorCooldowns() + + return sm +} + +func (m *DefaultStateManager) GetOverview(ctx context.Context) (*StateOverview, error) { + settings, err := m.configSvc.GetSettings(ctx) + if err != nil { + return nil, err + } + + routes, err := m.configSvc.ListRoutes(ctx) + if err != nil { + return nil, err + } + + overview := &StateOverview{ + UnifiedRoutingEnabled: settings.Enabled, + HideOriginalModels: settings.HideOriginalModels, + TotalRoutes: len(routes), + Routes: make([]RouteState, 0, len(routes)), + } + + for _, route := range routes { + routeState, err := m.GetRouteState(ctx, route.ID) + if err != nil { + continue + } + + switch routeState.Status { + case "healthy": + overview.HealthyRoutes++ + case "degraded": + overview.DegradedRoutes++ + case "unhealthy": + overview.UnhealthyRoutes++ + } + + overview.Routes = append(overview.Routes, *routeState) + } + + return overview, nil +} + +func (m *DefaultStateManager) GetRouteState(ctx context.Context, routeID string) (*RouteState, error) { + route, err := m.configSvc.GetRoute(ctx, routeID) + if err != nil { + return nil, err + } + + pipeline, err := m.configSvc.GetPipeline(ctx, routeID) + if err != nil { + return nil, err + } + + routeState := &RouteState{ + RouteID: route.ID, + RouteName: route.Name, + ActiveLayer: 1, + LayerStates: make([]LayerState, 0, len(pipeline.Layers)), + } + + healthyTargets := 0 + totalTargets := 0 + activeLayerFound := false + + for _, layer := range pipeline.Layers { + layerState := LayerState{ + Level: layer.Level, + Status: "standby", + TargetStates: make([]*TargetState, 0, len(layer.Targets)), + } + + healthyInLayer := 0 + for _, target := range layer.Targets { + totalTargets++ + state, _ := m.store.GetTargetState(ctx, 
target.ID) + if state == nil { + state = &TargetState{ + TargetID: target.ID, + Status: StatusHealthy, + } + } + + // Check if cooldown has expired + if state.Status == StatusCooling && state.CooldownEndsAt != nil { + if time.Now().After(*state.CooldownEndsAt) { + state.Status = StatusHealthy + state.CooldownEndsAt = nil + } + } + + if state.Status == StatusHealthy { + healthyTargets++ + healthyInLayer++ + } + + layerState.TargetStates = append(layerState.TargetStates, state) + } + + // Determine layer status + if healthyInLayer > 0 && !activeLayerFound { + layerState.Status = "active" + routeState.ActiveLayer = layer.Level + activeLayerFound = true + } else if healthyInLayer == 0 { + layerState.Status = "exhausted" + } + + routeState.LayerStates = append(routeState.LayerStates, layerState) + } + + // Determine overall route status + if healthyTargets == totalTargets { + routeState.Status = "healthy" + } else if healthyTargets == 0 { + routeState.Status = "unhealthy" + } else { + routeState.Status = "degraded" + } + + return routeState, nil +} + +func (m *DefaultStateManager) GetTargetState(ctx context.Context, targetID string) (*TargetState, error) { + state, err := m.store.GetTargetState(ctx, targetID) + if err != nil { + return nil, err + } + + // Check if cooldown has expired + if state.Status == StatusCooling && state.CooldownEndsAt != nil { + if time.Now().After(*state.CooldownEndsAt) { + state.Status = StatusHealthy + state.CooldownEndsAt = nil + _ = m.store.SetTargetState(ctx, state) + } + } + + return state, nil +} + +func (m *DefaultStateManager) ListTargetStates(ctx context.Context) ([]*TargetState, error) { + return m.store.ListTargetStates(ctx) +} + +func (m *DefaultStateManager) RecordSuccess(ctx context.Context, targetID string, latency time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + + state, _ := m.store.GetTargetState(ctx, targetID) + if state == nil { + state = &TargetState{TargetID: targetID} + } + + now := time.Now() + state.Status = 
StatusHealthy
+	state.ConsecutiveFailures = 0
+	state.LastSuccessAt = &now
+	state.CooldownEndsAt = nil
+	state.TotalRequests++
+	state.SuccessfulRequests++
+
+	_ = m.store.SetTargetState(ctx, state)
+}
+
+// RecordFailure increments failure counters and stores the failure reason.
+func (m *DefaultStateManager) RecordFailure(ctx context.Context, targetID string, reason string) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	state, _ := m.store.GetTargetState(ctx, targetID)
+	if state == nil {
+		state = &TargetState{TargetID: targetID}
+	}
+
+	now := time.Now()
+	state.ConsecutiveFailures++
+	state.LastFailureAt = &now
+	state.LastFailureReason = reason
+	state.TotalRequests++
+
+	_ = m.store.SetTargetState(ctx, state)
+}
+
+// StartCooldown marks the target as cooling and schedules an expiry
+// notification for the cooldown monitor.
+func (m *DefaultStateManager) StartCooldown(ctx context.Context, targetID string, duration time.Duration) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	state, _ := m.store.GetTargetState(ctx, targetID)
+	if state == nil {
+		state = &TargetState{TargetID: targetID}
+	}
+
+	cooldownEnd := time.Now().Add(duration)
+	state.Status = StatusCooling
+	state.CooldownEndsAt = &cooldownEnd
+
+	_ = m.store.SetTargetState(ctx, state)
+
+	// Schedule cooldown expiry
+	go func() {
+		timer := time.NewTimer(duration)
+		defer timer.Stop()
+
+		select {
+		case <-timer.C:
+			// Deliver the expiry, but never block forever: if the manager
+			// has been stopped the monitor no longer drains cooldownChan,
+			// and an unconditional send on the full buffer would leak this
+			// goroutine.
+			select {
+			case m.cooldownChan <- targetID:
+			case <-m.stopChan:
+			}
+		case <-m.stopChan:
+			return
+		}
+	}()
+}
+
+// EndCooldown clears a target's cooldown and marks it healthy again.
+func (m *DefaultStateManager) EndCooldown(ctx context.Context, targetID string) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	state, _ := m.store.GetTargetState(ctx, targetID)
+	if state == nil {
+		return
+	}
+
+	state.Status = StatusHealthy
+	state.CooldownEndsAt = nil
+
+	_ = m.store.SetTargetState(ctx, state)
+}
+
+// ResetTarget restores a target to a clean healthy state.
+func (m *DefaultStateManager) ResetTarget(ctx context.Context, targetID string) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	state := &TargetState{
+		TargetID:            targetID,
+		Status:              StatusHealthy,
+		ConsecutiveFailures: 0,
+		CooldownEndsAt:      nil,
+	}
+
+	return m.store.SetTargetState(ctx, state)
+}
+
+func (m *DefaultStateManager) ForceCooldown(ctx context.Context, targetID string, duration
time.Duration) error { + m.StartCooldown(ctx, targetID, duration) + return nil +} + +func (m *DefaultStateManager) InitializeTarget(ctx context.Context, targetID string) error { + state := &TargetState{ + TargetID: targetID, + Status: StatusHealthy, + } + return m.store.SetTargetState(ctx, state) +} + +func (m *DefaultStateManager) RemoveTarget(ctx context.Context, targetID string) error { + return m.store.DeleteTargetState(ctx, targetID) +} + +func (m *DefaultStateManager) monitorCooldowns() { + for { + select { + case targetID := <-m.cooldownChan: + ctx := context.Background() + state, err := m.store.GetTargetState(ctx, targetID) + if err != nil || state == nil { + continue + } + + // Check if still in cooldown and cooldown has expired + if state.Status == StatusCooling && state.CooldownEndsAt != nil { + if time.Now().After(*state.CooldownEndsAt) { + m.EndCooldown(ctx, targetID) + } + } + + case <-m.stopChan: + return + } + } +} + +// Stop stops the state manager background tasks. +func (m *DefaultStateManager) Stop() { + close(m.stopChan) +} + +// IsTargetAvailable checks if a target is available for routing. +func (m *DefaultStateManager) IsTargetAvailable(ctx context.Context, targetID string) bool { + state, err := m.GetTargetState(ctx, targetID) + if err != nil { + return true // Default to available if error + } + return state.Status == StatusHealthy +} + +// GetAvailableTargetsInLayer returns available targets in a layer. 
+func (m *DefaultStateManager) GetAvailableTargetsInLayer(ctx context.Context, layer *Layer) []Target { + available := make([]Target, 0, len(layer.Targets)) + for _, target := range layer.Targets { + if !target.Enabled { + continue + } + if m.IsTargetAvailable(ctx, target.ID) { + available = append(available, target) + } + } + return available +} diff --git a/internal/api/modules/unified-routing/store.go b/internal/api/modules/unified-routing/store.go new file mode 100644 index 000000000..90acbaf22 --- /dev/null +++ b/internal/api/modules/unified-routing/store.go @@ -0,0 +1,743 @@ +package unifiedrouting + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "gopkg.in/yaml.v3" +) + +// ================== Store Interfaces ================== + +// ConfigStore defines the interface for configuration persistence. +type ConfigStore interface { + // Settings + LoadSettings(ctx context.Context) (*Settings, error) + SaveSettings(ctx context.Context, settings *Settings) error + + // Health check config + LoadHealthCheckConfig(ctx context.Context) (*HealthCheckConfig, error) + SaveHealthCheckConfig(ctx context.Context, config *HealthCheckConfig) error + + // Routes + ListRoutes(ctx context.Context) ([]*Route, error) + GetRoute(ctx context.Context, id string) (*Route, error) + CreateRoute(ctx context.Context, route *Route) error + UpdateRoute(ctx context.Context, route *Route) error + DeleteRoute(ctx context.Context, id string) error + + // Pipelines + GetPipeline(ctx context.Context, routeID string) (*Pipeline, error) + SavePipeline(ctx context.Context, routeID string, pipeline *Pipeline) error +} + +// StateStore defines the interface for runtime state storage (in-memory). 
+type StateStore interface { + GetTargetState(ctx context.Context, targetID string) (*TargetState, error) + SetTargetState(ctx context.Context, state *TargetState) error + ListTargetStates(ctx context.Context) ([]*TargetState, error) + DeleteTargetState(ctx context.Context, targetID string) error +} + +// MetricsStore defines the interface for metrics storage. +type MetricsStore interface { + // Traces + RecordTrace(ctx context.Context, trace *RequestTrace) error + GetTraces(ctx context.Context, filter TraceFilter) ([]*RequestTrace, error) + GetTrace(ctx context.Context, traceID string) (*RequestTrace, error) + + // Events + RecordEvent(ctx context.Context, event *RoutingEvent) error + GetEvents(ctx context.Context, filter EventFilter) ([]*RoutingEvent, error) + + // Stats (computed from traces) + GetStats(ctx context.Context, filter StatsFilter) (*AggregatedStats, error) +} + +// ================== File-based Config Store ================== + +// FileConfigStore implements ConfigStore using file-based persistence. +type FileConfigStore struct { + baseDir string + mu sync.RWMutex +} + +// NewFileConfigStore creates a new file-based config store. 
+func NewFileConfigStore(baseDir string) (*FileConfigStore, error) { + // Create directories if they don't exist + dirs := []string{ + baseDir, + filepath.Join(baseDir, "routes"), + filepath.Join(baseDir, "pipelines"), + } + for _, dir := range dirs { + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create directory %s: %w", dir, err) + } + } + + return &FileConfigStore{baseDir: baseDir}, nil +} + +func (s *FileConfigStore) settingsPath() string { + return filepath.Join(s.baseDir, "settings.yaml") +} + +func (s *FileConfigStore) healthConfigPath() string { + return filepath.Join(s.baseDir, "health-config.yaml") +} + +func (s *FileConfigStore) routePath(id string) string { + return filepath.Join(s.baseDir, "routes", id+".yaml") +} + +func (s *FileConfigStore) pipelinePath(routeID string) string { + return filepath.Join(s.baseDir, "pipelines", routeID+".yaml") +} + +func (s *FileConfigStore) LoadSettings(ctx context.Context) (*Settings, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + data, err := os.ReadFile(s.settingsPath()) + if err != nil { + if os.IsNotExist(err) { + return &Settings{Enabled: false, HideOriginalModels: false}, nil + } + return nil, err + } + + var settings Settings + if err := yaml.Unmarshal(data, &settings); err != nil { + return nil, err + } + return &settings, nil +} + +func (s *FileConfigStore) SaveSettings(ctx context.Context, settings *Settings) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := yaml.Marshal(settings) + if err != nil { + return err + } + return os.WriteFile(s.settingsPath(), data, 0644) +} + +func (s *FileConfigStore) LoadHealthCheckConfig(ctx context.Context) (*HealthCheckConfig, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + data, err := os.ReadFile(s.healthConfigPath()) + if err != nil { + if os.IsNotExist(err) { + cfg := DefaultHealthCheckConfig() + return &cfg, nil + } + return nil, err + } + + var config HealthCheckConfig + if err := yaml.Unmarshal(data, &config); 
err != nil { + return nil, err + } + return &config, nil +} + +func (s *FileConfigStore) SaveHealthCheckConfig(ctx context.Context, config *HealthCheckConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := yaml.Marshal(config) + if err != nil { + return err + } + return os.WriteFile(s.healthConfigPath(), data, 0644) +} + +func (s *FileConfigStore) ListRoutes(ctx context.Context) ([]*Route, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + routesDir := filepath.Join(s.baseDir, "routes") + entries, err := os.ReadDir(routesDir) + if err != nil { + if os.IsNotExist(err) { + return []*Route{}, nil + } + return nil, err + } + + var routes []*Route + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".yaml" { + continue + } + + data, err := os.ReadFile(filepath.Join(routesDir, entry.Name())) + if err != nil { + continue + } + + var route Route + if err := yaml.Unmarshal(data, &route); err != nil { + continue + } + routes = append(routes, &route) + } + + return routes, nil +} + +func (s *FileConfigStore) GetRoute(ctx context.Context, id string) (*Route, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + data, err := os.ReadFile(s.routePath(id)) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("route not found: %s", id) + } + return nil, err + } + + var route Route + if err := yaml.Unmarshal(data, &route); err != nil { + return nil, err + } + return &route, nil +} + +func (s *FileConfigStore) CreateRoute(ctx context.Context, route *Route) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if route already exists + if _, err := os.Stat(s.routePath(route.ID)); err == nil { + return fmt.Errorf("route already exists: %s", route.ID) + } + + route.CreatedAt = time.Now() + route.UpdatedAt = route.CreatedAt + + data, err := yaml.Marshal(route) + if err != nil { + return err + } + return os.WriteFile(s.routePath(route.ID), data, 0644) +} + +func (s *FileConfigStore) UpdateRoute(ctx context.Context, route 
*Route) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if route exists + if _, err := os.Stat(s.routePath(route.ID)); os.IsNotExist(err) { + return fmt.Errorf("route not found: %s", route.ID) + } + + route.UpdatedAt = time.Now() + + data, err := yaml.Marshal(route) + if err != nil { + return err + } + return os.WriteFile(s.routePath(route.ID), data, 0644) +} + +func (s *FileConfigStore) DeleteRoute(ctx context.Context, id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Delete route file + if err := os.Remove(s.routePath(id)); err != nil && !os.IsNotExist(err) { + return err + } + + // Delete pipeline file + if err := os.Remove(s.pipelinePath(id)); err != nil && !os.IsNotExist(err) { + return err + } + + return nil +} + +func (s *FileConfigStore) GetPipeline(ctx context.Context, routeID string) (*Pipeline, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + data, err := os.ReadFile(s.pipelinePath(routeID)) + if err != nil { + if os.IsNotExist(err) { + return &Pipeline{RouteID: routeID, Layers: []Layer{}}, nil + } + return nil, err + } + + var pipeline Pipeline + if err := yaml.Unmarshal(data, &pipeline); err != nil { + return nil, err + } + pipeline.RouteID = routeID + return &pipeline, nil +} + +func (s *FileConfigStore) SavePipeline(ctx context.Context, routeID string, pipeline *Pipeline) error { + s.mu.Lock() + defer s.mu.Unlock() + + pipeline.RouteID = routeID + data, err := yaml.Marshal(pipeline) + if err != nil { + return err + } + return os.WriteFile(s.pipelinePath(routeID), data, 0644) +} + +// ================== In-Memory State Store ================== + +// MemoryStateStore implements StateStore using in-memory storage. +type MemoryStateStore struct { + mu sync.RWMutex + states map[string]*TargetState +} + +// NewMemoryStateStore creates a new in-memory state store. 
+func NewMemoryStateStore() *MemoryStateStore { + return &MemoryStateStore{ + states: make(map[string]*TargetState), + } +} + +func (s *MemoryStateStore) GetTargetState(ctx context.Context, targetID string) (*TargetState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + state, ok := s.states[targetID] + if !ok { + // Return default healthy state + return &TargetState{ + TargetID: targetID, + Status: StatusHealthy, + }, nil + } + return state, nil +} + +func (s *MemoryStateStore) SetTargetState(ctx context.Context, state *TargetState) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.states[state.TargetID] = state + return nil +} + +func (s *MemoryStateStore) ListTargetStates(ctx context.Context) ([]*TargetState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + states := make([]*TargetState, 0, len(s.states)) + for _, state := range s.states { + states = append(states, state) + } + return states, nil +} + +func (s *MemoryStateStore) DeleteTargetState(ctx context.Context, targetID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.states, targetID) + return nil +} + +// ================== File-based Metrics Store ================== + +// FileMetricsStore implements MetricsStore using file-based storage. +// Traces are stored as JSON files in the traces directory. +// Directory size is enforced by cleanup logic similar to LogDirCleaner. +type FileMetricsStore struct { + mu sync.RWMutex + baseDir string + maxSizeMB int // Maximum total size in MB for traces directory +} + +// NewFileMetricsStore creates a new file-based metrics store. 
+func NewFileMetricsStore(baseDir string, maxSizeMB int) (*FileMetricsStore, error) { + tracesDir := filepath.Join(baseDir, "traces") + if err := os.MkdirAll(tracesDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create traces directory: %w", err) + } + + if maxSizeMB <= 0 { + maxSizeMB = 100 // Default 100MB + } + + store := &FileMetricsStore{ + baseDir: baseDir, + maxSizeMB: maxSizeMB, + } + + // Start background cleanup + go store.runCleanup() + + return store, nil +} + +func (s *FileMetricsStore) tracesDir() string { + return filepath.Join(s.baseDir, "traces") +} + +func (s *FileMetricsStore) traceFilename(trace *RequestTrace) string { + // Format: {timestamp}-{trace_id}.json + ts := trace.Timestamp.Format("2006-01-02T150405") + return fmt.Sprintf("%s-%s.json", ts, trace.TraceID[:8]) +} + +func (s *FileMetricsStore) RecordTrace(ctx context.Context, trace *RequestTrace) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := json.Marshal(trace) + if err != nil { + return fmt.Errorf("failed to marshal trace: %w", err) + } + + filename := s.traceFilename(trace) + filePath := filepath.Join(s.tracesDir(), filename) + + if err := os.WriteFile(filePath, data, 0644); err != nil { + return fmt.Errorf("failed to write trace file: %w", err) + } + + return nil +} + +func (s *FileMetricsStore) GetTraces(ctx context.Context, filter TraceFilter) ([]*RequestTrace, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entries, err := os.ReadDir(s.tracesDir()) + if err != nil { + if os.IsNotExist(err) { + return []*RequestTrace{}, nil + } + return nil, err + } + + // Sort by name descending (newest first since filename starts with timestamp) + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name() > entries[j].Name() + }) + + var result []*RequestTrace + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + continue + } + + data, err := os.ReadFile(filepath.Join(s.tracesDir(), entry.Name())) + if err != nil { 
+ continue + } + + var trace RequestTrace + if err := json.Unmarshal(data, &trace); err != nil { + continue + } + + // Apply filters + if filter.RouteID != "" && trace.RouteID != filter.RouteID { + continue + } + if filter.Status != "" && string(trace.Status) != filter.Status { + continue + } + + result = append(result, &trace) + + if filter.Limit > 0 && len(result) >= filter.Limit { + break + } + } + + return result, nil +} + +func (s *FileMetricsStore) GetTrace(ctx context.Context, traceID string) (*RequestTrace, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entries, err := os.ReadDir(s.tracesDir()) + if err != nil { + return nil, fmt.Errorf("trace not found: %s", traceID) + } + + shortID := traceID + if len(traceID) > 8 { + shortID = traceID[:8] + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + // Check if filename contains the trace ID + if !strings.Contains(entry.Name(), shortID) { + continue + } + + data, err := os.ReadFile(filepath.Join(s.tracesDir(), entry.Name())) + if err != nil { + continue + } + + var trace RequestTrace + if err := json.Unmarshal(data, &trace); err != nil { + continue + } + + if trace.TraceID == traceID || (len(trace.TraceID) >= 8 && trace.TraceID[:8] == shortID) { + return &trace, nil + } + } + + return nil, fmt.Errorf("trace not found: %s", traceID) +} + +// RecordEvent is a no-op for file store (events are no longer tracked) +func (s *FileMetricsStore) RecordEvent(ctx context.Context, event *RoutingEvent) error { + // Events are no longer stored - this is intentionally a no-op + return nil +} + +// GetEvents returns empty list (events are no longer tracked) +func (s *FileMetricsStore) GetEvents(ctx context.Context, filter EventFilter) ([]*RoutingEvent, error) { + // Events are no longer stored + return []*RoutingEvent{}, nil +} + +func (s *FileMetricsStore) GetStats(ctx context.Context, filter StatsFilter) (*AggregatedStats, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + stats := &AggregatedStats{ 
+ Period: filter.Period, + } + + // Calculate time range + var since time.Time + switch filter.Period { + case "1h": + since = time.Now().Add(-1 * time.Hour) + case "24h": + since = time.Now().Add(-24 * time.Hour) + case "7d": + since = time.Now().Add(-7 * 24 * time.Hour) + case "30d": + since = time.Now().Add(-30 * 24 * time.Hour) + default: + since = time.Now().Add(-1 * time.Hour) + } + + // Load all traces from files + entries, err := os.ReadDir(s.tracesDir()) + if err != nil { + if os.IsNotExist(err) { + return stats, nil + } + return nil, err + } + + var totalLatency int64 + layerCounts := make(map[int]int64) + targetCounts := make(map[string]*TargetDistribution) + attemptsCounts := make(map[int]int64) // Track 1-attempt, 2-attempt, etc. successes + + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + continue + } + + data, err := os.ReadFile(filepath.Join(s.tracesDir(), entry.Name())) + if err != nil { + continue + } + + var trace RequestTrace + if err := json.Unmarshal(data, &trace); err != nil { + continue + } + + if trace.Timestamp.Before(since) { + continue + } + + stats.TotalRequests++ + totalLatency += trace.TotalLatencyMs + + switch trace.Status { + case TraceStatusSuccess, TraceStatusRetry, TraceStatusFallback: + stats.SuccessfulRequests++ + case TraceStatusFailed: + stats.FailedRequests++ + } + + // Track attempts distribution (how many attempts needed for success) + attemptCount := len(trace.Attempts) + if trace.Status == TraceStatusSuccess || trace.Status == TraceStatusRetry || trace.Status == TraceStatusFallback { + attemptsCounts[attemptCount]++ + } + + // Track layer distribution (use the successful attempt's layer) + for _, attempt := range trace.Attempts { + if attempt.Status == AttemptStatusSuccess { + layerCounts[attempt.Layer]++ + + // Track target distribution + if _, ok := targetCounts[attempt.TargetID]; !ok { + targetCounts[attempt.TargetID] = &TargetDistribution{ + TargetID: attempt.TargetID, 
+ CredentialID: attempt.CredentialID, + } + } + targetCounts[attempt.TargetID].Requests++ + break + } + } + } + + if stats.TotalRequests > 0 { + stats.SuccessRate = float64(stats.SuccessfulRequests) / float64(stats.TotalRequests) + stats.AvgLatencyMs = totalLatency / stats.TotalRequests + } + + // Build layer distribution + for level, count := range layerCounts { + stats.LayerDistribution = append(stats.LayerDistribution, LayerDistribution{ + Level: level, + Requests: count, + Percentage: float64(count) / float64(stats.TotalRequests) * 100, + }) + } + + // Build target distribution + for _, td := range targetCounts { + stats.TargetDistribution = append(stats.TargetDistribution, *td) + } + + // Build attempts distribution + for attempts, count := range attemptsCounts { + stats.AttemptsDistribution = append(stats.AttemptsDistribution, AttemptsDistribution{ + Attempts: attempts, + Count: count, + Percentage: float64(count) / float64(stats.SuccessfulRequests) * 100, + }) + } + // Sort by attempts + sort.Slice(stats.AttemptsDistribution, func(i, j int) bool { + return stats.AttemptsDistribution[i].Attempts < stats.AttemptsDistribution[j].Attempts + }) + + return stats, nil +} + +// runCleanup runs the background cleanup task to enforce directory size limit. +func (s *FileMetricsStore) runCleanup() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for range ticker.C { + s.enforceSize() + } +} + +// enforceSize removes old trace files if total size exceeds maxSizeMB. 
+func (s *FileMetricsStore) enforceSize() { + s.mu.Lock() + defer s.mu.Unlock() + + maxBytes := int64(s.maxSizeMB) * 1024 * 1024 + + entries, err := os.ReadDir(s.tracesDir()) + if err != nil { + return + } + + type traceFile struct { + path string + size int64 + modTime time.Time + } + + var files []traceFile + var total int64 + + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + path := filepath.Join(s.tracesDir(), entry.Name()) + files = append(files, traceFile{ + path: path, + size: info.Size(), + modTime: info.ModTime(), + }) + total += info.Size() + } + + if total <= maxBytes { + return + } + + // Sort by modTime ascending (oldest first) + sort.Slice(files, func(i, j int) bool { + return files[i].modTime.Before(files[j].modTime) + }) + + // Remove oldest files until under limit + for _, f := range files { + if total <= maxBytes { + break + } + if err := os.Remove(f.path); err == nil { + total -= f.size + } + } +} + +// MarshalJSON implements json.Marshaler for TargetState. +func (s *TargetState) MarshalJSON() ([]byte, error) { + type Alias TargetState + return json.Marshal(&struct { + *Alias + CooldownRemainingSeconds int `json:"cooldown_remaining_seconds,omitempty"` + }{ + Alias: (*Alias)(s), + CooldownRemainingSeconds: s.CooldownRemainingSeconds(), + }) +} + +// CooldownRemainingSeconds returns the remaining cooldown time in seconds. 
// CooldownRemainingSeconds returns the remaining cooldown time in seconds.
// It returns 0 when no cooldown is active (no deadline set, status is not
// cooling, or the deadline has already passed), so the JSON field can use
// omitempty.
func (s *TargetState) CooldownRemainingSeconds() int {
	if s.CooldownEndsAt == nil || s.Status != StatusCooling {
		return 0
	}
	remaining := time.Until(*s.CooldownEndsAt).Seconds()
	if remaining < 0 {
		return 0
	}
	return int(remaining)
}
diff --git a/internal/api/modules/unified-routing/types.go b/internal/api/modules/unified-routing/types.go
new file mode 100644
index 000000000..5e796358d
--- /dev/null
+++ b/internal/api/modules/unified-routing/types.go
@@ -0,0 +1,336 @@
// Package unifiedrouting provides a unified routing system that allows
// defining custom model aliases with multi-layer failover pipelines.
//
// Tag convention used throughout this file: JSON tags are snake_case (API
// payloads), YAML tags are kebab-case (config files); types that are never
// persisted to YAML carry `yaml:"-"` or no YAML tag.
package unifiedrouting

import (
	"time"
)

// ================== Configuration Types ==================

// Settings holds the global settings for unified routing.
type Settings struct {
	// Enabled turns the unified routing feature on or off globally.
	Enabled bool `json:"enabled" yaml:"enabled"`
	// HideOriginalModels, when true, hides upstream provider models so that
	// only route aliases are exposed to clients.
	HideOriginalModels bool `json:"hide_original_models" yaml:"hide-original-models"`
}

// HealthCheckConfig holds the health check configuration.
type HealthCheckConfig struct {
	DefaultCooldownSeconds int `json:"default_cooldown_seconds" yaml:"default-cooldown-seconds"`
	CheckIntervalSeconds   int `json:"check_interval_seconds" yaml:"check-interval-seconds"`
	CheckTimeoutSeconds    int `json:"check_timeout_seconds" yaml:"check-timeout-seconds"`
	MaxConsecutiveFailures int `json:"max_consecutive_failures" yaml:"max-consecutive-failures"`
}

// DefaultHealthCheckConfig returns the default health check configuration
// (60s cooldown, 30s check interval, 10s check timeout, 3 failures before a
// target is considered down).
func DefaultHealthCheckConfig() HealthCheckConfig {
	return HealthCheckConfig{
		DefaultCooldownSeconds: 60,
		CheckIntervalSeconds:   30,
		CheckTimeoutSeconds:    10,
		MaxConsecutiveFailures: 3,
	}
}

// Route represents a routing configuration (persistent entity).
type Route struct {
	ID          string    `json:"id" yaml:"id"`
	Name        string    `json:"name" yaml:"name"`
	Description string    `json:"description,omitempty" yaml:"description,omitempty"`
	Enabled     bool      `json:"enabled" yaml:"enabled"`
	// Timestamps are API-only; they are not round-tripped through YAML config.
	CreatedAt   time.Time `json:"created_at" yaml:"-"`
	UpdatedAt   time.Time `json:"updated_at" yaml:"-"`
}

// Pipeline represents the routing pipeline configuration (value object).
type Pipeline struct {
	RouteID string  `json:"route_id" yaml:"-"`
	Layers  []Layer `json:"layers" yaml:"layers"`
}

// Layer represents a layer in the pipeline (value object).
type Layer struct {
	Level           int          `json:"level" yaml:"level"`
	Strategy        LoadStrategy `json:"strategy" yaml:"strategy"`
	CooldownSeconds int          `json:"cooldown_seconds" yaml:"cooldown-seconds"`
	Targets         []Target     `json:"targets" yaml:"targets"`
}

// Target represents a target in a layer (value object).
type Target struct {
	ID           string `json:"id" yaml:"id"`
	CredentialID string `json:"credential_id" yaml:"credential-id"`
	Model        string `json:"model" yaml:"model"`
	// Weight is presumably consulted only by the weighted-round-robin
	// strategy — TODO confirm against the routing engine.
	Weight       int    `json:"weight,omitempty" yaml:"weight,omitempty"`
	Enabled      bool   `json:"enabled" yaml:"enabled"`
}

// LoadStrategy defines the load balancing strategy.
type LoadStrategy string

const (
	StrategyRoundRobin     LoadStrategy = "round-robin"
	StrategyWeightedRound  LoadStrategy = "weighted-round-robin"
	StrategyLeastConn      LoadStrategy = "least-connections"
	StrategyRandom         LoadStrategy = "random"
	StrategyFirstAvailable LoadStrategy = "first-available"
)

// ================== Runtime State Types ==================

// TargetState represents the runtime state of a target (in-memory entity).
type TargetState struct {
	TargetID            string       `json:"target_id"`
	Status              TargetStatus `json:"status"`
	ConsecutiveFailures int          `json:"consecutive_failures"`
	CooldownEndsAt      *time.Time   `json:"cooldown_ends_at,omitempty"`
	LastSuccessAt       *time.Time   `json:"last_success_at,omitempty"`
	LastFailureAt       *time.Time   `json:"last_failure_at,omitempty"`
	LastFailureReason   string       `json:"last_failure_reason,omitempty"`
	ActiveConnections   int64        `json:"active_connections"`
	TotalRequests       int64        `json:"total_requests"`
	SuccessfulRequests  int64        `json:"successful_requests"`
}

// TargetStatus defines the status of a target.
// Simplified to only two states:
//   - healthy: target is available (default state)
//   - cooling: target is in cooldown after failure
type TargetStatus string

const (
	StatusHealthy TargetStatus = "healthy"
	StatusCooling TargetStatus = "cooling"
)

// RouteState represents the runtime state of a route.
type RouteState struct {
	RouteID     string       `json:"route_id"`
	RouteName   string       `json:"route_name"`
	Status      string       `json:"status"` // "healthy", "degraded", "unhealthy"
	ActiveLayer int          `json:"active_layer"`
	LayerStates []LayerState `json:"layers"`
}

// LayerState represents the runtime state of a layer.
type LayerState struct {
	Level        int            `json:"level"`
	Status       string         `json:"status"` // "active", "standby", "exhausted"
	TargetStates []*TargetState `json:"targets"`
}

// StateOverview represents the overall state overview.
type StateOverview struct {
	UnifiedRoutingEnabled bool         `json:"unified_routing_enabled"`
	HideOriginalModels    bool         `json:"hide_original_models"`
	TotalRoutes           int          `json:"total_routes"`
	HealthyRoutes         int          `json:"healthy_routes"`
	DegradedRoutes        int          `json:"degraded_routes"`
	UnhealthyRoutes       int          `json:"unhealthy_routes"`
	Routes                []RouteState `json:"routes,omitempty"`
}

// ================== Monitoring Types ==================

// RequestTrace represents a request trace record.
type RequestTrace struct {
	TraceID        string         `json:"trace_id"`
	RouteID        string         `json:"route_id"`
	RouteName      string         `json:"route_name"`
	Timestamp      time.Time      `json:"timestamp"`
	Status         TraceStatus    `json:"status"`
	TotalLatencyMs int64          `json:"total_latency_ms"`
	Attempts       []AttemptTrace `json:"attempts"`
}

// TraceStatus defines the status of a trace.
// Note: "retry" and "fallback" describe requests that ultimately succeeded
// after one or more failed attempts (they are counted as successes by the
// stats aggregation).
type TraceStatus string

const (
	TraceStatusSuccess  TraceStatus = "success"
	TraceStatusRetry    TraceStatus = "retry"
	TraceStatusFallback TraceStatus = "fallback"
	TraceStatusFailed   TraceStatus = "failed"
)

// AttemptTrace represents a single attempt within a trace.
type AttemptTrace struct {
	Attempt      int           `json:"attempt"`
	Layer        int           `json:"layer"`
	TargetID     string        `json:"target_id"`
	CredentialID string        `json:"credential_id"`
	Model        string        `json:"model"`
	Status       AttemptStatus `json:"status"`
	LatencyMs    int64         `json:"latency_ms,omitempty"`
	Error        string        `json:"error,omitempty"`
}

// AttemptStatus defines the status of an attempt.
type AttemptStatus string

const (
	AttemptStatusSuccess AttemptStatus = "success"
	AttemptStatusFailed  AttemptStatus = "failed"
	AttemptStatusSkipped AttemptStatus = "skipped"
)

// RoutingEvent represents a routing event.
// NOTE(review): the file-backed metrics store no longer persists events
// (RecordEvent/GetEvents are no-ops there); this type remains for the
// interface and any other store implementations.
type RoutingEvent struct {
	ID        string           `json:"id"`
	Type      RoutingEventType `json:"type"`
	Timestamp time.Time        `json:"timestamp"`
	RouteID   string           `json:"route_id"`
	TargetID  string           `json:"target_id,omitempty"`
	Details   map[string]any   `json:"details,omitempty"`
}

// RoutingEventType defines the type of routing event.
type RoutingEventType string

const (
	EventTargetFailed    RoutingEventType = "target_failed"
	EventTargetRecovered RoutingEventType = "target_recovered"
	EventLayerFallback   RoutingEventType = "layer_fallback"
	EventCooldownStarted RoutingEventType = "cooldown_started"
	EventCooldownEnded   RoutingEventType = "cooldown_ended"
)

// ================== Statistics Types ==================

// AggregatedStats represents aggregated statistics.
// NOTE(review): P95LatencyMs/P99LatencyMs are not populated by the
// file-backed store's GetStats — confirm whether another implementation
// fills them or whether they are dead fields.
type AggregatedStats struct {
	Period               string                 `json:"period"`
	TotalRequests        int64                  `json:"total_requests"`
	SuccessfulRequests   int64                  `json:"successful_requests"`
	FailedRequests       int64                  `json:"failed_requests"`
	SuccessRate          float64                `json:"success_rate"`
	AvgLatencyMs         int64                  `json:"avg_latency_ms"`
	P95LatencyMs         int64                  `json:"p95_latency_ms"`
	P99LatencyMs         int64                  `json:"p99_latency_ms"`
	LayerDistribution    []LayerDistribution    `json:"layer_distribution,omitempty"`
	TargetDistribution   []TargetDistribution   `json:"target_distribution,omitempty"`
	AttemptsDistribution []AttemptsDistribution `json:"attempts_distribution,omitempty"`
}

// AttemptsDistribution represents the distribution of how many attempts
// were needed for successful requests.
type AttemptsDistribution struct {
	Attempts   int     `json:"attempts"`   // Number of attempts (1, 2, 3, ...)
	Count      int64   `json:"count"`      // Number of requests that succeeded with this many attempts
	Percentage float64 `json:"percentage"` // Percentage of successful requests
}

// LayerDistribution represents the distribution of requests across layers.
type LayerDistribution struct {
	Level      int     `json:"level"`
	Requests   int64   `json:"requests"`
	Percentage float64 `json:"percentage"`
}

// TargetDistribution represents the distribution of requests across targets.
type TargetDistribution struct {
	TargetID     string  `json:"target_id"`
	CredentialID string  `json:"credential_id"`
	Requests     int64   `json:"requests"`
	SuccessRate  float64 `json:"success_rate"`
	AvgLatencyMs int64   `json:"avg_latency_ms"`
}

// ================== Credential Types ==================

// CredentialInfo represents information about a credential.
type CredentialInfo struct {
	ID       string      `json:"id"`
	Provider string      `json:"provider"`
	Type     string      `json:"type"` // "oauth", "api-key"
	Label    string      `json:"label,omitempty"`
	Prefix   string      `json:"prefix,omitempty"`
	BaseURL  string      `json:"base_url,omitempty"`
	APIKey   string      `json:"api_key,omitempty"` // masked
	Status   string      `json:"status"`
	Models   []ModelInfo `json:"models"`
}

// ModelInfo represents information about a model.
type ModelInfo struct {
	ID        string `json:"id"`
	Name      string `json:"name"`
	Available bool   `json:"available"`
}

// ================== Health Check Types ==================

// HealthResult represents the result of a health check.
type HealthResult struct {
	TargetID     string    `json:"target_id"`
	CredentialID string    `json:"credential_id"`
	Model        string    `json:"model"`
	Status       string    `json:"status"` // "healthy", "unhealthy"
	LatencyMs    int64     `json:"latency_ms,omitempty"`
	Message      string    `json:"message,omitempty"`
	CheckedAt    time.Time `json:"checked_at"`
}

// ================== Filter Types ==================

// StatsFilter defines the filter for statistics queries.
type StatsFilter struct {
	Period      string    `json:"period"`      // "1h", "24h", "7d", "30d"
	Granularity string    `json:"granularity"` // "minute", "hour", "day"
	StartTime   time.Time `json:"start_time,omitempty"`
	EndTime     time.Time `json:"end_time,omitempty"`
}

// EventFilter defines the filter for event queries.
type EventFilter struct {
	Type    string `json:"type,omitempty"` // "failure", "recovery", "fallback", "all"
	RouteID string `json:"route_id,omitempty"`
	Limit   int    `json:"limit,omitempty"`
	Offset  int    `json:"offset,omitempty"`
}

// TraceFilter defines the filter for trace queries.
type TraceFilter struct {
	RouteID string `json:"route_id,omitempty"`
	Status  string `json:"status,omitempty"` // "success", "retry", "fallback", "failed"
	Limit   int    `json:"limit,omitempty"`
	Offset  int    `json:"offset,omitempty"`
}

// HealthHistoryFilter defines the filter for health history queries.
type HealthHistoryFilter struct {
	TargetID string    `json:"target_id,omitempty"`
	Status   string    `json:"status,omitempty"`
	Limit    int       `json:"limit,omitempty"`
	Since    time.Time `json:"since,omitempty"`
}

// ================== Export/Import Types ==================

// ExportData represents the data for export/import.
type ExportData struct {
	Version    string         `json:"version"`
	ExportedAt time.Time      `json:"exported_at"`
	Config     ExportedConfig `json:"config"`
}

// ExportedConfig represents the exported configuration.
type ExportedConfig struct {
	Settings    Settings            `json:"settings"`
	HealthCheck HealthCheckConfig   `json:"health_check"`
	Routes      []RouteWithPipeline `json:"routes"`
}

// RouteWithPipeline combines route and its pipeline for export.
type RouteWithPipeline struct {
	Route    Route    `json:"route"`
	Pipeline Pipeline `json:"pipeline"`
}

// ================== Validation Types ==================

// ValidationError represents a validation error.
+type ValidationError struct { + Field string `json:"field"` + Message string `json:"message"` +} diff --git a/internal/api/server.go b/internal/api/server.go index 8b26044e1..d1b535511 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -5,10 +5,12 @@ package api import ( + "bytes" "context" "crypto/subtle" "errors" "fmt" + "io" "net/http" "os" "path/filepath" @@ -23,6 +25,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp" + unifiedrouting "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/unified-routing" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" @@ -36,7 +39,11 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "gopkg.in/yaml.v3" ) @@ -155,6 +162,9 @@ type Server struct { // ampModule is the Amp routing module for model mapping hot-reload ampModule *ampmodule.AmpModule + // unifiedRoutingModule is the unified routing module for custom model routing + unifiedRoutingModule *unifiedrouting.Module + // managementRoutesRegistered tracks whether the management routes have been attached to the engine. managementRoutesRegistered atomic.Bool // managementRoutesEnabled controls whether management endpoints serve real handlers. 
@@ -280,6 +290,18 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk log.Errorf("Failed to register Amp module: %v", err) } + // Register Unified Routing module + // Note: We skip auto route registration here because unified-routing routes + // should use management auth middleware (not API key auth). + // Routes will be registered in registerManagementRoutes() with correct auth. + s.unifiedRoutingModule = unifiedrouting.New( + unifiedrouting.WithAuthManager(authManager), + unifiedrouting.WithSkipAutoRoutes(), // Routes registered later with management auth + ) + if err := modules.RegisterModule(ctx, s.unifiedRoutingModule); err != nil { + log.Errorf("Failed to register Unified Routing module: %v", err) + } + // Apply additional router configurators from options if optionState.routerConfigurator != nil { optionState.routerConfigurator(engine, s.handlers, cfg) @@ -320,20 +342,21 @@ func (s *Server) setupRoutes() { v1.Use(AuthMiddleware(s.accessManager)) { v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers)) - v1.POST("/chat/completions", openaiHandlers.ChatCompletions) - v1.POST("/completions", openaiHandlers.Completions) - v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) - v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) - v1.POST("/responses", openaiResponsesHandlers.Responses) + // Wrap handlers with unified routing support + v1.POST("/chat/completions", s.wrapWithUnifiedRouting(openaiHandlers.ChatCompletions)) + v1.POST("/completions", s.wrapWithUnifiedRouting(openaiHandlers.Completions)) + v1.POST("/messages", s.wrapWithUnifiedRoutingClaude(claudeCodeHandlers.ClaudeMessages)) + v1.POST("/messages/count_tokens", s.wrapWithUnifiedRoutingClaude(claudeCodeHandlers.ClaudeCountTokens)) + v1.POST("/responses", s.wrapWithUnifiedRouting(openaiResponsesHandlers.Responses)) } // Gemini compatible API routes v1beta := s.engine.Group("/v1beta") v1beta.Use(AuthMiddleware(s.accessManager)) 
{ - v1beta.GET("/models", geminiHandlers.GeminiModels) - v1beta.POST("/models/*action", geminiHandlers.GeminiHandler) - v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler) + v1beta.GET("/models", s.unifiedGeminiModelsHandler(geminiHandlers)) + v1beta.POST("/models/*action", s.wrapWithUnifiedRoutingGemini(geminiHandlers.GeminiHandler)) + v1beta.GET("/models/*action", s.wrapWithUnifiedRoutingGemini(geminiHandlers.GeminiGetHandler)) } // Root endpoint @@ -347,7 +370,7 @@ func (s *Server) setupRoutes() { }, }) }) - s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) + s.engine.POST("/v1internal:method", s.wrapWithUnifiedRoutingGeminiCLI(geminiCLIHandlers.CLIHandler)) // OAuth callback endpoints (reuse main server port) // These endpoints receive provider redirects and persist @@ -607,10 +630,15 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/auth-files", s.mgmt.ListAuthFiles) mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels) + mgmt.GET("/auth-files/health", s.mgmt.CheckAuthFileModelsHealth) mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile) mgmt.POST("/auth-files", s.mgmt.UploadAuthFile) mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile) mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus) + + mgmt.GET("/providers", s.mgmt.ListProviders) + mgmt.GET("/providers/health", s.mgmt.CheckProvidersHealth) + mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential) mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) @@ -623,6 +651,23 @@ func (s *Server) registerManagementRoutes() { mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } + + // Register unified-routing module routes with management auth + if s.unifiedRoutingModule != nil { + // Create a combined middleware that includes both availability check and management auth + managementAuth := func() gin.HandlerFunc { + availMiddleware := s.managementAvailabilityMiddleware() + authMiddleware := 
s.mgmt.Middleware() + return func(c *gin.Context) { + availMiddleware(c) + if c.IsAborted() { + return + } + authMiddleware(c) + } + }() + s.unifiedRoutingModule.RegisterRoutes(s.engine, managementAuth) + } } func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc { @@ -740,12 +785,78 @@ func (s *Server) watchKeepAlive() { } } +// unifiedGeminiModelsHandler creates a unified handler for the /v1beta/models endpoint. +// When unified routing is enabled with hide_original_models=true, only route aliases are returned in Gemini format. +func (s *Server) unifiedGeminiModelsHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc { + return func(c *gin.Context) { + // Check if unified routing should hide original models + if s.unifiedRoutingModule != nil { + engine := s.unifiedRoutingModule.GetEngine() + if engine != nil && engine.ShouldHideOriginalModels(c.Request.Context()) { + // Return only route aliases as models in Gemini format + routeNames := engine.GetRouteNames(c.Request.Context()) + models := make([]map[string]any, len(routeNames)) + for i, name := range routeNames { + models[i] = map[string]any{ + "name": "models/" + name, + "displayName": name, + "description": name, + "supportedGenerationMethods": []string{"generateContent"}, + } + } + c.JSON(200, gin.H{ + "models": models, + }) + return + } + } + + // Delegate to original handler + geminiHandler.GeminiModels(c) + } +} + // unifiedModelsHandler creates a unified handler for the /v1/models endpoint // that routes to different handlers based on the User-Agent header. // If User-Agent starts with "claude-cli", it routes to Claude handler, // otherwise it routes to OpenAI handler. +// When unified routing is enabled with hide_original_models=true, only route aliases are returned. 
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc { return func(c *gin.Context) { + // Check if unified routing should hide original models + if s.unifiedRoutingModule != nil { + engine := s.unifiedRoutingModule.GetEngine() + if engine != nil { + isEnabled := engine.IsEnabled(c.Request.Context()) + shouldHide := engine.ShouldHideOriginalModels(c.Request.Context()) + log.Debugf("[UnifiedRouting] /v1/models check: enabled=%v, hideOriginal=%v", isEnabled, shouldHide) + + if shouldHide { + // Return only route aliases as models + routeNames := engine.GetRouteNames(c.Request.Context()) + log.Debugf("[UnifiedRouting] Returning %d route aliases as models: %v", len(routeNames), routeNames) + models := make([]map[string]any, len(routeNames)) + for i, name := range routeNames { + models[i] = map[string]any{ + "id": name, + "object": "model", + "created": 1700000000, + "owned_by": "unified-routing", + } + } + c.JSON(200, gin.H{ + "object": "list", + "data": models, + }) + return + } + } else { + log.Debugf("[UnifiedRouting] /v1/models: engine is nil") + } + } else { + log.Debugf("[UnifiedRouting] /v1/models: module is nil") + } + userAgent := c.GetHeader("User-Agent") // Route to Claude handler if User-Agent starts with "claude-cli" @@ -759,6 +870,502 @@ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, cl } } +// wrapWithUnifiedRouting wraps an API handler with unified routing support for OpenAI format. +// When a model name is a route alias, it executes the request using the target credential directly. +// When hide_original_models is enabled and the model is not a route alias, it returns 404. +// Otherwise, it delegates to the original handler. 
+func (s *Server) wrapWithUnifiedRouting(originalHandler gin.HandlerFunc) gin.HandlerFunc { + return s.wrapWithUnifiedRoutingFormat(originalHandler, sdktranslator.FormatOpenAI, "model") +} + +// wrapWithUnifiedRoutingClaude wraps an API handler with unified routing support for Claude format. +func (s *Server) wrapWithUnifiedRoutingClaude(originalHandler gin.HandlerFunc) gin.HandlerFunc { + return s.wrapWithUnifiedRoutingFormat(originalHandler, sdktranslator.FormatClaude, "model") +} + +// wrapWithUnifiedRoutingFormat wraps an API handler with unified routing support for a specific format. +func (s *Server) wrapWithUnifiedRoutingFormat(originalHandler gin.HandlerFunc, sourceFormat sdktranslator.Format, modelField string) gin.HandlerFunc { + return func(c *gin.Context) { + // Skip if unified routing module is not configured + if s.unifiedRoutingModule == nil { + originalHandler(c) + return + } + + engine := s.unifiedRoutingModule.GetEngine() + if engine == nil || !engine.IsEnabled(c.Request.Context()) { + originalHandler(c) + return + } + + // Read the request body + rawBody, err := io.ReadAll(c.Request.Body) + if err != nil { + originalHandler(c) + return + } + + // Extract model from request + modelName := gjson.GetBytes(rawBody, modelField).String() + if modelName == "" { + c.Request.Body = io.NopCloser(bytes.NewReader(rawBody)) + originalHandler(c) + return + } + + // Check if this model is a route alias + _, _, routeErr := engine.GetRoutingTarget(c.Request.Context(), modelName) + + if routeErr == nil { + // Model is a route alias - execute with full failover support + log.Debugf("[UnifiedRouting] Routing request for model: %s (format: %s)", modelName, sourceFormat) + + stream := gjson.GetBytes(rawBody, "stream").Bool() + + // Use ExecuteWithFailover for full multi-layer failover support + s.executeWithUnifiedRoutingFailoverFormat(c, engine, modelName, rawBody, stream, sourceFormat) + return + } + + // Model is not a route alias + // Check if we should hide 
original models + if engine.ShouldHideOriginalModels(c.Request.Context()) { + // Return model not found error + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("The model '%s' does not exist or you do not have access to it.", modelName), + "type": "invalid_request_error", + "code": "model_not_found", + }, + }) + return + } + + // Allow the request to proceed with original handler + c.Request.Body = io.NopCloser(bytes.NewReader(rawBody)) + originalHandler(c) + } +} + +// wrapWithUnifiedRoutingGemini wraps a Gemini API handler with unified routing support. +// Gemini format has the model name in the URL path (e.g., /v1beta/models/gemini-pro:generateContent) +func (s *Server) wrapWithUnifiedRoutingGemini(originalHandler gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + // Skip if unified routing module is not configured + if s.unifiedRoutingModule == nil { + originalHandler(c) + return + } + + engine := s.unifiedRoutingModule.GetEngine() + if engine == nil || !engine.IsEnabled(c.Request.Context()) { + originalHandler(c) + return + } + + // Extract model name from URL path + // Format: /v1beta/models/{model}:{method} or /v1beta/models/{model} + action := c.Param("action") + if action == "" { + originalHandler(c) + return + } + + action = strings.TrimPrefix(action, "/") + parts := strings.Split(action, ":") + modelName := parts[0] + + if modelName == "" { + originalHandler(c) + return + } + + // Check if this model is a route alias + _, _, routeErr := engine.GetRoutingTarget(c.Request.Context(), modelName) + + if routeErr == nil { + // Model is a route alias - execute with unified routing + log.Debugf("[UnifiedRouting] Routing Gemini request for model: %s", modelName) + + // Read the request body + rawBody, err := io.ReadAll(c.Request.Body) + if err != nil { + rawBody = []byte("{}") + } + + // For Gemini, check if it's streaming based on the method + method := "" + if len(parts) > 1 { + method = parts[1] + } + stream := 
method == "streamGenerateContent" + + s.executeWithUnifiedRoutingFailoverFormat(c, engine, modelName, rawBody, stream, sdktranslator.FormatGemini) + return + } + + // Model is not a route alias + // Check if we should hide original models + if engine.ShouldHideOriginalModels(c.Request.Context()) { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("The model '%s' does not exist or you do not have access to it.", modelName), + "type": "invalid_request_error", + "code": "model_not_found", + }, + }) + return + } + + // Allow the request to proceed with original handler + originalHandler(c) + } +} + +// wrapWithUnifiedRoutingGeminiCLI wraps a Gemini CLI API handler with unified routing support. +// Gemini CLI format has the model name in the request body's "model" field. +// The method is determined by the URL path (e.g., /v1internal:generateContent, /v1internal:streamGenerateContent) +func (s *Server) wrapWithUnifiedRoutingGeminiCLI(originalHandler gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + // Skip if unified routing module is not configured + if s.unifiedRoutingModule == nil { + originalHandler(c) + return + } + + engine := s.unifiedRoutingModule.GetEngine() + if engine == nil || !engine.IsEnabled(c.Request.Context()) { + originalHandler(c) + return + } + + // Read the request body + rawBody, err := io.ReadAll(c.Request.Body) + if err != nil { + originalHandler(c) + return + } + + // Extract model from request body + modelName := gjson.GetBytes(rawBody, "model").String() + if modelName == "" { + c.Request.Body = io.NopCloser(bytes.NewReader(rawBody)) + originalHandler(c) + return + } + + // Check if this model is a route alias + _, _, routeErr := engine.GetRoutingTarget(c.Request.Context(), modelName) + + if routeErr == nil { + // Model is a route alias - execute with unified routing + log.Debugf("[UnifiedRouting] Routing Gemini CLI request for model: %s", modelName) + + // Determine if streaming based on URL 
path + requestPath := c.Request.URL.Path + stream := strings.Contains(requestPath, "streamGenerateContent") + + s.executeWithUnifiedRoutingFailoverFormat(c, engine, modelName, rawBody, stream, sdktranslator.FormatGeminiCLI) + return + } + + // Model is not a route alias + // Check if we should hide original models + if engine.ShouldHideOriginalModels(c.Request.Context()) { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("The model '%s' does not exist or you do not have access to it.", modelName), + "type": "invalid_request_error", + "code": "model_not_found", + }, + }) + return + } + + // Allow the request to proceed with original handler + c.Request.Body = io.NopCloser(bytes.NewReader(rawBody)) + originalHandler(c) + } +} + +// executeWithUnifiedRoutingFailover executes a request with full multi-layer failover support (OpenAI format). +func (s *Server) executeWithUnifiedRoutingFailover(c *gin.Context, engine unifiedrouting.RoutingEngine, modelName string, rawBody []byte, stream bool) { + s.executeWithUnifiedRoutingFailoverFormat(c, engine, modelName, rawBody, stream, sdktranslator.FormatOpenAI) +} + +// executeWithUnifiedRoutingFailoverFormat executes a request with full multi-layer failover support for any format. 
+func (s *Server) executeWithUnifiedRoutingFailoverFormat(c *gin.Context, engine unifiedrouting.RoutingEngine, modelName string, rawBody []byte, stream bool, sourceFormat sdktranslator.Format) { + ctx := c.Request.Context() + + // Get routing decision + routingEngine, ok := engine.(*unifiedrouting.DefaultRoutingEngine) + if !ok { + // Fallback to simple routing + s.executeWithUnifiedRoutingSimpleFormat(c, engine, modelName, rawBody, stream, sourceFormat) + return + } + + decision, err := routingEngine.GetRoutingDecision(ctx, modelName) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{"message": err.Error(), "type": "invalid_request_error", "code": "model_not_found"}, + }) + return + } + + // For non-streaming requests, use ExecuteWithFailover + if !stream { + var responsePayload []byte + + // Create executor function that will be called for each target + executeFunc := func(execCtx context.Context, targetAuth *auth.Auth, targetModel string) error { + // Replace model in request body + newBody, err := sjson.SetBytes(rawBody, "model", targetModel) + if err != nil { + newBody = rawBody + } + + req := cliproxyexecutor.Request{ + Model: targetModel, + Payload: newBody, + } + opts := cliproxyexecutor.Options{ + Stream: false, + OriginalRequest: rawBody, + SourceFormat: sourceFormat, + } + + resp, err := s.handlers.AuthManager.ExecuteWithAuth(execCtx, targetAuth, req, opts) + if err != nil { + return err + } + responsePayload = resp.Payload + return nil + } + + // Execute with failover + err := routingEngine.ExecuteWithFailover(ctx, decision, executeFunc) + if err != nil { + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + c.JSON(status, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "server_error", + }, + }) + return + } + + c.Writer.Header().Set("Content-Type", "application/json") + 
c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(responsePayload) + return + } + + // For streaming requests, use streaming failover + streamExecuteFunc := func(execCtx context.Context, targetAuth *auth.Auth, targetModel string) (<-chan cliproxyexecutor.StreamChunk, error) { + // Replace model in request body + newBody, err := sjson.SetBytes(rawBody, "model", targetModel) + if err != nil { + newBody = rawBody + } + + req := cliproxyexecutor.Request{ + Model: targetModel, + Payload: newBody, + } + opts := cliproxyexecutor.Options{ + Stream: true, + OriginalRequest: rawBody, + SourceFormat: sourceFormat, + } + + return s.handlers.AuthManager.ExecuteStreamWithAuth(execCtx, targetAuth, req, opts) + } + + chunks, err := routingEngine.ExecuteStreamWithFailover(ctx, decision, streamExecuteFunc) + if err != nil { + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + c.JSON(status, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "server_error", + }, + }) + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + flusher, _ := c.Writer.(http.Flusher) + wroteData := false + for chunk := range chunks { + if chunk.Err != nil { + log.Warnf("[UnifiedRouting] Stream error: %v", chunk.Err) + break + } + if len(chunk.Payload) > 0 { + wroteData = true + // Check if chunk already has SSE format (from Claude, Gemini, etc.) 
+ if bytes.HasPrefix(chunk.Payload, []byte("data:")) || + bytes.HasPrefix(chunk.Payload, []byte("event:")) { + // Already SSE formatted, write directly + _, _ = c.Writer.Write(chunk.Payload) + // Ensure newline termination if not present + if !bytes.HasSuffix(chunk.Payload, []byte("\n\n")) { + _, _ = c.Writer.Write([]byte("\n\n")) + } + } else { + // Raw JSON, wrap in SSE format + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk.Payload)) + } + if flusher != nil { + flusher.Flush() + } + } + } + // Send SSE termination signal only if we wrote data and it's OpenAI format + // Claude/Gemini handle their own termination + if wroteData { + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + } +} + +// executeWithUnifiedRoutingSimple executes a request with simple single-target routing (OpenAI format). +func (s *Server) executeWithUnifiedRoutingSimple(c *gin.Context, engine unifiedrouting.RoutingEngine, modelName string, rawBody []byte, stream bool) { + s.executeWithUnifiedRoutingSimpleFormat(c, engine, modelName, rawBody, stream, sdktranslator.FormatOpenAI) +} + +// executeWithUnifiedRoutingSimpleFormat executes a request with simple single-target routing for any format. 
+func (s *Server) executeWithUnifiedRoutingSimpleFormat(c *gin.Context, engine unifiedrouting.RoutingEngine, modelName string, rawBody []byte, stream bool, sourceFormat sdktranslator.Format) { + ctx := c.Request.Context() + + targetModel, credentialID, err := engine.GetRoutingTarget(ctx, modelName) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{"message": err.Error(), "type": "server_error"}, + }) + return + } + + targetAuth, found := s.handlers.AuthManager.GetByID(credentialID) + if !found { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Credential '%s' not found", credentialID), + "type": "server_error", + }, + }) + return + } + + // Replace model in request body + newBody, err := sjson.SetBytes(rawBody, "model", targetModel) + if err != nil { + newBody = rawBody + } + + req := cliproxyexecutor.Request{ + Model: targetModel, + Payload: newBody, + } + opts := cliproxyexecutor.Options{ + Stream: stream, + OriginalRequest: rawBody, + SourceFormat: sourceFormat, + } + + if stream { + chunks, err := s.handlers.AuthManager.ExecuteStreamWithAuth(ctx, targetAuth, req, opts) + if err != nil { + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + c.JSON(status, gin.H{ + "error": gin.H{"message": err.Error(), "type": "server_error"}, + }) + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + flusher, _ := c.Writer.(http.Flusher) + wroteData := false + for chunk := range chunks { + if chunk.Err != nil { + log.Warnf("[UnifiedRouting] Stream error: %v", chunk.Err) + break + } + if len(chunk.Payload) > 0 { + wroteData = true + // Check if chunk already has SSE format (from Claude, Gemini, etc.) 
+ if bytes.HasPrefix(chunk.Payload, []byte("data:")) || + bytes.HasPrefix(chunk.Payload, []byte("event:")) { + // Already SSE formatted, write directly + _, _ = c.Writer.Write(chunk.Payload) + // Ensure newline termination if not present + if !bytes.HasSuffix(chunk.Payload, []byte("\n\n")) { + _, _ = c.Writer.Write([]byte("\n\n")) + } + } else { + // Raw JSON, wrap in SSE format + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk.Payload)) + } + if flusher != nil { + flusher.Flush() + } + } + } + // Send SSE termination signal only if we wrote data and it's OpenAI format + if wroteData { + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + } + } else { + resp, err := s.handlers.AuthManager.ExecuteWithAuth(ctx, targetAuth, req, opts) + if err != nil { + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + c.JSON(status, gin.H{ + "error": gin.H{"message": err.Error(), "type": "server_error"}, + }) + return + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(resp.Payload) + } +} + // Start begins listening for and serving HTTP or HTTPS requests. // It's a blocking call and will only return on an unrecoverable error. // @@ -999,6 +1606,14 @@ func (s *Server) UpdateClients(cfg *config.Config) { log.Warnf("amp module is nil, skipping config update") } + // Notify Unified Routing module of config changes + if s.unifiedRoutingModule != nil { + log.Debugf("triggering unified routing module config update") + if err := s.unifiedRoutingModule.OnConfigUpdated(cfg); err != nil { + log.Errorf("failed to update Unified Routing module config: %v", err) + } + } + // Count client sources from configuration and auth store. 
tokenStore := sdkAuth.GetTokenStore() if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok { diff --git a/internal/config/config.go b/internal/config/config.go index 839b7b057..b5593bf65 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -287,6 +287,10 @@ type ClaudeKey struct { // ProxyURL overrides the global proxy setting for this API key if provided. ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when proxy-url uses the ss:// scheme. Leave empty to use system DNS. + ProxyDNS string `yaml:"proxy-dns,omitempty" json:"proxy-dns,omitempty"` + // Models defines upstream model names and aliases for request routing. Models []ClaudeModel `yaml:"models" json:"models"` @@ -335,6 +339,10 @@ type CodexKey struct { // ProxyURL overrides the global proxy setting for this API key if provided. ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when proxy-url uses the ss:// scheme. Leave empty to use system DNS. + ProxyDNS string `yaml:"proxy-dns,omitempty" json:"proxy-dns,omitempty"` + // Models defines upstream model names and aliases for request routing. Models []CodexModel `yaml:"models" json:"models"` @@ -379,6 +387,10 @@ type GeminiKey struct { // ProxyURL optionally overrides the global proxy for this API key. ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when proxy-url uses the ss:// scheme. Leave empty to use system DNS. + ProxyDNS string `yaml:"proxy-dns,omitempty" json:"proxy-dns,omitempty"` + // Models defines upstream model names and aliases for request routing. 
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"` @@ -437,6 +449,10 @@ type OpenAICompatibilityAPIKey struct { // ProxyURL overrides the global proxy setting for this API key if provided. ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when proxy-url uses the ss:// scheme. Leave empty to use system DNS. + ProxyDNS string `yaml:"proxy-dns,omitempty" json:"proxy-dns,omitempty"` } // OpenAICompatibilityModel represents a model configuration for OpenAI compatibility, diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index 4d4abc37a..8a292aeb6 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -9,6 +9,10 @@ type SDKConfig struct { // ProxyURL is the URL of an optional proxy server to use for outbound requests. ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when proxy-url uses the ss:// scheme. Leave empty to use system DNS. + ProxyDNS string `yaml:"proxy-dns,omitempty" json:"proxy-dns,omitempty"` + // ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") // to target prefixed credentials. When false, unprefixed model requests may use prefixed // credentials as well. diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index 786c5318c..1be3548bb 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -28,6 +28,10 @@ type VertexCompatKey struct { // ProxyURL optionally overrides the global proxy for this API key. ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when proxy-url uses the ss:// scheme. 
Leave empty to use system DNS. + ProxyDNS string `yaml:"proxy-dns,omitempty" json:"proxy-dns,omitempty"` + // Headers optionally adds extra HTTP headers for requests sent with this key. // Commonly used for cookies, user-agent, and other authentication headers. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` diff --git a/internal/managementasset/updater.go b/internal/managementasset/updater.go index c941da024..abce72553 100644 --- a/internal/managementasset/updater.go +++ b/internal/managementasset/updater.go @@ -188,6 +188,12 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL ctx = context.Background() } + // Skip auto-update when MANAGEMENT_STATIC_PATH is explicitly set (user wants to use custom/local version) + if override := strings.TrimSpace(os.Getenv("MANAGEMENT_STATIC_PATH")); override != "" { + log.Debug("management asset auto-update skipped: MANAGEMENT_STATIC_PATH is set, using local file") + return + } + if disableControlPanel.Load() { log.Debug("management asset sync skipped: control panel disabled by configuration") return diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index ab0f626ac..ce2b8fbe3 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -2,6 +2,10 @@ package executor import ( "context" + "crypto/tls" + "encoding/base64" + "encoding/binary" + "fmt" "net" "net/http" "net/url" @@ -10,7 +14,9 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sscore "github.com/shadowsocks/go-shadowsocks2/core" log "github.com/sirupsen/logrus" + "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/proxy" ) @@ -33,20 +39,22 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip httpClient.Timeout = timeout } - // Priority 1: Use auth.ProxyURL if configured - var proxyURL 
string + // Priority 1: Use auth.ProxyURL and auth.ProxyDNS if configured + var proxyURL, proxyDNS string if auth != nil { proxyURL = strings.TrimSpace(auth.ProxyURL) + proxyDNS = strings.TrimSpace(auth.ProxyDNS) } - // Priority 2: Use cfg.ProxyURL if auth proxy is not configured + // Priority 2: Use cfg.ProxyURL and cfg.ProxyDNS if auth proxy is not configured if proxyURL == "" && cfg != nil { proxyURL = strings.TrimSpace(cfg.ProxyURL) + proxyDNS = strings.TrimSpace(cfg.ProxyDNS) } // If we have a proxy URL configured, set up the transport if proxyURL != "" { - transport := buildProxyTransport(proxyURL) + transport := buildProxyTransport(proxyURL, proxyDNS) if transport != nil { httpClient.Transport = transport return httpClient @@ -64,14 +72,15 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip } // buildProxyTransport creates an HTTP transport configured for the given proxy URL. -// It supports SOCKS5, HTTP, and HTTPS proxy protocols. +// It supports SOCKS5, SS (Shadowsocks), HTTP, and HTTPS proxy protocols. 
// // Parameters: -// - proxyURL: The proxy URL string (e.g., "socks5://user:pass@host:port", "http://host:port") +// - proxyURL: The proxy URL string (e.g., "socks5://user:pass@host:port", "ss://method:pass@host:port", "http://host:port") +// - proxyDNS: Optional DoT DNS server (format: "tls://host:port") for resolving SS proxy hostnames // // Returns: // - *http.Transport: A configured transport, or nil if the proxy URL is invalid -func buildProxyTransport(proxyURL string) *http.Transport { +func buildProxyTransport(proxyURL, proxyDNS string) *http.Transport { if proxyURL == "" { return nil } @@ -104,6 +113,43 @@ func buildProxyTransport(proxyURL string) *http.Transport { return dialer.Dial(network, addr) }, } + } else if parsedURL.Scheme == "ss" { + // Configure Shadowsocks proxy + ssMethod, ssPassword, ssServer, errSS := parseSSProxyURL(proxyURL) + if errSS != nil { + log.Errorf("parse Shadowsocks URL failed: %v", errSS) + return nil + } + + // Resolve SS server address using custom DNS if provided + resolvedServer, errResolve := resolveSSServerAddr(ssServer, proxyDNS) + if errResolve != nil { + log.Errorf("resolve Shadowsocks server address failed: %v", errResolve) + return nil + } + + cipher, errCipher := sscore.PickCipher(ssMethod, nil, ssPassword) + if errCipher != nil { + log.Errorf("create Shadowsocks cipher failed (method=%s): %v", ssMethod, errCipher) + return nil + } + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Connect to the Shadowsocks server using resolved address + rawConn, errDial := net.Dial("tcp", resolvedServer) + if errDial != nil { + return nil, errDial + } + // Wrap the connection with Shadowsocks encryption + ssConn := cipher.StreamConn(rawConn) + // Write the target address in SOCKS5-style format + if errAddr := writeSSProxyTargetAddr(ssConn, addr); errAddr != nil { + rawConn.Close() + return nil, errAddr + } + return ssConn, nil + }, + } } else if parsedURL.Scheme 
== "http" || parsedURL.Scheme == "https" { // Configure HTTP or HTTPS proxy transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} @@ -114,3 +160,232 @@ func buildProxyTransport(proxyURL string) *http.Transport { return transport } + +// resolveSSServerAddr resolves the SS server address using custom DoT DNS if provided. +// If proxyDNS is empty or the host is already an IP, returns the original address. +// +// Parameters: +// - serverAddr: The server address in "host:port" format +// - proxyDNS: Optional DoT DNS server (format: "tls://host:port") +// +// Returns: +// - resolved address in "ip:port" format +// - error if resolution fails +func resolveSSServerAddr(serverAddr, proxyDNS string) (string, error) { + host, port, err := net.SplitHostPort(serverAddr) + if err != nil { + return "", fmt.Errorf("invalid server address: %w", err) + } + + // If host is already an IP, return as-is + if ip := net.ParseIP(host); ip != nil { + return serverAddr, nil + } + + // If no custom DNS, use system DNS + if proxyDNS == "" { + return serverAddr, nil + } + + // Parse DoT DNS URL (format: tls://host:port) + dnsURL, err := url.Parse(proxyDNS) + if err != nil { + return "", fmt.Errorf("parse proxy-dns URL: %w", err) + } + if dnsURL.Scheme != "tls" { + return "", fmt.Errorf("proxy-dns must use tls:// scheme, got: %s", dnsURL.Scheme) + } + dnsServer := dnsURL.Host + if dnsServer == "" { + return "", fmt.Errorf("proxy-dns missing host") + } + + // Resolve using DoT + resolvedIP, err := resolveWithDoT(host, dnsServer) + if err != nil { + return "", fmt.Errorf("DoT resolution failed: %w", err) + } + + log.Debugf("resolved SS server %s to %s using DoT DNS %s", host, resolvedIP, dnsServer) + return net.JoinHostPort(resolvedIP, port), nil +} + +// resolveWithDoT resolves a domain name using DNS over TLS. 
+// +// Parameters: +// - domain: The domain name to resolve +// - dnsServer: The DoT server address in "host:port" format +// +// Returns: +// - The resolved IP address as a string +// - error if resolution fails +func resolveWithDoT(domain, dnsServer string) (string, error) { + // Connect to DoT server with TLS + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 10 * time.Second}, + "tcp", + dnsServer, + &tls.Config{InsecureSkipVerify: true}, + ) + if err != nil { + return "", fmt.Errorf("connect to DoT server: %w", err) + } + defer conn.Close() + + // Build DNS query message + var msg dnsmessage.Message + msg.Header.ID = uint16(time.Now().UnixNano() & 0xFFFF) + msg.Header.RecursionDesired = true + msg.Questions = []dnsmessage.Question{{ + Name: dnsmessage.MustNewName(domain + "."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }} + + packed, err := msg.Pack() + if err != nil { + return "", fmt.Errorf("pack DNS query: %w", err) + } + + // DNS over TLS requires length prefix (2 bytes, big-endian) + length := make([]byte, 2) + binary.BigEndian.PutUint16(length, uint16(len(packed))) + if _, err := conn.Write(length); err != nil { + return "", fmt.Errorf("write length prefix: %w", err) + } + if _, err := conn.Write(packed); err != nil { + return "", fmt.Errorf("write DNS query: %w", err) + } + + // Read response + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + respLen := make([]byte, 2) + if _, err := conn.Read(respLen); err != nil { + return "", fmt.Errorf("read response length: %w", err) + } + respSize := int(binary.BigEndian.Uint16(respLen)) + resp := make([]byte, respSize) + if _, err := conn.Read(resp); err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + // Parse response + var respMsg dnsmessage.Message + if err := respMsg.Unpack(resp); err != nil { + return "", fmt.Errorf("unpack DNS response: %w", err) + } + + // Extract A record + for _, ans := range respMsg.Answers { + if ans.Header.Type == dnsmessage.TypeA { 
+ aRecord := ans.Body.(*dnsmessage.AResource) + return fmt.Sprintf("%d.%d.%d.%d", aRecord.A[0], aRecord.A[1], aRecord.A[2], aRecord.A[3]), nil + } + } + + return "", fmt.Errorf("no A record found for %s", domain) +} + +// parseSSProxyURL parses a Shadowsocks URL and returns method, password, and server address. +// Supports formats: +// - ss://method:password@host:port +// - ss://BASE64(method:password)@host:port (SIP002 format) +func parseSSProxyURL(ssURL string) (method, password, server string, err error) { + u, errParse := url.Parse(ssURL) + if errParse != nil { + return "", "", "", fmt.Errorf("parse URL: %w", errParse) + } + if u.Scheme != "ss" { + return "", "", "", fmt.Errorf("not a Shadowsocks URL") + } + server = u.Host + if server == "" { + return "", "", "", fmt.Errorf("missing server address") + } + // Try to get method:password from userinfo + if u.User != nil { + // Format: ss://method:password@host:port + method = u.User.Username() + password, _ = u.User.Password() + if method != "" && password != "" { + return method, password, server, nil + } + // If only username is present, it might be base64 encoded + encoded := u.User.Username() + if encoded != "" { + decoded, errDecode := decodeSSProxyUserinfo(encoded) + if errDecode == nil { + parts := strings.SplitN(decoded, ":", 2) + if len(parts) == 2 { + return parts[0], parts[1], server, nil + } + } + } + } + return "", "", "", fmt.Errorf("cannot parse method and password from URL") +} + +// decodeSSProxyUserinfo decodes base64 userinfo (supports both standard and URL-safe base64). 
+func decodeSSProxyUserinfo(encoded string) (string, error) { + // Try URL-safe base64 first (used by SIP002) + decoded, err := base64.RawURLEncoding.DecodeString(encoded) + if err == nil { + return string(decoded), nil + } + // Try standard base64 + decoded, err = base64.StdEncoding.DecodeString(encoded) + if err == nil { + return string(decoded), nil + } + // Try standard base64 without padding + decoded, err = base64.RawStdEncoding.DecodeString(encoded) + if err == nil { + return string(decoded), nil + } + return "", fmt.Errorf("failed to decode base64: %w", err) +} + +// writeSSProxyTargetAddr writes the target address in SOCKS5-style format to the Shadowsocks connection. +// Format: ATYP (1 byte) + DST.ADDR (variable) + DST.PORT (2 bytes big-endian) +func writeSSProxyTargetAddr(conn net.Conn, addr string) error { + host, portStr, errSplit := net.SplitHostPort(addr) + if errSplit != nil { + return fmt.Errorf("split host port: %w", errSplit) + } + port, errPort := net.LookupPort("tcp", portStr) + if errPort != nil { + return fmt.Errorf("lookup port: %w", errPort) + } + var buf []byte + // Check if the host is an IP address + ip := net.ParseIP(host) + if ip != nil { + if ip4 := ip.To4(); ip4 != nil { + // IPv4 address: ATYP=0x01 + buf = make([]byte, 1+4+2) + buf[0] = 0x01 + copy(buf[1:5], ip4) + } else { + // IPv6 address: ATYP=0x04 + buf = make([]byte, 1+16+2) + buf[0] = 0x04 + copy(buf[1:17], ip.To16()) + } + } else { + // Domain name: ATYP=0x03 + if len(host) > 255 { + return fmt.Errorf("domain name too long: %d", len(host)) + } + buf = make([]byte, 1+1+len(host)+2) + buf[0] = 0x03 + buf[1] = byte(len(host)) + copy(buf[2:2+len(host)], host) + } + // Write port (big-endian) at the end + binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(port)) + _, errWrite := conn.Write(buf) + if errWrite != nil { + return fmt.Errorf("write target address: %w", errWrite) + } + return nil +} diff --git a/internal/watcher/synthesizer/config.go 
b/internal/watcher/synthesizer/config.go index b1ae58856..dec735809 100644 --- a/internal/watcher/synthesizer/config.go +++ b/internal/watcher/synthesizer/config.go @@ -55,6 +55,7 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea prefix := strings.TrimSpace(entry.Prefix) base := strings.TrimSpace(entry.BaseURL) proxyURL := strings.TrimSpace(entry.ProxyURL) + proxyDNS := strings.TrimSpace(entry.ProxyDNS) id, token := idGen.Next("gemini:apikey", key, base) attrs := map[string]string{ "source": fmt.Sprintf("config:gemini[%s]", token), @@ -77,6 +78,7 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, + ProxyDNS: proxyDNS, Attributes: attrs, CreatedAt: now, UpdatedAt: now, @@ -118,6 +120,7 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea } addConfigHeadersToAttrs(ck.Headers, attrs) proxyURL := strings.TrimSpace(ck.ProxyURL) + proxyDNS := strings.TrimSpace(ck.ProxyDNS) a := &coreauth.Auth{ ID: id, Provider: "claude", @@ -125,6 +128,7 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, + ProxyDNS: proxyDNS, Attributes: attrs, CreatedAt: now, UpdatedAt: now, @@ -165,6 +169,7 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau } addConfigHeadersToAttrs(ck.Headers, attrs) proxyURL := strings.TrimSpace(ck.ProxyURL) + proxyDNS := strings.TrimSpace(ck.ProxyDNS) a := &coreauth.Auth{ ID: id, Provider: "codex", @@ -172,6 +177,7 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, + ProxyDNS: proxyDNS, Attributes: attrs, CreatedAt: now, UpdatedAt: now, @@ -204,6 +210,7 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor entry := &compat.APIKeyEntries[j] key := 
strings.TrimSpace(entry.APIKey) proxyURL := strings.TrimSpace(entry.ProxyURL) + proxyDNS := strings.TrimSpace(entry.ProxyDNS) idKind := fmt.Sprintf("openai-compatibility:%s", providerName) id, token := idGen.Next(idKind, key, base, proxyURL) attrs := map[string]string{ @@ -229,6 +236,7 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, + ProxyDNS: proxyDNS, Attributes: attrs, CreatedAt: now, UpdatedAt: now, @@ -284,6 +292,7 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor key := strings.TrimSpace(compat.APIKey) prefix := strings.TrimSpace(compat.Prefix) proxyURL := strings.TrimSpace(compat.ProxyURL) + proxyDNS := strings.TrimSpace(compat.ProxyDNS) idKind := "vertex:apikey" id, token := idGen.Next(idKind, key, base, proxyURL) attrs := map[string]string{ @@ -308,6 +317,7 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, + ProxyDNS: proxyDNS, Attributes: attrs, CreatedAt: now, UpdatedAt: now, diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 190d310ab..c3d56e635 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -77,6 +77,11 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e proxyURL = p } + proxyDNS := "" + if p, ok := metadata["proxy_dns"].(string); ok { + proxyDNS = p + } + prefix := "" if rawPrefix, ok := metadata["prefix"].(string); ok { trimmed := strings.TrimSpace(rawPrefix) @@ -97,6 +102,7 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e "path": full, }, ProxyURL: proxyURL, + ProxyDNS: proxyDNS, Metadata: metadata, CreatedAt: now, UpdatedAt: now, @@ -171,6 +177,10 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an if proxy != "" { 
metadataCopy["proxy_url"] = proxy } + proxyDNS := strings.TrimSpace(primary.ProxyDNS) + if proxyDNS != "" { + metadataCopy["proxy_dns"] = proxyDNS + } virtual := &coreauth.Auth{ ID: buildGeminiVirtualID(primary.ID, projectID), Provider: originalProvider, @@ -179,6 +189,7 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an Attributes: attrs, Metadata: metadataCopy, ProxyURL: primary.ProxyURL, + ProxyDNS: primary.ProxyDNS, Prefix: primary.Prefix, CreatedAt: primary.CreatedAt, UpdatedAt: primary.UpdatedAt, diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 6ac8b8a3f..1974448e0 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -124,6 +124,11 @@ func (s *FileTokenStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error) return walkErr } if d.IsDir() { + // Skip logs directory and unified-routing traces + name := d.Name() + if name == "logs" || name == "traces" { + return filepath.SkipDir + } return nil } if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 6662f9b9e..9be7d6e72 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -2204,3 +2204,98 @@ func (m *Manager) HttpRequest(ctx context.Context, auth *Auth, req *http.Request } return exec.HttpRequest(ctx, auth, req) } + +// GetExecutor returns the registered ProviderExecutor for the given provider key. +// Returns nil if no executor is registered for the provider. +// This enables external packages to directly invoke provider-specific execution logic. +func (m *Manager) GetExecutor(provider string) ProviderExecutor { + return m.executorFor(provider) +} + +// ExecuteStreamWithAuth performs a streaming execution using a specific auth directly. +// Unlike ExecuteStream which performs load balancing across auths of the same provider, +// this method targets a specific auth entry - useful for health checks and diagnostics. 
+// The method handles model rewriting, proxy/RoundTripper setup, and proper request translation. +func (m *Manager) ExecuteStreamWithAuth(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { + if m == nil { + return nil, &Error{Code: "manager_nil", Message: "auth manager is nil"} + } + if auth == nil { + return nil, &Error{Code: "auth_not_found", Message: "auth is nil"} + } + + // Determine provider key from auth + providerKey := executorKeyFromAuth(auth) + if providerKey == "" { + return nil, &Error{Code: "provider_not_found", Message: "auth provider is empty"} + } + + // Get the executor for this provider + executor := m.executorFor(providerKey) + if executor == nil { + return nil, &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + providerKey} + } + + // Setup execution context with RoundTripper (proxy support) + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + + // Apply model rewriting if configured + routeModel := req.Model + execReq := req + execReq.Model = rewriteModelForAuth(routeModel, auth) + execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) + execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + + // Ensure requested model metadata is set + opts = ensureRequestedModelMetadata(opts, routeModel) + + // Execute the stream request directly with the specified auth + return executor.ExecuteStream(execCtx, auth, execReq, opts) +} + +// ExecuteWithAuth performs a non-streaming execution using a specific auth directly. +// This is the non-streaming counterpart to ExecuteStreamWithAuth. 
+func (m *Manager) ExecuteWithAuth(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if m == nil { + return cliproxyexecutor.Response{}, &Error{Code: "manager_nil", Message: "auth manager is nil"} + } + if auth == nil { + return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "auth is nil"} + } + + // Determine provider key from auth + providerKey := executorKeyFromAuth(auth) + if providerKey == "" { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "auth provider is empty"} + } + + // Get the executor for this provider + executor := m.executorFor(providerKey) + if executor == nil { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + providerKey} + } + + // Setup execution context with RoundTripper (proxy support) + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + + // Apply model rewriting if configured + routeModel := req.Model + execReq := req + execReq.Model = rewriteModelForAuth(routeModel, auth) + execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) + execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + + // Ensure requested model metadata is set + opts = ensureRequestedModelMetadata(opts, routeModel) + + // Execute the request directly with the specified auth + return executor.Execute(execCtx, auth, execReq, opts) +} diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 4c69ae905..0a62f9c85 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -38,6 +38,9 @@ type Auth struct { Unavailable bool `json:"unavailable"` // ProxyURL overrides the global proxy setting for this auth if provided. 
ProxyURL string `json:"proxy_url,omitempty"` + // ProxyDNS is the DNS server (DoT format: tls://host:port) used to resolve SS proxy server hostnames. + // Only used when ProxyURL uses the ss:// scheme. Leave empty to use system DNS. + ProxyDNS string `json:"proxy_dns,omitempty"` // Attributes stores provider specific metadata needed by executors (immutable configuration). Attributes map[string]string `json:"attributes,omitempty"` // Metadata stores runtime mutable provider state (e.g. tokens, cookies). diff --git a/sdk/cliproxy/rtprovider.go b/sdk/cliproxy/rtprovider.go index dad4fc238..d0acf5c8b 100644 --- a/sdk/cliproxy/rtprovider.go +++ b/sdk/cliproxy/rtprovider.go @@ -2,14 +2,21 @@ package cliproxy import ( "context" + "crypto/tls" + "encoding/base64" + "encoding/binary" + "fmt" "net" "net/http" "net/url" "strings" "sync" + "time" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sscore "github.com/shadowsocks/go-shadowsocks2/core" log "github.com/sirupsen/logrus" + "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/proxy" ) @@ -33,8 +40,13 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http. if proxyStr == "" { return nil } + proxyDNS := strings.TrimSpace(auth.ProxyDNS) + + // Cache key includes both proxy URL and DNS to handle different DNS configs + cacheKey := proxyStr + "|" + proxyDNS + p.mu.RLock() - rt := p.cache[proxyStr] + rt := p.cache[cacheKey] p.mu.RUnlock() if rt != nil { return rt @@ -63,6 +75,43 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http. return dialer.Dial(network, addr) }, } + } else if proxyURL.Scheme == "ss" { + // Configure Shadowsocks proxy. 
+ ssMethod, ssPassword, ssServer, errSS := parseSSURL(proxyStr) + if errSS != nil { + log.Errorf("parse Shadowsocks URL failed: %v", errSS) + return nil + } + + // Resolve SS server address using custom DNS if provided + resolvedServer, errResolve := resolveSSServer(ssServer, proxyDNS) + if errResolve != nil { + log.Errorf("resolve Shadowsocks server address failed: %v", errResolve) + return nil + } + + cipher, errCipher := sscore.PickCipher(ssMethod, nil, ssPassword) + if errCipher != nil { + log.Errorf("create Shadowsocks cipher failed (method=%s): %v", ssMethod, errCipher) + return nil + } + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Connect to the Shadowsocks server using resolved address. + rawConn, errDial := net.Dial("tcp", resolvedServer) + if errDial != nil { + return nil, errDial + } + // Wrap the connection with Shadowsocks encryption. + ssConn := cipher.StreamConn(rawConn) + // Write the target address in SOCKS5-style format. + if errAddr := writeSSTargetAddr(ssConn, addr); errAddr != nil { + rawConn.Close() + return nil, errAddr + } + return ssConn, nil + }, + } } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { // Configure HTTP or HTTPS proxy. transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} @@ -71,7 +120,220 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http. return nil } p.mu.Lock() - p.cache[proxyStr] = transport + p.cache[cacheKey] = transport p.mu.Unlock() return transport } + +// parseSSURL parses a Shadowsocks URL and returns method, password, and server address. 
// Supports formats:
//   - ss://method:password@host:port
//   - ss://BASE64(method:password)@host:port (SIP002 format)
func parseSSURL(ssURL string) (method, password, server string, err error) {
	parsed, parseErr := url.Parse(ssURL)
	if parseErr != nil {
		return "", "", "", fmt.Errorf("parse URL: %w", parseErr)
	}
	if parsed.Scheme != "ss" {
		return "", "", "", fmt.Errorf("not a Shadowsocks URL")
	}
	server = parsed.Host
	if server == "" {
		return "", "", "", fmt.Errorf("missing server address")
	}
	if user := parsed.User; user != nil {
		// Plain form: ss://method:password@host:port
		uname := user.Username()
		if pw, _ := user.Password(); uname != "" && pw != "" {
			return uname, pw, server, nil
		}
		// SIP002 form: the whole userinfo is BASE64(method:password).
		if uname != "" {
			if decoded, decodeErr := decodeSSUserinfo(uname); decodeErr == nil {
				if mth, pw, found := strings.Cut(decoded, ":"); found {
					return mth, pw, server, nil
				}
			}
		}
	}
	return "", "", "", fmt.Errorf("cannot parse method and password from URL")
}

// decodeSSUserinfo decodes base64 userinfo (supports both standard and URL-safe base64).
func decodeSSUserinfo(encoded string) (string, error) {
	// SIP002 uses URL-safe base64 without padding; fall back to the standard
	// alphabet with and then without padding. The last decode error is what
	// gets reported to the caller.
	var err error
	for _, enc := range []*base64.Encoding{base64.RawURLEncoding, base64.StdEncoding, base64.RawStdEncoding} {
		var decoded []byte
		if decoded, err = enc.DecodeString(encoded); err == nil {
			return string(decoded), nil
		}
	}
	return "", fmt.Errorf("failed to decode base64: %w", err)
}

// writeSSTargetAddr writes the target address in SOCKS5-style format to the Shadowsocks connection.
+// Format: ATYP (1 byte) + DST.ADDR (variable) + DST.PORT (2 bytes big-endian) +func writeSSTargetAddr(conn net.Conn, addr string) error { + host, portStr, errSplit := net.SplitHostPort(addr) + if errSplit != nil { + return fmt.Errorf("split host port: %w", errSplit) + } + port, errPort := net.LookupPort("tcp", portStr) + if errPort != nil { + return fmt.Errorf("lookup port: %w", errPort) + } + var buf []byte + // Check if the host is an IP address. + ip := net.ParseIP(host) + if ip != nil { + if ip4 := ip.To4(); ip4 != nil { + // IPv4 address: ATYP=0x01 + buf = make([]byte, 1+4+2) + buf[0] = 0x01 + copy(buf[1:5], ip4) + } else { + // IPv6 address: ATYP=0x04 + buf = make([]byte, 1+16+2) + buf[0] = 0x04 + copy(buf[1:17], ip.To16()) + } + } else { + // Domain name: ATYP=0x03 + if len(host) > 255 { + return fmt.Errorf("domain name too long: %d", len(host)) + } + buf = make([]byte, 1+1+len(host)+2) + buf[0] = 0x03 + buf[1] = byte(len(host)) + copy(buf[2:2+len(host)], host) + } + // Write port (big-endian) at the end. + binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(port)) + _, errWrite := conn.Write(buf) + if errWrite != nil { + return fmt.Errorf("write target address: %w", errWrite) + } + return nil +} + +// resolveSSServer resolves the SS server address using custom DoT DNS if provided. +// If proxyDNS is empty or the host is already an IP, returns the original address. 
+func resolveSSServer(serverAddr, proxyDNS string) (string, error) { + host, port, err := net.SplitHostPort(serverAddr) + if err != nil { + return "", fmt.Errorf("invalid server address: %w", err) + } + + // If host is already an IP, return as-is + if ip := net.ParseIP(host); ip != nil { + return serverAddr, nil + } + + // If no custom DNS, use system DNS + if proxyDNS == "" { + return serverAddr, nil + } + + // Parse DoT DNS URL (format: tls://host:port) + dnsURL, err := url.Parse(proxyDNS) + if err != nil { + return "", fmt.Errorf("parse proxy-dns URL: %w", err) + } + if dnsURL.Scheme != "tls" { + return "", fmt.Errorf("proxy-dns must use tls:// scheme, got: %s", dnsURL.Scheme) + } + dnsServer := dnsURL.Host + if dnsServer == "" { + return "", fmt.Errorf("proxy-dns missing host") + } + + // Resolve using DoT + resolvedIP, err := resolveWithDoT(host, dnsServer) + if err != nil { + return "", fmt.Errorf("DoT resolution failed: %w", err) + } + + log.Debugf("resolved SS server %s to %s using DoT DNS %s", host, resolvedIP, dnsServer) + return net.JoinHostPort(resolvedIP, port), nil +} + +// resolveWithDoT resolves a domain name using DNS over TLS. 
+func resolveWithDoT(domain, dnsServer string) (string, error) { + // Connect to DoT server with TLS + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 10 * time.Second}, + "tcp", + dnsServer, + &tls.Config{InsecureSkipVerify: true}, + ) + if err != nil { + return "", fmt.Errorf("connect to DoT server: %w", err) + } + defer conn.Close() + + // Build DNS query message + var msg dnsmessage.Message + msg.Header.ID = uint16(time.Now().UnixNano() & 0xFFFF) + msg.Header.RecursionDesired = true + msg.Questions = []dnsmessage.Question{{ + Name: dnsmessage.MustNewName(domain + "."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }} + + packed, err := msg.Pack() + if err != nil { + return "", fmt.Errorf("pack DNS query: %w", err) + } + + // DNS over TLS requires length prefix (2 bytes, big-endian) + length := make([]byte, 2) + binary.BigEndian.PutUint16(length, uint16(len(packed))) + if _, err := conn.Write(length); err != nil { + return "", fmt.Errorf("write length prefix: %w", err) + } + if _, err := conn.Write(packed); err != nil { + return "", fmt.Errorf("write DNS query: %w", err) + } + + // Read response + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + respLen := make([]byte, 2) + if _, err := conn.Read(respLen); err != nil { + return "", fmt.Errorf("read response length: %w", err) + } + respSize := int(binary.BigEndian.Uint16(respLen)) + resp := make([]byte, respSize) + if _, err := conn.Read(resp); err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + // Parse response + var respMsg dnsmessage.Message + if err := respMsg.Unpack(resp); err != nil { + return "", fmt.Errorf("unpack DNS response: %w", err) + } + + // Extract A record + for _, ans := range respMsg.Answers { + if ans.Header.Type == dnsmessage.TypeA { + aRecord := ans.Body.(*dnsmessage.AResource) + return fmt.Sprintf("%d.%d.%d.%d", aRecord.A[0], aRecord.A[1], aRecord.A[2], aRecord.A[3]), nil + } + } + + return "", fmt.Errorf("no A record found for %s", domain) +}