diff --git a/src/ltm.ts b/src/ltm.ts index df76ac4..f3a7a57 100644 --- a/src/ltm.ts +++ b/src/ltm.ts @@ -417,6 +417,56 @@ export function search(input: { } } +export type ScoredKnowledgeEntry = KnowledgeEntry & { rank: number }; + +/** + * Search with BM25 scores included. Returns results with raw FTS5 rank values + * for use in cross-source score fusion (RRF). + */ +export function searchScored(input: { + query: string; + projectPath?: string; + limit?: number; +}): ScoredKnowledgeEntry[] { + const limit = input.limit ?? 20; + const q = ftsQuery(input.query); + if (q === EMPTY_QUERY) return []; + + const pid = input.projectPath ? ensureProject(input.projectPath) : null; + const { title, content, category } = FTS_WEIGHTS; + + const ftsSQL = pid + ? `SELECT k.*, bm25(knowledge_fts, ?, ?, ?) as rank FROM knowledge k + JOIN knowledge_fts f ON k.rowid = f.rowid + WHERE knowledge_fts MATCH ? + AND (k.project_id = ? OR k.project_id IS NULL OR k.cross_project = 1) + AND k.confidence > 0.2 + ORDER BY rank LIMIT ?` + : `SELECT k.*, bm25(knowledge_fts, ?, ?, ?) as rank FROM knowledge k + JOIN knowledge_fts f ON k.rowid = f.rowid + WHERE knowledge_fts MATCH ? + AND k.confidence > 0.2 + ORDER BY rank LIMIT ?`; + + const ftsParams = pid + ? [title, content, category, q, pid, limit] + : [title, content, category, q, limit]; + + try { + const results = db().query(ftsSQL).all(...ftsParams) as ScoredKnowledgeEntry[]; + if (results.length) return results; + + const qOr = ftsQueryOr(input.query); + if (qOr === EMPTY_QUERY) return []; + const ftsParamsOr = pid + ? [title, content, category, qOr, pid, limit] + : [title, content, category, qOr, limit]; + return db().query(ftsSQL).all(...ftsParamsOr) as ScoredKnowledgeEntry[]; + } catch { + return []; + } +} + export function get(id: string): KnowledgeEntry | null { return db() .query("SELECT * FROM knowledge WHERE id = ?") diff --git a/src/reflect.ts b/src/reflect.ts index 792c33e..0a266c6 100644 --- a/src/reflect.ts +++ b/src/reflect.ts @@ -3,7 +3,7 @@ import * as temporal from "./temporal"; import * as ltm from "./ltm"; import * as log from "./log"; import { db, ensureProject } from "./db"; -import { ftsQuery, ftsQueryOr, EMPTY_QUERY } from "./search"; +import { ftsQuery, ftsQueryOr, EMPTY_QUERY, reciprocalRankFusion } from "./search"; import { serialize, inline, h, p, ul, lip, liph, t, root } from "./markdown"; type Distillation = { @@ -41,25 +41,27 @@ function searchDistillationsLike(input: { .all(...allParams) as Distillation[]; } -function searchDistillations(input: { +type ScoredDistillation = Distillation & { rank: number }; + +function searchDistillationsScored(input: { projectPath: string; query: string; sessionID?: string; limit?: number; -}): Distillation[] { +}): ScoredDistillation[] { const pid = ensureProject(input.projectPath); const limit = input.limit ?? 10; const q = ftsQuery(input.query); if (q === EMPTY_QUERY) return []; const ftsSQL = input.sessionID - ? `SELECT d.id, d.observations, d.generation, d.created_at, d.session_id + ? `SELECT d.id, d.observations, d.generation, d.created_at, d.session_id, rank FROM distillations d JOIN distillation_fts f ON d.rowid = f.rowid WHERE distillation_fts MATCH ? AND d.project_id = ? AND d.session_id = ? ORDER BY rank LIMIT ?` - : `SELECT d.id, d.observations, d.generation, d.created_at, d.session_id + : `SELECT d.id, d.observations, d.generation, d.created_at, d.session_id, rank FROM distillations d JOIN distillation_fts f ON d.rowid = f.rowid WHERE distillation_fts MATCH ? @@ -70,7 +72,7 @@ function searchDistillations(input: { : [q, pid, limit]; try { - const results = db().query(ftsSQL).all(...params) as Distillation[]; + const results = db().query(ftsSQL).all(...params) as ScoredDistillation[]; if (results.length) return results; // AND returned nothing — try OR fallback @@ -79,15 +81,15 @@ function searchDistillations(input: { const paramsOr = input.sessionID ? [qOr, pid, input.sessionID, limit] : [qOr, pid, limit]; - return db().query(ftsSQL).all(...paramsOr) as Distillation[]; + return db().query(ftsSQL).all(...paramsOr) as ScoredDistillation[]; } catch { - // FTS5 failed — fall back to LIKE search + // FTS5 failed — fall back to LIKE search with synthetic rank return searchDistillationsLike({ pid, query: input.query, sessionID: input.sessionID, limit, - }); + }).map((d, i) => ({ ...d, rank: -(10 - i) })); } } @@ -137,6 +139,53 @@ function formatResults(input: { return serialize(root(...children)); } +type TaggedResult = + | { source: "knowledge"; item: ltm.ScoredKnowledgeEntry } + | { source: "distillation"; item: ScoredDistillation } + | { source: "temporal"; item: temporal.ScoredTemporalMessage }; + +function formatFusedResults( + results: Array<{ item: TaggedResult; score: number }>, + maxResults: number, +): string { + if (!results.length) return "No results found for this query."; + + const items = results.slice(0, maxResults).map(({ item: tagged }) => { + switch (tagged.source) { + case "knowledge": { + const k = tagged.item; + return liph( + t( + `**[knowledge/${k.category}]** ${inline(k.title)}: ${inline(k.content)}`, + ), + ); + } + case "distillation": { + const d = tagged.item; + const preview = + d.observations.length > 500 + ? d.observations.slice(0, 500) + "..." + : d.observations; + return lip( + `**[distilled]** ${inline(preview)}`, + ); + } + case "temporal": { + const m = tagged.item; + const preview = + m.content.length > 500 + ? m.content.slice(0, 500) + "..." + : m.content; + return lip( + `**[temporal/${m.role}]** (session: ${m.session_id.slice(0, 8)}...) ${inline(preview)}`, + ); + } + } + }); + + return serialize(root(h(2, "Recall Results"), ul(items))); +} + export function createRecallTool(projectPath: string, knowledgeEnabled = true): ReturnType { return tool({ description: @@ -163,52 +212,80 @@ export function createRecallTool(projectPath: string, knowledgeEnabled = true): return "Query too vague — try using specific keywords, file names, or technical terms."; } - let temporalResults: temporal.TemporalMessage[] = []; - if (scope !== "knowledge") { + // Run scored searches across all sources + const knowledgeResults: ltm.ScoredKnowledgeEntry[] = []; + if (knowledgeEnabled && scope !== "session") { try { - temporalResults = temporal.search({ - projectPath, - query: args.query, - sessionID: scope === "session" ? sid : undefined, - limit: 10, - }); + knowledgeResults.push( + ...ltm.searchScored({ + query: args.query, + projectPath, + limit: 10, + }), + ); } catch (err) { - log.error("recall: temporal search failed:", err); + log.error("recall: knowledge search failed:", err); } } - let distillationResults: Distillation[] = []; + const distillationResults: ScoredDistillation[] = []; if (scope !== "knowledge") { try { - distillationResults = searchDistillations({ - projectPath, - query: args.query, - sessionID: scope === "session" ? sid : undefined, - limit: 5, - }); + distillationResults.push( + ...searchDistillationsScored({ + projectPath, + query: args.query, + sessionID: scope === "session" ? sid : undefined, + limit: 10, + }), + ); } catch (err) { log.error("recall: distillation search failed:", err); } } - let knowledgeResults: ltm.KnowledgeEntry[] = []; - if (knowledgeEnabled && scope !== "session") { + const temporalResults: temporal.ScoredTemporalMessage[] = []; + if (scope !== "knowledge") { try { - knowledgeResults = ltm.search({ - query: args.query, - projectPath, - limit: 10, - }); + temporalResults.push( + ...temporal.searchScored({ + projectPath, + query: args.query, + sessionID: scope === "session" ? sid : undefined, + limit: 10, + }), + ); } catch (err) { - log.error("recall: knowledge search failed:", err); + log.error("recall: temporal search failed:", err); } } - return formatResults({ - temporalResults, - distillationResults, - knowledgeResults, - }); + // Fuse results using Reciprocal Rank Fusion + const fused = reciprocalRankFusion([ + { + items: knowledgeResults.map((item) => ({ + source: "knowledge" as const, + item, + })), + key: (r) => `k:${r.item.id}`, + }, + { + items: distillationResults.map((item) => ({ + source: "distillation" as const, + item, + })), + key: (r) => `d:${r.item.id}`, + }, + { + items: temporalResults.map((item) => ({ + source: "temporal" as const, + item, + })), + key: (r) => `t:${r.item.id}`, + }, + ]); + + return formatFusedResults(fused, 20); }, }); } diff --git a/src/search.ts b/src/search.ts index 548e426..1d879ca 100644 --- a/src/search.ts +++ b/src/search.ts @@ -172,3 +172,65 @@ export function ftsQueryOr(raw: string): string { if (!terms.length) return EMPTY_QUERY; return terms.map((w) => `${w}*`).join(" OR "); } + +// --------------------------------------------------------------------------- +// Score normalization & fusion (Phase 2) +// --------------------------------------------------------------------------- + +/** + * Normalize a raw FTS5 BM25 rank to a 0–1 range using min-max normalization. + * + * FTS5 rank/bm25() values are negative (more negative = better match). + * This converts them to 0–1 where 1 = best match in the result set. + * + * Used for display scores only — RRF fusion uses rank positions, not scores. + */ +export function normalizeRank( + rank: number, + minRank: number, + maxRank: number, +): number { + // All same rank → everything is equally relevant + if (minRank === maxRank) return 1; + // minRank is most negative (best), maxRank is least negative (worst) + // Invert: best match → 1.0, worst → 0.0 + return (maxRank - rank) / (maxRank - minRank); +} + +/** + * Reciprocal Rank Fusion: merge multiple ranked lists into a single ranked list. + * + * RRF score = Σ(1 / (k + rank_i)) for each list where the item appears. + * k = 60 is standard (from Cormack et al., 2009; also used by QMD). + * + * RRF is rank-based, not score-based — raw score magnitude differences across + * different FTS5 tables don't matter. Only relative ordering within each list. + * + * @param lists Each list provides items (in ranked order) and a key function + * for deduplication. Items at the front of the array are rank 0. + * @param k Smoothing constant. Default 60. + * @returns Fused list sorted by RRF score descending. When items appear + * in multiple lists, the first occurrence's item is kept. + */ +export function reciprocalRankFusion( + lists: Array<{ items: T[]; key: (item: T) => string }>, + k = 60, +): Array<{ item: T; score: number }> { + const scores = new Map(); + + for (const list of lists) { + for (let rank = 0; rank < list.items.length; rank++) { + const item = list.items[rank]; + const id = list.key(item); + const rrfScore = 1 / (k + rank); + const existing = scores.get(id); + if (existing) { + existing.score += rrfScore; + } else { + scores.set(id, { item, score: rrfScore }); + } + } + } + + return [...scores.values()].sort((a, b) => b.score - a.score); +} diff --git a/src/temporal.ts b/src/temporal.ts index 5e018ea..46f432a 100644 --- a/src/temporal.ts +++ b/src/temporal.ts @@ -201,6 +201,51 @@ export function search(input: { } } +export type ScoredTemporalMessage = TemporalMessage & { rank: number }; + +/** + * Search with BM25 scores included. Returns results with raw FTS5 rank values + * for use in cross-source score fusion (RRF). + */ +export function searchScored(input: { + projectPath: string; + query: string; + sessionID?: string; + limit?: number; +}): ScoredTemporalMessage[] { + const pid = ensureProject(input.projectPath); + const limit = input.limit ?? 20; + const q = ftsQuery(input.query); + if (q === EMPTY_QUERY) return []; + + const ftsSQL = input.sessionID + ? `SELECT m.*, rank FROM temporal_messages m + JOIN temporal_fts f ON m.rowid = f.rowid + WHERE f.content MATCH ? AND m.project_id = ? AND m.session_id = ? + ORDER BY rank LIMIT ?` + : `SELECT m.*, rank FROM temporal_messages m + JOIN temporal_fts f ON m.rowid = f.rowid + WHERE f.content MATCH ? AND m.project_id = ? + ORDER BY rank LIMIT ?`; + const params = input.sessionID + ? [q, pid, input.sessionID, limit] + : [q, pid, limit]; + + try { + const results = db().query(ftsSQL).all(...params) as ScoredTemporalMessage[]; + if (results.length) return results; + + const qOr = ftsQueryOr(input.query); + if (qOr === EMPTY_QUERY) return []; + const paramsOr = input.sessionID + ? [qOr, pid, input.sessionID, limit] + : [qOr, pid, limit]; + return db().query(ftsSQL).all(...paramsOr) as ScoredTemporalMessage[]; + } catch { + return []; + } +} + export function count(projectPath: string, sessionID?: string): number { const pid = ensureProject(projectPath); const query = sessionID diff --git a/test/search.test.ts b/test/search.test.ts index f8466db..6f9c856 100644 --- a/test/search.test.ts +++ b/test/search.test.ts @@ -1,5 +1,12 @@ import { describe, test, expect } from "bun:test"; -import { ftsQuery, ftsQueryOr, STOPWORDS, EMPTY_QUERY } from "../src/search"; +import { + ftsQuery, + ftsQueryOr, + STOPWORDS, + EMPTY_QUERY, + normalizeRank, + reciprocalRankFusion, +} from "../src/search"; describe("search", () => { describe("ftsQuery (AND semantics)", () => { @@ -131,4 +138,124 @@ describe("search", () => { expect(EMPTY_QUERY).toBe('""'); }); }); + + describe("normalizeRank", () => { + test("best rank (most negative) normalizes to 1.0", () => { + // minRank=-10 is best, maxRank=-1 is worst + expect(normalizeRank(-10, -10, -1)).toBe(1); + }); + + test("worst rank normalizes to 0.0", () => { + expect(normalizeRank(-1, -10, -1)).toBe(0); + }); + + test("mid-range rank normalizes proportionally", () => { + const score = normalizeRank(-5.5, -10, -1); + expect(score).toBeCloseTo(0.5, 1); + }); + + test("all same rank returns 1.0", () => { + expect(normalizeRank(-5, -5, -5)).toBe(1); + }); + + test("single result returns 1.0", () => { + expect(normalizeRank(-3, -3, -3)).toBe(1); + }); + }); + + describe("reciprocalRankFusion", () => { + test("merges two lists by RRF score", () => { + const fused = reciprocalRankFusion([ + { + items: [{ id: "a" }, { id: "b" }, { id: "c" }], + key: (x) => x.id, + }, + { + items: [{ id: "b" }, { id: "a" }, { id: "d" }], + key: (x) => x.id, + }, + ]); + + const ids = fused.map((r) => r.item.id); + // "a" appears at rank 0 in list 1 and rank 1 in list 2 → highest combined RRF + // "b" appears at rank 1 in list 1 and rank 0 in list 2 → same as "a" + expect(ids.slice(0, 2).sort()).toEqual(["a", "b"]); + // "c" and "d" only appear in one list each + expect(ids).toContain("c"); + expect(ids).toContain("d"); + expect(ids.length).toBe(4); + }); + + test("items in multiple lists score higher than single-list items", () => { + const fused = reciprocalRankFusion([ + { + items: [{ id: "shared" }, { id: "only-in-1" }], + key: (x) => x.id, + }, + { + items: [{ id: "shared" }, { id: "only-in-2" }], + key: (x) => x.id, + }, + ]); + + // "shared" appears in both lists → highest score + expect(fused[0].item.id).toBe("shared"); + // Its score should be roughly 2 * 1/(60+0) ≈ 0.0333 + expect(fused[0].score).toBeCloseTo(2 / 60, 4); + }); + + test("preserves first occurrence when item appears in multiple lists", () => { + const fused = reciprocalRankFusion([ + { + items: [{ id: "x", source: "list1" }], + key: (x) => x.id, + }, + { + items: [{ id: "x", source: "list2" }], + key: (x) => x.id, + }, + ]); + + // First occurrence (list1) should be kept + expect((fused[0].item as { source: string }).source).toBe("list1"); + }); + + test("empty lists produce empty result", () => { + const fused = reciprocalRankFusion<{ id: string }>([ + { items: [], key: (x) => x.id }, + { items: [], key: (x) => x.id }, + ]); + expect(fused.length).toBe(0); + }); + + test("single list returns items in order", () => { + const fused = reciprocalRankFusion([ + { + items: [{ id: "first" }, { id: "second" }, { id: "third" }], + key: (x) => x.id, + }, + ]); + + expect(fused.map((r) => r.item.id)).toEqual([ + "first", + "second", + "third", + ]); + }); + + test("custom k parameter changes scores", () => { + const fused = reciprocalRankFusion( + [ + { + items: [{ id: "a" }], + key: (x) => x.id, + }, + ], + 10, // smaller k → higher scores + ); + + // With k=10, rank 0 → 1/(10+0) = 0.1 + expect(fused[0].score).toBeCloseTo(0.1, 4); + }); + }); });