Skip to content

Commit 19c0bb6

Browse files
feat(backend/kernel): wire TSparkParameter through to kernel bind_param
Lifts the NotSupportedError that execute_command currently raises for parametrized queries. The kernel-side PyO3 binding for Statement.bind_param landed in databricks-sql-kernel#18; this commit wires the connector's TSparkParameter shape through to it. Implementation: - New `bind_tspark_params(kernel_stmt, parameters)` in type_mapping.py forwards each TSparkParameter to the kernel as `(ordinal, value.stringValue, type)`. ordinal is the 1-based position in the parameters list; the connector's `ordinal: bool` flag is checked only to reject named bindings (kernel v0 doesn't accept them on the wire). - execute_command no longer raises on `parameters=[...]`. The query_tags branch stays — that's a separate gap. Tests: - 6 new unit tests in tests/unit/test_kernel_type_mapping.py for the mapper: - positional forwarding preserves ordering and (ordinal, value, type) - None value forwards as SQL NULL - VOID passes through verbatim (kernel parser ignores value for VOID) - named bindings raise NotSupportedError with a pointed message - missing TSparkParameter.type defaults to STRING (defensive) - empty parameters list is a no-op - `test_execute_command_rejects_parameters` (which previously asserted the NotSupportedError) replaced with `test_execute_command_forwards_parameters_to_bind_param` — stubs the kernel statement and verifies bind_param is called once per TSparkParameter in order with 1-based ordinals, and execute fires after binding. - 3 new e2e tests in tests/e2e/test_kernel_backend.py against dogfood: - mixed-type round-trip (INT, STRING, BOOLEAN) via the connector's native IntegerParameter/StringParameter/BooleanParameter - None parameter (VoidParameter → SQL NULL) - DECIMAL parameter with precision/scale carried in the SQL type string (auto-inferred — explicit-arg path has a pre-existing bug in native.py where format-args are swapped) 106/106 kernel unit tests pass. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 9164ba0 commit 19c0bb6

5 files changed

Lines changed: 278 additions & 28 deletions

File tree

