From 8415d26dc2fed6c8fc8aa656f30b034f51a24884 Mon Sep 17 00:00:00 2001 From: Ted Willke Date: Thu, 19 Mar 2026 01:33:32 +0000 Subject: [PATCH 1/3] Perf optimized version of AccuracyMetrics and removal of dead and noisy code. --- .../jvector/example/util/AccuracyMetrics.java | 100 +++++++++--------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java index ba537ed06..6a0ba0f52 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java @@ -17,11 +17,9 @@ package io.github.jbellis.jvector.example.util; import io.github.jbellis.jvector.graph.SearchResult; - -import java.util.Arrays; +import java.util.HashSet; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import java.util.Set; /** * Computes accuracy metrics, such as recall and mean average precision. @@ -41,43 +39,44 @@ public static double recallFromSearchResults(List> gt, L if (gt.size() != retrieved.size()) { throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements"); } - Long correctCount = IntStream.range(0, gt.size()) - .mapToObj(i -> topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved)) - .reduce(0L, Long::sum); + + long correctCount = 0; + for (int i = 0; i < gt.size(); i++) { + correctCount += topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved); + } + return (double) correctCount / (kGT * gt.size()); } - private static long topKCorrect(List gt, List retrieved, int kGT, int kRetrieved) { + private static long topKCorrect(List gt, SearchResult retrieved, int kGT, int kRetrieved) { + // Exception validation + var nodes = retrieved.getNodes(); if (kGT > kRetrieved) { throw new IllegalArgumentException("kGT: " + kGT + " > kRetrieved: " + kRetrieved); } if (kGT > gt.size()) { throw new IllegalArgumentException("kGT: " + kGT + " > Gt size: " + gt.size()); } - if (kRetrieved > retrieved.size()) { - throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + retrieved.size()); + if (kRetrieved > nodes.length) { + throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + nodes.length); } - var gtView = crop(gt, kGT); - var retrievedView = crop(retrieved, kRetrieved); - - if (gtView.size() > retrieved.size()) { - return gtView.stream().filter(retrievedView::contains).count(); - } else { - return retrievedView.stream().filter(gtView::contains).count(); + // Build HashSet with explicit capacity to avoid rehashing. + // Load factor is 0.75, so sized to kGT / 0.75. + Set gtSet = new HashSet<>((int) (kGT / 0.75f) + 1); + for (int i = 0; i < kGT; i++) { + gtSet.add(gt.get(i)); } - } - private static long topKCorrect(List gt, SearchResult retrieved, int kGT, int kRetrieved) { - var temp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node) - .boxed() - .collect(Collectors.toList()); - return topKCorrect(gt, temp, kGT, kRetrieved); - } + // Manual primitive loop for speed (no Stream setup). + int hits = 0; + for (int i = 0; i < kRetrieved; i++) { + if (gtSet.contains(nodes[i].node)) { + hits++; + } + } - private static List crop(List list, int k) { - int count = Math.min(list.size(), k); - return list.subList(0, count); + return hits; } /** @@ -89,33 +88,34 @@ private static List crop(List list, int k) { * @return the average precision */ public static double averagePrecisionAtK(List gt, SearchResult retrieved, int k) { - var retrievedTemp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node) - .boxed() - .collect(Collectors.toList()); - + var nodes = retrieved.getNodes(); if (k > gt.size()) { throw new IllegalArgumentException("k: " + k + " > Gt size: " + gt.size()); } - if (k > retrievedTemp.size()) { - throw new IllegalArgumentException("k: " + k + " > retrieved size: " + retrievedTemp.size()); + if (k > nodes.length) { + throw new IllegalArgumentException("k: " + k + " > retrieved size: " + nodes.length); } - var gtView = crop(gt, k); - var retrievedView = crop(retrievedTemp, k); + // Sized hashset used for performance. + Set gtSet = new HashSet<>((int) (k / 0.75f) + 1); + for (int i = 0; i < k; i++) { + gtSet.add(gt.get(i)); + } - double score = 0.; - int num_hits = 0; - int i = 0; + // Handles potential duplicates in O(1). + Set seen = new HashSet<>((int) (k / 0.75f) + 1); - for (var p : retrievedView) { - if (gtView.contains(p) && !retrievedView.subList(0, i).contains(p)) { - num_hits += 1; - score += num_hits / (i + 1.0); + double score = 0.; + int hits = 0; + for (int i = 0; i < k; i++) { + int p = nodes[i].node; + if (gtSet.contains(p) && seen.add(p)) { + hits++; + score += (double) hits / (i + 1); } - i++; } - return score / gtView.size(); + return score / k; } /** @@ -130,10 +130,12 @@ public static double meanAveragePrecisionAtK(List> gt, L if (gt.size() != retrieved.size()) { throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements"); } - Double apk = IntStream.range(0, gt.size()) - .mapToObj(i -> averagePrecisionAtK(gt.get(i), retrieved.get(i), k)) - .reduce(0., Double::sum); - return apk / gt.size(); - } + double totalAp = 0; + for (int i = 0; i < gt.size(); i++) { + totalAp += averagePrecisionAtK(gt.get(i), retrieved.get(i), k); + } + + return totalAp / gt.size(); + } } From 5ce81aca28047061f6c090aa57ea42bbd5030eab Mon Sep 17 00:00:00 2001 From: Ted Willke Date: Tue, 31 Mar 2026 00:29:13 +0000 Subject: [PATCH 2/3] AccuracyMetrics now treats duplicate ground truth and retrieved results as an error condition. --- .../jvector/example/util/AccuracyMetrics.java | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java index 6a0ba0f52..b26d8ac11 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java @@ -65,13 +65,20 @@ private static long topKCorrect(List gt, SearchResult retrieved, int kG // Load factor is 0.75, so sized to kGT / 0.75. Set gtSet = new HashSet<>((int) (kGT / 0.75f) + 1); for (int i = 0; i < kGT; i++) { - gtSet.add(gt.get(i)); + int ord = gt.get(i); + if (!gtSet.add(ord)) { + throw new IllegalArgumentException("Duplicate ground truth ordinal in top-" + kGT + ": " + ord); + } } - // Manual primitive loop for speed (no Stream setup). + Set seenRetrieved = new HashSet<>((int) (kRetrieved / 0.75f) + 1); int hits = 0; for (int i = 0; i < kRetrieved; i++) { - if (gtSet.contains(nodes[i].node)) { + int p = nodes[i].node; + if (!seenRetrieved.add(p)) { + throw new IllegalArgumentException("Duplicate retrieved ordinal in top-" + kRetrieved + ": " + p); + } + if (gtSet.contains(p)) { hits++; } } @@ -99,17 +106,21 @@ public static double averagePrecisionAtK(List gt, SearchResult retrieve // Sized hashset used for performance. Set gtSet = new HashSet<>((int) (k / 0.75f) + 1); for (int i = 0; i < k; i++) { - gtSet.add(gt.get(i)); + int ord = gt.get(i); + if (!gtSet.add(ord)) { + throw new IllegalArgumentException("Duplicate ground truth ordinal in top-" + k + ": " + ord); + } } - // Handles potential duplicates in O(1). - Set seen = new HashSet<>((int) (k / 0.75f) + 1); - + Set seenRetrieved = new HashSet<>((int) (k / 0.75f) + 1); double score = 0.; int hits = 0; for (int i = 0; i < k; i++) { int p = nodes[i].node; - if (gtSet.contains(p) && seen.add(p)) { + if (!seenRetrieved.add(p)) { + throw new IllegalArgumentException("Duplicate retrieved ordinal in top-" + k + ": " + p); + } + if (gtSet.contains(p)) { hits++; score += (double) hits / (i + 1); } From 6c4902a6f7cb3778b17230d5ceb23b7bd46ffa70 Mon Sep 17 00:00:00 2001 From: Ted Willke Date: Tue, 31 Mar 2026 04:35:30 +0000 Subject: [PATCH 3/3] AccuracyMetrics tests, including duplicate Exceptions. --- .../example/util/AccuracyMetricsTest.java | 357 ++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/AccuracyMetricsTest.java diff --git a/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/AccuracyMetricsTest.java b/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/AccuracyMetricsTest.java new file mode 100644 index 000000000..8e2dcefde --- /dev/null +++ b/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/AccuracyMetricsTest.java @@ -0,0 +1,357 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.example.util; + +import io.github.jbellis.jvector.graph.SearchResult; + +import java.util.List; + +/** + * Simple test class to verify the AccuracyMetrics functionality. + */ +public class AccuracyMetricsTest { + public static void main(String[] args) { + System.out.println("Running AccuracyMetrics tests..."); + + testRecallPerfect(); + testRecallPartial(); + testRecallAveragesAcrossQueries(); + + testAveragePrecisionPerfect(); + testAveragePrecisionPartial(); + testAveragePrecisionNoHits(); + testAveragePrecisionRespectsRankingOrder(); + + testMeanAveragePrecision(); + + testRecallThrowsOnMismatchedQueryCounts(); + testMeanAveragePrecisionThrowsOnMismatchedQueryCounts(); + + testRecallThrowsWhenKgtExceedsKRetrieved(); + testRecallThrowsWhenKgtExceedsGtSize(); + testRecallThrowsWhenKRetrievedExceedsRetrievedSize(); + + testAveragePrecisionThrowsWhenKExceedsGtSize(); + testAveragePrecisionThrowsWhenKExceedsRetrievedSize(); + + testRecallThrowsOnDuplicateGroundTruthOrdinal(); + testRecallThrowsOnDuplicateRetrievedOrdinal(); + testAveragePrecisionThrowsOnDuplicateGroundTruthOrdinal(); + testAveragePrecisionThrowsOnDuplicateRetrievedOrdinal(); + + System.out.println("All AccuracyMetrics tests completed successfully!"); + } + + private static void testRecallPerfect() { + System.out.println("\nTest: Recall perfect"); + + double recall = AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2, 3)), + List.of(result(1, 2, 3)), + 3, + 3 + ); + + assertEquals("Recall", 1.0, recall, 0.0000001); + } + + private static void testRecallPartial() { + System.out.println("\nTest: Recall partial"); + + double recall = AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2, 3)), + List.of(result(1, 4, 5)), + 3, + 3 + ); + + assertEquals("Recall", 1.0 / 3.0, recall, 0.0000001); + } + + private static void testRecallAveragesAcrossQueries() { + System.out.println("\nTest: Recall averages across queries"); + + double recall = AccuracyMetrics.recallFromSearchResults( + List.of( + List.of(1, 2, 3), + List.of(10, 20, 30) + ), + List.of( + result(1, 2, 3), + result(10, 99, 98) + ), + 3, + 3 + ); + + assertEquals("Recall", 2.0 / 3.0, recall, 0.0000001); + } + + private static void testAveragePrecisionPerfect() { + System.out.println("\nTest: Average precision perfect"); + + double ap = AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(1, 2, 3), + 3 + ); + + assertEquals("Average precision", 1.0, ap, 0.0000001); + } + + private static void testAveragePrecisionPartial() { + System.out.println("\nTest: Average precision partial"); + + double ap = AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(1, 4, 2), + 3 + ); + + // Relevant hits at ranks 1 and 3: + // P@1 = 1/1 + // P@3 = 2/3 + // AP@3 = (1 + 2/3) / 3 = 5/9 + assertEquals("Average precision", 5.0 / 9.0, ap, 0.0000001); + } + + private static void testAveragePrecisionNoHits() { + System.out.println("\nTest: Average precision no hits"); + + double ap = AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(4, 5, 6), + 3 + ); + + assertEquals("Average precision", 0.0, ap, 0.0000001); + } + + private static void testAveragePrecisionRespectsRankingOrder() { + System.out.println("\nTest: Average precision respects ranking order"); + + double better = AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(1, 2, 9), + 3 + ); + + double worse = AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(9, 1, 2), + 3 + ); + + assertEquals("Better-ranked AP", 2.0 / 3.0, better, 0.0000001); + assertEquals("Worse-ranked AP", (1.0 / 2.0 + 2.0 / 3.0) / 3.0, worse, 0.0000001); + } + + private static void testMeanAveragePrecision() { + System.out.println("\nTest: Mean average precision"); + + double map = AccuracyMetrics.meanAveragePrecisionAtK( + List.of( + List.of(1, 2, 3), + List.of(10, 20, 30) + ), + List.of( + result(1, 2, 3), + result(99, 98, 97) + ), + 3 + ); + + assertEquals("Mean average precision", 0.5, map, 0.0000001); + } + + private static void testRecallThrowsOnMismatchedQueryCounts() { + System.out.println("\nTest: Recall throws on mismatched query counts"); + + assertThrows( + "Insufficient ground truth for the number of retrieved elements", + () -> AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2, 3)), + List.of(result(1, 2, 3), result(4, 5, 6)), + 3, + 3 + ) + ); + } + + private static void testMeanAveragePrecisionThrowsOnMismatchedQueryCounts() { + System.out.println("\nTest: MAP throws on mismatched query counts"); + + assertThrows( + "Insufficient ground truth for the number of retrieved elements", + () -> AccuracyMetrics.meanAveragePrecisionAtK( + List.of(List.of(1, 2, 3)), + List.of(result(1, 2, 3), result(4, 5, 6)), + 3 + ) + ); + } + + private static void testRecallThrowsWhenKgtExceedsKRetrieved() { + System.out.println("\nTest: Recall throws when kGT exceeds kRetrieved"); + + assertThrows( + "kGT: 3 > kRetrieved: 2", + () -> AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2, 3)), + List.of(result(1, 2)), + 3, + 2 + ) + ); + } + + private static void testRecallThrowsWhenKgtExceedsGtSize() { + System.out.println("\nTest: Recall throws when kGT exceeds GT size"); + + assertThrows( + "kGT: 3 > Gt size: 2", + () -> AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2)), + List.of(result(1, 2, 3)), + 3, + 3 + ) + ); + } + + private static void testRecallThrowsWhenKRetrievedExceedsRetrievedSize() { + System.out.println("\nTest: Recall throws when kRetrieved exceeds retrieved size"); + + assertThrows( + "kRetrieved: 3 > retrieved size: 2", + () -> AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2, 3)), + List.of(result(1, 2)), + 2, + 3 + ) + ); + } + + private static void testAveragePrecisionThrowsWhenKExceedsGtSize() { + System.out.println("\nTest: AP throws when k exceeds GT size"); + + assertThrows( + "k: 3 > Gt size: 2", + () -> AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2), + result(1, 2, 3), + 3 + ) + ); + } + + private static void testAveragePrecisionThrowsWhenKExceedsRetrievedSize() { + System.out.println("\nTest: AP throws when k exceeds retrieved size"); + + assertThrows( + "k: 3 > retrieved size: 2", + () -> AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(1, 2), + 3 + ) + ); + } + + private static void testRecallThrowsOnDuplicateGroundTruthOrdinal() { + System.out.println("\nTest: Recall throws on duplicate ground truth ordinal"); + + assertThrows( + "Duplicate ground truth ordinal in top-3: 1", + () -> AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 1, 2)), + List.of(result(1, 2, 3)), + 3, + 3 + ) + ); + } + + private static void testRecallThrowsOnDuplicateRetrievedOrdinal() { + System.out.println("\nTest: Recall throws on duplicate retrieved ordinal"); + + assertThrows( + "Duplicate retrieved ordinal in top-3: 1", + () -> AccuracyMetrics.recallFromSearchResults( + List.of(List.of(1, 2, 3)), + List.of(result(1, 1, 2)), + 3, + 3 + ) + ); + } + + private static void testAveragePrecisionThrowsOnDuplicateGroundTruthOrdinal() { + System.out.println("\nTest: AP throws on duplicate ground truth ordinal"); + + assertThrows( + "Duplicate ground truth ordinal in top-3: 1", + () -> AccuracyMetrics.averagePrecisionAtK( + List.of(1, 1, 2), + result(1, 2, 3), + 3 + ) + ); + } + + private static void testAveragePrecisionThrowsOnDuplicateRetrievedOrdinal() { + System.out.println("\nTest: AP throws on duplicate retrieved ordinal"); + + assertThrows( + "Duplicate retrieved ordinal in top-3: 1", + () -> AccuracyMetrics.averagePrecisionAtK( + List.of(1, 2, 3), + result(1, 1, 2), + 3 + ) + ); + } + + private static SearchResult result(int... nodes) { + SearchResult.NodeScore[] nodeScores = new SearchResult.NodeScore[nodes.length]; + for (int i = 0; i < nodes.length; i++) { + nodeScores[i] = new SearchResult.NodeScore(nodes[i], 0.0f); + } + return new SearchResult(nodeScores, 0, 0, 0, 0, Float.POSITIVE_INFINITY); + } + + private static void assertEquals(String message, double expected, double actual, double delta) { + if (Math.abs(expected - actual) > delta) { + throw new AssertionError(message + " - Expected: " + expected + ", Actual: " + actual); + } + System.out.println("✓ " + message + " - Value: " + actual); + } + + private static void assertThrows(String expectedMessage, Runnable runnable) { + try { + runnable.run(); + throw new AssertionError("Expected exception with message: " + expectedMessage); + } catch (IllegalArgumentException e) { + if (!expectedMessage.equals(e.getMessage())) { + throw new AssertionError( + "Expected exception message: " + expectedMessage + ", Actual: " + e.getMessage() + ); + } + System.out.println("✓ Threw expected exception - Message: " + e.getMessage()); + } + } +}