Skip to content

Commit d54d451

Browse files
authored
Merge pull request #209 from CausalInferenceLab/feat/keyword-retriever
Feat/keyword retriever
2 parents 26507fd + b38b1dd commit d54d451

File tree

6 files changed

+461
-0
lines changed

6 files changed

+461
-0
lines changed

src/lang2sql/components/__init__.py

Whitespace-only changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .keyword import KeywordRetriever
2+
3+
__all__ = ["KeywordRetriever"]
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
Internal BM25 index — stdlib only (math, collections).
3+
4+
BM25 parameters:
5+
k1 = 1.5 (term frequency saturation)
6+
b = 0.75 (document length normalization)
7+
8+
Tokenization: text.lower().split() (whitespace, no external deps)
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import math
14+
from collections import Counter
15+
from typing import Any
16+
17+
_K1 = 1.5
18+
_B = 0.75
19+
20+
21+
def _tokenize(text: str) -> list[str]:
22+
return text.lower().split()
23+
24+
25+
def _extract_text(value: Any) -> list[str]:
26+
"""Recursively extract text tokens from any value (str, list, dict, other)."""
27+
if isinstance(value, str):
28+
return [value]
29+
if isinstance(value, list):
30+
result: list[str] = []
31+
for item in value:
32+
result.extend(_extract_text(item))
33+
return result
34+
if isinstance(value, dict):
35+
result = []
36+
for k, v in value.items():
37+
result.append(str(k))
38+
result.extend(_extract_text(v))
39+
return result
40+
return [str(value)]
41+
42+
43+
def _entry_to_text(entry: dict[str, Any], index_fields: list[str]) -> str:
    """
    Flatten the selected fields of a catalog entry into one indexable string.

    Field handling (via :func:`_extract_text`):
        - str fields → joined as-is
        - dict fields → "key value key value ..." (for columns: {col_name: col_desc})
        - list fields → each element extracted recursively
        - other types → str(value)

    Args:
        entry: Catalog entry dict.
        index_fields: Names of the fields to include; missing or
            ``None``-valued fields are skipped.

    Returns:
        Space-joined text for the entry.
    """
    fragments = [
        token
        for field in index_fields
        if (value := entry.get(field)) is not None
        for token in _extract_text(value)
    ]
    return " ".join(fragments)
