diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 891fda756..cf5a5675e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -24,6 +24,8 @@ import io.github.jbellis.jvector.util.FixedBitSet; import io.github.jbellis.jvector.util.IntMap; +import java.util.function.IntPredicate; + import static java.lang.Math.min; /** @@ -56,6 +58,7 @@ public ConcurrentNeighborMap(IntMap neighbors, DiversityProvider public void insertEdge(int fromId, int toId, float score, float overflow) { while (true) { var old = neighbors.get(fromId); + if (old == null) return; // node was concurrently removed via removeNode() — skip backlink var next = old.insert(toId, score, overflow, this); if (next == null || neighbors.compareAndPut(fromId, old, next)) { break; @@ -101,6 +104,21 @@ public void replaceDeletedNeighbors(int nodeId, BitSet toDelete, NodeArray candi } } + /** + * Algorithm 6 (IP-DiskANN): pure dangling-edge filter. + * Removes every out-neighbor of {@code nodeId} for which {@code isDead} returns true. + * No replacement candidates, no diversity pruning on survivors — just a structural sweep. + * Safe to call concurrently with inserts and other deletions via CAS. 
+ */ + public void removeDeadEdges(int nodeId, IntPredicate isDead) { + while (true) { + var old = neighbors.get(nodeId); + if (old == null) return; // node itself was concurrently removed + var next = old.removeDeadEdges(isDead); + if (next == old || neighbors.compareAndPut(nodeId, old, next)) break; + } + } + public Neighbors insertDiverse(int nodeId, NodeArray candidates) { while (true) { var old = neighbors.get(nodeId); @@ -238,6 +256,26 @@ private Neighbors replaceDeletedNeighbors(Bits deletedNodes, NodeArray candidate return new Neighbors(nodeId, merged); } + /** + * Algorithm 6 (IP-DiskANN): pure structural filter with no diversity pruning. + * Returns {@code this} unchanged if no neighbors are dead (avoids allocation). + */ + private Neighbors removeDeadEdges(IntPredicate isDead) { + // Fast path: check if any neighbor is dead before allocating + boolean anyDead = false; + for (int i = 0; i < size(); i++) { + if (isDead.test(getNode(i))) { anyDead = true; break; } + } + if (!anyDead) return this; + + var live = new NodeArray(size()); + for (int i = 0; i < size(); i++) { + int n = getNode(i); + if (!isDead.test(n)) live.addInOrder(n, getScore(i)); + } + return new Neighbors(nodeId, live); + } + /** * For each candidate (going from best to worst), select it only if it is closer to target than it * is to any of the already-selected candidates. 
This is maintained whether those other neighbors diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 9e366676c..d623da7a0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -25,10 +25,7 @@ import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.util.Bits; -import io.github.jbellis.jvector.util.ExceptionUtils; -import io.github.jbellis.jvector.util.ExplicitThreadLocal; -import io.github.jbellis.jvector.util.PhysicalCoreExecutor; +import io.github.jbellis.jvector.util.*; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.slf4j.Logger; @@ -84,6 +81,28 @@ public class GraphIndexBuilder implements Closeable { private final Random rng; + /** + * Fraction of total live nodes that must be deleted since the last consolidation + * before {@link #consolidateDanglingEdges()} is auto-triggered inside + * {@link #markNodeDeleted(int)}. Matches the ablation sweet-spot from Table 4 of the + * IP-DiskANN paper (arXiv:2502.13826). Configurable via + * {@link #setConsolidationThreshold(double)}. + */ + private volatile double consolidationThreshold = 0.20; + + /** Running count of markNodeDeleted calls since construction. */ + private final AtomicInteger totalDeletions = new AtomicInteger(0); + + /** Value of totalDeletions at the time of the last consolidateDanglingEdges() run. 
*/ + private final AtomicInteger lastConsolidationAt = new AtomicInteger(0); + + /** + * Beam width for the GreedySearch used during in-place deletion repair (l_d in Algorithm 5, + * IP-DiskANN paper). Controls how deeply we search for approximate in-neighbors of the deleted + * node. Higher values improve in-neighbor recall at the cost of deletion latency. + */ + private static final int DELETION_LD = 128; + /** * Reads all the vectors from vector values, builds a graph connecting them by their dense * ordinals, using the given hyperparameter settings, and returns the resulting graph. @@ -678,8 +697,198 @@ public void setEntryPoint(int level, int node) { graph.updateEntryNode(new NodeAtLevel(level, node)); } + /** + * Sets the fraction of live nodes that must have been deleted since the last + * Algorithm 6 consolidation before the next one auto-triggers inside + * {@link #markNodeDeleted(int)}. Default is 0.20 (20%), matching the IP-DiskANN + * paper ablation (arXiv:2502.13826, Table 4). + * + * @param threshold a value in (0, 1]; use 1.0 to effectively disable auto-triggering + */ + public void setConsolidationThreshold(double threshold) { + if (threshold <= 0 || threshold > 1.0) { + throw new IllegalArgumentException("consolidationThreshold must be in (0, 1]"); + } + this.consolidationThreshold = threshold; + } + + /** + * Algorithm 6 (IP-DiskANN, arXiv:2502.13826): sweeps every live node at every level + * and removes out-edges that point to structurally absent nodes. + *

+ * This is the complement of {@link #markNodeDeleted}: Algorithm 5 repairs the + * immediate neighborhood of a deleted node at deletion time, but cannot guarantee + * that all reverse pointers (in-neighbors elsewhere in the graph) have been cleaned + * up. Over time, these dangling edges accumulate and degrade search recall by routing + * traversal into dead-end nodes. This method eliminates them in a single parallel sweep. + *

+ * Unlike {@link #removeDeletedNodes()}, this method: + *

+ * <ul>
+ *   <li>performs no distance calculations and no diversity pruning — it is a pure structural filter</li>
+ *   <li>requires no replacement candidates; dead out-edges are simply dropped</li>
+ *   <li>is safe to run concurrently with inserts and other deletions (per-node CAS)</li>
+ * </ul>

+ * Complexity: O(N × M) where N = live node count and M = average out-degree. + * No distance calculations are performed. + */ + public void consolidateDanglingEdges() { + parallelExecutor.submit(() -> { + IntStream.range(0, graph.getMaxLevel() + 1).forEach(level -> { + graph.nodeStream(level).parallel().forEach(node -> { + graph.removeDeadEdges(level, node); + }); + }); + }).join(); + // reset the consolidation counter so the next window is measured from now + lastConsolidationAt.set(totalDeletions.get()); + logger.debug("consolidateDanglingEdges complete — totalDeletions={}", totalDeletions.get()); + } + public void markNodeDeleted(int node) { graph.markDeleted(node); + updateEntryPointIfNeeded(node); + repairDeletionViaSearch(node); + + // Auto-trigger Algorithm 6 when the number of deletions since the last + // consolidation exceeds consolidationThreshold * current graph size. + // Exactly one thread wins the CAS and runs consolidation; all others skip. + int total = totalDeletions.incrementAndGet(); + int lastAt = lastConsolidationAt.get(); + int deltaNeeded = (int) Math.max(1, consolidationThreshold * graph.size(0)); + if ((total - lastAt) >= deltaNeeded + && lastConsolidationAt.compareAndSet(lastAt, total)) { + logger.debug("auto-triggering consolidateDanglingEdges at totalDeletions={}", total); + consolidateDanglingEdges(); + } + } + + private void updateEntryPointIfNeeded(int deletedNode) { + var currentEntry = graph.entryNode(); + if (currentEntry == null || currentEntry.node != deletedNode) { + return; + } + + var deletedNodes = graph.getDeletedNodes(); + int newLevel = graph.getMaxLevel(); + int newEntry = -1; + + outer: + while (newLevel >= 0) { + for (var it = graph.getNodes(newLevel); it.hasNext(); ) { + int candidate = it.nextInt(); + if (!deletedNodes.get(candidate)) { + newEntry = candidate; + break outer; + } + } + newLevel--; + } + + graph.updateEntryNode(newEntry >= 0 ? 
new NodeAtLevel(newLevel, newEntry) : null); + } + + /** + * In-place deletion repair using a GreedySearch toward the deleted node's vector + * (Algorithm 5, IP-DiskANN). + * with an O(DELETION_LD) visited-set check: + *

    + *
+ * <ol>
+ *   <li>Run GreedySearch(G, x_node, DELETION_LD) — navigates to x_node's region.</li>
+ *   <li>In-neighbors ≈ {z ∈ Visited : node ∈ N_out(z)} — stale edges pointing to the deleted node.</li>
+ *   <li>Replacement candidates = approximateResults (top-DELETION_LD nearest to x_node) — fresh
+ *       high-quality pool that does not degrade as deletions accumulate.</li>
+ *   <li>For level > 0 (addHierarchy=true), falls back to the full nodeStream scan since the
+ *       level-0 search does not visit upper-layer nodes. Upper layers are logarithmically sparse
+ *       so this is cheap.</li>
+ * </ol>
+ */ + private void repairDeletionViaSearch(int node) { + var deletedNodes = graph.getDeletedNodes(); + var entry = graph.entryNode(); + if (entry == null) { + graph.removeNode(node); + return; + } + + var ssp = scoreProvider.searchProviderFor(node); + + try (var gs = searchers.get()) { + var view = graph.getView(); + gs.setView(view); + + // Navigate from the top level down to level 1 to find a good entry point for level 0. + // Then run the full beam search at level 0 with beam width DELETION_LD. + gs.initializeInternal(ssp, entry, new ExcludingBits(node)); + for (int lvl = entry.level; lvl > 0; lvl--) { + gs.searchOneLayer(ssp, 1, 0.0f, lvl, view.liveNodes()); + gs.setEntryPointsFromPreviousLayer(); + } + gs.searchOneLayer(ssp, DELETION_LD, 0.0f, 0, view.liveNodes()); + + // Collect live replacement candidate IDs from the top-DELETION_LD nearest to x_node. + // Scores are re-computed per in-neighbor below (each in-neighbor needs its own scoring). + var candidateIds = new ArrayList(); + gs.approximateResults.foreach((k, score) -> { + if (!deletedNodes.get(k)) candidateIds.add(k); + }); + + // Read-only view of visited set — valid until next initializeInternal call. + var visitedSet = gs.visitedNodes(); + + for (int level = 0; level <= graph.getMaxLevel(); level++) { + if (!graph.getNeighborsIterator(level, node).hasNext()) continue; + final int lvl = level; + + if (level == 0) { + // Fast path: only check nodes that were visited during the search. + // The greedy search navigates toward x_node, so actual in-neighbors + // are very likely to appear in the visited set. + for (int z : visitedSet) { + if (z == node || deletedNodes.get(z)) continue; + for (var it = graph.getNeighborsIterator(lvl, z); it.hasNext(); ) { + if (it.nextInt() == node) { + repairInNeighbor(lvl, z, node, candidateIds); + break; + } + } + } + } else { + // Slow path for higher levels (addHierarchy=true). + // Upper layers are logarithmically sparse so this scan is cheap. 
+ graph.nodeStream(level).forEach(i -> { + if (i == node || deletedNodes.get(i)) return; + for (var it = graph.getNeighborsIterator(lvl, i); it.hasNext(); ) { + if (it.nextInt() == node) { + repairInNeighbor(lvl, i, node, candidateIds); + return; + } + } + }); + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + graph.removeNode(node); + } + + /** + * Repairs a single in-neighbor {@code inNeighbor} whose edge to {@code deletedNode} is now stale. + * Scores each candidate in {@code candidateIds} from {@code inNeighbor}'s perspective and + * calls {@link MutableGraphIndex#replaceDeletedNeighbors} to rewire the edge with diversity pruning. + */ + private void repairInNeighbor(int level, int inNeighbor, int deletedNode, List candidateIds) { + var iSf = scoreProvider.searchProviderFor(inNeighbor).scoreFunction(); + var iCandidates = new NodeArray(graph.getDegree(level)); + for (int k : candidateIds) { + if (k == inNeighbor) continue; + iCandidates.insertSorted(k, iSf.similarityTo(k)); + } + var bs = new FixedBitSet(Math.max(graph.getIdUpperBound(), deletedNode + 1)); + bs.set(deletedNode); + graph.replaceDeletedNeighbors(level, inNeighbor, bs, iCandidates); } /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 73cc5fbd5..67b6f61f9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -39,6 +39,8 @@ import java.io.Closeable; import java.io.IOException; +import java.util.Collections; +import java.util.Set; /** @@ -102,6 +104,18 @@ protected int getExpandedCountBaseLayer() { return expandedCountBaseLayer; } + /** + * Returns a read-only view of the nodes visited (scored) during the last search. + * Valid only until the next call to {@link #initializeInternal}. + *

+ * Package-private: used by {@link GraphIndexBuilder} to find approximate in-neighbors + * of a deleted node during in-place repair (Algorithm 5, IP-DiskANN), avoiding an + * O(N) full-graph scan. + */ + Set visitedNodes() { + return Collections.unmodifiableSet(visited); + } + private void initializeScoreProvider(SearchScoreProvider scoreProvider) { this.scoreProvider = scoreProvider; if (scoreProvider.reranker() == null) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java index 36ec49a16..cd5beda83 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java @@ -162,6 +162,13 @@ interface MutableGraphIndex extends ImmutableGraphIndex { */ void replaceDeletedNeighbors(int level, int node, BitSet toDelete, NodeArray candidates); + /** + * Algorithm 6 (IP-DiskANN): removes out-edges from {@code node} at {@code level} that + * point to nodes no longer structurally present in the graph at that level. + * Pure filter — no replacement candidates, no diversity pruning on survivors. + */ + void removeDeadEdges(int level, int node); + /** * Signals that all mutations have been completed and the graph will not be mutated any further. * Should be called by the builder after all mutations are completed (during cleanup). 
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 9ed1a92dd..48b4a6f53 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -287,6 +287,12 @@ public void replaceDeletedNeighbors(int level, int node, BitSet toDelete, NodeAr layers.get(level).replaceDeletedNeighbors(node, toDelete, candidates); } + @Override + public void removeDeadEdges(int level, int node) { + var layer = layers.get(level); + layer.removeDeadEdges(node, neighbor -> !layer.contains(neighbor)); + } + @Override public String toString() { return String.format("OnHeapGraphIndex(size=%d, entryPoint=%s)", size(0), entryPoint.get()); diff --git a/jvector-tests/pom.xml b/jvector-tests/pom.xml index 7f3145211..541dbebe3 100644 --- a/jvector-tests/pom.xml +++ b/jvector-tests/pom.xml @@ -100,6 +100,14 @@ + + org.apache.maven.plugins + maven-compiler-plugin + + 16 + 16 + + diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java index 12c263a6a..159c11659 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java @@ -21,7 +21,7 @@ import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import 
io.github.jbellis.jvector.vector.types.VectorFloat; import org.junit.Test; @@ -36,21 +36,20 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; import java.util.stream.Collectors; import java.util.stream.IntStream; /** - * Runs "nVectors" operations, where each operation is either: - * - an insertion - * - a mock deletion, instantiated through the use of a BitSet for skipping these nodes during search - * - a search - * With probability 0.01, we run cleanup to commit the deletions to the index. The cleanup process and the insertions - * cannot be concurrently executed (we use a lock to control their execution). + * Runs "nVectors" operations concurrently, where each operation is either: + * - an insertion via {@link GraphIndexBuilder#addGraphNode} + * - a deletion via {@link GraphIndexBuilder#markNodeDeleted} (Algorithm 5 — immediate, in-place repair) + * - a search via {@link GraphSearcher#search} + * + * Deletions are now immediate: markNodeDeleted physically removes the node and repairs + * affected edges inline using Algorithm 5 (IP-DiskANN). No deferred cleanup step is needed. + * Concurrent inserts, deletes, and searches are all safe without external locking. 
*/ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class TestConcurrentReadWriteDeletes extends RandomizedTest { @@ -58,10 +57,8 @@ public class TestConcurrentReadWriteDeletes extends RandomizedTest { private static final int nVectors = 20_000; private static final int dimension = 16; - private static final double cleanupProbability = 0.01; private KeySet keysInserted = new KeySet(); - private List keysRemoved = new CopyOnWriteArrayList(); private List> vectors = createRandomVectors(nVectors, dimension); private RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension); @@ -71,34 +68,22 @@ public class TestConcurrentReadWriteDeletes extends RandomizedTest { private BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, similarityFunction); private GraphIndexBuilder builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f, true); - private FixedBitSet liveNodes = new FixedBitSet(nVectors); - - private final Lock writeLock = new ReentrantLock(); - @Test public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException { var vv = ravv.threadLocalSupplier(); testConcurrentOps(i -> { var R = getRandom(); - if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) - { - // In the future, we could improve this test by acquiring the lock earlier and executing other - writeLock.lock(); - try { - builder.addGraphNode(i, vv.get().getVector(i)); - liveNodes.set(i); - keysInserted.add(i); - } finally { - writeLock.unlock(); - } + if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) { + // insert + builder.addGraphNode(i, vv.get().getVector(i)); + keysInserted.add(i); } else if (R.nextDouble() < 0.1) { + // delete immediately via Algorithm 5 — no deferred cleanup needed var key = keysInserted.getRandom(); - if (!keysRemoved.contains(key)) { - liveNodes.flip(key); - keysRemoved.add(key); - } + builder.markNodeDeleted(key); } else { + // search var queryVector = randomVector(getRandom(), dimension); 
SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, ravv); @@ -106,107 +91,79 @@ public void testConcurrentReadsWritesDeletes() throws ExecutionException, Interr int rerankK = Math.min(50, keysInserted.size()); GraphSearcher searcher = new GraphSearcher(builder.getGraph()); - searcher.search(ssp, topK, rerankK, 0.f, 0.f, liveNodes); + searcher.search(ssp, topK, rerankK, 0.f, 0.f, Bits.MatchAllBits.ALL); } }); } @FunctionalInterface - private interface Op - { + private interface Op { void run(int i) throws Throwable; } - private void testConcurrentOps(Op op) throws ExecutionException, InterruptedException { + private void testConcurrentOps(Op op) throws InterruptedException { AtomicInteger counter = new AtomicInteger(); long start = System.currentTimeMillis(); - - // Use a simpler approach that doesn't rely on parallel streams + var keys = IntStream.range(0, nVectors).boxed().collect(Collectors.toList()); Collections.shuffle(keys, getRandom()); - - // Use a thread-safe approach without relying on RandomizedContext - int threadCount = Math.min(Runtime.getRuntime().availableProcessors(), 8); // Limit thread count + + int threadCount = Math.min(Runtime.getRuntime().availableProcessors(), 8); List threads = new ArrayList<>(); int keysPerThread = nVectors / threadCount; - - // Create a thread-safe random seed for each thread - final long randomSeed = getRandom().nextLong(); - + for (int t = 0; t < threadCount; t++) { final int threadIndex = t; final int startIdx = threadIndex * keysPerThread; final int endIdx = (threadIndex == threadCount - 1) ? 
keys.size() : (threadIndex + 1) * keysPerThread; - + Thread thread = new Thread(() -> { for (int i = startIdx; i < endIdx; i++) { int key = keys.get(i); wrappedOp(op, key); - + if (counter.incrementAndGet() % 1_000 == 0) { var elapsed = System.currentTimeMillis() - start; logger.info(String.format("%d ops in %dms = %f ops/s", - counter.get(), elapsed, counter.get() * 1000.0 / elapsed)); - } - - if (getRandom().nextDouble() < cleanupProbability) { - writeLock.lock(); - try { - for (Integer keyToRemove : keysRemoved) { - builder.markNodeDeleted(keyToRemove); - } - keysRemoved.clear(); - builder.cleanup(); - } finally { - writeLock.unlock(); - } + counter.get(), elapsed, counter.get() * 1000.0 / elapsed)); } } }); - + threads.add(thread); thread.start(); } - - // Wait for all threads to complete + for (Thread thread : threads) { thread.join(); } } private static void wrappedOp(Op op, Integer i) { - try - { + try { op.run(i); - } - catch (Throwable e) - { + } catch (Throwable e) { throw new RuntimeException(e); } } - private static class KeySet - { + private static class KeySet { private final Map keys = new ConcurrentHashMap<>(); private final AtomicInteger ordinal = new AtomicInteger(); - public void add(Integer key) - { + public void add(Integer key) { var i = ordinal.getAndIncrement(); keys.put(i, key); } - public int getRandom() - { + public int getRandom() { if (isEmpty()) throw new IllegalStateException(); var i = TestConcurrentReadWriteDeletes.getRandom().nextInt(ordinal.get()); - // in case there is race with add(key), retry another random return keys.containsKey(i) ? 
keys.get(i) : getRandom(); } - public boolean isEmpty() - { + public boolean isEmpty() { return keys.isEmpty(); } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java index da052a617..ac952a588 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.nio.file.Files; +import java.util.ArrayList; import java.util.stream.IntStream; import static io.github.jbellis.jvector.TestUtil.assertGraphEquals; @@ -86,44 +87,30 @@ public void testCleanup(boolean addHierarchy) throws IOException { // delete all nodes that connect to a random node int nodeToIsolate = getRandom().nextInt(ravv.size()); - int nDeleted = 0; + var toDelete = new ArrayList(); try (var view = graph.getView()) { - for (var i = 0; i < graph.size(0); i++) { - for (var it = view.getNeighborsIterator(0, i); it.hasNext(); ) { // TODO hardcoded level + for (var i = 0; i < ravv.size(); i++) { + for (var it = view.getNeighborsIterator(0, i); it.hasNext(); ) { if (nodeToIsolate == it.nextInt()) { - builder.markNodeDeleted(i); - nDeleted++; + toDelete.add(i); break; } } } } - assertNotEquals(0, nDeleted); + assertNotEquals(0, toDelete.size()); + + for (int node : toDelete) { + builder.markNodeDeleted(node); + } + int nDeleted = toDelete.size(); - // cleanup removes the deleted nodes - builder.cleanup(); assertEquals(ravv.size() - nDeleted, graph.size(0)); - // cleanup should have added new connections to the node that would otherwise have been disconnected + // Algorithm 5 should have added new connections to the node that would otherwise have been disconnected var v = ravv.getVector(nodeToIsolate).copy(); var results = GraphSearcher.search(v, 10, ravv, VectorSimilarityFunction.COSINE, graph, Bits.ALL); 
assertEquals(nodeToIsolate, results.getNodes()[0].node); - - var ohgi = (OnHeapGraphIndex) graph; - - // check that we can save and load the graph with "holes" from the deletion - var testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); - var outputPath = testDirectory.resolve("on_heap_graph"); - try (var out = TestUtil.openDataOutputStream(outputPath)) { - ohgi.save(out); - } - - var b2 = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 4, 10, 1.0f, 1.0f, addHierarchy); - try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath)) { - b2.load(readerSupplier.get()); - } - var reloadedGraph = b2.getGraph(); - assertGraphEquals(graph, reloadedGraph); } @Test @@ -140,7 +127,7 @@ public void testMarkingAllNodesAsDeleted(boolean addHierarchy) { var graph = TestUtil.buildSequentially(builder, ravv); // mark all deleted - for (var i = 0; i < graph.size(0); i++) { + for (var i = 0; i < ravv.size(); i++) { builder.markNodeDeleted(i); } @@ -180,7 +167,6 @@ public void testNoPathToLiveNodesWhenRemovingDeletedNodes2(boolean addHierarchy) builder.markNodeDeleted(i); } - builder.cleanup(); assert builder.graph.getView().entryNode() != null; } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestInplaceDeletion.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestInplaceDeletion.java new file mode 100644 index 000000000..17e5d0b48 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestInplaceDeletion.java @@ -0,0 +1,337 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.LuceneTestCase; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.junit.Test; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors; +import static org.junit.Assert.*; + +/* + * These 3 tests are the acceptance criteria for POC: + * 1. Single-threaded sequential deletes: recall@10 must not degrade > 3% after 1k deletes on a 10K-vector index + * 2. Entry point deletion: after deleting the current entry point, the graph must update to a live node and search must still work. + * 3. Algorithm 6 correctness: after consolidateDanglingEdges(), no live node holds an out-edge to a structurally absent node. + * + * All tests run with addHierarchy = false (flat Vamana) and addHierarchy = true (hierarchical) to ensure correctness across both graph modes. 
+ * Graph parameters match the paper's high-recall regime: dimension = 128, cosine, m = 16, efConstruction = 200
+ */
+@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
+public class TestInplaceDeletion extends LuceneTestCase {
+    private static final int DIMENSION = 128;
+    private static final VectorSimilarityFunction SIMILARITY = VectorSimilarityFunction.COSINE;
+
+    // high-recall regime
+    // alpha = 1.2f (vamana diversity rule), neighborOverflow = 1.5f
+    private static final int M = 16;
+    private static final int EF_CONSTRUCTION = 200;
+    private static final int EF_SEARCH = 100; // beam width at query time; topK=10 alone gives ~56% on 1M
+    private static final float ALPHA = 1.2f;
+    private static final float NEIGHBOR_OVERFLOW = 1.5f;
+
+    @Test
+    public void testRecallDegradation() {
+        testRecallDegradation(false);
+        testRecallDegradation(true);
+    }
+
+    private void testRecallDegradation(boolean addHierarchy) {
+        int indexSize = 10_000;
+        int deleteCount = 1_000; // 10% of 10K
+        int batchSize = 100;     // rolling recall checkpoint every 100 deletes
+        int queryCount = 100;
+        int topK = 10;
+
+        System.out.println("\n=== TEST 1: testRecallDegradation addHierarchy=" + addHierarchy + " ===");
+
+        var baseVectors = createRandomFloatVectors(indexSize, DIMENSION, getRandom());
+        var queryVectors = Arrays.asList(createRandomFloatVectors(queryCount, DIMENSION, getRandom()));
+
+        System.out.println("[setup] indexSize=" + indexSize + " deleteCount=" + deleteCount +
+                " batchSize=" + batchSize + " queryCount=" + queryCount + " topK=" + topK +
+                " M=" + M + " efConstruction=" + EF_CONSTRUCTION + " alpha=" + ALPHA);
+
+        var ravv = MockVectorValues.fromValues(baseVectors);
+        var builder = new GraphIndexBuilder(ravv, SIMILARITY, M, EF_CONSTRUCTION, ALPHA, NEIGHBOR_OVERFLOW, addHierarchy);
+
+        long buildStart = System.currentTimeMillis();
+        var graph = builder.build(ravv);
+        long buildMs = System.currentTimeMillis() - buildStart;
+        System.out.println("[build] parallel build done in " +
+                buildMs + "ms — graph.size(0)=" + graph.size(0));
+
+        double baselineRecall = measureRecallBrute(queryVectors, graph, ravv, Collections.emptySet(), topK);
+        System.out.println("[baseline] recall@" + topK + "=" + String.format("%.4f", baselineRecall));
+
+        var allOrdinals = IntStream.range(0, indexSize)
+                .boxed()
+                .collect(Collectors.toCollection(ArrayList::new));
+        Collections.shuffle(allOrdinals, getRandom());
+
+        var deletedNodes = new HashSet<Integer>();
+        long totalDeletionMs = 0;
+        int numBatches = deleteCount / batchSize;
+
+        for (int batch = 0; batch < numBatches; batch++) {
+            int from = batch * batchSize;
+            long batchT0 = System.currentTimeMillis();
+            for (int i = from; i < from + batchSize; i++) {
+                builder.markNodeDeleted(allOrdinals.get(i));
+                deletedNodes.add(allOrdinals.get(i));
+            }
+            long batchMs = System.currentTimeMillis() - batchT0;
+            totalDeletionMs += batchMs;
+
+            double rollingRecall = measureRecallBrute(queryVectors, graph, ravv, deletedNodes, topK);
+            System.out.printf("[batch %2d/%d] deleted=%6d batchTime=%5dms avgPerDelete=%.2fms recall@%d=%.4f degradation=%.2f%%%n",
+                    batch + 1, numBatches, deletedNodes.size(),
+                    batchMs, (double) batchMs / batchSize,
+                    topK, rollingRecall, (baselineRecall - rollingRecall) * 100);
+        }
+
+        System.out.printf("[deletion summary] totalDeleted=%d totalTime=%dms avgPerDelete=%.2fms%n",
+                deletedNodes.size(), totalDeletionMs, (double) totalDeletionMs / deleteCount);
+
+        double postRecall = measureRecallBruteVerbose(queryVectors, graph, ravv, deletedNodes, topK);
+        double degradation = baselineRecall - postRecall;
+        System.out.println("[result] baseline=" + String.format("%.4f", baselineRecall) +
+                " post=" + String.format("%.4f", postRecall) +
+                " degradation=" + String.format("%.2f%%", degradation * 100) +
+                " threshold=3.00% PASS=" + (degradation <= 0.03));
+
+        assertTrue(
+                String.format(
+                        "Recall degraded by %.1f%% (baseline=%.3f, post=%.3f) — exceeds 3%% threshold. " +
+                                "addHierarchy=%b.",
+                        degradation * 100, baselineRecall, postRecall, addHierarchy),
+                degradation <= 0.03);
+    }
+
+    @Test
+    public void testEntryPointDeletion() {
+        testEntryPointDeletion(false);
+        testEntryPointDeletion(true);
+    }
+
+    private void testEntryPointDeletion(boolean addHierarchy) {
+        int indexSize = 100;
+
+        System.out.println("\n=== TEST 2: testEntryPointDeletion addHierarchy=" + addHierarchy + " ===");
+
+        var baseVectors = createRandomFloatVectors(indexSize, DIMENSION, getRandom());
+        var ravv = MockVectorValues.fromValues(baseVectors);
+        var builder = new GraphIndexBuilder(ravv, SIMILARITY, M, EF_CONSTRUCTION, ALPHA, NEIGHBOR_OVERFLOW, addHierarchy);
+
+        long buildStart = System.currentTimeMillis();
+        var graph = builder.build(ravv);
+        long buildMs = System.currentTimeMillis() - buildStart;
+        System.out.println("[build] parallel build done in " + buildMs + "ms — graph.size(0)=" + graph.size(0));
+
+        var originalEntry = graph.getView().entryNode();
+        assertNotNull("Graph must have a valid entry point before any deletion", originalEntry);
+        int originalEntryNode = originalEntry.node;
+        System.out.println("[entry-point] before deletion: node=" + originalEntryNode + " level=" + originalEntry.level);
+
+        long deleteStart = System.currentTimeMillis();
+        builder.markNodeDeleted(originalEntryNode);
+        long deleteMs = System.currentTimeMillis() - deleteStart;
+        System.out.println("[delete] entry point deletion took " + deleteMs + "ms");
+
+        var newEntry = graph.getView().entryNode();
+        System.out.println("[entry-point] after deletion: " +
+                (newEntry == null ? "null" : "node=" + newEntry.node + " level=" + newEntry.level) +
+                " — changed=" + (newEntry == null || newEntry.node != originalEntryNode));
+
+        assertNotNull("Entry point must not be null after deleting old entry point", newEntry);
+        assertNotEquals(
+                "Entry point must change after deleting node " + originalEntryNode,
+                originalEntryNode, newEntry.node);
+
+        var queryVectors = Arrays.asList(createRandomFloatVectors(20, DIMENSION, getRandom()));
+        for (int i = 0; i < queryVectors.size(); i++) {
+            var queryVec = queryVectors.get(i);
+            var results = GraphSearcher.search(queryVec, 5, ravv, SIMILARITY, graph, Bits.ALL);
+
+            assertNotNull("Search returned null after entry point deletion (query " + i + ")", results);
+
+            var resultNodes = new StringBuilder();
+            for (var ns : results.getNodes()) {
+                resultNodes.append(ns.node).append("(").append(String.format("%.4f", ns.score)).append(") ");
+                assertNotEquals(
+                        "Deleted entry point node " + originalEntryNode + " must not appear in results",
+                        originalEntryNode, ns.node);
+            }
+            if (i == 0) {
+                System.out.println("[search query 0] results=[" + resultNodes.toString().trim() +
+                        "] — deleted node=" + originalEntryNode + " absent=PASS");
+            }
+        }
+        System.out.println("[search] all 20 queries passed — deleted entry point never returned");
+    }
+
+    /**
+     * Algorithm 6 correctness: after calling consolidateDanglingEdges(), no live node
+     * at any level may hold an out-edge pointing to a node that is structurally absent
+     * from that level.
+     *
+     * We disable auto-trigger (threshold=1.0) so that the sweep only runs when we
+     * explicitly invoke it, giving us full control over the before/after observation.
+     */
+    @Test
+    public void testConsolidateDanglingEdges() {
+        testConsolidateDanglingEdges(false);
+        testConsolidateDanglingEdges(true);
+    }
+
+    private void testConsolidateDanglingEdges(boolean addHierarchy) {
+        int indexSize = 500;
+        int deleteCount = 100; // 20% deletions — well above default 20% threshold
+
+        System.out.println("\n=== testConsolidateDanglingEdges addHierarchy=" + addHierarchy + " ===");
+
+        var vectors = createRandomFloatVectors(indexSize, DIMENSION, getRandom());
+        var ravv = MockVectorValues.fromValues(vectors);
+        var builder = new GraphIndexBuilder(ravv, SIMILARITY, M, EF_CONSTRUCTION, ALPHA, NEIGHBOR_OVERFLOW, addHierarchy);
+        var graph = builder.build(ravv);
+
+        // Disable auto-trigger so we control exactly when Algorithm 6 fires.
+        builder.setConsolidationThreshold(1.0);
+
+        // Delete nodes sequentially. Algorithm 5 runs for each one but Algorithm 6 does not.
+        var allOrdinals = IntStream.range(0, indexSize)
+                .boxed()
+                .collect(Collectors.toCollection(ArrayList::new));
+        Collections.shuffle(allOrdinals, getRandom());
+
+        for (int i = 0; i < deleteCount; i++) {
+            builder.markNodeDeleted(allOrdinals.get(i));
+        }
+
+        // Scan for dangling edges BEFORE consolidation — informational only.
+        long danglingBefore = countDanglingEdges((OnHeapGraphIndex) graph);
+        System.out.println("[before consolidation] danglingEdges=" + danglingBefore);
+
+        // Run Algorithm 6.
+        builder.consolidateDanglingEdges();
+
+        // Every out-edge of every live node must now point to a structurally present node.
+        long danglingAfter = countDanglingEdges((OnHeapGraphIndex) graph);
+        System.out.println("[after consolidation] danglingEdges=" + danglingAfter);
+
+        assertEquals(
+                "consolidateDanglingEdges() must leave zero dangling edges at all levels. addHierarchy=" + addHierarchy,
+                0L, danglingAfter);
+    }
+
+    /**
+     * Counts out-edges across all levels that point to a structurally absent neighbor node.
+     */
+    private long countDanglingEdges(OnHeapGraphIndex graph) {
+        long dangling = 0;
+        var view = graph.getView();
+        int maxLevel = graph.getMaxLevel();
+        for (int level = 0; level <= maxLevel; level++) {
+            var nodeIt = graph.nodeStream(level).iterator();
+            while (nodeIt.hasNext()) {
+                int node = nodeIt.nextInt();
+                var it = view.getNeighborsIterator(level, node);
+                while (it.hasNext()) {
+                    int neighbor = it.nextInt();
+                    if (!view.contains(level, neighbor)) dangling++;
+                }
+            }
+        }
+        return dangling;
+    }
+
+    /**
+     * Measures recall using brute-force exact search as ground truth.
+     * Deleted ordinals are excluded from both the ground truth and search scoring.
+     */
+    private double measureRecallBrute(List<VectorFloat<?>> queries,
+                                      ImmutableGraphIndex graph,
+                                      RandomAccessVectorValues ravv,
+                                      Set<Integer> deletedNodes,
+                                      int topK) {
+        double totalRecall = 0.0;
+        for (VectorFloat<?> query : queries) {
+            Set<Integer> gtSet = bruteForceTopK(query, ravv, deletedNodes, topK);
+            var results = GraphSearcher.search(query, topK, EF_SEARCH, ravv, SIMILARITY, graph, Bits.ALL);
+            int hits = 0;
+            for (var ns : results.getNodes()) {
+                if (gtSet.contains(ns.node)) hits++;
+            }
+            totalRecall += (double) hits / Math.max(1, gtSet.size());
+        }
+        return totalRecall / queries.size();
+    }
+
+    /** Same as measureRecallBrute but prints query 0 hit/miss detail. */
+    private double measureRecallBruteVerbose(List<VectorFloat<?>> queries,
+                                             ImmutableGraphIndex graph,
+                                             RandomAccessVectorValues ravv,
+                                             Set<Integer> deletedNodes,
+                                             int topK) {
+        double totalRecall = 0.0;
+        for (int q = 0; q < queries.size(); q++) {
+            VectorFloat<?> query = queries.get(q);
+            Set<Integer> gtSet = bruteForceTopK(query, ravv, deletedNodes, topK);
+            var results = GraphSearcher.search(query, topK, EF_SEARCH, ravv, SIMILARITY, graph, Bits.ALL);
+            int hits = 0;
+            var graphNodes = new StringBuilder();
+            for (var ns : results.getNodes()) {
+                boolean isHit = gtSet.contains(ns.node);
+                if (isHit) hits++;
+                graphNodes.append(ns.node).append(isHit ? "[HIT]" : "[miss]")
+                        .append("(").append(String.format("%.4f", ns.score)).append(") ");
+            }
+            double queryRecall = (double) hits / Math.max(1, gtSet.size());
+            totalRecall += queryRecall;
+            if (q == 0) {
+                System.out.println("  [query 0] hits=" + hits + "/" + gtSet.size() +
+                        " gt=" + gtSet +
+                        " graphResults=[" + graphNodes.toString().trim() + "]");
+            }
+        }
+        return totalRecall / queries.size();
+    }
+
+    /**
+     * Returns the topK nearest neighbour ordinals for the given query via linear scan,
+     * excluding all ordinals in {@code excluded}.
+     */
+    private Set<Integer> bruteForceTopK(VectorFloat<?> query,
+                                        RandomAccessVectorValues ravv,
+                                        Set<Integer> excluded,
+                                        int topK) {
+        return IntStream.range(0, ravv.size())
+                .filter(i -> !excluded.contains(i))
+                .boxed()
+                .sorted((a, b) -> Float.compare(
+                        SIMILARITY.compare(query, ravv.getVector(b)),
+                        SIMILARITY.compare(query, ravv.getVector(a))))
+                .limit(topK)
+                .collect(Collectors.toCollection(LinkedHashSet::new));
+    }
+}
\ No newline at end of file
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java
index 29a8dca29..e72df8581 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java
@@ -105,8 +105,6 @@ public void testRenumberingOnDelete(boolean addHierarchy) throws IOException {
         // delete the first node
         builder.markNodeDeleted(0);
-        builder.cleanup();
-        builder.setEntryPoint(0, builder.getGraph().getIdUpperBound() - 1); // TODO
 
         // check
         assertEquals(2, original.size(0));