From 0b9c33488fc7c0baab44ab40ac941df7db29d4bb Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 22 Feb 2026 23:52:54 +0900 Subject: [PATCH 1/5] feat(core): add EmbeddingPort Protocol as placeholder for vector retrieval --- src/lang2sql/core/ports.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 src/lang2sql/core/ports.py diff --git a/src/lang2sql/core/ports.py b/src/lang2sql/core/ports.py new file mode 100644 index 0000000..7452d4c --- /dev/null +++ b/src/lang2sql/core/ports.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Protocol + + +class EmbeddingPort(Protocol): + """ + Placeholder — will be implemented in OQ-2 (VectorRetriever). + + Abstracts embedding backends (OpenAI, Azure, Bedrock, etc.) + so VectorRetriever is not coupled to any specific provider. + """ + + def embed_query(self, text: str) -> list[float]: ... + + def embed_texts(self, texts: list[str]) -> list[list[float]]: ... From b1ff1187ffd652303b78b29430288f3274e89f28 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 22 Feb 2026 23:54:03 +0900 Subject: [PATCH 2/5] feat(components/retrieval): implement stdlib-only BM25 index with recursive field extraction --- src/lang2sql/components/__init__.py | 0 src/lang2sql/components/retrieval/__init__.py | 3 + src/lang2sql/components/retrieval/_bm25.py | 129 ++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 src/lang2sql/components/__init__.py create mode 100644 src/lang2sql/components/retrieval/__init__.py create mode 100644 src/lang2sql/components/retrieval/_bm25.py diff --git a/src/lang2sql/components/__init__.py b/src/lang2sql/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lang2sql/components/retrieval/__init__.py b/src/lang2sql/components/retrieval/__init__.py new file mode 100644 index 0000000..3d9f791 --- /dev/null +++ b/src/lang2sql/components/retrieval/__init__.py @@ -0,0 +1,3 @@ +from .keyword import KeywordRetriever + +__all__ = ["KeywordRetriever"] diff --git a/src/lang2sql/components/retrieval/_bm25.py b/src/lang2sql/components/retrieval/_bm25.py new file mode 100644 index 0000000..df82187 --- /dev/null +++ b/src/lang2sql/components/retrieval/_bm25.py @@ -0,0 +1,129 @@ +""" +Internal BM25 index — stdlib only (math, collections). + +BM25 parameters: + k1 = 1.5 (term frequency saturation) + b = 0.75 (document length normalization) + +Tokenization: text.lower().split() (whitespace, no external deps) +""" +from __future__ import annotations + +import math +from collections import Counter +from typing import Any + +_K1 = 1.5 +_B = 0.75 + + +def _tokenize(text: str) -> list[str]: + return text.lower().split() + + +def _extract_text(value: Any) -> list[str]: + """Recursively extract text tokens from any value (str, list, dict, other).""" + if isinstance(value, str): + return [value] + if isinstance(value, list): + result: list[str] = [] + for item in value: + result.extend(_extract_text(item)) + return result + if isinstance(value, dict): + result = [] + for k, v in value.items(): + result.append(str(k)) + result.extend(_extract_text(v)) + return result + return [str(value)] + + +def _entry_to_text(entry: dict[str, Any], index_fields: list[str]) -> str: + """ + Convert a catalog dict entry into a single text string for indexing. + + Handles: + - str fields → joined as-is + - dict fields → "key value key value ..." (for columns: {col_name: col_desc}) + - list fields → each element extracted recursively + - other types → str(value) + """ + parts: list[str] = [] + for field in index_fields: + value = entry.get(field) + if value is None: + continue + parts.extend(_extract_text(value)) + return " ".join(parts) + + +class _BM25Index: + """ + In-memory BM25 index over a list[dict] catalog. + + Usage: + index = _BM25Index(catalog, index_fields=["name", "description", "columns"]) + scores = index.score("주문 테이블") # list[float], one per catalog entry + """ + + def __init__( + self, + catalog: list[dict[str, Any]], + index_fields: list[str], + ) -> None: + self._catalog = catalog + self._n = len(catalog) + + # Tokenize each document + self._docs: list[list[str]] = [ + _tokenize(_entry_to_text(entry, index_fields)) for entry in catalog + ] + + # Term frequencies per document + self._tfs: list[Counter[str]] = [Counter(doc) for doc in self._docs] + + # Document lengths + doc_lengths = [len(doc) for doc in self._docs] + self._avgdl: float = sum(doc_lengths) / self._n if self._n > 0 else 0.0 + + # Inverted index: term → set of doc indices that contain it + self._df: Counter[str] = Counter() + for tf in self._tfs: + for term in tf: + self._df[term] += 1 + + def score(self, query: str) -> list[float]: + """ + Return a BM25 score for each catalog entry. + + Args: + query: Natural language query string. + + Returns: + List of float scores, one per catalog entry, in original order. + """ + if self._n == 0: + return [] + + query_terms = _tokenize(query) + scores = [0.0] * self._n + + for term in query_terms: + df_t = self._df.get(term, 0) + if df_t == 0: + continue + + # IDF — smoothed to avoid log(0) + idf = math.log((self._n - df_t + 0.5) / (df_t + 0.5) + 1) + + for i, tf in enumerate(self._tfs): + tf_t = tf.get(term, 0) + if tf_t == 0: + continue + + dl = len(self._docs[i]) + denom = tf_t + _K1 * (1 - _B + _B * dl / self._avgdl) + scores[i] += idf * (tf_t * (_K1 + 1)) / denom + + return scores From 547b2cc499a41d1d6b14e4885e33c6959dd780f2 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 22 Feb 2026 23:54:36 +0900 Subject: [PATCH 3/5] feat(components/retrieval): add KeywordRetriever with BM25-based catalog search --- src/lang2sql/components/retrieval/keyword.py | 87 ++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 src/lang2sql/components/retrieval/keyword.py diff --git a/src/lang2sql/components/retrieval/keyword.py b/src/lang2sql/components/retrieval/keyword.py new file mode 100644 index 0000000..74fd17d --- /dev/null +++ b/src/lang2sql/components/retrieval/keyword.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Any, Optional + +from ...core.base import BaseComponent +from ...core.context import RunContext +from ...core.hooks import TraceHook +from ._bm25 import _BM25Index + +_DEFAULT_INDEX_FIELDS = ["name", "description", "columns"] + + +class KeywordRetriever(BaseComponent): + """ + BM25-based keyword retriever over a table catalog. + + Indexes catalog entries at init time (in-memory). + On each call, reads ``run.query`` and writes top-N matches + into ``run.schema_selected``. + + Args: + catalog: List of table dicts. Each dict should have at minimum + ``name`` (str) and ``description`` (str). + Optional keys: ``columns`` (dict[str, str]), ``meta`` (dict). + top_n: Maximum number of results to return. Defaults to 5. + index_fields: Fields to index. Defaults to ["name", "description", "columns"]. + Pass a custom list to replace the default (complete override). + name: Component name for tracing. Defaults to "KeywordRetriever". + hook: Optional TraceHook for observability. + + Example:: + + retriever = KeywordRetriever(catalog=[ + {"name": "orders", "description": "주문 정보 테이블"}, + ]) + run = retriever(RunContext(query="주문 조회")) + print(run.schema_selected) # [{"name": "orders", ...}] + """ + + def __init__( + self, + *, + catalog: list[dict[str, Any]], + top_n: int = 5, + index_fields: Optional[list[str]] = None, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "KeywordRetriever", hook=hook) + self._catalog = catalog + self._top_n = top_n + self._index_fields = index_fields if index_fields is not None else _DEFAULT_INDEX_FIELDS + self._index = _BM25Index(catalog, self._index_fields) + + def run(self, run: RunContext) -> RunContext: + """ + Search the catalog with BM25 and store results in ``run.schema_selected``. + + Args: + run: Current RunContext. Reads ``run.query``. + + Returns: + The same RunContext with ``run.schema_selected`` set to a + ranked list[dict] (BM25 score descending). Empty list if no match. + """ + if not self._catalog: + run.schema_selected = [] + return run + + scores = self._index.score(run.query) + + # Pair each catalog entry with its score, sort descending + ranked = sorted( + zip(scores, self._catalog), + key=lambda x: x[0], + reverse=True, + ) + + # Return up to top_n entries that have a positive score + results = [ + entry + for score, entry in ranked[: self._top_n] + if score > 0.0 + ] + + run.schema_selected = results + return run From 5aad33d9d73962d6f14b3eb46d83ed8a4cbd880c Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 22 Feb 2026 23:55:02 +0900 Subject: [PATCH 4/5] test(components/retrieval): add 14 unit tests for KeywordRetriever --- tests/test_components_keyword_retriever.py | 222 +++++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 tests/test_components_keyword_retriever.py diff --git a/tests/test_components_keyword_retriever.py b/tests/test_components_keyword_retriever.py new file mode 100644 index 0000000..772608f --- /dev/null +++ b/tests/test_components_keyword_retriever.py @@ -0,0 +1,222 @@ +""" +Tests for KeywordRetriever — 14 cases. + +Pattern follows test_core_base.py: +- pytest, inline fixtures, MemoryHook +""" +import pytest + +from lang2sql.components.retrieval import KeywordRetriever +from lang2sql.core.context import RunContext +from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.baseline import SequentialFlow + + +# ------------------------- +# Shared test catalog +# ------------------------- + +ORDER_TABLE = { + "name": "order_table", + "description": "고객 주문 정보를 저장하는 테이블", + "columns": {"order_id": "주문 고유 ID", "amount": "주문 금액"}, + "meta": {"primary_key": "order_id", "tags": ["finance", "core"]}, +} + +USER_TABLE = { + "name": "user_table", + "description": "사용자 계정 정보 테이블", + "columns": {"user_id": "사용자 고유 ID", "email": "이메일"}, + "meta": {"primary_key": "user_id"}, +} + +PRODUCT_TABLE = { + "name": "product_table", + "description": "상품 목록 및 재고 테이블", + "columns": {"product_id": "상품 ID", "stock": "재고 수량"}, +} + +CATALOG = [ORDER_TABLE, USER_TABLE, PRODUCT_TABLE] + + +# ------------------------- +# Tests +# ------------------------- + + +def test_basic_search_returns_relevant_table(): + """'주문' 질문 → order_table이 top 위치.""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="주문 정보 조회")) + + assert run.schema_selected + assert run.schema_selected[0]["name"] == "order_table" + + +def test_top_n_limits_results(): + """top_n=2 → 최대 2개 반환.""" + retriever = KeywordRetriever(catalog=CATALOG, top_n=2) + run = retriever(RunContext(query="테이블")) + + assert len(run.schema_selected) <= 2 + + +def test_top_n_larger_than_catalog(): + """top_n=10, catalog 3개 → 최대 3개 반환.""" + retriever = KeywordRetriever(catalog=CATALOG, top_n=10) + run = retriever(RunContext(query="테이블")) + + assert len(run.schema_selected) <= len(CATALOG) + + +def test_zero_results_returns_empty_list(): + """완전히 무관한 query → schema_selected == [].""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="xyzzy_no_match_token_12345")) + + assert run.schema_selected == [] + + +def test_schema_selected_is_list_of_dict(): + """결과가 list[dict]인지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="주문")) + + assert isinstance(run.schema_selected, list) + assert len(run.schema_selected) > 0 + assert isinstance(run.schema_selected[0], dict) + + +def test_returns_runcontext(): + """run 메서드가 RunContext를 반환하는지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + result = retriever(RunContext(query="주문")) + + assert isinstance(result, RunContext) + + +def test_hook_start_end_events(): + """MemoryHook으로 start/end 이벤트 확인.""" + hook = MemoryHook() + retriever = KeywordRetriever(catalog=CATALOG, hook=hook) + retriever(RunContext(query="주문")) + + assert len(hook.events) == 2 + assert hook.events[0].name == "component.run" + assert hook.events[0].phase == "start" + assert hook.events[1].name == "component.run" + assert hook.events[1].phase == "end" + assert hook.events[1].duration_ms is not None + assert hook.events[1].duration_ms >= 0.0 + + +def test_empty_catalog(): + """catalog=[] → schema_selected == [].""" + retriever = KeywordRetriever(catalog=[]) + run = retriever(RunContext(query="주문")) + + assert run.schema_selected == [] + + +def test_meta_preserved_in_results(): + """meta 필드가 결과 dict에 그대로 포함되는지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="주문")) + + result = run.schema_selected[0] + assert "meta" in result + assert result["meta"]["primary_key"] == "order_id" + + +def test_index_fields_meta(): + """index_fields=["description","meta"] → meta 텍스트도 검색에 반영.""" + # finance라는 단어는 meta.tags에만 존재 (name/description/columns에는 없음) + catalog = [ + { + "name": "alpha", + "description": "일반 데이터 저장소", + "meta": {"tags": ["finance", "core"]}, + }, + { + "name": "beta", + "description": "기타 로그 테이블", + "meta": {"tags": ["logging"]}, + }, + ] + + retriever = KeywordRetriever( + catalog=catalog, + index_fields=["description", "meta"], + ) + run = retriever(RunContext(query="finance")) + + assert len(run.schema_selected) > 0 + assert run.schema_selected[0]["name"] == "alpha" + + +def test_result_order_by_relevance(): + """관련도 높은 테이블이 앞에 위치하는지 확인.""" + catalog = [ + { + "name": "order_summary", + "description": "주문 요약 주문 집계 주문 통계", # '주문' 3회 + }, + { + "name": "user_table", + "description": "사용자 주문 기록", # '주문' 1회 + }, + ] + + retriever = KeywordRetriever(catalog=catalog) + run = retriever(RunContext(query="주문")) + + assert len(run.schema_selected) >= 2 + assert run.schema_selected[0]["name"] == "order_summary" + + +def test_columns_text_indexed(): + """컬럼명/컬럼설명으로 검색 가능한지 확인.""" + catalog = [ + { + "name": "sales", + "description": "판매 데이터", + "columns": {"revenue": "매출액", "region": "지역"}, + }, + { + "name": "logs", + "description": "시스템 로그", + "columns": {"event_type": "이벤트 유형"}, + }, + ] + + retriever = KeywordRetriever(catalog=catalog) + run = retriever(RunContext(query="매출액")) + + assert len(run.schema_selected) > 0 + assert run.schema_selected[0]["name"] == "sales" + + +def test_missing_optional_fields_no_error(): + """columns/meta 없는 entry가 있어도 crash 없음.""" + catalog = [ + {"name": "minimal", "description": "최소 필드만 있는 테이블"}, + {"name": "full", "description": "전체 필드", "columns": {"id": "ID"}, "meta": {}}, + ] + + retriever = KeywordRetriever(catalog=catalog) + # 예외가 발생하지 않으면 테스트 통과 + run = retriever(RunContext(query="테이블")) + assert isinstance(run.schema_selected, list) + + +def test_end_to_end_in_sequential_flow(): + """SequentialFlow(steps=[retriever]).run_query('...') 가 동작하는지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + flow = SequentialFlow(steps=[retriever]) + + run = flow.run_query("주문 내역 확인") + + assert isinstance(run, RunContext) + assert isinstance(run.schema_selected, list) + assert len(run.schema_selected) > 0 + assert run.schema_selected[0]["name"] == "order_table" From b38b1dd43d19684737584378c51ee52f114e3926 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 23 Feb 2026 00:01:16 +0900 Subject: [PATCH 5/5] refactor : black check --- src/lang2sql/components/retrieval/_bm25.py | 1 + src/lang2sql/components/retrieval/keyword.py | 10 ++++------ tests/test_components_keyword_retriever.py | 9 +++++++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/lang2sql/components/retrieval/_bm25.py b/src/lang2sql/components/retrieval/_bm25.py index df82187..d3e87c0 100644 --- a/src/lang2sql/components/retrieval/_bm25.py +++ b/src/lang2sql/components/retrieval/_bm25.py @@ -7,6 +7,7 @@ Tokenization: text.lower().split() (whitespace, no external deps) """ + from __future__ import annotations import math diff --git a/src/lang2sql/components/retrieval/keyword.py b/src/lang2sql/components/retrieval/keyword.py index 74fd17d..1444086 100644 --- a/src/lang2sql/components/retrieval/keyword.py +++ b/src/lang2sql/components/retrieval/keyword.py @@ -49,7 +49,9 @@ def __init__( super().__init__(name=name or "KeywordRetriever", hook=hook) self._catalog = catalog self._top_n = top_n - self._index_fields = index_fields if index_fields is not None else _DEFAULT_INDEX_FIELDS + self._index_fields = ( + index_fields if index_fields is not None else _DEFAULT_INDEX_FIELDS + ) self._index = _BM25Index(catalog, self._index_fields) def run(self, run: RunContext) -> RunContext: @@ -77,11 +79,7 @@ def run(self, run: RunContext) -> RunContext: ) # Return up to top_n entries that have a positive score - results = [ - entry - for score, entry in ranked[: self._top_n] - if score > 0.0 - ] + results = [entry for score, entry in ranked[: self._top_n] if score > 0.0] run.schema_selected = results return run diff --git a/tests/test_components_keyword_retriever.py b/tests/test_components_keyword_retriever.py index 772608f..27b9d5c 100644 --- a/tests/test_components_keyword_retriever.py +++ b/tests/test_components_keyword_retriever.py @@ -4,6 +4,7 @@ Pattern follows test_core_base.py: - pytest, inline fixtures, MemoryHook """ + import pytest from lang2sql.components.retrieval import KeywordRetriever @@ -11,7 +12,6 @@ from lang2sql.core.hooks import MemoryHook from lang2sql.flows.baseline import SequentialFlow - # ------------------------- # Shared test catalog # ------------------------- @@ -200,7 +200,12 @@ def test_missing_optional_fields_no_error(): """columns/meta 없는 entry가 있어도 crash 없음.""" catalog = [ {"name": "minimal", "description": "최소 필드만 있는 테이블"}, - {"name": "full", "description": "전체 필드", "columns": {"id": "ID"}, "meta": {}}, + { + "name": "full", + "description": "전체 필드", + "columns": {"id": "ID"}, + "meta": {}, + }, ] retriever = KeywordRetriever(catalog=catalog)