Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 44 additions & 25 deletions aigateway/component/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,30 @@ func filterByModelID(query string) modelFilter {
}
}

func filterBySource(source string) modelFilter {
func filterByLLMTypes(llmTypes []string) modelFilter {
allowedTypes := parseLLMTypes(llmTypes)
return func(m *types.Model) bool {
switch source {
case string(types.ModelSourceCSGHub):
return m.CSGHubModelID != ""
case string(types.ModelSourceExternal):
return m.Provider != ""
default:
if len(allowedTypes) == 0 {
return true
}
if m == nil || m.Metadata == nil {
return false
}
llmType, _ := m.Metadata[types.MetaKeyLLMType].(string)
return allowedTypes[strings.ToLower(strings.TrimSpace(llmType))]
}
}

// parseLLMTypes converts a list of raw LLM type names into a lookup set.
//
// Each entry is trimmed of surrounding whitespace and lower-cased; entries
// that are empty after normalization are dropped. The returned map is never
// nil and maps every accepted type name to true, so callers can test
// membership with a plain index expression.
func parseLLMTypes(llmTypes []string) map[string]bool {
	set := map[string]bool{}
	for i := range llmTypes {
		// Normalize first so that " OpenAI " and "openai" collapse to one key.
		if normalized := strings.ToLower(strings.TrimSpace(llmTypes[i])); normalized != "" {
			set[normalized] = true
		}
	}
	return set
}

func filterByTask(task string) modelFilter {
Expand Down Expand Up @@ -233,13 +246,13 @@ func applyFilters(models []types.Model, filters []modelFilter) []types.Model {
}

func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) types.ModelList {
var filters []modelFilter
filters := modelListDefaultFilters()

if searchQuery := strings.ToLower(req.ModelID); searchQuery != "" {
filters = append(filters, filterByModelID(searchQuery))
}
if source := strings.ToLower(req.Source); source != "" {
filters = append(filters, filterBySource(source))
if len(req.LLMTypes) > 0 {
filters = append(filters, filterByLLMTypes(req.LLMTypes))
}
if task := strings.ToLower(req.Task); task != "" {
filters = append(filters, filterByTask(task))
Expand Down Expand Up @@ -328,7 +341,8 @@ func (c *openaiComponentImpl) getCSGHubModels(ctx context.Context, userID int64)
SupportFunctionCall: supportFunctionCall,
Task: string(deploy.Task),
Metadata: map[string]any{
types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type),
types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type),
types.MetaKeyRepoPath: deploy.Repository.Path,
},
},
InternalModelInfo: types.InternalModelInfo{
Expand Down Expand Up @@ -375,30 +389,35 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model
page := 1
var models []types.Model
for {
extModels, _, err := m.extllmStore.Index(c, per, page, search)
extModels, _, err := m.extllmStore.IndexWithRepo(c, per, page, search)
if err != nil {
slog.Error("failed to get external models", "error", err)
break
}

for _, extModel := range extModels {
// Extract tasks from metadata if present
metadata := maps.Clone(extModel.Metadata)
if metadata == nil {
metadata = map[string]any{}
}
task := ""
if extModel.Metadata != nil {
if tasks, ok := extModel.Metadata[types.MetaKeyTasks].([]any); ok && len(tasks) > 0 {
tasksStrings := make([]string, 0, len(tasks))
for _, t := range tasks {
if s, ok := t.(string); ok {
tasksStrings = append(tasksStrings, s)
}
if tasks, ok := metadata[types.MetaKeyTasks].([]any); ok && len(tasks) > 0 {
tasksStrings := make([]string, 0, len(tasks))
for _, t := range tasks {
if s, ok := t.(string); ok {
tasksStrings = append(tasksStrings, s)
}
task = strings.Join(tasksStrings, ",")
}
task = strings.Join(tasksStrings, ",")
}
if extModel.Metadata == nil {
extModel.Metadata = map[string]any{}
if extModel.RepoID != 0 {
if extModel.Repo != nil && extModel.Repo.Path != "" {
metadata[types.MetaKeyRepoPath] = extModel.Repo.Path
} else {
slog.WarnContext(c, "llm config repo relation unavailable", "llm_config_id", extModel.ID, "repo_id", extModel.RepoID)
}
}
extModel.Metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM
metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM
// Convert relational upstreams to types.UpstreamConfig for routing
upstreams := dbUpstreamsToConfigs(extModel.Upstreams)
provider := extModel.PrimaryProvider()
Expand All @@ -407,7 +426,7 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model
Object: "model",
ID: extModel.ModelName,
OwnedBy: provider,
Metadata: extModel.Metadata,
Metadata: metadata,
Task: task,
},
Endpoint: router.FirstEnabledUpstream(upstreams),
Expand Down
85 changes: 70 additions & 15 deletions aigateway/component/openai_ce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {
SecureLevel: commontypes.EndpointPublic,
Repository: &database.Repository{
HFPath: "hf-model2",
Path: "model2",
},
User: &database.User{
Username: "serverless-owner",
Expand All @@ -84,7 +85,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {

mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, int64(0)).
Return(deploys, nil).Once()
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything).
Return([]*database.LLMConfig{}, 0, nil)

var wg sync.WaitGroup
Expand All @@ -104,8 +105,10 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {
require.Len(t, models, 2)
assert.Equal(t, "model1:svc1", models[0].ID)
assert.Equal(t, "publicuser", models[0].OwnedBy)
assert.Equal(t, "model1", models[0].Metadata[types.MetaKeyRepoPath])
assert.Equal(t, "hf-model2:svc2", models[1].ID)
assert.Equal(t, "OpenCSG", models[1].OwnedBy)
assert.Equal(t, "model2", models[1].Metadata[types.MetaKeyRepoPath])
wg.Wait()
})

Expand All @@ -117,7 +120,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {
}
mockUserStore.EXPECT().FindByUsername(mock.Anything, "testuser").
Return(*user, nil).Once()
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything).
Return([]*database.LLMConfig{}, 0, nil)
now := time.Now()
deploys := []database.Deploy{
Expand Down Expand Up @@ -145,6 +148,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {
SecureLevel: commontypes.EndpointPublic,
Repository: &database.Repository{
HFPath: "hf-model2",
Path: "model2",
},
User: &database.User{
Username: "testuser",
Expand Down Expand Up @@ -179,12 +183,14 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {
assert.Equal(t, "testuser", models[0].OwnedBy)
assert.Equal(t, "endpoint1", models[0].Endpoint)
assert.Equal(t, "text-generation", models[0].Task)
assert.Equal(t, "model1", models[0].Metadata[types.MetaKeyRepoPath])

// Verify second model (serverless)
assert.Equal(t, "hf-model2:svc2", models[1].ID)
assert.Equal(t, "OpenCSG", models[1].OwnedBy)
assert.Equal(t, "endpoint2", models[1].Endpoint)
assert.Equal(t, "text-to-image", models[1].Task)
assert.Equal(t, "model2", models[1].Metadata[types.MetaKeyRepoPath])
wg.Wait()
})

Expand All @@ -196,7 +202,7 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) {
}
mockUserStore.EXPECT().FindByUsername(mock.Anything, "testuser").
Return(*user, nil).Once()
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything).
Return([]*database.LLMConfig{}, 0, nil)

now := time.Now()
Expand Down Expand Up @@ -294,7 +300,7 @@ func TestOpenAIComponent_GetAvailableModels_CacheUsesModelSnapshot(t *testing.T)

mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, int64(0)).
Return(deploys, nil).Once()
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything).
Return([]*database.LLMConfig{}, 0, nil).Once()

firstWriteStarted := make(chan struct{})
Expand Down Expand Up @@ -367,7 +373,7 @@ func TestOpenAIComponent_ListModels_CacheUsesOriginalID(t *testing.T) {
SortBy: "model_size_b",
SortOrder: "desc",
}
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, search).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, search).
Return([]*database.LLMConfig{
{
ID: 1,
Expand Down Expand Up @@ -463,7 +469,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) {
Return(*user, nil)
mockCache.EXPECT().Exists(mock.Anything, modelCacheKey).
Return(0, nil).Once()
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything).
Return([]*database.LLMConfig{}, 0, nil).Once()
now := time.Now()
deploys := []database.Deploy{
Expand Down Expand Up @@ -513,7 +519,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) {
Return("", redis.Nil).Once()
// Cache miss: GetModelByID falls through to GetAvailableModels, which calls getCSGHubModels and getExternalModels
mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, int64(1)).Return([]database.Deploy{}, nil).Once()
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, mock.Anything).
Return([]*database.LLMConfig{}, 0, nil).Once()
model, err := comp.GetModelByID(context.Background(), "testuser", "nonexistent:svc")
assert.NoError(t, err)
Expand All @@ -540,7 +546,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) {
SortBy: "model_size_b",
SortOrder: "desc",
}
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, search).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, search).
Return([]*database.LLMConfig{
{
ID: 1,
Expand Down Expand Up @@ -639,7 +645,7 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) {
SortBy: "model_size_b",
SortOrder: "desc",
}
mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, search).
mockLLMConfigStore.EXPECT().IndexWithRepo(mock.Anything, 50, 1, search).
Return([]*database.LLMConfig{
{
ID: 1,
Expand Down Expand Up @@ -685,10 +691,10 @@ func TestOpenAIComponent_saveModelsToCache(t *testing.T) {
},
Endpoint: "http://test-endpoint",
ExternalModelInfo: types.ExternalModelInfo{
Provider: "openai",
AuthHead: "Bearer test-token",
NeedSensitiveCheck: true,
},
Provider: "openai",
AuthHead: "Bearer test-token",
NeedSensitiveCheck: true,
},
},
}

Expand Down Expand Up @@ -794,7 +800,7 @@ func TestOpenAIComponent_ExtGetAvailableModels_Error(t *testing.T) {
SortBy: "model_size_b",
SortOrder: "desc",
}
mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search).
mockLLMConfigStore.EXPECT().IndexWithRepo(ctx, 50, 1, search).
Return(nil, 0, errors.New("test error")).Once()
user := &database.User{
ID: 1,
Expand Down Expand Up @@ -823,13 +829,23 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) {
extllmStore: mockLLMConfigStore,
modelListCache: mockCache,
}
originalMetadata := map[string]any{
types.MetaKeyTasks: []any{"text-generation", "text-to-image"},
}
mockModels := []*database.LLMConfig{
{
ID: 1,
ModelName: "test-model-1",
Type: 16,
Enabled: true,
Provider: "OpenAI",
Metadata: originalMetadata,
RepoID: 100,
Repo: &database.Repository{
ID: 100,
Path: "test-ns/test-model-1",
GitPath: "models/test-ns/test-model-1",
},
},
}
user := &database.User{
Expand All @@ -848,7 +864,7 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) {
SortBy: "model_size_b",
SortOrder: "desc",
}
mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search).Return(mockModels, 1, nil)
mockLLMConfigStore.EXPECT().IndexWithRepo(ctx, 50, 1, search).Return(mockModels, 1, nil)
mockCache.EXPECT().HSet(mock.Anything, modelCacheKey, "test-model-1", mock.Anything).
Return(nil).Once()
var wg sync.WaitGroup
Expand All @@ -863,5 +879,44 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) {
require.Nil(t, err)
require.Len(t, models, 1)
require.Equal(t, "test-model-1", models[0].ID)
require.Equal(t, "text-generation,text-to-image", models[0].Task)
require.Equal(t, "test-ns/test-model-1", models[0].Metadata[types.MetaKeyRepoPath])
require.Equal(t, types.ProviderTypeExternalLLM, models[0].Metadata[types.MetaKeyLLMType])
require.NotContains(t, originalMetadata, types.MetaKeyRepoPath)
require.NotContains(t, originalMetadata, types.MetaKeyLLMType)
wg.Wait()
}

