Skip to content

Commit 6639992

Browse files
Reducing the number of allocations in GraphSearcher (#501)
* Reduces the number of allocations in GraphSearcher
* Minor improvements to code readability
1 parent 51d4f0b commit 6639992

File tree

2 files changed

+163
-27
lines changed

2 files changed

+163
-27
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import io.github.jbellis.jvector.util.Bits;
3333
import io.github.jbellis.jvector.util.BoundedLongHeap;
3434
import io.github.jbellis.jvector.util.GrowableLongHeap;
35-
import io.github.jbellis.jvector.util.SparseBits;
3635
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
3736
import io.github.jbellis.jvector.vector.types.VectorFloat;
3837
import org.agrona.collections.Int2ObjectHashMap;
@@ -47,8 +46,6 @@
4746
* search algorithm, see {@link GraphIndex}.
4847
*/
4948
public class GraphSearcher implements Closeable {
50-
private boolean pruneSearch;
51-
5249
private GraphIndex.View view;
5350

5451
// Scratch data structures that are used in each {@link #searchInternal} call. These can be expensive
@@ -64,6 +61,9 @@ public class GraphSearcher implements Closeable {
6461
private SearchScoreProvider scoreProvider;
6562
private CachingReranker cachingReranker;
6663

64+
private boolean pruneSearch;
65+
private final ScoreTracker.ScoreTrackerFactory scoreTrackerFactory;
66+
6767
private int visitedCount;
6868
private int expandedCount;
6969
private int expandedCountBaseLayer;
@@ -75,14 +75,31 @@ public GraphSearcher(GraphIndex graph) {
7575
this(graph.getView());
7676
}
7777

78-
private GraphSearcher(GraphIndex.View view) {
78+
/**
79+
* Creates a new graph searcher from the given GraphIndex.View
80+
*/
81+
protected GraphSearcher(GraphIndex.View view) {
7982
this.view = view;
8083
this.candidates = new NodeQueue(new GrowableLongHeap(100), NodeQueue.Order.MAX_HEAP);
8184
this.evictedResults = new NodesUnsorted(100);
8285
this.approximateResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP);
8386
this.rerankedResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP);
8487
this.visited = new IntHashSet();
88+
8589
this.pruneSearch = true;
90+
this.scoreTrackerFactory = new ScoreTracker.ScoreTrackerFactory();
91+
}
92+
93+
protected int getVisitedCount() {
94+
return visitedCount;
95+
}
96+
97+
protected int getExpandedCount() {
98+
return expandedCount;
99+
}
100+
101+
protected int getExpandedCountBaseLayer() {
102+
return expandedCountBaseLayer;
86103
}
87104

88105
private void initializeScoreProvider(SearchScoreProvider scoreProvider) {
@@ -208,6 +225,35 @@ public SearchResult search(SearchScoreProvider scoreProvider,
208225
return new SearchResult(new SearchResult.NodeScore[0], 0, 0, 0, 0, Float.POSITIVE_INFINITY);
209226
}
210227

228+
internalSearch(scoreProvider, entry, topK, rerankK, threshold, acceptOrds);
229+
return reranking(topK, rerankK, rerankFloor);
230+
}
231+
232+
/**
233+
* Performs a search, leaving the results in the internal member variable approximateResults.
234+
* It does not perform reranking.
235+
*
236+
* @param scoreProvider provides functions to return the similarity of a given node to the query vector
237+
* @param entry the entry point to the graph. Assumed to be not null.
238+
* @param topK the number of results to look for. With threshold=0, the search will continue until at least
239+
* `topK` results have been found, or until the entire graph has been searched.
240+
* @param rerankK the number of (approximately-scored) results to rerank before returning the best `topK`.
241+
* @param threshold the minimum similarity (0..1) to accept; 0 will accept everything. May be used
242+
* with a large topK to find (approximately) all nodes above the given threshold.
243+
* If threshold > 0 then the search will stop when it is probabilistically unlikely
244+
* to find more nodes above the threshold, even if `topK` results have not yet been found.
245+
* @param acceptOrds a Bits instance indicating which nodes are acceptable results.
246+
* If {@link Bits#ALL}, all nodes are acceptable.
247+
* It is caller's responsibility to ensure that there are enough acceptable nodes
248+
* that we don't search the entire graph trying to satisfy topK.
249+
*/
250+
protected void internalSearch(SearchScoreProvider scoreProvider,
251+
NodeAtLevel entry,
252+
int topK,
253+
int rerankK,
254+
float threshold,
255+
Bits acceptOrds)
256+
{
211257
initializeInternal(scoreProvider, entry, acceptOrds);
212258

213259
// Move downward from entry.level to 1
@@ -219,7 +265,7 @@ public SearchResult search(SearchScoreProvider scoreProvider,
219265
}
220266

221267
// Now do the main search at layer 0
222-
return resume(topK, rerankK, threshold, rerankFloor);
268+
searchLayer0(topK, rerankK, threshold);;
223269
}
224270

225271
/**
@@ -276,6 +322,7 @@ void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bi
276322
this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, view.liveNodes());
277323

278324
// reset the scratch data structures
325+
approximateResults.clear();
279326
evictedResults.clear();
280327
candidates.clear();
281328
visited.clear();
@@ -290,6 +337,19 @@ void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bi
290337
expandedCountBaseLayer = 0;
291338
}
292339

340+
private boolean stopSearch(NodeQueue localCandidates, ScoreTracker scoreTracker, int rerankK, float threshold) {
341+
float topCandidateScore = localCandidates.topScore();
342+
// we're done when we have K results and the best candidate is worse than the worst result so far
343+
if (approximateResults.size() >= rerankK && topCandidateScore < approximateResults.topScore()) {
344+
return true;
345+
}
346+
// when querying by threshold, also stop when we are probabilistically unlikely to find more qualifying results
347+
if (threshold > 0 && scoreTracker.shouldStop()) {
348+
return true;
349+
}
350+
return false;
351+
}
352+
293353
/**
294354
* Performs a single-layer ANN search, expanding from the given candidates queue.
295355
*
@@ -335,24 +395,17 @@ void searchOneLayer(SearchScoreProvider scoreProvider,
335395
approximateResults.setMaxSize(rerankK);
336396

337397
// track scores to predict when we are done with threshold queries
338-
var scoreTracker = threshold > 0
339-
? new ScoreTracker.TwoPhaseTracker(threshold)
340-
: pruneSearch ? new ScoreTracker.RelaxedMonotonicityTracker(rerankK) : new ScoreTracker.NoOpTracker();
398+
var scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold);
341399
VectorFloat<?> similarities = null;
342400

343401
// the main search loop
344402
while (candidates.size() > 0) {
345-
// we're done when we have K results and the best candidate is worse than the worst result so far
346-
float topCandidateScore = candidates.topScore();
347-
if (approximateResults.size() >= rerankK && topCandidateScore < approximateResults.topScore()) {
348-
break;
349-
}
350-
// when querying by threshold, also stop when we are probabilistically unlikely to find more qualifying results
351-
if (threshold > 0 && scoreTracker.shouldStop()) {
403+
if (stopSearch(candidates, scoreTracker, rerankK, threshold)) {
352404
break;
353405
}
354406

355407
// process the top candidate
408+
float topCandidateScore = candidates.topScore();
356409
int topCandidateNode = candidates.pop();
357410
if (acceptOrdsThisLayer.get(topCandidateNode) && topCandidateScore >= threshold) {
358411
addTopCandidate(topCandidateNode, topCandidateScore, rerankK);
@@ -397,7 +450,7 @@ void searchOneLayer(SearchScoreProvider scoreProvider,
397450
}
398451
}
399452

400-
SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
453+
private void searchLayer0(int topK, int rerankK, float threshold) {
401454
// rR is persistent to save on allocations
402455
rerankedResults.clear();
403456
rerankedResults.setMaxSize(topK);
@@ -407,7 +460,9 @@ SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
407460
evictedResults.clear();
408461

409462
searchOneLayer(scoreProvider, rerankK, threshold, 0, acceptOrds);
463+
}
410464

465+
SearchResult reranking(int topK, int rerankK, float rerankFloor) {
411466
// rerank results
412467
assert approximateResults.size() <= rerankK;
413468
NodeQueue popFromQueue;
@@ -445,6 +500,11 @@ SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
445500
return new SearchResult(nodes, visitedCount, expandedCount, expandedCountBaseLayer, reranked, worstApproximateInTopK);
446501
}
447502

503+
SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
504+
searchLayer0(topK, rerankK, threshold);
505+
return reranking(topK, rerankK, rerankFloor);
506+
}
507+
448508
@SuppressWarnings("StatementWithEmptyBody")
449509
private void addTopCandidate(int topCandidateNode, float topCandidateScore, int rerankK) {
450510
// add the new node to the results queue, and any evicted node to evictedResults in case we resume later

jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,54 @@
1616

1717
package io.github.jbellis.jvector.graph;
1818

19+
import io.github.jbellis.jvector.util.ArrayUtil;
1920
import io.github.jbellis.jvector.util.BoundedLongHeap;
2021
import org.apache.commons.math3.stat.StatUtils;
2122

2223
import static io.github.jbellis.jvector.util.NumericUtils.floatToSortableInt;
2324
import static io.github.jbellis.jvector.util.NumericUtils.sortableIntToFloat;
2425

25-
interface ScoreTracker {
26+
public interface ScoreTracker {
27+
class ScoreTrackerFactory {
28+
private TwoPhaseTracker twoPhaseTracker;
29+
private RelaxedMonotonicityTracker relaxedMonotonicityTracker;
30+
private NoOpTracker noOpTracker;
31+
32+
ScoreTrackerFactory() {
33+
twoPhaseTracker = null;
34+
relaxedMonotonicityTracker = null;
35+
noOpTracker = null;
36+
}
37+
38+
public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float threshold) {
39+
// track scores to predict when we are done with threshold queries
40+
final ScoreTracker scoreTracker;
41+
42+
if (threshold > 0) {
43+
if (twoPhaseTracker == null) {
44+
twoPhaseTracker = new ScoreTracker.TwoPhaseTracker();
45+
} else {
46+
twoPhaseTracker.reset(threshold);
47+
}
48+
scoreTracker = twoPhaseTracker;
49+
} else {
50+
if (pruneSearch) {
51+
if (relaxedMonotonicityTracker == null) {
52+
relaxedMonotonicityTracker = new ScoreTracker.RelaxedMonotonicityTracker();
53+
} else {
54+
relaxedMonotonicityTracker.reset(rerankK);
55+
}
56+
scoreTracker = relaxedMonotonicityTracker;
57+
} else {
58+
if (noOpTracker == null) {
59+
noOpTracker = new ScoreTracker.NoOpTracker();
60+
}
61+
scoreTracker = noOpTracker;
62+
}
63+
}
64+
return scoreTracker;
65+
}
66+
}
2667

2768
ScoreTracker NO_OP = new NoOpTracker();
2869

@@ -59,11 +100,22 @@ class TwoPhaseTracker implements ScoreTracker {
59100
// observation count
60101
private int observationCount;
61102

62-
private final double threshold;
103+
private double threshold;
63104

64105
// Allocates the recent-score window and the best-score heap, and starts with zero observations.
TwoPhaseTracker(double threshold) {
    this.threshold = threshold;
    observationCount = 0;
    bestScores = new BoundedLongHeap(BEST_SCORES_TRACKED);
    recentScores = new double[RECENT_SCORES_TRACKED];
}
111+
112+
// Convenience constructor: same as TwoPhaseTracker(double) with a threshold of zero.
TwoPhaseTracker() {
    this(0.0);
}
115+
116+
void reset(double threshold) {
117+
this.bestScores.clear();
118+
this.observationCount = 0;
67119
this.threshold = threshold;
68120
}
69121

@@ -108,10 +160,13 @@ public boolean shouldStop() {
108160
* (approximately the 96th percentile of the Normal distribution).
109161
*/
110162
class RelaxedMonotonicityTracker implements ScoreTracker {
111-
static final double SIGMA_FACTOR = 1.75;
163+
private static final double SIGMA_FACTOR = 1.75;
164+
165+
private static final int BASE_RECENT_SCORES_SIZE = 200;
112166

113167
// a sliding window of recent scores
114-
private final double[] recentScores;
168+
private double[] recentScores;
169+
private int recentScoresSize;
115170
private int recentEntryIndex;
116171

117172
// Heap of the best scores seen so far
@@ -132,11 +187,32 @@ class RelaxedMonotonicityTracker implements ScoreTracker {
132187
* the results anymore. An empirical rule of thumb is bestScoresTracked=rerankK.
133188
*/
134189
RelaxedMonotonicityTracker(int bestScoresTracked) {
    // the sliding window is sized from the number of best scores being tracked
    int windowSize = getRecentScoresSize(bestScoresTracked);
    recentScoresSize = windowSize;
    recentScores = new double[windowSize];

    bestScores = new BoundedLongHeap(bestScoresTracked);

    // running statistics start empty
    observationCount = 0;
    mean = 0;
    dSquared = 0;
}
197+
198+
// Convenience constructor: same as RelaxedMonotonicityTracker(int) with 100 best scores tracked.
RelaxedMonotonicityTracker() {
    this(100);
}
201+
202+
private static int getRecentScoresSize(int bestScoresTracked) {
135203
// A quick empirical study yields that the number of recent scores
136204
// that we need to consider grows by a factor of ~sqrt(bestScoresTracked / 2)
137205
int factor = (int) Math.round(Math.sqrt(bestScoresTracked / 2.0));
138-
this.recentScores = new double[200 * factor];
139-
this.bestScores = new BoundedLongHeap(bestScoresTracked);
206+
return BASE_RECENT_SCORES_SIZE * factor;
207+
}
208+
209+
void reset(int bestScoresTracked) {
210+
this.recentScoresSize = getRecentScoresSize(bestScoresTracked);
211+
if (this.recentScoresSize > recentScores.length) {
212+
recentScores = ArrayUtil.grow(recentScores, this.recentScoresSize);
213+
}
214+
this.bestScores.clear();
215+
this.observationCount = 0;
140216
this.mean = 0;
141217
this.dSquared = 0;
142218
}
@@ -150,7 +226,7 @@ public void track(float score) {
150226
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online
151227
// and
152228
// https://nestedsoftware.com/2019/09/26/incremental-average-and-standard-deviation-with-sliding-window-470k.176143.html
153-
if (observationCount <= this.recentScores.length) {
229+
if (observationCount <= this.recentScoresSize) {
154230
// if the buffer is not full yet, use standard Welford method
155231
var meanDelta = (score - this.mean) / observationCount;
156232
var newMean = this.mean + meanDelta;
@@ -163,7 +239,7 @@ public void track(float score) {
163239
} else {
164240
// once the buffer is full, adjust Welford method for window size
165241
var oldScore = recentScores[recentEntryIndex];
166-
var meanDelta = (score - oldScore) / this.recentScores.length;
242+
var meanDelta = (score - oldScore) / this.recentScoresSize;
167243
var newMean = this.mean + meanDelta;
168244

169245
var dSquaredDelta = ((score - oldScore) * (score - newMean + oldScore - this.mean));
@@ -173,21 +249,21 @@ public void track(float score) {
173249
this.dSquared = newDSquared;
174250
}
175251
recentScores[recentEntryIndex] = score;
176-
recentEntryIndex = (recentEntryIndex + 1) % this.recentScores.length;
252+
recentEntryIndex = (recentEntryIndex + 1) % this.recentScoresSize;
177253
}
178254

179255
@Override
180256
public boolean shouldStop() {
181257
// don't stop if we don't have enough data points
182-
if (observationCount < this.recentScores.length) {
258+
if (observationCount < this.recentScoresSize) {
183259
return false;
184260
}
185261

186262
// We're in phase 2 if the q-th percentile of the recent scores evaluated,
187263
// mean + SIGMA_FACTOR * sqrt(variance),
188264
// is lower than the worst of the best scores seen.
189265
// (paper suggests using the median of recent scores, but experimentally that is too prone to false positives)
190-
double std = Math.sqrt(this.dSquared / (this.recentScores.length - 1));
266+
double std = Math.sqrt(this.dSquared / (this.recentScoresSize - 1));
191267
double windowPercentile = this.mean + SIGMA_FACTOR * std;
192268
double worstBestScore = sortableIntToFloat((int) bestScores.top());
193269
return windowPercentile < worstBestScore;

0 commit comments

Comments
 (0)