diff --git a/bitnet_tools/planner.py b/bitnet_tools/planner.py index 85fe941..0d1fc0f 100644 --- a/bitnet_tools/planner.py +++ b/bitnet_tools/planner.py @@ -19,6 +19,8 @@ class AnalysisIntent: region_column: str | None = None compare_periods: bool = False metric_column: str | None = None + intent_schema: dict[str, Any] = field(default_factory=dict) + routing_source: str = "fallback" @dataclass @@ -112,20 +114,56 @@ def parse_question_to_intent(question: str, schema: dict[str, Any]) -> AnalysisI intent.region = str(rg) break - if intent.region is None and intent.region_column: + should_infer_region = any(token in text for token in ["지역", "region", "도시", "city", "area"]) + if intent.region is None and intent.region_column and should_infer_region: tokens = [t for t in re.split(r"\s+", text) if t] for tok in tokens: if re.fullmatch(r"[가-힣A-Za-z][가-힣A-Za-z0-9_-]+", tok): - if tok.lower() not in {"top", "sample", "threshold", "임계값", "상위", "샘플"}: + if tok.lower() not in {"top", "sample", "threshold", "임계값", "상위", "샘플", "지역", "도시", "region", "city", "area"}: if tok in columns: continue intent.region = tok break intent.metric_column = _first_numeric_column(schema) + intent.intent_schema = _build_intent_parsing_schema(intent) return intent +def _build_intent_parsing_schema(intent: AnalysisIntent) -> dict[str, Any]: + return { + "topN": intent.top_n, + "sampleN": intent.sample_n, + "threshold": { + "value": intent.threshold, + "column": intent.threshold_column, + }, + "filter": { + "region": intent.region, + "region_column": intent.region_column, + }, + "compare": intent.compare_periods, + "include_code": _should_include_code(intent.question), + } + + +def _should_include_code(question: str) -> bool: + lowered = (question or "").lower() + return any(token in lowered for token in ["include_code", "include code", "코드 포함", "코드도"]) + + +def _template_query_type(intent: AnalysisIntent) -> str | None: + if intent.compare_periods and intent.threshold is not None: + return "compare_threshold" + if intent.top_n is not None and (intent.region is not None or intent.threshold is not None): + return "top_filtered" + if intent.top_n is not None: + return "top_only" + if intent.sample_n is not None: + return "sample_only" + return None + + def build_plan(intent: AnalysisIntent, schema_profile: dict[str, Any]) -> AnalysisPlan: warnings: list[str] = [] group_col = intent.region_column or _first_text_column(schema_profile) @@ -134,6 +172,72 @@ def build_plan(intent: AnalysisIntent, schema_profile: dict[str, Any]) -> Analys if metric_col is None: warnings.append("numeric metric column not found") + template_type = _template_query_type(intent) + if template_type: + intent.routing_source = f"template:{template_type}" + return _build_template_plan(intent, group_col, metric_col, warnings) + + intent.routing_source = "fallback" + warnings.append("template not matched: fallback planner used") + + return _build_fallback_plan(intent, group_col, metric_col, warnings) + + +def _build_template_plan( + intent: AnalysisIntent, + group_col: str | None, + metric_col: str | None, + warnings: list[str], +) -> AnalysisPlan: + template_nodes = [ + { + "op": "filter", + "enabled": bool(intent.region or intent.threshold is not None), + "region_column": intent.region_column, + "region": intent.region, + "threshold_column": intent.threshold_column, + "threshold": intent.threshold, + }, + { + "op": "groupby", + "enabled": bool(group_col), + "columns": [group_col] if group_col else [], + }, + { + "op": "agg", + "enabled": bool(metric_col), + "metric": metric_col, + "fn": "sum", + }, + { + "op": "rank", + "enabled": bool(intent.top_n), + "top_n": intent.top_n, + "order": "desc", + }, + { + "op": "sample", + "enabled": bool(intent.sample_n), + "sample_n": intent.sample_n, + "seed": 42, + }, + { + "op": "export", + "enabled": True, + "include_meta": True, + }, + ] + + return AnalysisPlan(intent=intent, nodes=template_nodes, fallback=False, warnings=warnings) + + +def _build_fallback_plan( + intent: AnalysisIntent, + group_col: str | None, + metric_col: str | None, + warnings: list[str], +) -> AnalysisPlan: + nodes = [ { "op": "filter", diff --git a/tests/test_planner.py b/tests/test_planner.py index c01013d..7e40b81 100644 --- a/tests/test_planner.py +++ b/tests/test_planner.py @@ -18,14 +18,37 @@ def test_parse_question_to_intent_extracts_controls(): assert intent.threshold_column == "sales" assert intent.region == "서울" assert intent.compare_periods is True + assert intent.intent_schema == { + "topN": 3, + "sampleN": 2, + "threshold": {"value": 100, "column": "sales"}, + "filter": {"region": "서울", "region_column": "region"}, + "compare": True, + "include_code": False, + } + + +def test_parse_question_to_intent_extracts_include_code_flag(): + intent = parse_question_to_intent("서울 매출 top 5 코드 포함", _schema()) + assert intent.intent_schema["include_code"] is True -def test_build_plan_contains_execution_graph_nodes(): + +def test_build_plan_routes_template_queries_first(): intent = parse_question_to_intent("상위 5 샘플 2", _schema()) plan = build_plan(intent, _schema()) assert [n["op"] for n in plan.nodes] == ["filter", "groupby", "agg", "rank", "sample", "export"] assert any(node["op"] == "rank" and node["enabled"] for node in plan.nodes) + assert plan.intent.routing_source == "template:top_only" + + +def test_build_plan_routes_only_unmatched_queries_to_fallback(): + intent = parse_question_to_intent("기본 분석만 해줘", _schema()) + plan = build_plan(intent, _schema()) + + assert plan.intent.routing_source == "fallback" + assert any("fallback" in warning for warning in plan.warnings) def test_execute_plan_fallback_on_invalid_node(): @@ -39,3 +62,47 @@ def test_execute_plan_fallback_on_invalid_node(): assert result["meta"]["fallback"] is True assert "unsupported op" in result["meta"]["error"] + + +def test_user_examples_are_kept_as_regression_cases(): + cases = [ + ( + "서울 지역 top 3, sample 2, sales 임계값 100 전/후 비교", + { + "top_n": 3, + "sample_n": 2, + "threshold": 100, + "region": "서울", + "compare_periods": True, + }, + ), + ( + "상위 5 샘플 2", + { + "top_n": 5, + "sample_n": 2, + "threshold": None, + "region": None, + "compare_periods": False, + }, + ), + ( + "기본 분석만 해줘", + { + "top_n": None, + "sample_n": None, + "threshold": None, + "region": None, + "compare_periods": False, + }, + ), + ] + + for question, expected in cases: + intent = parse_question_to_intent(question, _schema()) + + assert intent.top_n == expected["top_n"] + assert intent.sample_n == expected["sample_n"] + assert intent.threshold == expected["threshold"] + assert intent.region == expected["region"] + assert intent.compare_periods == expected["compare_periods"]