Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 106 additions & 2 deletions bitnet_tools/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class AnalysisIntent:
region_column: str | None = None
compare_periods: bool = False
metric_column: str | None = None
intent_schema: dict[str, Any] = field(default_factory=dict)
routing_source: str = "fallback"


@dataclass
Expand Down Expand Up @@ -112,20 +114,56 @@ def parse_question_to_intent(question: str, schema: dict[str, Any]) -> AnalysisI
intent.region = str(rg)
break

if intent.region is None and intent.region_column:
should_infer_region = any(token in text for token in ["지역", "region", "도시", "city", "area"])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize region-keyword checks before gating inference

The new `should_infer_region` gate checks lowercase English keywords against the original, un-normalized `text`, so mixed-case or uppercase inputs (for example, `"Region seoul top 3"`) no longer enter the region-inference block when `region_values` is absent. In that case `intent.region` stays `None`, the filter node is disabled, and the planner can route the query as unfiltered (`top_only`) even though it includes a region qualifier. The analysis result therefore changes based only on the casing of the input.

Useful? React with 👍 / 👎.

if intent.region is None and intent.region_column and should_infer_region:
tokens = [t for t in re.split(r"\s+", text) if t]
for tok in tokens:
if re.fullmatch(r"[가-힣A-Za-z][가-힣A-Za-z0-9_-]+", tok):
if tok.lower() not in {"top", "sample", "threshold", "임계값", "상위", "샘플"}:
if tok.lower() not in {"top", "sample", "threshold", "임계값", "상위", "샘플", "지역", "도시", "region", "city", "area"}:
if tok in columns:
continue
intent.region = tok
break

intent.metric_column = _first_numeric_column(schema)
intent.intent_schema = _build_intent_parsing_schema(intent)
return intent


def _build_intent_parsing_schema(intent: AnalysisIntent) -> dict[str, Any]:
    """Summarize the parsed controls of *intent* as a plain nested dict.

    The result mirrors the intent's fields under fixed keys: ``topN``,
    ``sampleN``, ``threshold`` (value + column), ``filter`` (region +
    region column), ``compare`` and ``include_code``. The ``include_code``
    flag is derived from the raw question text via ``_should_include_code``.
    """
    threshold_part = {
        "value": intent.threshold,
        "column": intent.threshold_column,
    }
    filter_part = {
        "region": intent.region,
        "region_column": intent.region_column,
    }

    # Key insertion order is kept stable: topN, sampleN, threshold, filter,
    # compare, include_code.
    parsed: dict[str, Any] = {}
    parsed["topN"] = intent.top_n
    parsed["sampleN"] = intent.sample_n
    parsed["threshold"] = threshold_part
    parsed["filter"] = filter_part
    parsed["compare"] = intent.compare_periods
    parsed["include_code"] = _should_include_code(intent.question)
    return parsed


def _should_include_code(question: str) -> bool:
    """Return True when *question* asks for code to be included.

    Matching is case-insensitive substring search against a fixed set of
    English and Korean markers. A falsy *question* (``None`` or ``""``)
    yields False.
    """
    normalized = question.lower() if question else ""
    markers = ("include_code", "include code", "코드 포함", "코드도")
    for marker in markers:
        if marker in normalized:
            return True
    return False


def _template_query_type(intent: AnalysisIntent) -> str | None:
    """Classify *intent* into a template route name, or None if none fits.

    Rules are checked in priority order:

    1. ``"compare_threshold"`` - period comparison requested and a threshold set.
    2. ``"top_filtered"`` - ``top_n`` set together with a region or a threshold.
    3. ``"top_only"`` - ``top_n`` set with neither region nor threshold.
    4. ``"sample_only"`` - only ``sample_n`` set.

    Presence is tested with ``is not None``, so a value of ``0`` still counts
    as "set".
    """
    has_threshold = intent.threshold is not None
    has_top_n = intent.top_n is not None

    if intent.compare_periods and has_threshold:
        return "compare_threshold"

    if has_top_n:
        is_filtered = intent.region is not None or has_threshold
        return "top_filtered" if is_filtered else "top_only"

    if intent.sample_n is not None:
        return "sample_only"

    return None


def build_plan(intent: AnalysisIntent, schema_profile: dict[str, Any]) -> AnalysisPlan:
warnings: list[str] = []
group_col = intent.region_column or _first_text_column(schema_profile)
Expand All @@ -134,6 +172,72 @@ def build_plan(intent: AnalysisIntent, schema_profile: dict[str, Any]) -> Analys
if metric_col is None:
warnings.append("numeric metric column not found")

template_type = _template_query_type(intent)
if template_type:
intent.routing_source = f"template:{template_type}"
return _build_template_plan(intent, group_col, metric_col, warnings)

intent.routing_source = "fallback"
warnings.append("template not matched: fallback planner used")

return _build_fallback_plan(intent, group_col, metric_col, warnings)