src/databricks/sql/backend/kernel/client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
1515
Phase 1 gaps documented in the integration design:
1616
17-
- Parameter binding (``parameters=[TSparkParameter, ...]``) is not
18-
yet supported — the PyO3 ``Statement`` doesn't expose
19-
``bind_param``. ``execute_command(parameters=[...])`` raises
20-
``NotSupportedError``.
2117
- ``query_tags`` on execute is not supported (kernel exposes
2218
``statement_conf`` but PyO3 doesn't surface it).
2319
- ``get_tables`` with a non-empty ``table_types`` filter applies
@@ -231,11 +227,6 @@ def execute_command(
231227
) -> Union["ResultSet", None]:
232228
if self._kernel_session is None:
233229
raise InterfaceError("Cannot execute_command without an open session.")
234-
if parameters:
235-
raise NotSupportedError(
236-
"Parameter binding is not yet supported on the kernel backend "
237-
"(PyO3 Statement.bind_param lands in a follow-up PR)."
238-
)
239230
if query_tags:
240231
raise NotSupportedError(
241232
"Statement-level query_tags are not yet supported on the kernel backend."
@@ -248,6 +239,15 @@ def execute_command(
248239
try:
249240
try:
250241
stmt.set_sql(operation)
242+
if parameters:
243+
# Lazy import — type_mapping touches pyarrow at
244+
# module load; keep ``execute_command`` callable
245+
# from contexts that don't yet need it.
246+
from databricks.sql.backend.kernel.type_mapping import (
247+
bind_tspark_params,
248+
)
249+
250+
bind_tspark_params(stmt, parameters)
251251
if async_op:
252252
async_exec = stmt.submit()
253253
command_id = CommandId.from_sea_statement_id(

src/databricks/sql/backend/kernel/type_mapping.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,20 @@
1313
the kernel receives Arrow schemas directly), so the mapping
1414
function stays local but the names are shared.
1515
16-
Parameter binding (``TSparkParameter`` → kernel ``TypedValue``) is
17-
not yet implemented — the PyO3 ``Statement`` doesn't expose a
18-
``bind_param`` method on this branch. It'll land in a follow-up
19-
once that PyO3 surface ships.
16+
Parameter binding (``TSparkParameter`` → kernel
17+
``Statement.bind_param``) is handled by ``bind_tspark_params``
18+
forwards the connector's already-string-encoded form to the kernel
19+
binding without an intermediate Python-typed round-trip.
2020
"""
2121

2222
from __future__ import annotations
2323

24-
from typing import List, Tuple
24+
from typing import Any, List, Tuple
2525

2626
import pyarrow
2727

2828
from databricks.sql.backend.sea.utils.conversion import SqlType
29+
from databricks.sql.thrift_api.TCLIService import ttypes
2930

3031

3132
def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str:
@@ -92,3 +93,55 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]:
9293
)
9394
for field in schema
9495
]
96+
97+
98+
def _tspark_param_value_str(param: ttypes.TSparkParameter) -> Any:
99+
"""Extract the string-encoded value from a ``TSparkParameter``,
100+
or ``None`` for SQL NULL.
101+
102+
Native parameters (``IntegerParameter`` etc.) always wrap their
103+
value in ``TSparkParameterValue(stringValue=str(self.value))``;
104+
``VoidParameter`` sets ``stringValue="None"`` but the type is
105+
``"VOID"`` — the kernel-side parser ignores the value when the
106+
type is VOID, so we don't have to special-case here.
107+
"""
108+
if param.value is None:
109+
return None
110+
return param.value.stringValue
111+
112+
113+
def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> None:
114+
"""Bind a list of ``TSparkParameter`` onto a kernel ``Statement``.
115+
116+
The kernel expects positional bindings only (SEA v0 doesn't
117+
accept named bindings on the wire). The connector's
118+
``TSparkParameter`` has an ``ordinal: bool`` flag; ``True`` means
119+
"treat as positional in source-list order". Native bindings
120+
almost always come through positional today; named-binding
121+
parameters surface as ``NotSupportedError`` so the user gets a
122+
clear message instead of a server-side rejection.
123+
124+
Compound types (``ARRAY`` / ``MAP`` / ``STRUCT``) are routed
125+
through the kernel parser which currently rejects them — same
126+
user-visible message ("compound parameter types … are not yet
127+
supported"). Tracked as a follow-up.
128+
"""
129+
for i, param in enumerate(parameters, start=1):
130+
# The connector's `ordinal` field is a bool (True/False) on
131+
# native params and indicates positional vs named. Named
132+
# params can't flow through the kernel today; raise early
133+
# rather than letting the server reject.
134+
if getattr(param, "ordinal", None) is False and getattr(param, "name", None):
135+
from databricks.sql.exc import NotSupportedError
136+
137+
raise NotSupportedError(
138+
f"Named parameter binding (got name={param.name!r}) is not yet "
139+
"supported on the kernel backend; pass parameters positionally."
140+
)
141+
142+
sql_type = param.type or "STRING"
143+
value_str = _tspark_param_value_str(param)
144+
# The kernel takes 1-based ordinals; `i` is already that.
145+
# Errors from the kernel side (bad literal, unsupported type,
146+
# etc.) come up as KernelError and bubble through normally.
147+
kernel_stmt.bind_param(i, value_str, sql_type)

tests/e2e/test_kernel_backend.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,69 @@ def test_bad_sql_surfaces_as_databaseerror(conn):
199199
# Structured fields copied off the kernel exception:
200200
assert getattr(err, "code", None) == "SqlError"
201201
assert getattr(err, "sql_state", None) == "42P01"
202+
203+
204+
# ── Parameter binding ─────────────────────────────────────────────
205+
206+
207+
def test_parameterized_query_round_trips(conn):
208+
"""Positional parameter binding via the kernel backend. The
209+
connector's native parameter classes (IntegerParameter etc.)
210+
serialize to TSparkParameter under the hood; the kernel
211+
backend's mapper forwards them positionally to the kernel.
212+
"""
213+
from databricks.sql.parameters.native import (
214+
IntegerParameter,
215+
StringParameter,
216+
BooleanParameter,
217+
)
218+
219+
with conn.cursor() as cur:
220+
cur.execute(
221+
"SELECT ? AS i, ? AS s, ? AS b",
222+
[
223+
IntegerParameter(42),
224+
StringParameter("alice"),
225+
BooleanParameter(True),
226+
],
227+
)
228+
rows = cur.fetchall()
229+
assert len(rows) == 1
230+
assert rows[0][0] == 42
231+
assert rows[0][1] == "alice"
232+
assert rows[0][2] is True
233+
234+
235+
def test_parameterized_query_with_null(conn):
236+
"""`None` in the parameter list flows through as VoidParameter
237+
→ kernel TypedValue::Null."""
238+
with conn.cursor() as cur:
239+
cur.execute("SELECT ? IS NULL AS is_null", [None])
240+
rows = cur.fetchall()
241+
assert rows[0][0] is True
242+
243+
244+
def test_parameterized_query_decimal(conn):
245+
"""DECIMAL parameters carry precision/scale in the SQL type
246+
string ('DECIMAL(p,s)') — the kernel parser extracts them so
247+
fractional digits survive the wire.
248+
249+
Uses the connector's auto-inference path
250+
(`calculate_decimal_cast_string`) to derive precision/scale
251+
from the value; the explicit-arg path
252+
(`DecimalParameter(v, scale=, precision=)`) has a pre-existing
253+
bug in this branch where the format-args are passed
254+
`(scale, precision)` instead of `(precision, scale)` — out of
255+
scope for this PR.
256+
"""
257+
import decimal
258+
from databricks.sql.parameters.native import DecimalParameter
259+
260+
with conn.cursor() as cur:
261+
cur.execute(
262+
"SELECT ? AS d",
263+
[DecimalParameter(decimal.Decimal("-123.45"))],
264+
)
265+
rows = cur.fetchall()
266+
# Server echoes back as decimal.Decimal.
267+
assert str(rows[0][0]) == "-123.45"

tests/unit/test_kernel_client.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -234,25 +234,57 @@ def test_open_session_rejects_double_open(monkeypatch):
234234
c.open_session(session_configuration=None, catalog=None, schema=None)
235235

236236

237-
def test_execute_command_rejects_parameters():
237+
def test_execute_command_forwards_parameters_to_bind_param():
238+
"""``execute_command(parameters=[...])`` routes each parameter
239+
through ``bind_tspark_params`` onto the kernel statement before
240+
``execute()`` is called. Replaces the prior ``NotSupportedError``
241+
rejection now that the kernel-side ``Statement.bind_param`` is
242+
live (kernel PR #18)."""
243+
from databricks.sql.thrift_api.TCLIService import ttypes
244+
238245
c = _make_client()
239246
c._kernel_session = MagicMock()
240247
cursor = MagicMock()
241248
cursor.arraysize = 100
242249
cursor.buffer_size_bytes = 1024
243-
with pytest.raises(NotSupportedError, match="Parameter binding"):
244-
c.execute_command(
245-
operation="SELECT ?",
246-
session_id=MagicMock(),
247-
max_rows=1,
248-
max_bytes=1,
249-
lz4_compression=False,
250-
cursor=cursor,
251-
use_cloud_fetch=False,
252-
parameters=[object()], # any non-empty list
253-
async_op=False,
254-
enforce_embedded_schema_correctness=False,
255-
)
250+
251+
# Stub the statement chain so we can observe bind_param calls
252+
# without exercising the full ExecutedStatement → arrow_schema()
253+
# path (that's covered elsewhere).
254+
stmt = MagicMock()
255+
stmt.bind_param = MagicMock()
256+
stmt.execute.return_value = MagicMock(
257+
statement_id="stmt-id",
258+
arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])),
259+
)
260+
c._kernel_session.statement.return_value = stmt
261+
262+
p1 = ttypes.TSparkParameter(ordinal=True, name=None, type="INT")
263+
p1.value = ttypes.TSparkParameterValue(stringValue="42")
264+
p2 = ttypes.TSparkParameter(ordinal=True, name=None, type="STRING")
265+
p2.value = ttypes.TSparkParameterValue(stringValue="hello")
266+
267+
c.execute_command(
268+
operation="SELECT ?, ?",
269+
session_id=MagicMock(),
270+
max_rows=1,
271+
max_bytes=1,
272+
lz4_compression=False,
273+
cursor=cursor,
274+
use_cloud_fetch=False,
275+
parameters=[p1, p2],
276+
async_op=False,
277+
enforce_embedded_schema_correctness=False,
278+
)
279+
280+
# bind_param was called once per TSparkParameter, in order, with
281+
# 1-based ordinals.
282+
assert stmt.bind_param.call_args_list == [
283+
((1, "42", "INT"), {}),
284+
((2, "hello", "STRING"), {}),
285+
]
286+
# …and execute fired after binding.
287+
assert stmt.execute.called
256288

257289

258290
def test_execute_command_rejects_query_tags():

tests/unit/test_kernel_type_mapping.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,102 @@ def test_description_from_schema_reports_non_nullable_fields():
8484
desc = description_from_arrow_schema(schema)
8585
assert desc[0][6] is False
8686
assert desc[1][6] is True
87+
88+
89+
# ─── bind_tspark_params ──────────────────────────────────────────────────
90+
91+
92+
def _mk_param(*, type, value, ordinal=True, name=None):
93+
"""Build a minimal TSparkParameter for tests."""
94+
from databricks.sql.thrift_api.TCLIService import ttypes
95+
96+
p = ttypes.TSparkParameter(ordinal=ordinal, name=name, type=type)
97+
p.value = ttypes.TSparkParameterValue(stringValue=value) if value is not None else None
98+
return p
99+
100+
101+
class _RecordingStmt:
102+
"""Stand-in for the kernel `Statement` pyclass — records every
103+
`bind_param` call so tests can assert the (ordinal, value, type)
104+
triples the mapper forwarded."""
105+
106+
def __init__(self):
107+
self.calls = []
108+
109+
def bind_param(self, ordinal, value_str, sql_type):
110+
self.calls.append((ordinal, value_str, sql_type))
111+
112+
113+
def test_bind_tspark_params_forwards_each_param_positionally():
114+
from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
115+
116+
params = [
117+
_mk_param(type="INT", value="42"),
118+
_mk_param(type="STRING", value="alice"),
119+
_mk_param(type="DATE", value="2026-05-15"),
120+
]
121+
stmt = _RecordingStmt()
122+
bind_tspark_params(stmt, params)
123+
assert stmt.calls == [
124+
(1, "42", "INT"),
125+
(2, "alice", "STRING"),
126+
(3, "2026-05-15", "DATE"),
127+
]
128+
129+
130+
def test_bind_tspark_params_null_value():
131+
"""TSparkParameter with value=None → kernel sees value_str=None,
132+
interpreted as SQL NULL regardless of the SQL type."""
133+
from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
134+
135+
p = _mk_param(type="STRING", value=None)
136+
stmt = _RecordingStmt()
137+
bind_tspark_params(stmt, [p])
138+
assert stmt.calls == [(1, None, "STRING")]
139+
140+
141+
def test_bind_tspark_params_void_passes_through():
142+
"""VoidParameter sets type='VOID' with stringValue='None'; the
143+
kernel parser ignores the value when type=VOID."""
144+
from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
145+
146+
p = _mk_param(type="VOID", value="None")
147+
stmt = _RecordingStmt()
148+
bind_tspark_params(stmt, [p])
149+
assert stmt.calls == [(1, "None", "VOID")]
150+
151+
152+
def test_bind_tspark_params_named_param_rejected():
153+
"""The kernel doesn't accept named bindings on the SEA wire;
154+
surface that at the connector layer so the user gets a pointed
155+
error instead of a server-side rejection."""
156+
from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
157+
from databricks.sql.exc import NotSupportedError
158+
159+
p = _mk_param(type="INT", value="42", ordinal=False, name="my_param")
160+
stmt = _RecordingStmt()
161+
with pytest.raises(NotSupportedError, match="(?i)named"):
162+
bind_tspark_params(stmt, [p])
163+
# Nothing should have been forwarded before the rejection.
164+
assert stmt.calls == []
165+
166+
167+
def test_bind_tspark_params_missing_type_defaults_to_string():
168+
"""Defensive: a TSparkParameter with no `type` shouldn't crash
169+
the mapper — fall back to STRING and let the kernel parse."""
170+
from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
171+
from databricks.sql.thrift_api.TCLIService import ttypes
172+
173+
p = ttypes.TSparkParameter(ordinal=True, name=None, type=None)
174+
p.value = ttypes.TSparkParameterValue(stringValue="hello")
175+
stmt = _RecordingStmt()
176+
bind_tspark_params(stmt, [p])
177+
assert stmt.calls == [(1, "hello", "STRING")]
178+
179+
180+
def test_bind_tspark_params_empty_list_is_noop():
181+
from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
182+
183+
stmt = _RecordingStmt()
184+
bind_tspark_params(stmt, [])
185+
assert stmt.calls == []

0 commit comments

Comments
 (0)