diff --git a/aigateway/component/openai.go b/aigateway/component/openai.go index a1c8ca01..17b5ebbf 100644 --- a/aigateway/component/openai.go +++ b/aigateway/component/openai.go @@ -189,17 +189,30 @@ func filterByModelID(query string) modelFilter { } } -func filterBySource(source string) modelFilter { +func filterByLLMTypes(llmTypes []string) modelFilter { + allowedTypes := parseLLMTypes(llmTypes) return func(m *types.Model) bool { - switch source { - case string(types.ModelSourceCSGHub): - return m.CSGHubModelID != "" - case string(types.ModelSourceExternal): - return m.Provider != "" - default: + if len(allowedTypes) == 0 { return true } + if m == nil || m.Metadata == nil { + return false + } + llmType, _ := m.Metadata[types.MetaKeyLLMType].(string) + return allowedTypes[strings.ToLower(strings.TrimSpace(llmType))] + } +} + +func parseLLMTypes(llmTypes []string) map[string]bool { + allowedTypes := make(map[string]bool) + for _, rawType := range llmTypes { + llmType := strings.ToLower(strings.TrimSpace(rawType)) + if llmType == "" { + continue + } + allowedTypes[llmType] = true } + return allowedTypes } func filterByTask(task string) modelFilter { @@ -233,13 +246,13 @@ func applyFilters(models []types.Model, filters []modelFilter) []types.Model { } func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) types.ModelList { - var filters []modelFilter + filters := modelListDefaultFilters() if searchQuery := strings.ToLower(req.ModelID); searchQuery != "" { filters = append(filters, filterByModelID(searchQuery)) } - if source := strings.ToLower(req.Source); source != "" { - filters = append(filters, filterBySource(source)) + if len(req.LLMTypes) > 0 { + filters = append(filters, filterByLLMTypes(req.LLMTypes)) } if task := strings.ToLower(req.Task); task != "" { filters = append(filters, filterByTask(task)) @@ -328,7 +341,8 @@ func (c *openaiComponentImpl) getCSGHubModels(ctx context.Context, userID int64) SupportFunctionCall: supportFunctionCall, Task: string(deploy.Task), Metadata: map[string]any{ - types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type), + types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type), + types.MetaKeyRepoPath: deploy.Repository.Path, }, }, InternalModelInfo: types.InternalModelInfo{ @@ -375,30 +389,35 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model page := 1 var models []types.Model for { - extModels, _, err := m.extllmStore.Index(c, per, page, search) + extModels, _, err := m.extllmStore.IndexWithRepo(c, per, page, search) if err != nil { slog.Error("failed to get external models", "error", err) break } for _, extModel := range extModels { - // Extract tasks from metadata if present + metadata := maps.Clone(extModel.Metadata) + if metadata == nil { + metadata = map[string]any{} + } task := "" - if extModel.Metadata != nil { - if tasks, ok := extModel.Metadata[types.MetaKeyTasks].([]any); ok && len(tasks) > 0 { - tasksStrings := make([]string, 0, len(tasks)) - for _, t := range tasks { - if s, ok := t.(string); ok { - tasksStrings = append(tasksStrings, s) - } + if tasks, ok := metadata[types.MetaKeyTasks].([]any); ok && len(tasks) > 0 { + tasksStrings := make([]string, 0, len(tasks)) + for _, t := range tasks { + if s, ok := t.(string); ok { + tasksStrings = append(tasksStrings, s) } - task = strings.Join(tasksStrings, ",") } + task = strings.Join(tasksStrings, ",") } - if extModel.Metadata == nil { - extModel.Metadata = map[string]any{} + if extModel.RepoID != 0 { + if extModel.Repo != nil && extModel.Repo.Path != "" { + metadata[types.MetaKeyRepoPath] = extModel.Repo.Path + } else { + slog.WarnContext(c, "llm config repo relation unavailable", "llm_config_id", extModel.ID, "repo_id", extModel.RepoID) + } } - extModel.Metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM + metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM // Convert relational upstreams to types.UpstreamConfig for routing upstreams := dbUpstreamsToConfigs(extModel.Upstreams) provider := extModel.PrimaryProvider() @@ -407,7 +426,7 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model Object: "model", ID: extModel.ModelName, OwnedBy: provider, - Metadata: extModel.Metadata, + Metadata: metadata, Task: task, }, Endpoint: router.FirstEnabledUpstream(upstreams), diff --git a/aigateway/component/openai_ce_test.go b/aigateway/component/openai_ce_test.go index 57b72440..3a45b0e5 100644 --- a/aigateway/component/openai_ce_test.go +++ b/aigateway/component/openai_ce_test.go @@ -70,6 +70,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SecureLevel: commontypes.EndpointPublic, Repository: &database.Repository{ HFPath: "hf-model2", + Path: "model2", }, User: &database.User{ Username: "serverless-owner", @@ -84,7 +85,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, int64(0)). Return(deploys, nil).Once() - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil) var wg sync.WaitGroup @@ -104,8 +105,10 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { require.Len(t, models, 2) assert.Equal(t, "model1:svc1", models[0].ID) assert.Equal(t, "publicuser", models[0].OwnedBy) + assert.Equal(t, "model1", models[0].Metadata[types.MetaKeyRepoPath]) assert.Equal(t, "hf-model2:svc2", models[1].ID) assert.Equal(t, "OpenCSG", models[1].OwnedBy) + assert.Equal(t, "model2", models[1].Metadata[types.MetaKeyRepoPath]) wg.Wait() }) @@ -117,7 +120,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { } mockUserStore.EXPECT().FindByUsername(mock.Anything, "testuser"). Return(*user, nil).Once() - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil) now := time.Now() deploys := []database.Deploy{ @@ -145,6 +148,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SecureLevel: commontypes.EndpointPublic, Repository: &database.Repository{ HFPath: "hf-model2", + Path: "model2", }, User: &database.User{ Username: "testuser", @@ -179,12 +183,14 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { assert.Equal(t, "testuser", models[0].OwnedBy) assert.Equal(t, "endpoint1", models[0].Endpoint) assert.Equal(t, "text-generation", models[0].Task) + assert.Equal(t, "model1", models[0].Metadata[types.MetaKeyRepoPath]) // Verify second model (serverless) assert.Equal(t, "hf-model2:svc2", models[1].ID) assert.Equal(t, "OpenCSG", models[1].OwnedBy) assert.Equal(t, "endpoint2", models[1].Endpoint) assert.Equal(t, "text-to-image", models[1].Task) + assert.Equal(t, "model2", models[1].Metadata[types.MetaKeyRepoPath]) wg.Wait() }) @@ -196,7 +202,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { } mockUserStore.EXPECT().FindByUsername(mock.Anything, "testuser"). Return(*user, nil).Once() - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil) now := time.Now() @@ -294,7 +300,7 @@ func TestOpenAIComponent_GetAvailableModels_CacheUsesModelSnapshot(t *testing.T) mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, int64(0)). Return(deploys, nil).Once() - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil).Once() firstWriteStarted := make(chan struct{}) @@ -367,7 +373,7 @@ func TestOpenAIComponent_ListModels_CacheUsesOriginalID(t *testing.T) { SortBy: "model_size_b", SortOrder: "desc", } - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, search). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, search). Return([]*database.LLMConfig{ { ID: 1, @@ -463,7 +469,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { Return(*user, nil) mockCache.EXPECT().Exists(mock.Anything, modelCacheKey). Return(0, nil).Once() - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil).Once() now := time.Now() deploys := []database.Deploy{ @@ -513,7 +519,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { Return("", redis.Nil).Once() // Cache miss: GetModelByID falls through to GetAvailableModels, which calls getCSGHubModels and getExternalModels mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, int64(1)).Return([]database.Deploy{}, nil).Once() - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil).Once() model, err := comp.GetModelByID(context.Background(), "testuser", "nonexistent:svc") assert.NoError(t, err) @@ -540,7 +546,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { SortBy: "model_size_b", SortOrder: "desc", } - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, search). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, search). Return([]*database.LLMConfig{ { ID: 1, @@ -639,7 +645,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { SortBy: "model_size_b", SortOrder: "desc", } - mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, search). + mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, search). Return([]*database.LLMConfig{ { ID: 1, @@ -685,10 +691,10 @@ func TestOpenAIComponent_saveModelsToCache(t *testing.T) { }, Endpoint: "http://test-endpoint", ExternalModelInfo: types.ExternalModelInfo{ - Provider: "openai", - AuthHead: "Bearer test-token", - NeedSensitiveCheck: true, - }, + Provider: "openai", + AuthHead: "Bearer test-token", + NeedSensitiveCheck: true, + }, }, } @@ -794,7 +800,7 @@ func TestOpenAIComponent_ExtGetAvailableModels_Error(t *testing.T) { SortBy: "model_size_b", SortOrder: "desc", } - mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search). + mockLLMConfigStore.EXPECT().IndexWithRepo(ctx, 50, 1, search). Return(nil, 0, errors.New("test error")).Once() user := &database.User{ ID: 1, @@ -823,6 +829,9 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { extllmStore: mockLLMConfigStore, modelListCache: mockCache, } + originalMetadata := map[string]any{ + types.MetaKeyTasks: []any{"text-generation", "text-to-image"}, + } mockModels := []*database.LLMConfig{ { ID: 1, @@ -830,6 +839,13 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { Type: 16, Enabled: true, Provider: "OpenAI", + Metadata: originalMetadata, + RepoID: 100, + Repo: &database.Repository{ + ID: 100, + Path: "test-ns/test-model-1", + GitPath: "models/test-ns/test-model-1", + }, }, } user := &database.User{ @@ -848,7 +864,7 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { SortBy: "model_size_b", SortOrder: "desc", } - mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search).Return(mockModels, 1, nil) + mockLLMConfigStore.EXPECT().IndexWithRepo(ctx, 50, 1, search).Return(mockModels, 1, nil) mockCache.EXPECT().HSet(mock.Anything, modelCacheKey, "test-model-1", mock.Anything). Return(nil).Once() var wg sync.WaitGroup @@ -863,5 +879,44 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { require.Nil(t, err) require.Len(t, models, 1) require.Equal(t, "test-model-1", models[0].ID) + require.Equal(t, "text-generation,text-to-image", models[0].Task) + require.Equal(t, "test-ns/test-model-1", models[0].Metadata[types.MetaKeyRepoPath]) + require.Equal(t, types.ProviderTypeExternalLLM, models[0].Metadata[types.MetaKeyLLMType]) + require.NotContains(t, originalMetadata, types.MetaKeyRepoPath) + require.NotContains(t, originalMetadata, types.MetaKeyLLMType) wg.Wait() } + +func TestOpenAIComponent_GetExternalModelsWithoutRepoOmitsRepoPath(t *testing.T) { + ctx := context.Background() + mockLLMConfigStore := mockdb.NewMockLLMConfigStore(t) + component := &openaiComponentImpl{ + extllmStore: mockLLMConfigStore, + } + searchType := 16 + enabled := true + search := &commontypes.SearchLLMConfig{ + Type: &searchType, + Enabled: &enabled, + SortBy: "model_size_b", + SortOrder: "desc", + } + mockLLMConfigStore.EXPECT().IndexWithRepo(ctx, 50, 1, search).Return([]*database.LLMConfig{ + { + ID: 1, + ModelName: "model-without-repo", + Type: 16, + Enabled: true, + Provider: "OpenAI", + Metadata: map[string]any{"existing": "value"}, + }, + }, 1, nil).Once() + + models := component.getExternalModels(ctx) + + require.Len(t, models, 1) + require.Equal(t, "model-without-repo", models[0].ID) + require.Equal(t, "value", models[0].Metadata["existing"]) + require.Equal(t, types.ProviderTypeExternalLLM, models[0].Metadata[types.MetaKeyLLMType]) + require.NotContains(t, models[0].Metadata, types.MetaKeyRepoPath) +} diff --git a/aigateway/component/openai_model_filter_ce.go b/aigateway/component/openai_model_filter_ce.go new file mode 100644 index 00000000..fb28f45d --- /dev/null +++ b/aigateway/component/openai_model_filter_ce.go @@ -0,0 +1,7 @@ +//go:build !ee && !saas + +package component + +func modelListDefaultFilters() []modelFilter { + return nil +} diff --git a/aigateway/component/openai_test.go b/aigateway/component/openai_test.go index 8298120e..85fb6b5d 100644 --- a/aigateway/component/openai_test.go +++ b/aigateway/component/openai_test.go @@ -91,64 +91,51 @@ func TestFilterAndPaginateModels(t *testing.T) { assert.False(t, resp.HasMore) }) - t.Run("source filter csghub", func(t *testing.T) { - modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, - {BaseModel: types.BaseModel{ID: "csghub-model:svc2", Object: "model", OwnedBy: "u2"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "org/model2"}}, + t.Run("llm_types filter external_llm", func(t *testing.T) { + modelsWithLLMTypes := []types.Model{ + {BaseModel: types.BaseModel{ID: "inference-model", Object: "model", OwnedBy: "u1", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeInference}}}, + {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeExternalLLM, types.MetaKeyPricingConfigured: true}}}, + {BaseModel: types.BaseModel{ID: "serverless-model", Object: "model", OwnedBy: "u2", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeServerless, types.MetaKeyPricingConfigured: true}}}, } - resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceCSGHub)}) - assert.Equal(t, 2, resp.TotalCount) - assert.Len(t, resp.Data, 2) - assert.Equal(t, "csghub-model:svc1", resp.Data[0].ID) - assert.Equal(t, "csghub-model:svc2", resp.Data[1].ID) + resp := filterAndPaginateModels(modelsWithLLMTypes, types.ListModelsReq{LLMTypes: []string{types.ProviderTypeExternalLLM}}) + assert.Equal(t, 1, resp.TotalCount) + require.Len(t, resp.Data, 1) + assert.Equal(t, "external-model", resp.Data[0].ID) }) - t.Run("source filter external", func(t *testing.T) { - modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, - {BaseModel: types.BaseModel{ID: "claude", Object: "model", OwnedBy: "anthropic"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "anthropic"}}, + t.Run("llm_types filter supports multiple values", func(t *testing.T) { + modelsWithLLMTypes := []types.Model{ + {BaseModel: types.BaseModel{ID: "inference-model", Object: "model", OwnedBy: "u1", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeInference}}}, + {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeExternalLLM, types.MetaKeyPricingConfigured: true}}}, + {BaseModel: types.BaseModel{ID: "serverless-model", Object: "model", OwnedBy: "u2", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeServerless, types.MetaKeyPricingConfigured: true}}}, } - resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceExternal)}) + resp := filterAndPaginateModels(modelsWithLLMTypes, types.ListModelsReq{LLMTypes: []string{types.ProviderTypeServerless, types.ProviderTypeInference}}) assert.Equal(t, 2, resp.TotalCount) - assert.Len(t, resp.Data, 2) - assert.Equal(t, "gpt-4", resp.Data[0].ID) - assert.Equal(t, "claude", resp.Data[1].ID) + require.Len(t, resp.Data, 2) + assert.Equal(t, "inference-model", resp.Data[0].ID) + assert.Equal(t, "serverless-model", resp.Data[1].ID) }) - t.Run("source filter is case-insensitive", func(t *testing.T) { - modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + t.Run("llm_types filter is case-insensitive and trims spaces", func(t *testing.T) { + modelsWithLLMTypes := []types.Model{ + {BaseModel: types.BaseModel{ID: "serverless-model", Object: "model", OwnedBy: "u1", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeServerless, types.MetaKeyPricingConfigured: true}}}, + {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeExternalLLM, types.MetaKeyPricingConfigured: true}}}, } - resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: "CSGHub"}) + resp := filterAndPaginateModels(modelsWithLLMTypes, types.ListModelsReq{LLMTypes: []string{" SERVERLESS "}}) assert.Equal(t, 1, resp.TotalCount) - assert.Len(t, resp.Data, 1) - assert.Equal(t, "csghub-model:svc1", resp.Data[0].ID) - }) - - t.Run("unknown source filter includes all", func(t *testing.T) { - modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, - } - resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: "unknown"}) - assert.Equal(t, 2, resp.TotalCount) - assert.Len(t, resp.Data, 2) + require.Len(t, resp.Data, 1) + assert.Equal(t, "serverless-model", resp.Data[0].ID) }) - t.Run("source filter csghub includes public and private deployments", func(t *testing.T) { - modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-public", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "csghub-private", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model2"}}, - {BaseModel: types.BaseModel{ID: "external-public", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + t.Run("llm_types filter excludes models without llm_type", func(t *testing.T) { + modelsWithLLMTypes := []types.Model{ + {BaseModel: types.BaseModel{ID: "missing-llm-type", Object: "model", OwnedBy: "u1"}}, + {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeExternalLLM, types.MetaKeyPricingConfigured: true}}}, } - resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceCSGHub)}) - assert.Equal(t, 2, resp.TotalCount) - assert.Len(t, resp.Data, 2) - assert.Equal(t, "csghub-public", resp.Data[0].ID) - assert.Equal(t, "csghub-private", resp.Data[1].ID) + resp := filterAndPaginateModels(modelsWithLLMTypes, types.ListModelsReq{LLMTypes: []string{types.ProviderTypeExternalLLM}}) + assert.Equal(t, 1, resp.TotalCount) + require.Len(t, resp.Data, 1) + assert.Equal(t, "external-model", resp.Data[0].ID) }) t.Run("task filter text-generation", func(t *testing.T) { @@ -235,16 +222,16 @@ func TestFilterAndPaginateModels(t *testing.T) { assert.Equal(t, 1, resp.TotalCount) }) - t.Run("task filter combined with source filter", func(t *testing.T) { + t.Run("task filter combined with llm_types filter", func(t *testing.T) { modelsWithTask := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-gen", Object: "model", OwnedBy: "u1", Task: "text-generation"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "csghub-image", Object: "model", OwnedBy: "u1", Task: "text-to-image"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model2"}}, - {BaseModel: types.BaseModel{ID: "external-gen", Object: "model", OwnedBy: "openai", Task: "text-generation"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + {BaseModel: types.BaseModel{ID: "inference-gen", Object: "model", OwnedBy: "u1", Task: "text-generation", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeInference}}}, + {BaseModel: types.BaseModel{ID: "inference-image", Object: "model", OwnedBy: "u1", Task: "text-to-image", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeInference}}}, + {BaseModel: types.BaseModel{ID: "external-gen", Object: "model", OwnedBy: "openai", Task: "text-generation", Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeExternalLLM, types.MetaKeyPricingConfigured: true}}}, } - resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Source: string(types.ModelSourceCSGHub), Task: "text-generation"}) + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{LLMTypes: []string{types.ProviderTypeInference}, Task: "text-generation"}) assert.Equal(t, 1, resp.TotalCount) assert.Len(t, resp.Data, 1) - assert.Equal(t, "csghub-gen", resp.Data[0].ID) + assert.Equal(t, "inference-gen", resp.Data[0].ID) }) } diff --git a/aigateway/handler/model_price_guard_ce.go b/aigateway/handler/model_price_guard_ce.go new file mode 100644 index 00000000..1ddb229d --- /dev/null +++ b/aigateway/handler/model_price_guard_ce.go @@ -0,0 +1,9 @@ +//go:build !ee && !saas + +package handler + +import "opencsg.com/csghub-server/aigateway/types" + +func modelSKUPriceStatus(model *types.Model) (requiresSKUPrice bool, hasConfiguredSKUPrice bool) { + return false, false +} diff --git a/aigateway/handler/model_target.go b/aigateway/handler/model_target.go index 4466b386..926e000c 100644 --- a/aigateway/handler/model_target.go +++ b/aigateway/handler/model_target.go @@ -118,6 +118,16 @@ func newInvalidRequestModelTargetError(code, message string, options modelTarget }) } +func newServerModelTargetError(code, message string, options modelTargetErrorOptions) *modelTargetError { + return newModelTargetError(modelTargetErrorParams{ + Status: http.StatusInternalServerError, + Code: code, + Message: message, + Type: "server_error", + Options: options, + }) +} + func (h *OpenAIHandlerImpl) resolveModelTarget(ctx context.Context, username, modelID string, headers http.Header) (*resolvedModelTarget, error) { model, err := h.openaiComponent.GetModelByID(ctx, username, modelID) if err != nil { @@ -126,6 +136,14 @@ func (h *OpenAIHandlerImpl) resolveModelTarget(ctx context.Context, username, mo if model == nil { return nil, newInvalidRequestModelTargetError("model_not_found", fmt.Sprintf("model '%s' not found", modelID), modelTargetErrorOptions{}) } + requiresSKUPrice, hasConfiguredSKUPrice := modelSKUPriceStatus(model) + if requiresSKUPrice && !hasConfiguredSKUPrice { + return nil, newServerModelTargetError( + "model_price_not_configured", + "target model has no configured SKU price", + modelTargetErrorOptions{Model: model}, + ) + } targetReq := commonType.EndpointReq{ ClusterID: model.ClusterID, @@ -293,16 +311,17 @@ func (h *OpenAIHandlerImpl) filterAvailableUpstreams( } return filtered, nil } + // IsCacheUpstreamCircuitOpen returns true only when the runtime circuit cache // explicitly reports this upstream as open. // // Design note: -// - If the availability manager is not initialized, return false and allow the -// proxy request to proceed. -// - If reading the runtime circuit state fails, return false and allow the -// proxy request to proceed. -// - If the runtime cache reports a non-open state, return false and allow the -// proxy request to proceed even when the persisted upstream state says open. +// - If the availability manager is not initialized, return false and allow the +// proxy request to proceed. +// - If reading the runtime circuit state fails, return false and allow the +// proxy request to proceed. +// - If the runtime cache reports a non-open state, return false and allow the +// proxy request to proceed even when the persisted upstream state says open. // // In short, we intentionally keep circuit filtering permissive here. We only // block an upstream when the runtime cache confirms it is truly open, so users diff --git a/aigateway/handler/model_target_test.go b/aigateway/handler/model_target_test.go index ca801faf..bbc93fb1 100644 --- a/aigateway/handler/model_target_test.go +++ b/aigateway/handler/model_target_test.go @@ -360,6 +360,56 @@ func TestResolveModelTarget_ModelNotRunningWhenNoEndpoint(t *testing.T) { require.Equal(t, http.StatusBadRequest, targetErr.Status) } +func TestResolveModelTarget_PricedExternalLLM(t *testing.T) { + tester, _, _ := setupTest(t) + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "external-model", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeExternalLLM, + types.MetaKeyPricingConfigured: true, + }, + }, + ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}, + Upstreams: []commontypes.UpstreamConfig{{URL: "https://api.example.com/v1/chat/completions", Enabled: true, ModelName: "provider-model"}}, + } + tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "external-model").Return(model, nil).Once() + + resolved, err := tester.handler.resolveModelTarget(context.Background(), "testuser", "external-model", http.Header{}) + + require.NoError(t, err) + require.Equal(t, "https://api.example.com/v1/chat/completions", resolved.Target) + require.Equal(t, "provider-model", resolved.ModelName) +} + +func TestResolveModelTarget_InferenceWithoutPricingFlag(t *testing.T) { + tester, _, _ := setupTest(t) + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "inference-model", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, + }, + InternalModelInfo: types.InternalModelInfo{ + CSGHubModelID: "namespace/model", + ClusterID: "cluster-1", + SvcName: "svc-model", + }, + Endpoint: "https://model.internal/v1/chat/completions", + } + tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "inference-model").Return(model, nil).Once() + tester.mocks.mockClsComp.EXPECT().GetClusterByID(mock.Anything, "cluster-1").Return(&database.ClusterInfo{ + ClusterID: "cluster-1", + }, nil).Once() + + resolved, err := tester.handler.resolveModelTarget(context.Background(), "testuser", "inference-model", http.Header{}) + + require.NoError(t, err) + require.Equal(t, "https://model.internal/v1/chat/completions", resolved.Target) + require.Equal(t, "namespace/model", resolved.ModelName) +} + func TestExtractSessionKeyForModel(t *testing.T) { model := &types.Model{ RoutingPolicy: commontypes.RoutingPolicy{ @@ -486,7 +536,7 @@ func TestFilterAvailableUpstreams_FiltersCircuitOpen(t *testing.T) { require.NoError(t, err) require.Len(t, result, 3) require.Equal(t, "https://api.example.com/closed", result[0].URL) - require.Equal(t, "https://api.example.com/open", result[1].URL) // no runtime circuit state, allow proxy + require.Equal(t, "https://api.example.com/open", result[1].URL) // no runtime circuit state, allow proxy require.Equal(t, "https://api.example.com/cb_disabled", result[2].URL) // CB disabled, passes } diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index b2406885..c07f030c 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -252,31 +252,33 @@ type OpenAIHandlerImpl struct { // ListModels godoc // @Summary List available models -// @Description Returns a list of available models, supports fuzzy search by model_id query parameter and filtering by source and task +// @Description Returns a list of available models, supports fuzzy search by model_id query parameter and filtering by llm_types and task // @Tags AIGateway // @Accept json // @Produce json // @Param model_id query string false "Model ID for fuzzy search" -// @Param source query string false "Filter by source (csghub for CSGHub models, external for external models)" Enums(csghub, external) +// @Param llm_types query []string false "Filter by LLM types" Enums(external_llm, serverless, inference) // @Param task query string false "Filter by task (e.g., text-generation, text-to-image, image-to-image)" // @Param per query int false "Models per page (default 20, max 100)" // @Param page query int false "Page number (1-based, default 1)" // @Success 200 {object} types.ModelList "OK" -// @Failure 400 {object} error "Invalid source parameter" +// @Failure 400 {object} error "Invalid llm_types parameter" // @Failure 500 {object} error "Internal server error" // @Router /v1/models [get] func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { currentUser := httpbase.GetCurrentUser(c) - // Validate source parameter - source := strings.TrimSpace(c.Query("source")) - if source != "" { - sourceLower := strings.ToLower(source) - if sourceLower != string(types.ModelSourceCSGHub) && sourceLower != string(types.ModelSourceExternal) { + // Validate llm_types parameter + llmTypes := c.QueryArray("llm_types") + for _, llmType := range llmTypes { + if strings.TrimSpace(llmType) == "" { + continue + } + if !isValidListModelsLLMType(llmType) { c.JSON(http.StatusBadRequest, gin.H{ "error": types.Error{ Code: "invalid_request_error", - Message: fmt.Sprintf("Invalid source parameter. Must be '%s' or '%s'", types.ModelSourceCSGHub, types.ModelSourceExternal), + Message: invalidLLMTypesErrorMessage(), Type: "invalid_request_error", }}) return @@ -284,11 +286,11 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { } resp, err := h.openaiComponent.ListModels(c.Request.Context(), currentUser, types.ListModelsReq{ - ModelID: c.Query("model_id"), - Source: source, - Task: c.Query("task"), - Per: c.Query("per"), - Page: c.Query("page"), + ModelID: c.Query("model_id"), + LLMTypes: llmTypes, + Task: c.Query("task"), + Per: c.Query("per"), + Page: c.Query("page"), }) if err != nil { slog.ErrorContext(c.Request.Context(), "failed to get available models", "error", err.Error(), "current_user", currentUser) @@ -304,6 +306,19 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { c.PureJSON(http.StatusOK, resp) } +func isValidListModelsLLMType(llmType string) bool { + switch strings.ToLower(strings.TrimSpace(llmType)) { + case types.ProviderTypeExternalLLM, types.ProviderTypeServerless, types.ProviderTypeInference: + return true + default: + return false + } +} + +func invalidLLMTypesErrorMessage() string { + return fmt.Sprintf("Invalid llm_types parameter. Allowed values: %s, %s, %s", types.ProviderTypeExternalLLM, types.ProviderTypeServerless, types.ProviderTypeInference) +} + // GetModel godoc // @Security ApiKey // @Summary Get model details diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index 38af7898..c85d4b7a 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -230,9 +230,9 @@ func TestOpenAIHandler_ListModels(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, w.Code) }) - t.Run("invalid source parameter", func(t *testing.T) { + t.Run("invalid llm_types parameter", func(t *testing.T) { tester, c, w := setupTest(t) - tester.WithQuery("source", "invalid") + tester.WithQuery("llm_types", "invalid") tester.handler.ListModels(c) @@ -243,17 +243,18 @@ func TestOpenAIHandler_ListModels(t *testing.T) { errObj, ok := response["error"].(map[string]interface{}) assert.True(t, ok) assert.Equal(t, "invalid_request_error", errObj["code"]) - assert.Contains(t, errObj["message"], "Invalid source parameter") - assert.Contains(t, errObj["message"], string(types.ModelSourceCSGHub)) - assert.Contains(t, errObj["message"], string(types.ModelSourceExternal)) + assert.Contains(t, errObj["message"], "Invalid llm_types parameter") + assert.Contains(t, errObj["message"], types.ProviderTypeExternalLLM) + assert.Contains(t, errObj["message"], types.ProviderTypeServerless) + assert.Contains(t, errObj["message"], types.ProviderTypeInference) }) - t.Run("valid source parameter csghub", func(t *testing.T) { + t.Run("valid llm_types parameter external_llm", func(t *testing.T) { tester, c, w := setupTest(t) - tester.WithQuery("source", string(types.ModelSourceCSGHub)) + tester.WithQuery("llm_types", types.ProviderTypeExternalLLM) tester.mocks.openAIComp.EXPECT(). - ListModels(mock.Anything, "testuser", types.ListModelsReq{Source: string(types.ModelSourceCSGHub)}). + ListModels(mock.Anything, "testuser", types.ListModelsReq{LLMTypes: []string{types.ProviderTypeExternalLLM}}). Return(types.ModelList{Object: "list", Data: []types.Model{}, HasMore: false, TotalCount: 0}, nil).Once() tester.handler.ListModels(c) @@ -261,12 +262,13 @@ func TestOpenAIHandler_ListModels(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) }) - t.Run("valid source parameter external", func(t *testing.T) { + t.Run("valid llm_types parameter multiple values", func(t *testing.T) { tester, c, w := setupTest(t) - tester.WithQuery("source", string(types.ModelSourceExternal)) + tester.WithQuery("llm_types", types.ProviderTypeServerless) + tester.WithQuery("llm_types", types.ProviderTypeInference) tester.mocks.openAIComp.EXPECT(). - ListModels(mock.Anything, "testuser", types.ListModelsReq{Source: string(types.ModelSourceExternal)}). + ListModels(mock.Anything, "testuser", types.ListModelsReq{LLMTypes: []string{types.ProviderTypeServerless, types.ProviderTypeInference}}). Return(types.ModelList{Object: "list", Data: []types.Model{}, HasMore: false, TotalCount: 0}, nil).Once() tester.handler.ListModels(c) @@ -274,12 +276,12 @@ func TestOpenAIHandler_ListModels(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) }) - t.Run("source parameter is case-insensitive", func(t *testing.T) { + t.Run("llm_types parameter is case-insensitive", func(t *testing.T) { tester, c, w := setupTest(t) - tester.WithQuery("source", "CSGHub") + tester.WithQuery("llm_types", "SERVERLESS") tester.mocks.openAIComp.EXPECT(). - ListModels(mock.Anything, "testuser", types.ListModelsReq{Source: "CSGHub"}). + ListModels(mock.Anything, "testuser", types.ListModelsReq{LLMTypes: []string{"SERVERLESS"}}). Return(types.ModelList{Object: "list", Data: []types.Model{}, HasMore: false, TotalCount: 0}, nil).Once() tester.handler.ListModels(c) @@ -1890,9 +1892,12 @@ func TestOpenAIHandler_CreateVideo(t *testing.T) { model := &types.Model{ BaseModel: types.BaseModel{ - ID: "video-model", - Task: "text-to-video", - Metadata: map[string]any{types.MetaKeyLLMType: types.ProviderTypeExternalLLM}, + ID: "video-model", + Task: "text-to-video", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeExternalLLM, + types.MetaKeyPricingConfigured: true, + }, }, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}, Endpoint: downstream.URL + "/v1/videos", diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index ffcd94e5..47558e55 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -15,9 +15,11 @@ const ( // Metadata key constants used when enriching model metadata. const ( - MetaKeyLLMType = "llm_type" - MetaKeyPricing = "pricing" - MetaKeyTasks = "tasks" + MetaKeyLLMType = "llm_type" + MetaKeyPricing = "pricing" + MetaKeyPricingConfigured = "pricing_configured" + MetaKeyRepoPath = "repo_path" + MetaKeyTasks = "tasks" ) // Resource ID format strings for external LLM (model ID) and CSGHub internal (path segment, repo path). @@ -275,11 +277,11 @@ type ModelList struct { // Fields are passed as strings so the component layer can own parsing, // filtering, and pagination behavior consistently. type ListModelsReq struct { - ModelID string `json:"model_id"` - Per string `json:"per"` - Page string `json:"page"` - Source string `json:"source"` // filter by source (csghub for CSGHub models, external for external models) - Task string `json:"task"` // filter by task + ModelID string `json:"model_id"` + Per string `json:"per"` + Page string `json:"page"` + LLMTypes []string `json:"llm_types"` // filter by llm_type + Task string `json:"task"` // filter by task } // UserPreferenceRequest defines the request parameters for UserPreference method @@ -297,20 +299,10 @@ const ( MetaTaskValGuard = "guard" ) -// ModelSource represents the source of a model -type ModelSource string - -const ( - // ModelSourceCSGHub represents models from CSGHub (internal models) - ModelSourceCSGHub ModelSource = "csghub" - // ModelSourceExternal represents models from external providers - ModelSourceExternal ModelSource = "external" -) - // ModelTokenPrice is currency plus per-million-token rate (major units, from accounting cents + sku_unit). type ModelTokenPrice struct { - Currency string `json:"currency,omitempty"` - PricePerMillion float64 `json:"price_per_million,omitempty"` + Currency string `json:"currency"` + PricePerMillion float64 `json:"price_per_million"` } // ModelModalPrice is a unit-based media generation price. diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index c3f32fa2..c48908e0 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -161,7 +161,7 @@ func (s *lLMConfigStoreImpl) IndexWithRepo(ctx context.Context, per, page int, s var configs []*LLMConfig offset := (page - 1) * per - query := s.db.Operator.Core.NewSelect().Model(&configs).Relation("Repo").Relation("Upstreams").Limit(per).Offset(offset) + query := s.db.Operator.Core.NewSelect().Model(&configs).Relation("Repo").Relation("Upstreams.HealthState").Relation("Upstreams.CircuitState").Limit(per).Offset(offset) buildSearchLLMConfigQuery(search, query) err := query.Scan(ctx) if err != nil { diff --git a/builder/store/database/llm_config_test.go b/builder/store/database/llm_config_test.go index 143582a3..6e21de67 100644 --- a/builder/store/database/llm_config_test.go +++ b/builder/store/database/llm_config_test.go @@ -321,6 +321,7 @@ func TestLLMConfigStore_IndexWithRepo(t *testing.T) { repo := database.Repository{ Path: "test-ns-indexwithrepo/test-repo", + GitPath: "models/test-ns-indexwithrepo/test-repo", Name: "test-repo", Nickname: "Test Repo", Description: "A test repository", @@ -372,6 +373,8 @@ func TestLLMConfigStore_IndexWithRepo(t *testing.T) { require.NotNil(t, withRepo.Repo) require.Equal(t, repo.ID, withRepo.Repo.ID) require.Equal(t, "test-repo", withRepo.Repo.Name) + require.Equal(t, "test-ns-indexwithrepo/test-repo", withRepo.Repo.Path) + require.Equal(t, "models/test-ns-indexwithrepo/test-repo", withRepo.Repo.GitPath) require.Equal(t, "Test Repo", withRepo.Repo.Nickname) require.Equal(t, "A test repository", withRepo.Repo.Description)