def _build_template_plan(
    intent: AnalysisIntent,
    group_col: str | None,
    metric_col: str | None,
    warnings: list[str],
) -> AnalysisPlan:
    """Assemble the fixed six-node template plan for *intent*.

    The node sequence is always filter -> groupby -> agg -> rank -> sample ->
    export; each node carries an ``enabled`` flag derived from the intent or
    the resolved columns, so unused stages stay in the graph but are switched
    off. The returned plan has ``fallback=False`` and carries *warnings*
    through unchanged.

    Args:
        intent: Parsed analysis intent supplying region/threshold/top/sample.
        group_col: Column to group by, or None to disable grouping.
        metric_col: Column to aggregate, or None to disable aggregation.
        warnings: Warnings collected so far by the caller.
    """
    # Filtering is on when either a (truthy) region or any threshold is given.
    filter_enabled = bool(intent.region) or intent.threshold is not None
    filter_node = {
        "op": "filter",
        "enabled": filter_enabled,
        "region_column": intent.region_column,
        "region": intent.region,
        "threshold_column": intent.threshold_column,
        "threshold": intent.threshold,
    }

    groupby_columns = [group_col] if group_col else []
    groupby_node = {
        "op": "groupby",
        "enabled": bool(group_col),
        "columns": groupby_columns,
    }

    agg_node = {
        "op": "agg",
        "enabled": bool(metric_col),
        "metric": metric_col,
        "fn": "sum",
    }

    # Truthiness (not ``is not None``) gates rank/sample, so 0 disables them.
    rank_node = {
        "op": "rank",
        "enabled": bool(intent.top_n),
        "top_n": intent.top_n,
        "order": "desc",
    }

    sample_node = {
        "op": "sample",
        "enabled": bool(intent.sample_n),
        "sample_n": intent.sample_n,
        "seed": 42,
    }

    export_node = {
        "op": "export",
        "enabled": True,
        "include_meta": True,
    }

    pipeline = [
        filter_node,
        groupby_node,
        agg_node,
        rank_node,
        sample_node,
        export_node,
    ]
    return AnalysisPlan(
        intent=intent,
        nodes=pipeline,
        fallback=False,
        warnings=warnings,
    )


def _build_fallback_plan(
intent: AnalysisIntent,
group_col: str | None,
metric_col: str | None,
warnings: list[str],
) -> AnalysisPlan:

nodes = [
{
"op": "filter",
Expand Down
69 changes: 68 additions & 1 deletion tests/test_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,37 @@ def test_parse_question_to_intent_extracts_controls():
assert intent.threshold_column == "sales"
assert intent.region == "서울"
assert intent.compare_periods is True
assert intent.intent_schema == {
"topN": 3,
"sampleN": 2,
"threshold": {"value": 100, "column": "sales"},
"filter": {"region": "서울", "region_column": "region"},
"compare": True,
"include_code": False,
}


def test_parse_question_to_intent_extracts_include_code_flag():
    """A question containing the Korean "코드 포함" marker sets include_code."""
    question = "서울 매출 top 5 코드 포함"
    parsed = parse_question_to_intent(question, _schema())

    include_code = parsed.intent_schema["include_code"]
    assert include_code is True

def test_build_plan_contains_execution_graph_nodes():

def test_build_plan_routes_template_queries_first():
    """A plain top-N question is planned via the ``top_only`` template."""
    parsed = parse_question_to_intent("상위 5 샘플 2", _schema())
    plan = build_plan(parsed, _schema())

    expected_ops = ["filter", "groupby", "agg", "rank", "sample", "export"]
    actual_ops = [node["op"] for node in plan.nodes]
    assert actual_ops == expected_ops

    # At least one rank node must be present and switched on.
    rank_is_enabled = any(
        node["op"] == "rank" and node["enabled"] for node in plan.nodes
    )
    assert rank_is_enabled

    assert plan.intent.routing_source == "template:top_only"


def test_build_plan_routes_only_unmatched_queries_to_fallback():
    """A question with no template controls is routed to the fallback planner."""
    parsed = parse_question_to_intent("기본 분석만 해줘", _schema())
    plan = build_plan(parsed, _schema())

    assert plan.intent.routing_source == "fallback"

    # The fallback route must be announced in the plan's warnings.
    fallback_mentioned = any("fallback" in message for message in plan.warnings)
    assert fallback_mentioned


def test_execute_plan_fallback_on_invalid_node():
Expand All @@ -39,3 +62,47 @@ def test_execute_plan_fallback_on_invalid_node():

assert result["meta"]["fallback"] is True
assert "unsupported op" in result["meta"]["error"]


def test_user_examples_are_kept_as_regression_cases():
    """Pin the parsed controls for a fixed set of example user questions."""
    full_controls_case = (
        "서울 지역 top 3, sample 2, sales 임계값 100 전/후 비교",
        {
            "top_n": 3,
            "sample_n": 2,
            "threshold": 100,
            "region": "서울",
            "compare_periods": True,
        },
    )
    top_and_sample_case = (
        "상위 5 샘플 2",
        {
            "top_n": 5,
            "sample_n": 2,
            "threshold": None,
            "region": None,
            "compare_periods": False,
        },
    )
    no_controls_case = (
        "기본 분석만 해줘",
        {
            "top_n": None,
            "sample_n": None,
            "threshold": None,
            "region": None,
            "compare_periods": False,
        },
    )

    regression_cases = [full_controls_case, top_and_sample_case, no_controls_case]
    for question, expected in regression_cases:
        parsed = parse_question_to_intent(question, _schema())

        # Each expected key names an attribute on the parsed intent; they are
        # checked in declaration order (top_n, sample_n, threshold, region,
        # compare_periods).
        for attr_name, wanted in expected.items():
            assert getattr(parsed, attr_name) == wanted