60+
61+
62+
class _BM25Index:
    """
    In-memory BM25 index over a ``list[dict]`` catalog.

    Usage:
        index = _BM25Index(catalog, index_fields=["name", "description", "columns"])
        scores = index.score("주문 테이블")  # list[float], one per catalog entry
    """

    def __init__(
        self,
        catalog: list[dict[str, Any]],
        index_fields: list[str],
    ) -> None:
        """
        Build the index eagerly from *catalog*.

        Args:
            catalog: Catalog entries to index (original order is preserved).
            index_fields: Entry fields to include in the indexed text.
        """
        self._catalog = catalog
        self._n = len(catalog)

        # Tokenize each document once at construction time.
        self._docs: list[list[str]] = [
            _tokenize(_entry_to_text(entry, index_fields)) for entry in catalog
        ]

        # Term frequencies per document.
        self._tfs: list[Counter[str]] = [Counter(doc) for doc in self._docs]

        # Document lengths, cached so score() does not recompute them
        # for every (term, document) pair.
        self._doc_lengths: list[int] = [len(doc) for doc in self._docs]
        self._avgdl: float = (
            sum(self._doc_lengths) / self._n if self._n > 0 else 0.0
        )

        # Document frequencies: term → number of documents containing it.
        # (Not an inverted index — only the count is needed for IDF.)
        self._df: Counter[str] = Counter()
        for tf in self._tfs:
            for term in tf:
                self._df[term] += 1

    def score(self, query: str) -> list[float]:
        """
        Return a BM25 score for each catalog entry.

        Args:
            query: Natural language query string.

        Returns:
            List of float scores, one per catalog entry, in original order.
        """
        if self._n == 0:
            return []

        scores = [0.0] * self._n

        for term in _tokenize(query):
            df_t = self._df.get(term, 0)
            if df_t == 0:
                # Term occurs in no document; it contributes nothing.
                continue

            # Smoothed IDF — the "+ 1" keeps the log argument positive
            # even when the term occurs in more than half the documents.
            idf = math.log((self._n - df_t + 0.5) / (df_t + 0.5) + 1)

            for i, tf in enumerate(self._tfs):
                tf_t = tf.get(term, 0)
                if tf_t == 0:
                    continue

                dl = self._doc_lengths[i]
                denom = tf_t + _K1 * (1 - _B + _B * dl / self._avgdl)
                scores[i] += idf * (tf_t * (_K1 + 1)) / denom

        return scores
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Optional
4+
5+
from ...core.base import BaseComponent
6+
from ...core.context import RunContext
7+
from ...core.hooks import TraceHook
8+
from ._bm25 import _BM25Index
9+
10+
_DEFAULT_INDEX_FIELDS = ["name", "description", "columns"]
11+
12+
13+
class KeywordRetriever(BaseComponent):
    """
    BM25-based keyword retriever over a table catalog.

    The catalog is indexed in memory once, at construction time. Each
    invocation reads ``run.query`` and writes the top-N matching entries
    into ``run.schema_selected``.

    Args:
        catalog: List of table dicts. Each dict should have at minimum
            ``name`` (str) and ``description`` (str).
            Optional keys: ``columns`` (dict[str, str]), ``meta`` (dict).
        top_n: Maximum number of results to return. Defaults to 5.
        index_fields: Fields to index. Defaults to ["name", "description", "columns"].
            Pass a custom list to replace the default (complete override).
        name: Component name for tracing. Defaults to "KeywordRetriever".
        hook: Optional TraceHook for observability.

    Example::

        retriever = KeywordRetriever(catalog=[
            {"name": "orders", "description": "주문 정보 테이블"},
        ])
        run = retriever(RunContext(query="주문 조회"))
        print(run.schema_selected)  # [{"name": "orders", ...}]
    """

    def __init__(
        self,
        *,
        catalog: list[dict[str, Any]],
        top_n: int = 5,
        index_fields: Optional[list[str]] = None,
        name: Optional[str] = None,
        hook: Optional[TraceHook] = None,
    ) -> None:
        super().__init__(name=name or "KeywordRetriever", hook=hook)
        self._catalog = catalog
        self._top_n = top_n
        if index_fields is None:
            self._index_fields = _DEFAULT_INDEX_FIELDS
        else:
            # Caller-supplied list fully replaces the default.
            self._index_fields = index_fields
        self._index = _BM25Index(catalog, self._index_fields)

    def run(self, run: RunContext) -> RunContext:
        """
        Search the catalog with BM25 and store results in ``run.schema_selected``.

        Args:
            run: Current RunContext. Reads ``run.query``.

        Returns:
            The same RunContext with ``run.schema_selected`` set to a
            ranked list[dict] (BM25 score descending). Empty list if no match.
        """
        if not self._catalog:
            run.schema_selected = []
            return run

        scores = self._index.score(run.query)

        # Rank entry indices by score, descending. Python's sort is stable
        # (even with reverse=True), so ties keep original catalog order.
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

        # Keep up to top_n entries, dropping anything without a positive score.
        run.schema_selected = [
            self._catalog[i] for i in order[: self._top_n] if scores[i] > 0.0
        ]
        return run

src/lang2sql/core/ports.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from __future__ import annotations
2+
3+
from typing import Protocol
4+
5+
6+
class EmbeddingPort(Protocol):
    """
    Placeholder — will be implemented in OQ-2 (VectorRetriever).

    Abstracts embedding backends (OpenAI, Azure, Bedrock, etc.)
    so VectorRetriever is not coupled to any specific provider.
    """

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string into one vector."""
        ...

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts into one vector per text."""
        ...

0 commit comments

Comments
 (0)