From 690d138a133c4733f5a7795fbd481d0f497eee34 Mon Sep 17 00:00:00 2001
From: David Gageot
Date: Tue, 10 Mar 2026 16:57:49 +0100
Subject: [PATCH] perf: optimize BM25 scoring strategy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Move replacer and stopwords into BM25Strategy struct fields, built
  once in the constructor instead of on every tokenize() call
- Pre-tokenize all documents once per query and pre-compute document
  frequencies, reducing complexity from O(T×N²×L) to O(N×L + T×N)
- Replace O(n²) selection sort with slices.SortFunc (O(n log n))
- Guard against division by zero when avgDocLength is 0
- Fix DF computation to use exact token matching (consistent with TF)

Assisted-By: docker-agent
---
 pkg/rag/strategy/bm25.go | 115 +++++++++++++++++++++------------------
 1 file changed, 61 insertions(+), 54 deletions(-)

diff --git a/pkg/rag/strategy/bm25.go b/pkg/rag/strategy/bm25.go
index 1535575fd..72cb0fe56 100644
--- a/pkg/rag/strategy/bm25.go
+++ b/pkg/rag/strategy/bm25.go
@@ -9,6 +9,8 @@ import (
+	"cmp"
 	"math"
 	"os"
 	"path/filepath"
+	"slices"
 	"strings"
 	"sync"
 	"time"
@@ -99,6 +101,10 @@ type BM25Strategy struct {
 	b            float64 // length normalization parameter (typically 0.75)
 	avgDocLength float64 // average document length
 	docCount     int     // total number of documents
+
+	// Tokenization helpers (built once per strategy instance)
+	replacer  *strings.Replacer
+	stopwords map[string]bool
 }
 
 // newBM25Strategy creates a new BM25-based retrieval strategy
@@ -120,6 +126,18 @@ func newBM25Strategy(name string, db *bm25DB, events chan<- types.Event, k1, b f
 		shouldIgnore: shouldIgnore,
 		k1:           k1,
 		b:            b,
+		replacer: strings.NewReplacer(
+			".", " ", ",", " ", "!", " ", "?", " ",
+			";", " ", ":", " ", "(", " ", ")", " ",
+			"[", " ", "]", " ", "{", " ", "}", " ",
+			"\"", " ", "'", " ", "\n", " ", "\t", " ",
+		),
+		stopwords: map[string]bool{
+			"the": true, "a": true, "an": true, "and": true, "or": true,
+			"but": true, "in": true, "on": true, "at": true, "to": true,
+			"for": true, "of": true, "as": true, "by": true, "is": true,
+			"was": true, "are": true, "were": true, "be": true, "been": true,
+		},
 	}
 }
 
@@ -247,11 +265,7 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
 		return nil, errors.New("query contains no valid terms")
 	}
 
-	// For BM25, we need to retrieve all documents and score them
-	// In a production system, you'd use an inverted index for efficiency
-	// For now, this is a simplified implementation
-
-	// Get all documents (in production, use inverted index to get only relevant docs)
+	// Get all documents
 	allDocs, err := s.getAllDocuments(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("failed to retrieve documents: %w", err)
@@ -261,10 +275,33 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
 		return []database.SearchResult{}, nil
 	}
 
-	// Score each document using BM25
+	// Pre-tokenize all documents once: build term frequency maps and lengths.
+	docTermFreqs := make([]map[string]int, len(allDocs))
+	docLengths := make([]float64, len(allDocs))
+	for i, doc := range allDocs {
+		tokens := s.tokenize(doc.Content)
+		tf := make(map[string]int, len(tokens))
+		for _, term := range tokens {
+			tf[term]++
+		}
+		docTermFreqs[i] = tf
+		docLengths[i] = float64(len(tokens))
+	}
+
+	// Pre-compute document frequency for each query term.
+	df := make(map[string]int, len(queryTerms))
+	for _, term := range queryTerms {
+		for _, tf := range docTermFreqs {
+			if tf[term] > 0 {
+				df[term]++
+			}
+		}
+	}
+
+	// Score each document.
 	scores := make([]database.SearchResult, 0, len(allDocs))