// TestOpenAIComponent_GetExternalModelsWithoutRepoOmitsRepoPath checks that
// getExternalModels does not add types.MetaKeyRepoPath to a model's metadata
// when the stored LLM config has no repo set (no RepoID / Repo in the fixture),
// while pre-existing metadata entries are kept and the LLM type is still set.
func TestOpenAIComponent_GetExternalModelsWithoutRepoOmitsRepoPath(t *testing.T) {
ctx := context.Background()
mockLLMConfigStore := mockdb.NewMockLLMConfigStore(t)
// Only the external LLM store is wired in; the test calls getExternalModels directly.
component := &openaiComponentImpl{
extllmStore: mockLLMConfigStore,
}
// Search arguments the expectation below is registered with.
searchType := 16
enabled := true
search := &commontypes.SearchLLMConfig{
Type: &searchType,
Enabled: &enabled,
SortBy: "model_size_b",
SortOrder: "desc",
}
// Expect one IndexWithRepo call for page 1 with a page size of 50, returning a
// single config that carries metadata but no RepoID and no Repo relation.
mockLLMConfigStore.EXPECT().IndexWithRepo(ctx, 50, 1, search).Return([]*database.LLMConfig{
{
ID: 1,
ModelName: "model-without-repo",
Type: 16,
Enabled: true,
Provider: "OpenAI",
Metadata: map[string]any{"existing": "value"},
},
}, 1, nil).Once()

models := component.getExternalModels(ctx)

// The model is returned with its original metadata plus the external LLM type,
// and without any repo path entry.
require.Len(t, models, 1)
require.Equal(t, "model-without-repo", models[0].ID)
require.Equal(t, "value", models[0].Metadata["existing"])
require.Equal(t, types.ProviderTypeExternalLLM, models[0].Metadata[types.MetaKeyLLMType])
require.NotContains(t, models[0].Metadata, types.MetaKeyRepoPath)
}
7 changes: 7 additions & 0 deletions aigateway/component/openai_model_filter_ce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !ee && !saas

package component

// modelListDefaultFilters returns the filters that are applied to every model
// listing before any request-specific filters are appended.
//
// This variant is compiled when neither the ee nor the saas build tag is set
// and contributes no default filters: it returns a nil slice, which callers
// can append to directly.
func modelListDefaultFilters() []modelFilter {
	var defaults []modelFilter
	return defaults
}
Loading
Loading