Skip to content

Commit e08e641

Browse files
committed
test(checkmodel): 添加跨类型模型检查测试
新增 TestCrossTypeChecks 测试函数,用于验证不同模型类型(embedding 和 rerank)的检查逻辑
1 parent f30b4cd commit e08e641

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

test/checkmodel_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"os"
7+
"regexp"
78
"strings"
89
"testing"
910
"time"
@@ -207,3 +208,114 @@ func TestCheckModelCombinations(t *testing.T) {
207208
}
208209
}
209210
}
211+
212+
func TestCrossTypeChecks(t *testing.T) {
213+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
214+
defer cancel()
215+
mk := usecase.NewModelKit(nil)
216+
217+
apiKeyBZ := strings.TrimSpace(os.Getenv("baizhiapikey"))
218+
if apiKeyBZ == "" {
219+
t.Fatalf("missing baizhiapikey")
220+
}
221+
222+
listBaseURL := "https://model-square.app.baizhi.cloud/v1"
223+
224+
// 1. 获取模型列表
225+
ml, err := mk.ModelList(ctx, &domain.ModelListReq{
226+
Provider: string(consts.ModelProviderBaiZhiCloud),
227+
BaseURL: listBaseURL,
228+
APIKey: apiKeyBZ,
229+
Type: "",
230+
})
231+
if err != nil {
232+
t.Fatalf("list model error: %v", err)
233+
}
234+
if ml == nil || len(ml.Models) == 0 {
235+
if ml != nil && ml.Error != "" {
236+
t.Fatalf("list model error: %s", ml.Error)
237+
}
238+
t.Fatalf("no models returned")
239+
}
240+
241+
// 2. 测试所有模型
242+
// Helper functions to guess model type
243+
getLowerBaseModelName := func(id string) string {
244+
parts := strings.Split(id, "/")
245+
return strings.ToLower(parts[len(parts)-1])
246+
}
247+
248+
isRerank := func(modelID string) bool {
249+
mid := getLowerBaseModelName(modelID)
250+
re := regexp.MustCompile(`(?i)(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)`)
251+
return re.MatchString(mid)
252+
}
253+
254+
isEmbedding := func(modelID string) bool {
255+
if isRerank(modelID) {
256+
return false
257+
}
258+
mid := getLowerBaseModelName(modelID)
259+
re := regexp.MustCompile(`(?i)(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings|voyage-)`)
260+
return re.MatchString(mid)
261+
}
262+
263+
checkTypes := []string{"embedding", "rerank"}
264+
for _, m := range ml.Models {
265+
for _, ct := range checkTypes {
266+
testName := fmt.Sprintf(
267+
"provider=%s base=%s model=%s type=%s apiKey=present",
268+
string(consts.ModelProviderBaiZhiCloud),
269+
listBaseURL,
270+
m.Model,
271+
ct,
272+
)
273+
274+
// Determine expected outcome
275+
expectSuccess := false
276+
if ct == "embedding" && isEmbedding(m.Model) {
277+
expectSuccess = true
278+
} else if ct == "rerank" && isRerank(m.Model) {
279+
expectSuccess = true
280+
}
281+
282+
t.Run(testName, func(t *testing.T) {
283+
// t.Parallel() // Optional: User didn't ask for parallel, and it might rate limit. Safer to run sequential.
284+
resp, err := mk.CheckModel(ctx, &domain.CheckModelReq{
285+
Provider: string(consts.ModelProviderBaiZhiCloud),
286+
Model: m.Model,
287+
BaseURL: listBaseURL,
288+
APIKey: apiKeyBZ,
289+
Type: ct,
290+
})
291+
292+
respError := ""
293+
respContent := ""
294+
if resp != nil {
295+
respError = resp.Error
296+
respContent = resp.Content
297+
}
298+
299+
logMsg := fmt.Sprintf("RespError: %q, RespContent: %q", respError, respContent)
300+
301+
if expectSuccess {
302+
// Expect Success
303+
if err != nil {
304+
t.Errorf("FAIL (Expected Success): %s; error: %v; %s", testName, err, logMsg)
305+
} else if respError != "" {
306+
t.Errorf("FAIL (Expected Success): %s; %s", testName, logMsg)
307+
} else {
308+
t.Logf("PASS: %s; %s", testName, logMsg)
309+
}
310+
} else {
311+
// Expect Failure
312+
if err == nil && respError == "" {
313+
t.Errorf("FAIL (Expected Failure for mismatched type): %s; got success but expected error; %s", testName, logMsg)
314+
} else {
315+
fmt.Printf("PASS: %s; correctly failed as expected for mismatched type. Error: %v; %s\n", testName, err, logMsg)
316+
}
317+
}
318+
})
319+
}
320+
}
321+
}

0 commit comments

Comments
 (0)