Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 60 additions & 54 deletions pkg/rag/strategy/bm25.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"math"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -99,6 +100,10 @@ type BM25Strategy struct {
b float64 // length normalization parameter (typically 0.75)
avgDocLength float64 // average document length
docCount int // total number of documents

// Tokenization helpers (built once per strategy instance)
replacer *strings.Replacer
stopwords map[string]bool
}

// newBM25Strategy creates a new BM25-based retrieval strategy
Expand All @@ -120,6 +125,18 @@ func newBM25Strategy(name string, db *bm25DB, events chan<- types.Event, k1, b f
shouldIgnore: shouldIgnore,
k1: k1,
b: b,
replacer: strings.NewReplacer(
".", " ", ",", " ", "!", " ", "?", " ",
";", " ", ":", " ", "(", " ", ")", " ",
"[", " ", "]", " ", "{", " ", "}", " ",
"\"", " ", "'", " ", "\n", " ", "\t", " ",
),
stopwords: map[string]bool{
"the": true, "a": true, "an": true, "and": true, "or": true,
"but": true, "in": true, "on": true, "at": true, "to": true,
"for": true, "of": true, "as": true, "by": true, "is": true,
"was": true, "are": true, "were": true, "be": true, "been": true,
},
}
}

Expand Down Expand Up @@ -247,11 +264,7 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
return nil, errors.New("query contains no valid terms")
}

// For BM25, we need to retrieve all documents and score them
// In a production system, you'd use an inverted index for efficiency
// For now, this is a simplified implementation

// Get all documents (in production, use inverted index to get only relevant docs)
// Get all documents
allDocs, err := s.getAllDocuments(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve documents: %w", err)
Expand All @@ -261,10 +274,33 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
return []database.SearchResult{}, nil
}

// Score each document using BM25
// Pre-tokenize all documents once: build term frequency maps and lengths.
docTermFreqs := make([]map[string]int, len(allDocs))
docLengths := make([]float64, len(allDocs))
for i, doc := range allDocs {
tokens := s.tokenize(doc.Content)
tf := make(map[string]int, len(tokens))
for _, term := range tokens {
tf[term]++
}
docTermFreqs[i] = tf
docLengths[i] = float64(len(tokens))
}

// Pre-compute document frequency for each query term.
df := make(map[string]int, len(queryTerms))
for _, term := range queryTerms {
for _, tf := range docTermFreqs {
if tf[term] > 0 {
df[term]++
}
}
}

// Score each document.
scores := make([]database.SearchResult, 0, len(allDocs))
for _, doc := range allDocs {
score := s.calculateBM25Score(queryTerms, doc, allDocs)
for i, doc := range allDocs {
score := s.calculateBM25Score(queryTerms, docTermFreqs[i], docLengths[i], df)
if score >= threshold {
scores = append(scores, database.SearchResult{
Document: doc,
Expand All @@ -273,14 +309,10 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
}
}

// Sort by score descending
for i := 0; i < len(scores); i++ {
for j := i + 1; j < len(scores); j++ {
if scores[j].Similarity > scores[i].Similarity {
scores[i], scores[j] = scores[j], scores[i]
}
}
}
// Sort by score descending.
slices.SortFunc(scores, func(a, b database.SearchResult) int {
return cmp.Compare(b.Similarity, a.Similarity)
})

// Return top N results
if len(scores) > numResults {
Expand Down Expand Up @@ -384,77 +416,51 @@ func (s *BM25Strategy) Close() error {
// Helper methods

func (s *BM25Strategy) tokenize(text string) []string {
// Simple tokenization: lowercase and split on whitespace/punctuation
text = strings.ToLower(text)
// Replace common punctuation with spaces
replacer := strings.NewReplacer(
".", " ", ",", " ", "!", " ", "?", " ",
";", " ", ":", " ", "(", " ", ")", " ",
"[", " ", "]", " ", "{", " ", "}", " ",
"\"", " ", "'", " ", "\n", " ", "\t", " ",
)
text = replacer.Replace(text)
text = s.replacer.Replace(text)

tokens := strings.Fields(text)

// Remove stopwords (simplified list)
stopwords := map[string]bool{
"the": true, "a": true, "an": true, "and": true, "or": true,
"but": true, "in": true, "on": true, "at": true, "to": true,
"for": true, "of": true, "as": true, "by": true, "is": true,
"was": true, "are": true, "were": true, "be": true, "been": true,
}

filtered := make([]string, 0, len(tokens))
for _, token := range tokens {
if len(token) > 2 && !stopwords[token] {
if len(token) > 2 && !s.stopwords[token] {
filtered = append(filtered, token)
}
}

return filtered
}

func (s *BM25Strategy) calculateBM25Score(queryTerms []string, doc database.Document, allDocs []database.Document) float64 {
docLength := float64(len(s.tokenize(doc.Content)))
func (s *BM25Strategy) calculateBM25Score(queryTerms []string, docTermFreq map[string]int, docLength float64, df map[string]int) float64 {
score := 0.0

docTerms := s.tokenize(doc.Content)
docTermFreq := make(map[string]int)
for _, term := range docTerms {
docTermFreq[term]++
}

for _, queryTerm := range queryTerms {
// Term frequency in document
tf := float64(docTermFreq[queryTerm])
if tf == 0 {
continue
}

// Document frequency (number of documents containing the term)
df := 0
for _, d := range allDocs {
if strings.Contains(strings.ToLower(d.Content), queryTerm) {
df++
}
}

if df == 0 {
// Document frequency (pre-computed)
termDF := df[queryTerm]
if termDF == 0 {
continue
}

// IDF calculation
idf := math.Log((float64(s.docCount)-float64(df)+0.5)/(float64(df)+0.5) + 1.0)
idf := math.Log((float64(s.docCount)-float64(termDF)+0.5)/(float64(termDF)+0.5) + 1.0)

// BM25 formula
numerator := tf * (s.k1 + 1.0)
denominator := tf + s.k1*(1.0-s.b+s.b*(docLength/s.avgDocLength))
lengthRatio := 1.0
if s.avgDocLength > 0 {
lengthRatio = docLength / s.avgDocLength
}
denominator := tf + s.k1*(1.0-s.b+s.b*lengthRatio)
score += idf * (numerator / denominator)
}

// Normalize score to 0-1 range for consistency with vector similarity
// This is a simple normalization; in production, you might use a different approach
return math.Min(score/float64(len(queryTerms)), 1.0)
}

Expand Down