-	for _, doc := range allDocs {
-		score := s.calculateBM25Score(queryTerms, doc, allDocs)
+	for i, doc := range allDocs {
+		score := s.calculateBM25Score(queryTerms, docTermFreqs[i], docLengths[i], df)
 		if score >= threshold {
 			scores = append(scores, database.SearchResult{
 				Document:   doc,
@@ -273,14 +310,10 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
 		}
 	}
 
-	// Sort by score descending
-	for i := 0; i < len(scores); i++ {
-		for j := i + 1; j < len(scores); j++ {
-			if scores[j].Similarity > scores[i].Similarity {
-				scores[i], scores[j] = scores[j], scores[i]
-			}
-		}
-	}
+	// Sort by score descending.
+	slices.SortFunc(scores, func(a, b database.SearchResult) int {
+		return cmp.Compare(b.Similarity, a.Similarity)
+	})
 
 	// Return top N results
 	if len(scores) > numResults {
@@ -384,30 +417,14 @@ func (s *BM25Strategy) Close() error {
 // Helper methods
 
 func (s *BM25Strategy) tokenize(text string) []string {
-	// Simple tokenization: lowercase and split on whitespace/punctuation
 	text = strings.ToLower(text)
 
-	// Replace common punctuation with spaces
-	replacer := strings.NewReplacer(
-		".", " ", ",", " ", "!", " ", "?", " ",
-		";", " ", ":", " ", "(", " ", ")", " ",
-		"[", " ", "]", " ", "{", " ", "}", " ",
-		"\"", " ", "'", " ", "\n", " ", "\t", " ",
-	)
-	text = replacer.Replace(text)
+	text = s.replacer.Replace(text)
 	tokens := strings.Fields(text)
 
-	// Remove stopwords (simplified list)
-	stopwords := map[string]bool{
-		"the": true, "a": true, "an": true, "and": true, "or": true,
-		"but": true, "in": true, "on": true, "at": true, "to": true,
-		"for": true, "of": true, "as": true, "by": true, "is": true,
-		"was": true, "are": true, "were": true, "be": true, "been": true,
-	}
-
 	filtered := make([]string, 0, len(tokens))
 	for _, token := range tokens {
-		if len(token) > 2 && !stopwords[token] {
+		if len(token) > 2 && !s.stopwords[token] {
 			filtered = append(filtered, token)
 		}
 	}
@@ -415,16 +432,9 @@ func (s *BM25Strategy) tokenize(text string) []string {
 	return filtered
 }
 
-func (s *BM25Strategy) calculateBM25Score(queryTerms []string, doc database.Document, allDocs []database.Document) float64 {
-	docLength := float64(len(s.tokenize(doc.Content)))
+func (s *BM25Strategy) calculateBM25Score(queryTerms []string, docTermFreq map[string]int, docLength float64, df map[string]int) float64 {
 	score := 0.0
 
-	docTerms := s.tokenize(doc.Content)
-	docTermFreq := make(map[string]int)
-	for _, term := range docTerms {
-		docTermFreq[term]++
-	}
-
 	for _, queryTerm := range queryTerms {
 		// Term frequency in document
 		tf := float64(docTermFreq[queryTerm])
@@ -432,29 +442,26 @@ func (s *BM25Strategy) calculateBM25Score(queryTerms []string, doc database.Docu
 			continue
 		}
 
-		// Document frequency (number of documents containing the term)
-		df := 0
-		for _, d := range allDocs {
-			if strings.Contains(strings.ToLower(d.Content), queryTerm) {
-				df++
-			}
-		}
-
-		if df == 0 {
+		// Document frequency (pre-computed)
+		termDF := df[queryTerm]
+		if termDF == 0 {
 			continue
 		}
 
 		// IDF calculation
-		idf := math.Log((float64(s.docCount)-float64(df)+0.5)/(float64(df)+0.5) + 1.0)
+		idf := math.Log((float64(s.docCount)-float64(termDF)+0.5)/(float64(termDF)+0.5) + 1.0)
 
 		// BM25 formula
 		numerator := tf * (s.k1 + 1.0)
-		denominator := tf + s.k1*(1.0-s.b+s.b*(docLength/s.avgDocLength))
+		lengthRatio := 1.0
+		if s.avgDocLength > 0 {
+			lengthRatio = docLength / s.avgDocLength
+		}
+		denominator := tf + s.k1*(1.0-s.b+s.b*lengthRatio)
 
 		score += idf * (numerator / denominator)
 	}
 
 	// Normalize score to 0-1 range for consistency with vector similarity
-	// This is a simple normalization; in production, you might use a different approach
 	return math.Min(score/float64(len(queryTerms)), 1.0)
 }