3232import io .github .jbellis .jvector .util .Bits ;
3333import io .github .jbellis .jvector .util .BoundedLongHeap ;
3434import io .github .jbellis .jvector .util .GrowableLongHeap ;
35- import io .github .jbellis .jvector .util .SparseBits ;
3635import io .github .jbellis .jvector .vector .VectorSimilarityFunction ;
3736import io .github .jbellis .jvector .vector .types .VectorFloat ;
3837import org .agrona .collections .Int2ObjectHashMap ;
4746 * search algorithm, see {@link GraphIndex}.
4847 */
4948public class GraphSearcher implements Closeable {
50- private boolean pruneSearch ;
51-
5249 private GraphIndex .View view ;
5350
5451 // Scratch data structures that are used in each {@link #searchInternal} call. These can be expensive
@@ -64,6 +61,9 @@ public class GraphSearcher implements Closeable {
6461 private SearchScoreProvider scoreProvider ;
6562 private CachingReranker cachingReranker ;
6663
64+ private boolean pruneSearch ;
65+ private final ScoreTracker .ScoreTrackerFactory scoreTrackerFactory ;
66+
6767 private int visitedCount ;
6868 private int expandedCount ;
6969 private int expandedCountBaseLayer ;
@@ -75,14 +75,31 @@ public GraphSearcher(GraphIndex graph) {
7575 this (graph .getView ());
7676 }
7777
78- private GraphSearcher (GraphIndex .View view ) {
78+ /**
79+ * Creates a new graph searcher from the given GraphIndex.View
80+ */
81+ protected GraphSearcher (GraphIndex .View view ) {
7982 this .view = view ;
8083 this .candidates = new NodeQueue (new GrowableLongHeap (100 ), NodeQueue .Order .MAX_HEAP );
8184 this .evictedResults = new NodesUnsorted (100 );
8285 this .approximateResults = new NodeQueue (new BoundedLongHeap (100 ), NodeQueue .Order .MIN_HEAP );
8386 this .rerankedResults = new NodeQueue (new BoundedLongHeap (100 ), NodeQueue .Order .MIN_HEAP );
8487 this .visited = new IntHashSet ();
88+
8589 this .pruneSearch = true ;
90+ this .scoreTrackerFactory = new ScoreTracker .ScoreTrackerFactory ();
91+ }
92+
93+ protected int getVisitedCount () {
94+ return visitedCount ;
95+ }
96+
97+ protected int getExpandedCount () {
98+ return expandedCount ;
99+ }
100+
101+ protected int getExpandedCountBaseLayer () {
102+ return expandedCountBaseLayer ;
86103 }
87104
88105 private void initializeScoreProvider (SearchScoreProvider scoreProvider ) {
@@ -208,6 +225,35 @@ public SearchResult search(SearchScoreProvider scoreProvider,
208225 return new SearchResult (new SearchResult .NodeScore [0 ], 0 , 0 , 0 , 0 , Float .POSITIVE_INFINITY );
209226 }
210227
228+ internalSearch (scoreProvider , entry , topK , rerankK , threshold , acceptOrds );
229+ return reranking (topK , rerankK , rerankFloor );
230+ }
231+
232+ /**
233+ * Performs a search, leaving the results in the internal member variable approximateResults.
234+ * It does not perform reranking.
235+ *
236+ * @param scoreProvider provides functions to return the similarity of a given node to the query vector
237+ * @param entry the entry point to the graph. Assumed to be not null.
238+ * @param topK the number of results to look for. With threshold=0, the search will continue until at least
239+ * `topK` results have been found, or until the entire graph has been searched.
240+ * @param rerankK the number of (approximately-scored) results to rerank before returning the best `topK`.
241+ * @param threshold the minimum similarity (0..1) to accept; 0 will accept everything. May be used
242+ * with a large topK to find (approximately) all nodes above the given threshold.
243+ * If threshold > 0 then the search will stop when it is probabilistically unlikely
244+ * to find more nodes above the threshold, even if `topK` results have not yet been found.
245+ * @param acceptOrds a Bits instance indicating which nodes are acceptable results.
246+ * If {@link Bits#ALL}, all nodes are acceptable.
247+ * It is caller's responsibility to ensure that there are enough acceptable nodes
248+ * that we don't search the entire graph trying to satisfy topK.
249+ */
250+ protected void internalSearch (SearchScoreProvider scoreProvider ,
251+ NodeAtLevel entry ,
252+ int topK ,
253+ int rerankK ,
254+ float threshold ,
255+ Bits acceptOrds )
256+ {
211257 initializeInternal (scoreProvider , entry , acceptOrds );
212258
213259 // Move downward from entry.level to 1
@@ -219,7 +265,7 @@ public SearchResult search(SearchScoreProvider scoreProvider,
219265 }
220266
221267 // Now do the main search at layer 0
222- return resume (topK , rerankK , threshold , rerankFloor ) ;
268+ searchLayer0 (topK , rerankK , threshold ); ;
223269 }
224270
225271 /**
@@ -276,6 +322,7 @@ void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bi
276322 this .acceptOrds = Bits .intersectionOf (rawAcceptOrds , view .liveNodes ());
277323
278324 // reset the scratch data structures
325+ approximateResults .clear ();
279326 evictedResults .clear ();
280327 candidates .clear ();
281328 visited .clear ();
@@ -290,6 +337,19 @@ void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bi
290337 expandedCountBaseLayer = 0 ;
291338 }
292339
340+ private boolean stopSearch (NodeQueue localCandidates , ScoreTracker scoreTracker , int rerankK , float threshold ) {
341+ float topCandidateScore = localCandidates .topScore ();
342+ // we're done when we have K results and the best candidate is worse than the worst result so far
343+ if (approximateResults .size () >= rerankK && topCandidateScore < approximateResults .topScore ()) {
344+ return true ;
345+ }
346+ // when querying by threshold, also stop when we are probabilistically unlikely to find more qualifying results
347+ if (threshold > 0 && scoreTracker .shouldStop ()) {
348+ return true ;
349+ }
350+ return false ;
351+ }
352+
293353 /**
294354 * Performs a single-layer ANN search, expanding from the given candidates queue.
295355 *
@@ -335,24 +395,17 @@ void searchOneLayer(SearchScoreProvider scoreProvider,
335395 approximateResults .setMaxSize (rerankK );
336396
337397 // track scores to predict when we are done with threshold queries
338- var scoreTracker = threshold > 0
339- ? new ScoreTracker .TwoPhaseTracker (threshold )
340- : pruneSearch ? new ScoreTracker .RelaxedMonotonicityTracker (rerankK ) : new ScoreTracker .NoOpTracker ();
398+ var scoreTracker = scoreTrackerFactory .getScoreTracker (pruneSearch , rerankK , threshold );
341399 VectorFloat <?> similarities = null ;
342400
343401 // the main search loop
344402 while (candidates .size () > 0 ) {
345- // we're done when we have K results and the best candidate is worse than the worst result so far
346- float topCandidateScore = candidates .topScore ();
347- if (approximateResults .size () >= rerankK && topCandidateScore < approximateResults .topScore ()) {
348- break ;
349- }
350- // when querying by threshold, also stop when we are probabilistically unlikely to find more qualifying results
351- if (threshold > 0 && scoreTracker .shouldStop ()) {
403+ if (stopSearch (candidates , scoreTracker , rerankK , threshold )) {
352404 break ;
353405 }
354406
355407 // process the top candidate
408+ float topCandidateScore = candidates .topScore ();
356409 int topCandidateNode = candidates .pop ();
357410 if (acceptOrdsThisLayer .get (topCandidateNode ) && topCandidateScore >= threshold ) {
358411 addTopCandidate (topCandidateNode , topCandidateScore , rerankK );
@@ -397,7 +450,7 @@ void searchOneLayer(SearchScoreProvider scoreProvider,
397450 }
398451 }
399452
400- SearchResult resume (int topK , int rerankK , float threshold , float rerankFloor ) {
453+ private void searchLayer0 (int topK , int rerankK , float threshold ) {
401454 // rR is persistent to save on allocations
402455 rerankedResults .clear ();
403456 rerankedResults .setMaxSize (topK );
@@ -407,7 +460,9 @@ SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
407460 evictedResults .clear ();
408461
409462 searchOneLayer (scoreProvider , rerankK , threshold , 0 , acceptOrds );
463+ }
410464
465+ SearchResult reranking (int topK , int rerankK , float rerankFloor ) {
411466 // rerank results
412467 assert approximateResults .size () <= rerankK ;
413468 NodeQueue popFromQueue ;
@@ -445,6 +500,11 @@ SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
445500 return new SearchResult (nodes , visitedCount , expandedCount , expandedCountBaseLayer , reranked , worstApproximateInTopK );
446501 }
447502
503+ SearchResult resume (int topK , int rerankK , float threshold , float rerankFloor ) {
504+ searchLayer0 (topK , rerankK , threshold );
505+ return reranking (topK , rerankK , rerankFloor );
506+ }
507+
448508 @ SuppressWarnings ("StatementWithEmptyBody" )
449509 private void addTopCandidate (int topCandidateNode , float topCandidateScore , int rerankK ) {
450510 // add the new node to the results queue, and any evicted node to evictedResults in case we resume later
0 commit comments