diff --git a/conftest.py b/conftest.py index c8b350bee..748f73443 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,41 @@ +import importlib.util import os import pytest +def _kernel_wheel_available() -> bool: + """The ``use_sea=True`` code path now routes through the Rust + kernel via PyO3. The ``databricks_sql_kernel`` wheel is not + yet on PyPI (built from a separate repo); CI environments + without it should skip ``use_sea=True`` parametrized cases + rather than fail with a hard ImportError.""" + return importlib.util.find_spec("databricks_sql_kernel") is not None + + +def pytest_collection_modifyitems(config, items): + """Skip parametrized test cases that pass ``use_sea=True`` when + the kernel wheel isn't installed. + + The existing e2e suite uses ``@pytest.mark.parametrize( + "extra_params", [{}, {"use_sea": True}])`` to exercise both + backends. When the kernel wheel is missing those cases die at + ``connect()`` time with our pointed ImportError; mark them + skipped at collection time so CI signal stays accurate. + """ + if _kernel_wheel_available(): + return + skip_marker = pytest.mark.skip( + reason="use_sea=True requires databricks-sql-kernel (not installed)" + ) + for item in items: + params = getattr(item, "callspec", None) + if params is None: + continue + extra_params = params.params.get("extra_params") + if isinstance(extra_params, dict) and extra_params.get("use_sea") is True: + item.add_marker(skip_marker) + + @pytest.fixture(scope="session") def host(): return os.getenv("DATABRICKS_SERVER_HOSTNAME") diff --git a/pyproject.toml b/pyproject.toml index 5e9f7f0ca..6868919d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,20 @@ requests-kerberos = {version = "^0.15.0", optional = true} [tool.poetry.extras] pyarrow = ["pyarrow"] +# `[kernel]` extra is intentionally not declared here yet. +# `databricks-sql-kernel` is built from the databricks-sql-kernel +# repo and not yet published to PyPI; declaring it as a poetry dep +# breaks `poetry lock` for every CI job. Once the wheel is on PyPI +# the extra will be added back here: +# +# databricks-sql-kernel = {version = "^0.1.0", optional = true} +# [tool.poetry.extras] +# kernel = ["databricks-sql-kernel"] +# +# Until then, install the kernel separately: +# pip install databricks-sql-kernel +# or (local dev): +# cd databricks-sql-kernel/pyo3 && maturin develop --release [tool.poetry.group.dev.dependencies] pytest = "^7.1.2" diff --git a/scripts/bench_kernel_vs_thrift.py b/scripts/bench_kernel_vs_thrift.py new file mode 100644 index 000000000..243b73f5a --- /dev/null +++ b/scripts/bench_kernel_vs_thrift.py @@ -0,0 +1,294 @@ +"""Benchmark the kernel-backed connector against the Thrift backend. + +One-shot script, not a CI gate. Runs each (backend × SQL-shape) +combination N+1 times against a live warehouse, drops the first +run (cache warm-up), and reports min / median / max wall-clock for +session-open, time-to-first-row, drain, and RSS delta. + +Usage: + + set -a && source ~/.databricks/pecotesting-creds && set +a + # If DATABRICKS_HOST is set but DATABRICKS_SERVER_HOSTNAME is + # not, normalise it (matches the e2e suite's convention). + export DATABRICKS_SERVER_HOSTNAME=${DATABRICKS_SERVER_HOSTNAME:-${DATABRICKS_HOST#https://}} + .venv/bin/python scripts/bench_kernel_vs_thrift.py + +Honest disclaimers: +- Single warehouse, single machine, single network route. High + server-side variance is expected. +- Server-side caches warm differently between back-to-back runs; + the first-run-drop helps but doesn't eliminate it. +- Comparison is **kernel-backed vs Thrift**. The pure-Python + native SEA backend (``backend/sea/``) is no longer reachable via + ``use_sea=True`` after this PR, so it's not included. +- RSS delta is process-wide and includes pyarrow tables we hold + in scope during the drain. Two-orders-of-magnitude differences + are signal; 10% differences are noise. + +The output is a Markdown table you can paste into a PR +description. +""" + +from __future__ import annotations + +import argparse +import gc +import os +import resource +import statistics +import sys +import time +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple + +import databricks.sql as sql + + +# ─── Config ────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class Shape: + name: str + sql: Optional[str] # None means it's a metadata call + metadata_call: Optional[str] # e.g. "catalogs" + expected_rows: Optional[int] # None when we don't assert + + +SHAPES: List[Shape] = [ + Shape("SELECT 1", "SELECT 1 AS n", None, 1), + Shape("range(10k)", "SELECT * FROM range(10000)", None, 10_000), + Shape("range(1M)", "SELECT * FROM range(1000000)", None, 1_000_000), + Shape( + "wide-uuid(100k)", + "SELECT id, uuid() AS u FROM range(100000)", + None, + 100_000, + ), + Shape("metadata.catalogs", None, "catalogs", None), +] + + +BACKENDS: List[Tuple[str, Dict]] = [ + ("thrift", {"use_sea": False}), + ("kernel", {"use_sea": True}), +] + + +# ─── Measurement ───────────────────────────────────────────────── + + +@dataclass +class SampleMetrics: + open_s: float + ttfr_s: float + drain_s: float + rows: int + rss_delta_kb: int + + +def _rss_kb() -> int: + # ru_maxrss is in KB on Linux, bytes on macOS — the script is + # primarily for Linux CI / dev shells, document the macOS + # caveat and move on. + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + +def run_one(backend_kwargs: Dict, shape: Shape, conn_params: Dict) -> SampleMetrics: + """Open a fresh connection, run the shape, drain, return metrics.""" + gc.collect() + rss_before = _rss_kb() + + t0 = time.perf_counter() + conn = sql.connect(**conn_params, **backend_kwargs) + t_open = time.perf_counter() + try: + cur = conn.cursor() + try: + t_pre_exec = time.perf_counter() + if shape.sql is not None: + cur.execute(shape.sql) + else: + getattr(cur, shape.metadata_call)() + # First row marks the end of poll + first-fetch latency. + first = cur.fetchmany(1) + t_ttfr = time.perf_counter() + # Drain the rest. + tail_rows = 0 + while True: + chunk = cur.fetchmany(10_000) + if not chunk: + break + tail_rows += len(chunk) + t_drain = time.perf_counter() + total_rows = len(first) + tail_rows + if shape.expected_rows is not None and total_rows != shape.expected_rows: + raise RuntimeError( + f"{shape.name}: expected {shape.expected_rows} rows, got {total_rows}" + ) + finally: + cur.close() + finally: + conn.close() + + rss_after = _rss_kb() + return SampleMetrics( + open_s=t_open - t0, + ttfr_s=t_ttfr - t_pre_exec, + drain_s=t_drain - t_pre_exec, + rows=total_rows, + rss_delta_kb=max(0, rss_after - rss_before), + ) + + +# ─── Aggregation ───────────────────────────────────────────────── + + +@dataclass +class Aggregated: + open_min: float + open_med: float + open_max: float + ttfr_min: float + ttfr_med: float + ttfr_max: float + drain_min: float + drain_med: float + drain_max: float + rows: int + rss_med_kb: int + + +def aggregate(samples: List[SampleMetrics]) -> Aggregated: + o = [s.open_s for s in samples] + t = [s.ttfr_s for s in samples] + d = [s.drain_s for s in samples] + r = [s.rss_delta_kb for s in samples] + return Aggregated( + open_min=min(o), open_med=statistics.median(o), open_max=max(o), + ttfr_min=min(t), ttfr_med=statistics.median(t), ttfr_max=max(t), + drain_min=min(d), drain_med=statistics.median(d), drain_max=max(d), + rows=samples[0].rows, + rss_med_kb=int(statistics.median(r)), + ) + + +def fmt_ms(seconds: float) -> str: + return f"{seconds * 1000:.0f}" + + +def fmt_rate(rows: int, seconds: float) -> str: + if seconds <= 0: + return "—" + return f"{int(rows / seconds):,}" + + +# ─── Driver ────────────────────────────────────────────────────── + + +def build_conn_params() -> Dict: + host = os.environ.get("DATABRICKS_SERVER_HOSTNAME") or os.environ.get("DATABRICKS_HOST", "") + host = host.replace("https://", "").rstrip("/") + http_path = os.environ.get("DATABRICKS_HTTP_PATH", "") + token = os.environ.get("DATABRICKS_TOKEN", "") + if not (host and http_path and token): + sys.exit( + "Missing credentials. Set DATABRICKS_SERVER_HOSTNAME (or _HOST), " + "DATABRICKS_HTTP_PATH, DATABRICKS_TOKEN before running." + ) + return { + "server_hostname": host, + "http_path": http_path, + "access_token": token, + } + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--samples", type=int, default=5, + help="Sample runs per (backend, shape). First run is dropped as warm-up. Default: 5.", + ) + parser.add_argument( + "--shapes", nargs="*", + help="Subset of shapes to run by name. Default: all. Choices: " + + ", ".join(s.name for s in SHAPES), + ) + parser.add_argument( + "--backends", nargs="*", choices=[b for b, _ in BACKENDS], + help="Subset of backends. Default: both.", + ) + args = parser.parse_args() + + conn_params = build_conn_params() + shapes = [s for s in SHAPES if not args.shapes or s.name in args.shapes] + backends = [(n, k) for (n, k) in BACKENDS if not args.backends or n in args.backends] + + if not shapes: + sys.exit(f"No shapes match {args.shapes!r}") + if not backends: + sys.exit(f"No backends match {args.backends!r}") + + # results[(shape_name, backend_name)] = Aggregated + results: Dict[Tuple[str, str], Aggregated] = {} + + total_runs = len(shapes) * len(backends) * (args.samples + 1) + print(f"Running {total_runs} samples ({len(shapes)} shapes × {len(backends)} backends × {args.samples + 1} runs/cell)\n", flush=True) + + for shape in shapes: + for backend_name, backend_kwargs in backends: + print(f" {shape.name:24s} on {backend_name:8s} … ", end="", flush=True) + samples: List[SampleMetrics] = [] + # +1 because we drop the first run. + for run_idx in range(args.samples + 1): + try: + m = run_one(backend_kwargs, shape, conn_params) + except Exception as exc: + print(f"\n run {run_idx} FAILED: {exc}", flush=True) + raise + if run_idx == 0: + continue # warmup + samples.append(m) + agg = aggregate(samples) + results[(shape.name, backend_name)] = agg + print( + f"open={fmt_ms(agg.open_med)}ms " + f"ttfr={fmt_ms(agg.ttfr_med)}ms " + f"drain={fmt_ms(agg.drain_med)}ms " + f"rows={agg.rows:,} " + f"rss+={agg.rss_med_kb}kb", + flush=True, + ) + + # ─── Report ───────────────────────────────────────────────── + print("\n" + "=" * 70) + print("Results (median across {} samples; warm-up dropped):".format(args.samples)) + print("=" * 70) + for shape in shapes: + print(f"\n### {shape.name}") + if shape.metadata_call: + print(f"_metadata: cursor.{shape.metadata_call}()_") + else: + print(f"_SQL: `{shape.sql}`_") + print() + print("| backend | open (ms) | ttfr (ms) | drain (ms) | rows/s | rss Δ (KB) |") + print("|---|---|---|---|---|---|") + for backend_name, _ in backends: + agg = results.get((shape.name, backend_name)) + if agg is None: + print(f"| {backend_name} | (skipped) | | | | |") + continue + print( + f"| {backend_name} | " + f"{fmt_ms(agg.open_med)} ({fmt_ms(agg.open_min)}–{fmt_ms(agg.open_max)}) | " + f"{fmt_ms(agg.ttfr_med)} ({fmt_ms(agg.ttfr_min)}–{fmt_ms(agg.ttfr_max)}) | " + f"{fmt_ms(agg.drain_med)} ({fmt_ms(agg.drain_min)}–{fmt_ms(agg.drain_max)}) | " + f"{fmt_rate(agg.rows, agg.drain_med)} | " + f"{agg.rss_med_kb} |" + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/databricks/sql/backend/kernel/__init__.py b/src/databricks/sql/backend/kernel/__init__.py new file mode 100644 index 000000000..4a1ad8205 --- /dev/null +++ b/src/databricks/sql/backend/kernel/__init__.py @@ -0,0 +1,25 @@ +"""Backend that delegates to the Databricks SQL Kernel (Rust) via PyO3. + +Routed when ``use_sea=True`` is passed to ``databricks.sql.connect``. +The module's identity is "delegates to the kernel" — not the wire +protocol the kernel happens to use today (SEA REST). The kernel may +switch its default transport (SEA REST → SEA gRPC → …) without +renaming this module. + +This ``__init__`` deliberately does **not** re-export +``KernelDatabricksClient`` from ``.client``. Importing ``.client`` +loads the ``databricks_sql_kernel`` PyO3 extension at module-import +time; doing that eagerly here would make ``import +databricks.sql.backend.kernel.type_mapping`` (used by tests / by +``KernelResultSet`` consumers) require the kernel wheel even when +the caller never plans to open a kernel-backed session. Callers +that need the client import it directly: + + from databricks.sql.backend.kernel.client import KernelDatabricksClient + +``session.py::_create_backend`` already does this lazy import under +the ``use_sea=True`` branch. + +See ``docs/designs/pysql-kernel-integration.md`` in +``databricks-sql-kernel`` for the full integration design. +""" diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py new file mode 100644 index 000000000..bb94dddf1 --- /dev/null +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -0,0 +1,97 @@ +"""Translate the connector's ``AuthProvider`` into ``databricks_sql_kernel`` +``Session`` auth kwargs. + +This phase ships PAT only. The kernel-side PyO3 binding accepts +``auth_type='pat'``; OAuth / federation / custom credentials +providers are reserved but not yet wired in either layer. Non-PAT +auth raises ``NotSupportedError`` from this bridge so the failure +surfaces at session-open time with a clear message rather than +deep inside the kernel. + +Token extraction goes through ``AuthProvider.add_headers({})`` +rather than touching auth-provider-specific attributes, so the +bridge works uniformly for every PAT shape — including +``AccessTokenAuthProvider`` wrapped in ``TokenFederationProvider`` +(which ``get_python_sql_connector_auth_provider`` does for every +provider it builds). +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider +from databricks.sql.auth.token_federation import TokenFederationProvider +from databricks.sql.exc import NotSupportedError + +logger = logging.getLogger(__name__) + + +_BEARER_PREFIX = "Bearer " + + +def _is_pat(auth_provider: AuthProvider) -> bool: + """Return True iff this provider ultimately wraps an + ``AccessTokenAuthProvider``. + + ``get_python_sql_connector_auth_provider`` always wraps the + base provider in a ``TokenFederationProvider``, so an + ``isinstance`` check against ``AccessTokenAuthProvider`` alone + never matches in practice. We peek through the federation + wrapper to find the real type. + """ + if isinstance(auth_provider, AccessTokenAuthProvider): + return True + if isinstance(auth_provider, TokenFederationProvider) and isinstance( + auth_provider.external_provider, AccessTokenAuthProvider + ): + return True + return False + + +def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: + """Pull the current bearer token out of an ``AuthProvider``. + + The connector's ``AuthProvider.add_headers`` mutates a header + dict and writes the ``Authorization: Bearer `` value. + Going through that public surface keeps us insulated from + provider-specific internals. + + Returns ``None`` if the provider did not write an Authorization + header or wrote a non-Bearer scheme — neither is representable + in the kernel's PAT auth surface. + """ + headers: Dict[str, str] = {} + auth_provider.add_headers(headers) + auth = headers.get("Authorization") + if not auth: + return None + if not auth.startswith(_BEARER_PREFIX): + return None + return auth[len(_BEARER_PREFIX) :] + + +def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: + """Build the kwargs passed to ``databricks_sql_kernel.Session(...)``. + + PAT (including ``TokenFederationProvider``-wrapped PAT) routes + through the kernel's PAT path. Anything else raises + ``NotSupportedError`` — the kernel binding doesn't accept OAuth + today, and routing OAuth through PAT would silently break + token refresh during long-running sessions. + """ + if _is_pat(auth_provider): + token = _extract_bearer_token(auth_provider) + if not token: + raise ValueError( + "PAT auth provider did not produce a Bearer Authorization " + "header; cannot route through the kernel's PAT path" + ) + return {"auth_type": "pat", "access_token": token} + + raise NotSupportedError( + f"The kernel backend (use_sea=True) currently only supports PAT auth, " + f"but got {type(auth_provider).__name__}. Use use_sea=False (Thrift) " + "for OAuth / federation / custom credential providers." + ) diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py new file mode 100644 index 000000000..a54501edf --- /dev/null +++ b/src/databricks/sql/backend/kernel/client.py @@ -0,0 +1,509 @@ +"""``DatabricksClient`` backed by the Rust kernel via PyO3. + +Routed when ``use_sea=True``. Constructor takes the connector's +already-built ``auth_provider`` and forwards everything else to the +kernel's ``Session``. Every kernel call goes through this thin +wrapper; this module is the single seam between the connector's +``DatabricksClient`` contract and the kernel's Python surface. + +Errors map cleanly: ``KernelError`` from the kernel is inspected +for its ``code`` attribute and re-raised as the appropriate PEP +249 exception (``DatabaseError``, ``OperationalError``, +``ProgrammingError``, etc.). Connector callers see standard +exception types, never the underlying kernel error. + +Phase 1 gaps documented in the integration design: + +- ``query_tags`` on execute is not supported (kernel exposes + ``statement_conf`` but PyO3 doesn't surface it). +- ``get_tables`` with a non-empty ``table_types`` filter applies + the filter client-side; today the kernel returns the full + ``SHOW TABLES`` shape unchanged. The connector's existing + ``ResultSetFilter.filter_tables_by_type`` is keyed on + ``SeaResultSet`` not ``KernelResultSet``, so we punt and let + the caller see all rows — documented as a known gap in the + design doc. +- Volume PUT/GET (staging operations): kernel has no Volume API + yet. Users on Thrift-only paths. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.kernel.auth_bridge import kernel_auth_kwargs +from databricks.sql.backend.kernel.result_set import KernelResultSet +from databricks.sql.backend.types import ( + BackendType, + CommandId, + CommandState, + SessionId, +) +from databricks.sql.exc import ( + DatabaseError, + Error, + InterfaceError, + NotSupportedError, + OperationalError, + ProgrammingError, +) +from databricks.sql.thrift_api.TCLIService import ttypes + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet + +logger = logging.getLogger(__name__) + + +try: + import databricks_sql_kernel as _kernel # type: ignore[import-not-found] +except ImportError as exc: # pragma: no cover - import-time error surfaces clearly + # The `databricks-sql-kernel` wheel is not yet on PyPI, so we + # don't yet declare it as an optional extra in pyproject.toml + # (doing so breaks `poetry lock`). Once published the install + # hint will move to `pip install 'databricks-sql-connector[kernel]'`. + raise ImportError( + "use_sea=True requires the databricks-sql-kernel package. Install it with:\n" + " pip install databricks-sql-kernel\n" + "or for local development from the kernel repo:\n" + " cd databricks-sql-kernel/pyo3 && maturin develop --release" + ) from exc + + +# ─── Error mapping ────────────────────────────────────────────────────────── + + +# Map a kernel `code` slug to the PEP 249 exception class that best +# captures it. The match isn't a perfect 1:1 — PEP 249 has a +# narrower taxonomy than the kernel — so several kernel codes +# collapse onto the same Python exception. This table is the only +# place that mapping lives. +_CODE_TO_EXCEPTION = { + "InvalidArgument": ProgrammingError, + "Unauthenticated": OperationalError, + "PermissionDenied": OperationalError, + "NotFound": ProgrammingError, + "ResourceExhausted": OperationalError, + "Unavailable": OperationalError, + "Timeout": OperationalError, + "Cancelled": OperationalError, + "DataLoss": DatabaseError, + "Internal": DatabaseError, + "InvalidStatementHandle": ProgrammingError, + "NetworkError": OperationalError, + "SqlError": DatabaseError, + "Unknown": DatabaseError, +} + + +def _reraise_kernel_error(exc: BaseException) -> "Error": + """Convert a ``databricks_sql_kernel.KernelError`` to a PEP 249 + exception. Other exception types fall through unchanged. + + Kernel errors carry their structured attrs (``code``, + ``message``, ``sql_state``, ``error_code``, ``query_id`` …) as + plain attributes — we copy them onto the re-raised exception so + callers can branch on them without reaching back through + ``__cause__``. + """ + if not isinstance(exc, _kernel.KernelError): + return exc # type: ignore[return-value] + code = getattr(exc, "code", "Unknown") + cls = _CODE_TO_EXCEPTION.get(code, DatabaseError) + new = cls(getattr(exc, "message", str(exc))) + # Forward the structured fields so connector users can read + # err.sql_state / err.query_id / etc. without a type-switch. + for attr in ( + "code", + "sql_state", + "error_code", + "vendor_code", + "http_status", + "retryable", + "query_id", + ): + try: + setattr(new, attr, getattr(exc, attr)) + except (AttributeError, TypeError): # pragma: no cover - defensive + pass + new.__cause__ = exc + return new + + +# ─── Client ───────────────────────────────────────────────────────────────── + + +class KernelDatabricksClient(DatabricksClient): + """``DatabricksClient`` that delegates to the Rust kernel. + + Owns one ``databricks_sql_kernel.Session`` per ``open_session`` + call. Async-execute handles (from ``submit()``) live in a dict + keyed on ``CommandId`` so the connector's polling APIs + (``get_query_state`` / ``get_execution_result`` / + ``cancel_command`` / ``close_command``) can find them again. + """ + + def __init__( + self, + server_hostname: str, + http_path: str, + auth_provider, + ssl_options, + catalog: Optional[str] = None, + schema: Optional[str] = None, + http_headers=None, + http_client=None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ): + # The connector hands us several fields the kernel doesn't + # consume directly (ssl_options, http_headers, http_client, + # port, _use_arrow_native_complex_types). Kernel manages + # its own HTTP stack so we accept-and-ignore. + self._server_hostname = server_hostname + self._http_path = http_path + self._auth_provider = auth_provider + self._catalog = catalog + self._schema = schema + self._auth_kwargs = kernel_auth_kwargs(auth_provider) + # Open ``databricks_sql_kernel.Session`` lazily in + # ``open_session`` so the Session lifecycle gates the + # underlying connection setup — same shape as Thrift's + # ``TOpenSession``. + self._kernel_session: Optional[Any] = None + self._session_id: Optional[SessionId] = None + # Async-exec handles keyed by CommandId.guid. Populated by + # ``execute_command(async_op=True)``; drained by ``close_command``. + self._async_handles: Dict[str, Any] = {} + + # ── Session lifecycle ────────────────────────────────────────── + + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + if self._kernel_session is not None: + raise InterfaceError("KernelDatabricksClient already has an open session.") + # ``session_configuration`` flows through to the kernel's + # ``session_conf`` map verbatim; the SEA endpoint enforces + # its own allow-list and rejects unknown keys. + session_conf: Optional[Dict[str, str]] = None + if session_configuration: + session_conf = {k: str(v) for k, v in session_configuration.items()} + try: + self._kernel_session = _kernel.Session( + host=self._server_hostname, + http_path=self._http_path, + catalog=catalog or self._catalog, + schema=schema or self._schema, + session_conf=session_conf, + **self._auth_kwargs, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + + # Use the kernel's real server-issued session id, not a + # synthetic UUID. Matches what the native SEA backend does. + # Bind to a local first so mypy sees a non-Optional return. + session_id = SessionId.from_sea_session_id(self._kernel_session.session_id) + self._session_id = session_id + logger.info("Opened kernel-backed session %s", session_id) + return session_id + + def close_session(self, session_id: SessionId) -> None: + if self._kernel_session is None: + return + # Close any tracked async handles first so they fire their + # server-side CloseStatement before the session goes away. + for handle in list(self._async_handles.values()): + try: + handle.close() + except _kernel.KernelError as exc: + logger.warning( + "Error closing async handle during session close: %s", exc + ) + self._async_handles.clear() + try: + self._kernel_session.close() + except _kernel.KernelError as exc: + # Surface as a non-fatal warning — the kernel's Drop + # impl will retry the close fire-and-forget. PEP 249 + # discourages raising from connection.close(). + logger.warning("Error closing kernel session: %s", exc) + self._kernel_session = None + self._session_id = None + + # ── Query execution ──────────────────────────────────────────── + + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, + ) -> Union["ResultSet", None]: + if self._kernel_session is None: + raise InterfaceError("Cannot execute_command without an open session.") + if query_tags: + raise NotSupportedError( + "Statement-level query_tags are not yet supported on the kernel backend." + ) + + stmt = self._kernel_session.statement() + try: + stmt.set_sql(operation) + if parameters: + # Lazy import — type_mapping touches pyarrow at + # module load; keep `execute_command` callable from + # contexts that don't yet need it. + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + bind_tspark_params(stmt, parameters) + if async_op: + async_exec = stmt.submit() + command_id = CommandId.from_sea_statement_id(async_exec.statement_id) + cursor.active_command_id = command_id + self._async_handles[command_id.guid] = async_exec + return None + executed = stmt.execute() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + finally: + # ``Statement`` is a lifecycle owner separate from the + # executed handle it produces. Drop it here so the + # parent doesn't keep the handle alive longer than the + # caller expects. + try: + stmt.close() + except _kernel.KernelError: + pass + + command_id = CommandId.from_sea_statement_id(executed.statement_id) + cursor.active_command_id = command_id + return KernelResultSet( + connection=cursor.connection, + backend=self, + kernel_handle=executed, + command_id=command_id, + arraysize=cursor.arraysize, + buffer_size_bytes=cursor.buffer_size_bytes, + ) + + def cancel_command(self, command_id: CommandId) -> None: + handle = self._async_handles.get(command_id.guid) + if handle is None: + # Sync-execute paths fully materialise the result before + # ``execute_command`` returns, so by the time + # cancel_command can fire there's nothing in flight. + # Match the Thrift backend's tolerant behaviour. + logger.debug("cancel_command: no in-flight async handle for %s", command_id) + return + try: + handle.cancel() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + + def close_command(self, command_id: CommandId) -> None: + handle = self._async_handles.pop(command_id.guid, None) + if handle is None: + logger.debug("close_command: no tracked handle for %s", command_id) + return + try: + handle.close() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + + def get_query_state(self, command_id: CommandId) -> CommandState: + handle = self._async_handles.get(command_id.guid) + if handle is None: + # No tracked async handle means execute_command ran + # sync and the result was materialised before returning; + # the command is terminal by construction. + return CommandState.SUCCEEDED + try: + state, failure = handle.status() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + if state == "Failed" and failure is not None: + # Surface server-reported failure as a database error so + # the cursor's polling loop terminates with the right + # exception class — matches the Thrift backend's + # behaviour on TOperationState::ERROR_STATE. + raise _reraise_kernel_error(failure) + return _STATE_TO_COMMAND_STATE.get(state, CommandState.FAILED) + + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ) -> "ResultSet": + handle = self._async_handles.get(command_id.guid) + if handle is None: + raise ProgrammingError( + "get_execution_result called for an unknown command_id; " + "the kernel backend only tracks async-submitted statements." + ) + try: + stream = handle.await_result() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return KernelResultSet( + connection=cursor.connection, + backend=self, + kernel_handle=stream, + command_id=command_id, + arraysize=cursor.arraysize, + buffer_size_bytes=cursor.buffer_size_bytes, + ) + + # ── Metadata ─────────────────────────────────────────────────── + + def _metadata_result(self, stream, cursor, command_id): + return KernelResultSet( + connection=cursor.connection, + backend=self, + kernel_handle=stream, + command_id=command_id, + arraysize=cursor.arraysize, + buffer_size_bytes=cursor.buffer_size_bytes, + ) + + def _synthetic_command_id(self) -> CommandId: + """Metadata calls don't produce a server statement id; mint + a synthetic one so the ``ResultSet`` still has a stable + identifier the cursor can attribute logs to.""" + return CommandId.from_sea_statement_id(f"metadata-{uuid.uuid4()}") + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_catalogs requires an open session.") + try: + stream = self._kernel_session.metadata().list_catalogs() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_schemas requires an open session.") + try: + stream = self._kernel_session.metadata().list_schemas( + catalog=catalog_name, + schema_pattern=schema_name, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_tables requires an open session.") + if table_types: + # Documented gap: native SEA backend filters here, but + # its filter is keyed on SeaResultSet. Day-1 we surface + # the unfiltered result; a small follow-up ports the + # filter to operate on KernelResultSet. + logger.warning( + "get_tables: client-side table_types filter not yet implemented " + "on the kernel backend; returning unfiltered rows for %r", + table_types, + ) + try: + stream = self._kernel_session.metadata().list_tables( + catalog=catalog_name, + schema_pattern=schema_name, + table_pattern=table_name, + table_types=table_types, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_columns requires an open session.") + if not catalog_name: + # Kernel's list_columns requires a catalog (SEA `SHOW + # COLUMNS` cannot span catalogs). Surface the constraint + # explicitly rather than letting the kernel error. + raise ProgrammingError( + "get_columns requires catalog_name on the kernel backend." + ) + try: + stream = self._kernel_session.metadata().list_columns( + catalog=catalog_name, + schema_pattern=schema_name, + table_pattern=table_name, + column_pattern=column_name, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + # ── Misc ─────────────────────────────────────────────────────── + + @property + def max_download_threads(self) -> int: + # CloudFetch parallelism lives kernel-side. This property is + # consulted by Thrift code paths that don't run for + # use_sea=True; return a non-zero default so anything that + # peeks at it does not divide by zero. + return 10 + + +_STATE_TO_COMMAND_STATE: Dict[str, CommandState] = { + "Pending": CommandState.PENDING, + "Running": CommandState.RUNNING, + "Succeeded": CommandState.SUCCEEDED, + "Failed": CommandState.FAILED, + "Cancelled": CommandState.CANCELLED, + "Closed": CommandState.CLOSED, +} diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py new file mode 100644 index 000000000..0ee85c2be --- /dev/null +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -0,0 +1,222 @@ +"""Streaming ``ResultSet`` over a kernel ``ExecutedStatement`` or +``ResultStream``. + +The kernel surfaces two flavours of result-bearing handle: + +- ``ExecutedStatement`` — returned by ``Statement.execute()``. Has a + ``statement_id`` and a ``cancel()`` method. +- ``ResultStream`` — returned by ``Session.metadata().list_*`` and by + ``ExecutedAsyncStatement.await_result()``. No statement id; no + cancel. + +Both implement the same three methods this class actually calls: +``arrow_schema() / fetch_next_batch() / fetch_all_arrow() / close()``. +``KernelResultSet`` takes either via the ``kernel_handle`` parameter +and treats them uniformly — the connector's ``ResultSet`` contract +doesn't need to distinguish them. + +Buffer shape mirrors the prior ADBC POC's ``AdbcResultSet``: a FIFO +of pyarrow ``RecordBatch``es, fed one batch at a time from the +kernel as the connector calls ``fetch*``. ``fetchmany(n)`` slices +within a batch when ``n`` is smaller than the kernel's natural +batch size; ``fetchall`` drains the whole stream. +""" + +from __future__ import annotations + +import logging +from collections import deque +from typing import Any, Deque, List, Optional, TYPE_CHECKING + +import pyarrow + +from databricks.sql.backend.kernel.type_mapping import description_from_arrow_schema +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.result_set import ResultSet +from databricks.sql.types import Row + +if TYPE_CHECKING: + from databricks.sql.client import Connection + from databricks.sql.backend.kernel.client import KernelDatabricksClient + +logger = logging.getLogger(__name__) + + +class KernelResultSet(ResultSet): + """Streaming ``ResultSet`` over a kernel handle. + + The ``kernel_handle`` is duck-typed: it must implement + ``arrow_schema() -> pyarrow.Schema``, ``fetch_next_batch() -> + Optional[pyarrow.RecordBatch]``, and ``close() -> None``. + Both ``databricks_sql_kernel.ExecutedStatement`` and + ``databricks_sql_kernel.ResultStream`` satisfy that contract. + """ + + def __init__( + self, + connection: "Connection", + backend: "KernelDatabricksClient", + kernel_handle: Any, + command_id: CommandId, + arraysize: int, + buffer_size_bytes: int, + ): + schema = kernel_handle.arrow_schema() + super().__init__( + connection=connection, + backend=backend, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=CommandState.RUNNING, + has_been_closed_server_side=False, + has_more_rows=True, + results_queue=None, + description=description_from_arrow_schema(schema), + is_staging_operation=False, + lz4_compressed=False, + arrow_schema_bytes=None, + ) + self._kernel_handle = kernel_handle + self._schema: pyarrow.Schema = schema + # FIFO of record batches plus a per-head row offset, so + # partial fetches (fetchmany(n) for n < batch_size) don't + # re-fetch from the kernel. + self._buffer: Deque[pyarrow.RecordBatch] = deque() + self._buffer_offset: int = 0 + self._exhausted: bool = False + + # ----- internal helpers ----- + + def _pull_one_batch(self) -> bool: + """Pull the next batch from the kernel into the local buffer. + Returns True if a batch was added; False if the kernel side + is exhausted.""" + if self._exhausted: + return False + batch = self._kernel_handle.fetch_next_batch() + if batch is None: + self._exhausted = True + self.has_more_rows = False + self.status = CommandState.SUCCEEDED + return False + if batch.num_rows > 0: + self._buffer.append(batch) + return True + + def _ensure_buffered(self, n_rows: int) -> int: + """Pull batches until ``n_rows`` are buffered or the kernel + is exhausted. Returns total rows currently buffered.""" + while self._buffered_rows() < n_rows: + if not self._pull_one_batch(): + break + return self._buffered_rows() + + def _buffered_rows(self) -> int: + if not self._buffer: + return 0 + first = self._buffer[0].num_rows - self._buffer_offset + rest = sum(b.num_rows for b in list(self._buffer)[1:]) + return first + rest + + def _take_buffered(self, n: int) -> pyarrow.Table: + """Slice up to ``n`` rows out of the buffer; advances state.""" + slices: List[pyarrow.RecordBatch] = [] + remaining = n + while remaining > 0 and self._buffer: + head = self._buffer[0] + avail = head.num_rows - self._buffer_offset + take = min(avail, remaining) + slices.append(head.slice(self._buffer_offset, take)) + self._buffer_offset += take + remaining -= take + if self._buffer_offset >= head.num_rows: + self._buffer.popleft() + self._buffer_offset = 0 + self._next_row_index += n - remaining + if not slices: + return pyarrow.Table.from_batches([], schema=self._schema) + return pyarrow.Table.from_batches(slices, schema=self._schema) + + def _drain(self) -> pyarrow.Table: + """Consume everything left in the buffer + kernel stream + and return as a single Table.""" + chunks: List[pyarrow.RecordBatch] = [] + if self._buffer and self._buffer_offset > 0: + head = self._buffer.popleft() + chunks.append( + head.slice(self._buffer_offset, head.num_rows - self._buffer_offset) + ) + self._buffer_offset = 0 + while self._buffer: + chunks.append(self._buffer.popleft()) + if not self._exhausted: + while True: + batch = self._kernel_handle.fetch_next_batch() + if batch is None: + self._exhausted = True + self.has_more_rows = False + self.status = CommandState.SUCCEEDED + break + if batch.num_rows > 0: + chunks.append(batch) + rows = sum(c.num_rows for c in chunks) + self._next_row_index += rows + if not chunks: + return pyarrow.Table.from_batches([], schema=self._schema) + return pyarrow.Table.from_batches(chunks, schema=self._schema) + + # ----- Arrow fetches ----- + + def fetchall_arrow(self) -> pyarrow.Table: + return self._drain() + + def fetchmany_arrow(self, size: int) -> pyarrow.Table: + if size < 0: + raise ValueError(f"fetchmany_arrow size must be >= 0, got {size}") + if size == 0: + return pyarrow.Table.from_batches([], schema=self._schema) + self._ensure_buffered(size) + return self._take_buffered(size) + + # ----- Row fetches ----- + + def fetchone(self) -> Optional[Row]: + self._ensure_buffered(1) + if self._buffered_rows() == 0: + return None + table = self._take_buffered(1) + rows = self._convert_arrow_table(table) + return rows[0] if rows else None + + def fetchmany(self, size: int) -> List[Row]: + if size < 0: + raise ValueError(f"fetchmany size must be >= 0, got {size}") + if size == 0: + return [] + self._ensure_buffered(size) + table = self._take_buffered(size) + return self._convert_arrow_table(table) + + def fetchall(self) -> List[Row]: + return self._convert_arrow_table(self._drain()) + + def close(self) -> None: + """Close the underlying kernel handle. Idempotent — the + kernel's own ``close()`` is idempotent, and we guard against + repeated calls so partially-drained streams don't double- + decrement reference counts.""" + if self._kernel_handle is None: + return + try: + self._kernel_handle.close() + except Exception as exc: + # close() failures are not actionable at the connector + # level; log and swallow so the cursor's __del__ / + # connection close path stays clean. + logger.warning("Error closing kernel handle: %s", exc) + self._buffer.clear() + self._kernel_handle = None + self._exhausted = True + self.has_been_closed_server_side = True + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py new file mode 100644 index 000000000..ce5d2939f --- /dev/null +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -0,0 +1,135 @@ +"""Arrow ↔ PEP 249 type translation for the kernel backend. + +The kernel returns results as pyarrow ``Schema`` / ``RecordBatch``; +PEP 249 ``cursor.description`` is a list of 7-tuples with a +type-name string per column. ``description_from_arrow_schema`` +flattens the conversion so ``KernelResultSet`` and any future +kernel-result wrapper share the same mapping. + +Parameter binding (``TSparkParameter`` → kernel +``Statement.bind_param``) is handled by ``bind_tspark_params`` — +forwards the connector's already-string-encoded form to the kernel +binding without an intermediate Python-typed round-trip. +""" + +from __future__ import annotations + +from typing import Any, List, Tuple + +import pyarrow + +from databricks.sql.thrift_api.TCLIService import ttypes + + +def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str: + """Map a pyarrow type to the Databricks SQL type name used in + PEP 249 ``description``. Names match what the Thrift backend + produces so consumers can branch on them identically. + """ + if pyarrow.types.is_boolean(arrow_type): + return "boolean" + if pyarrow.types.is_int8(arrow_type): + return "tinyint" + if pyarrow.types.is_int16(arrow_type): + return "smallint" + if pyarrow.types.is_int32(arrow_type): + return "int" + if pyarrow.types.is_int64(arrow_type): + return "bigint" + if pyarrow.types.is_float32(arrow_type): + return "float" + if pyarrow.types.is_float64(arrow_type): + return "double" + if pyarrow.types.is_decimal(arrow_type): + return "decimal" + if pyarrow.types.is_string(arrow_type) or pyarrow.types.is_large_string(arrow_type): + return "string" + if pyarrow.types.is_binary(arrow_type) or pyarrow.types.is_large_binary(arrow_type): + return "binary" + if pyarrow.types.is_date(arrow_type): + return "date" + if pyarrow.types.is_timestamp(arrow_type): + return "timestamp" + if pyarrow.types.is_list(arrow_type) or pyarrow.types.is_large_list(arrow_type): + return "array" + if pyarrow.types.is_struct(arrow_type): + return "struct" + if pyarrow.types.is_map(arrow_type): + return "map" + return str(arrow_type) + + +def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: + """Build a PEP 249 ``description`` list from a pyarrow Schema. + + Each tuple is ``(name, type_code, display_size, internal_size, + precision, scale, null_ok)``. The kernel does not report the + last five so they're all ``None`` — same shape the existing + ADBC / Thrift result paths produce. + """ + return [ + ( + field.name, + _arrow_type_to_dbapi_string(field.type), + None, + None, + None, + None, + None, + ) + for field in schema + ] + + +def _tspark_param_value_str(param: ttypes.TSparkParameter) -> Any: + """Extract the string-encoded value from a ``TSparkParameter``, + or ``None`` for SQL NULL. + + Native parameters (``IntegerParameter`` etc.) always wrap their + value in ``TSparkParameterValue(stringValue=str(self.value))``; + ``VoidParameter`` sets ``stringValue="None"`` but the type is + ``"VOID"`` — the kernel-side parser ignores the value when the + type is VOID, so we don't have to special-case here. + """ + if param.value is None: + return None + return param.value.stringValue + + +def bind_tspark_params( + kernel_stmt, parameters: List[ttypes.TSparkParameter] +) -> None: + """Bind a list of ``TSparkParameter`` onto a kernel ``Statement``. + + The kernel expects positional bindings only (SEA v0 doesn't + accept named bindings on the wire). The connector's + ``TSparkParameter`` has an ``ordinal: bool`` flag; ``True`` means + "treat as positional in source-list order". Native bindings + almost always come through positional today; named-binding + parameters surface as ``NotSupportedError`` so the user gets a + clear message instead of a server-side rejection. + + Compound types (``ARRAY`` / ``MAP`` / ``STRUCT``) are routed + through the kernel parser which currently rejects them — same + user-visible message ("compound parameter types … are not yet + supported"). Tracked as a follow-up. + """ + for i, param in enumerate(parameters, start=1): + # The connector's `ordinal` field is a bool (True/False) on + # native params and indicates positional vs named. Named + # params can't flow through the kernel today; raise early + # rather than letting the server reject. + if getattr(param, "ordinal", None) is False and getattr(param, "name", None): + from databricks.sql.exc import NotSupportedError + + raise NotSupportedError( + f"Named parameter binding (got name={param.name!r}) is not yet " + "supported on the kernel backend; pass parameters positionally." + ) + + sql_type = param.type or "STRING" + value_str = _tspark_param_value_str(param) + # The kernel takes 1-based ordinals; `i` is already that. + # Errors from the kernel side (bad literal, unsupported type, + # etc.) come up as KernelError and bubble through normally. + kernel_stmt.bind_param(i, value_str, sql_type) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 65c0d6aca..be2bdb4c2 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -9,7 +9,6 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.common.unified_http_client import UnifiedHttpClient @@ -123,14 +122,33 @@ def _create_backend( """Create and return the appropriate backend client.""" self.use_sea = kwargs.get("use_sea", False) - databricks_client_class: Type[DatabricksClient] if self.use_sea: - logger.debug("Creating SEA backend client") - databricks_client_class = SeaDatabricksClient - else: - logger.debug("Creating Thrift backend client") - databricks_client_class = ThriftDatabricksClient + # `use_sea=True` now routes through the Rust kernel via + # PyO3. The native pure-Python SEA backend + # (`backend/sea/`) is no longer reachable through this + # flag; whether it's removed is tracked separately. See + # `docs/designs/pysql-kernel-integration.md` in the + # databricks-sql-kernel repo. + # + # Lazy import so the connector doesn't ImportError at + # startup when the kernel wheel isn't installed — the + # error surfaces only when a caller actually requests + # use_sea=True. + from databricks.sql.backend.kernel.client import KernelDatabricksClient + + logger.debug("Creating kernel-backed client for use_sea=True") + return KernelDatabricksClient( + server_hostname=server_hostname, + http_path=http_path, + http_headers=all_headers, + auth_provider=auth_provider, + ssl_options=self.ssl_options, + http_client=self.http_client, + catalog=kwargs.get("catalog"), + schema=kwargs.get("schema"), + ) + logger.debug("Creating Thrift backend client") common_args = { "server_hostname": server_hostname, "port": self.port, @@ -142,7 +160,7 @@ def _create_backend( "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } - return databricks_client_class(**common_args) + return ThriftDatabricksClient(**common_args) @staticmethod def _extract_spog_headers(http_path, existing_headers): diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py new file mode 100644 index 000000000..c4c0425f5 --- /dev/null +++ b/tests/e2e/test_kernel_backend.py @@ -0,0 +1,252 @@ +"""E2E tests for ``use_sea=True`` (routes through the Rust kernel +via the PyO3 ``databricks_sql_kernel`` module). + +PAT auth only. Anything else surfaces as ``NotSupportedError`` +from the auth bridge — covered as a unit test, not exercised here. + +Skipped automatically when: + - The standard ``DATABRICKS_SERVER_HOSTNAME`` / ``HTTP_PATH`` / + ``TOKEN`` creds aren't set (existing connector convention). + - ``databricks_sql_kernel`` isn't importable (the wheel hasn't + been installed; run ``pip install + 'databricks-sql-connector[kernel]'`` or, for local dev, + ``cd databricks-sql-kernel/pyo3 && maturin develop --release`` + into this venv). + +Run from the connector repo root: + + set -a && source ~/.databricks/pecotesting-creds && set +a + .venv/bin/pytest tests/e2e/test_kernel_backend.py -v +""" + +from __future__ import annotations + +import pytest + +import databricks.sql as sql +from databricks.sql.exc import DatabaseError + + +# Skip the whole module unless the kernel wheel is importable. +pytest.importorskip( + "databricks_sql_kernel", + reason="use_sea=True requires the databricks-sql-kernel package", +) + + +@pytest.fixture(scope="module") +def kernel_conn_params(connection_details): + """Live-cred check + connection params for use_sea=True. + + Skips the module if any cred is missing rather than letting + every test fail with a confusing connect-time error. + """ + host = connection_details.get("host") + http_path = connection_details.get("http_path") + token = connection_details.get("access_token") + if not (host and http_path and token): + pytest.skip( + "DATABRICKS_SERVER_HOSTNAME / DATABRICKS_HTTP_PATH / " + "DATABRICKS_TOKEN not set" + ) + return { + "server_hostname": host, + "http_path": http_path, + "access_token": token, + "use_sea": True, + } + + +@pytest.fixture +def conn(kernel_conn_params): + """One-shot connection per test (the simple_test pattern the + existing e2e suite uses for cursor-level tests).""" + c = sql.connect(**kernel_conn_params) + try: + yield c + finally: + c.close() + + +def test_connect_with_use_sea_opens_a_session(conn): + assert conn.open, "connection should report open after connect()" + + +def test_select_one(conn): + with conn.cursor() as cur: + cur.execute("SELECT 1 AS n") + assert cur.description[0][0] == "n" + # description type slug matches what Thrift produces + assert cur.description[0][1] == "int" + rows = cur.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 1 + + +def test_drain_large_range_to_arrow(conn): + """SELECT * FROM range(10000) drains as a pyarrow Table with + 10000 rows. Exercises the CloudFetch / multi-batch path on the + kernel side.""" + with conn.cursor() as cur: + cur.execute("SELECT * FROM range(10000)") + rows = cur.fetchall() + assert len(rows) == 10000 + + +def test_fetchmany_pacing(conn): + """fetchmany honours the requested size and stops cleanly at + end-of-stream — covers the buffer-slicing logic in + KernelResultSet.""" + with conn.cursor() as cur: + cur.execute("SELECT * FROM range(50)") + r1 = cur.fetchmany(10) + r2 = cur.fetchmany(20) + r3 = cur.fetchmany(100) # capped at remaining + assert (len(r1), len(r2), len(r3)) == (10, 20, 20) + + +def test_fetchall_arrow(conn): + with conn.cursor() as cur: + cur.execute("SELECT 1 AS a, 'hi' AS b") + table = cur.fetchall_arrow() + assert table.num_rows == 1 + assert table.column_names == ["a", "b"] + + +# ── Metadata ────────────────────────────────────────────────────── + + +def test_metadata_catalogs(conn): + with conn.cursor() as cur: + cur.catalogs() + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_metadata_schemas(conn): + with conn.cursor() as cur: + cur.schemas(catalog_name="main") + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_metadata_tables(conn): + with conn.cursor() as cur: + cur.tables(catalog_name="system", schema_name="information_schema") + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_metadata_columns(conn): + with conn.cursor() as cur: + cur.columns( + catalog_name="system", + schema_name="information_schema", + table_name="tables", + ) + rows = cur.fetchall() + assert len(rows) > 0 + + +# ── Session configuration ───────────────────────────────────────── + + +def test_session_configuration_round_trips(kernel_conn_params): + """`session_configuration` flows through to the kernel's + `session_conf` and is honoured by the server. + + `ANSI_MODE` is the safe choice — it's on the SEA allow-list and + isn't workspace-policy-clamped (unlike `STATEMENT_TIMEOUT`) or + rejected by the warehouse (unlike `TIMEZONE` on dogfood).""" + params = dict(kernel_conn_params) + params["session_configuration"] = {"ANSI_MODE": "false"} + with sql.connect(**params) as c: + with c.cursor() as cur: + cur.execute("SET ANSI_MODE") + rows = cur.fetchall() + kv = {r[0]: r[1] for r in rows} + assert kv.get("ANSI_MODE") == "false", f"got {rows!r}" + + +# ── Error mapping ───────────────────────────────────────────────── + + +def test_bad_sql_surfaces_as_databaseerror(conn): + """Bad SQL should surface as a PEP 249 ``DatabaseError`` with + the kernel's structured fields (`code`, `sql_state`, `query_id`) + attached as attributes — the connector backend re-raises the + kernel's ``SqlError`` to ``DatabaseError`` while preserving the + server-reported state.""" + with conn.cursor() as cur: + with pytest.raises(DatabaseError) as exc_info: + cur.execute("SELECT * FROM definitely_not_a_table_xyz_kernel_e2e") + err = exc_info.value + # Structured fields copied off the kernel exception: + assert getattr(err, "code", None) == "SqlError" + assert getattr(err, "sql_state", None) == "42P01" + + +# ── Parameter binding ───────────────────────────────────────────── + + +def test_parameterized_query_round_trips(conn): + """Positional parameter binding via the kernel backend. The + connector's native parameter classes (IntegerParameter etc.) + serialize to TSparkParameter under the hood; the kernel + backend's mapper forwards them positionally to the kernel. + """ + from databricks.sql.parameters.native import ( + IntegerParameter, + StringParameter, + BooleanParameter, + ) + + with conn.cursor() as cur: + cur.execute( + "SELECT ? AS i, ? AS s, ? AS b", + [ + IntegerParameter(42), + StringParameter("alice"), + BooleanParameter(True), + ], + ) + rows = cur.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 42 + assert rows[0][1] == "alice" + assert rows[0][2] is True + + +def test_parameterized_query_with_null(conn): + """`None` in the parameter list flows through as VoidParameter + → kernel TypedValue::Null.""" + with conn.cursor() as cur: + cur.execute("SELECT ? IS NULL AS is_null", [None]) + rows = cur.fetchall() + assert rows[0][0] is True + + +def test_parameterized_query_decimal(conn): + """DECIMAL parameters carry precision/scale in the SQL type + string ('DECIMAL(p,s)') — the kernel parser extracts them so + fractional digits survive the wire. + + Uses the connector's auto-inference path + (`calculate_decimal_cast_string`) to derive precision/scale + from the value; the explicit-arg path + (`DecimalParameter(v, scale=, precision=)`) has a pre-existing + bug in this branch where the format-args are passed + `(scale, precision)` instead of `(precision, scale)` — out of + scope for this PR. + """ + import decimal + from databricks.sql.parameters.native import DecimalParameter + + with conn.cursor() as cur: + cur.execute( + "SELECT ? AS d", + [DecimalParameter(decimal.Decimal("-123.45"))], + ) + rows = cur.fetchall() + # Server echoes back as decimal.Decimal. + assert str(rows[0][0]) == "-123.45" diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py new file mode 100644 index 000000000..57f1ecaaf --- /dev/null +++ b/tests/unit/test_kernel_auth_bridge.py @@ -0,0 +1,133 @@ +"""Unit tests for the kernel backend's auth bridge. + +Phase 1 ships PAT only. Tests verify: + - PAT routes through ``auth_type='pat'``. + - ``TokenFederationProvider``-wrapped PAT also routes through + PAT (every provider built by ``get_python_sql_connector_auth_provider`` + is federation-wrapped, so the naive isinstance check has to + look through the wrapper). + - Anything else raises ``NotSupportedError`` with a clear message. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +# auth_bridge.py itself has no pyarrow or kernel-wheel deps. The +# `databricks.sql.backend.kernel` package's __init__.py deliberately +# does *not* eagerly re-export from .client either (which would +# require the kernel wheel). So this test can run on the +# default-deps CI matrix without any extras. No importorskip needed. + +from databricks.sql.auth.authenticators import ( + AccessTokenAuthProvider, + AuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, +) +from databricks.sql.backend.kernel.auth_bridge import ( + _extract_bearer_token, + kernel_auth_kwargs, +) +from databricks.sql.exc import NotSupportedError + + +class _FakeOAuthProvider(AuthProvider): + """Stand-in for any non-PAT provider. The bridge should reject + these with NotSupportedError.""" + + def add_headers(self, request_headers): + request_headers["Authorization"] = "Bearer oauth-token-xyz" + + +class _MalformedProvider(AuthProvider): + """Provider that returns a non-Bearer Authorization header.""" + + def add_headers(self, request_headers): + request_headers["Authorization"] = "Basic dXNlcjpwYXNz" + + +class _SilentProvider(AuthProvider): + """Provider that writes nothing — misconfigured auth.""" + + def add_headers(self, request_headers): + pass + + +class TestExtractBearerToken: + def test_pat_provider_returns_token(self): + p = AccessTokenAuthProvider("dapi-abc-123") + assert _extract_bearer_token(p) == "dapi-abc-123" + + def test_non_bearer_auth_returns_none(self): + assert _extract_bearer_token(_MalformedProvider()) is None + + def test_silent_provider_returns_none(self): + assert _extract_bearer_token(_SilentProvider()) is None + + +class TestKernelAuthKwargs: + def test_pat_routes_to_kernel_pat(self): + kwargs = kernel_auth_kwargs(AccessTokenAuthProvider("dapi-xyz")) + assert kwargs == {"auth_type": "pat", "access_token": "dapi-xyz"} + + def test_federation_wrapped_pat_routes_to_kernel_pat(self): + """``get_python_sql_connector_auth_provider`` always wraps + the base provider in a ``TokenFederationProvider``, so the + PAT case never reaches us unwrapped in practice. The bridge + must look through the federation wrapper to find the + underlying ``AccessTokenAuthProvider``.""" + from databricks.sql.auth.token_federation import TokenFederationProvider + + base = AccessTokenAuthProvider("dapi-abc") + # TokenFederationProvider's __init__ requires an http_client + # to construct cleanly; for this unit test we only exercise + # the add_headers passthrough + the external_provider + # attribute. Bypass __init__ with __new__ and stash just + # the fields the bridge touches. + federated = TokenFederationProvider.__new__(TokenFederationProvider) + federated.external_provider = base + federated.add_headers = base.add_headers + kwargs = kernel_auth_kwargs(federated) + assert kwargs == {"auth_type": "pat", "access_token": "dapi-abc"} + + def test_pat_with_silent_provider_raises_value_error(self): + """An AccessTokenAuthProvider that produces no Authorization + header is misconfigured; surface that at bridge-build time, + not on the first kernel HTTP request.""" + broken = AccessTokenAuthProvider("dapi-x") + broken.add_headers = lambda h: None # type: ignore[method-assign] + with pytest.raises(ValueError, match="Bearer"): + kernel_auth_kwargs(broken) + + def test_generic_oauth_provider_raises_not_supported(self): + with pytest.raises(NotSupportedError, match="only supports PAT"): + kernel_auth_kwargs(_FakeOAuthProvider()) + + def test_external_credentials_provider_raises_not_supported(self): + """``ExternalAuthProvider`` wraps user-supplied + credentials_provider — kernel doesn't accept these today, + and the bridge surfaces that explicitly.""" + # ExternalAuthProvider's __init__ calls the credentials + # provider; supply a noop one. + from databricks.sql.auth.authenticators import CredentialsProvider + + class _NoopCreds(CredentialsProvider): + def auth_type(self): + return "noop" + + def __call__(self, *args, **kwargs): + return lambda: {"Authorization": "Bearer noop"} + + ext = ExternalAuthProvider(_NoopCreds()) + with pytest.raises(NotSupportedError, match="only supports PAT"): + kernel_auth_kwargs(ext) + + def test_silent_non_pat_provider_also_raises_not_supported(self): + """Even if a non-PAT provider produces no header, the bridge + rejects the type itself — we don't try to extract a token + from something we already know is unsupported.""" + with pytest.raises(NotSupportedError): + kernel_auth_kwargs(_SilentProvider()) diff --git a/tests/unit/test_kernel_result_set.py b/tests/unit/test_kernel_result_set.py new file mode 100644 index 000000000..c83bfce94 --- /dev/null +++ b/tests/unit/test_kernel_result_set.py @@ -0,0 +1,169 @@ +"""Unit tests for ``KernelResultSet`` — the buffer behavior + +close() semantics. Uses a fake kernel handle so tests run with no +network and no Rust extension dependency.""" + +from __future__ import annotations + +from collections import deque +from typing import Deque +from unittest.mock import MagicMock + +import pytest + +# pyarrow is an optional connector dep; the default-deps CI test +# job runs without it. KernelResultSet imports pyarrow eagerly, +# so the whole module must skip when pyarrow is unavailable. +pa = pytest.importorskip("pyarrow") + +from databricks.sql.backend.kernel.result_set import KernelResultSet +from databricks.sql.backend.types import CommandId, CommandState + + +class _FakeKernelHandle: + """Stand-in for ``databricks_sql_kernel.ExecutedStatement`` / + ``ResultStream``. Emits a configured list of ``RecordBatch``es + via ``fetch_next_batch`` and then returns ``None``.""" + + def __init__(self, schema: pa.Schema, batches): + self._schema = schema + self._batches: Deque[pa.RecordBatch] = deque(batches) + self.closed = False + + def arrow_schema(self) -> pa.Schema: + return self._schema + + def fetch_next_batch(self): + if self.closed: + raise RuntimeError("fetched after close") + if not self._batches: + return None + return self._batches.popleft() + + def close(self): + self.closed = True + + +def _make_rs(handle) -> KernelResultSet: + # The base ResultSet __init__ takes a `connection` ref it never + # actually dereferences during these buffer tests, so a Mock is + # fine. + connection = MagicMock() + backend = MagicMock() + return KernelResultSet( + connection=connection, + backend=backend, + kernel_handle=handle, + command_id=CommandId.from_sea_statement_id("smoke-test"), + arraysize=100, + buffer_size_bytes=1024, + ) + + +def _batch(schema: pa.Schema, values) -> pa.RecordBatch: + return pa.RecordBatch.from_arrays( + [pa.array(values, type=schema.field(0).type)], schema=schema + ) + + +# Renamed from `schema` -> `int_schema` because the connector's +# top-level conftest.py defines a session-scoped `schema` fixture +# for E2E tests; pytest's fixture-resolution complains about +# scope-mismatch if we shadow it with a function-scoped one here. +@pytest.fixture +def int_schema(): + return pa.schema([("n", pa.int64())]) + + +def test_description_built_from_kernel_schema(int_schema): + handle = _FakeKernelHandle(int_schema, []) + rs = _make_rs(handle) + assert rs.description == [("n", "bigint", None, None, None, None, None)] + + +def test_fetchall_arrow_drains_all_batches(int_schema): + handle = _FakeKernelHandle( + int_schema, [_batch(int_schema, [1, 2]), _batch(int_schema, [3, 4, 5])] + ) + rs = _make_rs(handle) + table = rs.fetchall_arrow() + assert table.num_rows == 5 + assert table.column(0).to_pylist() == [1, 2, 3, 4, 5] + assert rs.status == CommandState.SUCCEEDED + assert rs.has_more_rows is False + + +def test_fetchmany_arrow_slices_within_batch(int_schema): + handle = _FakeKernelHandle(int_schema, [_batch(int_schema, [10, 20, 30, 40])]) + rs = _make_rs(handle) + t1 = rs.fetchmany_arrow(2) + assert t1.num_rows == 2 and t1.column(0).to_pylist() == [10, 20] + t2 = rs.fetchmany_arrow(2) + assert t2.num_rows == 2 and t2.column(0).to_pylist() == [30, 40] + t3 = rs.fetchmany_arrow(2) + assert t3.num_rows == 0 + + +def test_fetchmany_arrow_spans_batch_boundary(int_schema): + handle = _FakeKernelHandle( + int_schema, + [_batch(int_schema, [1, 2]), _batch(int_schema, [3, 4]), _batch(int_schema, [5, 6])], + ) + rs = _make_rs(handle) + t = rs.fetchmany_arrow(5) + assert t.num_rows == 5 + assert t.column(0).to_pylist() == [1, 2, 3, 4, 5] + t = rs.fetchmany_arrow(2) + assert t.column(0).to_pylist() == [6] + + +def test_fetchone_returns_row_then_none(int_schema): + handle = _FakeKernelHandle(int_schema, [_batch(int_schema, [42])]) + rs = _make_rs(handle) + row = rs.fetchone() + assert row is not None + assert row[0] == 42 + assert rs.fetchone() is None + + +def test_fetchall_rows(int_schema): + handle = _FakeKernelHandle( + int_schema, [_batch(int_schema, [1, 2]), _batch(int_schema, [3])] + ) + rs = _make_rs(handle) + rows = rs.fetchall() + assert [r[0] for r in rows] == [1, 2, 3] + + +def test_fetchmany_negative_raises(int_schema): + rs = _make_rs(_FakeKernelHandle(int_schema, [])) + with pytest.raises(ValueError): + rs.fetchmany(-1) + with pytest.raises(ValueError): + rs.fetchmany_arrow(-1) + + +def test_close_is_idempotent_and_calls_handle(int_schema): + handle = _FakeKernelHandle(int_schema, [_batch(int_schema, [1])]) + rs = _make_rs(handle) + rs.close() + assert handle.closed is True + assert rs.status == CommandState.CLOSED + rs.close() # second call is a no-op (kernel handle is None) + + +def test_empty_stream(int_schema): + rs = _make_rs(_FakeKernelHandle(int_schema, [])) + assert rs.fetchone() is None + assert rs.fetchall_arrow().num_rows == 0 + assert rs.status == CommandState.SUCCEEDED + + +def test_close_swallows_handle_close_failures(int_schema): + """ResultSet.close() must not raise even if the kernel + handle's close() fails — PEP 249 discourages exceptions from + close paths (cursor/connection teardown depends on it).""" + handle = _FakeKernelHandle(int_schema, []) + handle.close = MagicMock(side_effect=RuntimeError("kernel boom")) + rs = _make_rs(handle) + rs.close() # must not raise + assert rs.status == CommandState.CLOSED diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py new file mode 100644 index 000000000..3ebf45001 --- /dev/null +++ b/tests/unit/test_kernel_type_mapping.py @@ -0,0 +1,172 @@ +"""Unit tests for Arrow → PEP 249 description-string mapping.""" + +from __future__ import annotations + +import pytest + +# pyarrow is an optional connector dep; the default-deps CI test +# job runs without it. The kernel backend itself imports pyarrow +# at module load, so any test that touches the backend must skip +# when pyarrow is unavailable. +pa = pytest.importorskip("pyarrow") + +from databricks.sql.backend.kernel.type_mapping import ( + _arrow_type_to_dbapi_string, + description_from_arrow_schema, +) + + +@pytest.mark.parametrize( + "arrow_type, expected", + [ + (pa.bool_(), "boolean"), + (pa.int8(), "tinyint"), + (pa.int16(), "smallint"), + (pa.int32(), "int"), + (pa.int64(), "bigint"), + (pa.float32(), "float"), + (pa.float64(), "double"), + (pa.decimal128(10, 2), "decimal"), + (pa.string(), "string"), + (pa.large_string(), "string"), + (pa.binary(), "binary"), + (pa.large_binary(), "binary"), + (pa.date32(), "date"), + (pa.timestamp("us"), "timestamp"), + (pa.list_(pa.int32()), "array"), + (pa.large_list(pa.int32()), "array"), + (pa.struct([("a", pa.int32())]), "struct"), + (pa.map_(pa.string(), pa.int32()), "map"), + ], +) +def test_arrow_to_dbapi_known_types(arrow_type, expected): + assert _arrow_type_to_dbapi_string(arrow_type) == expected + + +def test_arrow_to_dbapi_unknown_falls_back_to_str(): + # null type isn't in the explicit list but should fall through + # to the default str() so unknown variants are still printable + # rather than silently misclassified. + assert _arrow_type_to_dbapi_string(pa.null()) == "null" + + +def test_description_from_schema_preserves_field_names_and_order(): + schema = pa.schema( + [ + ("user_id", pa.int64()), + ("name", pa.string()), + ("created_at", pa.timestamp("us")), + ] + ) + desc = description_from_arrow_schema(schema) + assert len(desc) == 3 + assert [(d[0], d[1]) for d in desc] == [ + ("user_id", "bigint"), + ("name", "string"), + ("created_at", "timestamp"), + ] + # PEP 249 says all 7-tuples; the last 5 slots are None for the + # kernel backend (we don't report display_size / precision / + # scale / nullability). + for d in desc: + assert len(d) == 7 + assert d[2:] == (None, None, None, None, None) + + +# ─── bind_tspark_params ────────────────────────────────────────────────── + + +def _mk_param(*, type, value, ordinal=True, name=None): + """Build a minimal TSparkParameter for tests.""" + from databricks.sql.thrift_api.TCLIService import ttypes + + p = ttypes.TSparkParameter(ordinal=ordinal, name=name, type=type) + p.value = ttypes.TSparkParameterValue(stringValue=value) if value is not None else None + return p + + +class _RecordingStmt: + """Stand-in for the kernel `Statement` pyclass — records every + `bind_param` call so tests can assert the (ordinal, value, type) + triples the mapper forwarded.""" + + def __init__(self): + self.calls = [] + + def bind_param(self, ordinal, value_str, sql_type): + self.calls.append((ordinal, value_str, sql_type)) + + +def test_bind_tspark_params_forwards_each_param_positionally(): + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + params = [ + _mk_param(type="INT", value="42"), + _mk_param(type="STRING", value="alice"), + _mk_param(type="DATE", value="2026-05-15"), + ] + stmt = _RecordingStmt() + bind_tspark_params(stmt, params) + assert stmt.calls == [ + (1, "42", "INT"), + (2, "alice", "STRING"), + (3, "2026-05-15", "DATE"), + ] + + +def test_bind_tspark_params_null_value(): + """TSparkParameter with value=None → kernel sees value_str=None, + interpreted as SQL NULL regardless of the SQL type.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + p = _mk_param(type="STRING", value=None) + stmt = _RecordingStmt() + bind_tspark_params(stmt, [p]) + assert stmt.calls == [(1, None, "STRING")] + + +def test_bind_tspark_params_void_passes_through(): + """VoidParameter sets type='VOID' with stringValue='None'; the + kernel parser ignores the value when type=VOID.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + p = _mk_param(type="VOID", value="None") + stmt = _RecordingStmt() + bind_tspark_params(stmt, [p]) + assert stmt.calls == [(1, "None", "VOID")] + + +def test_bind_tspark_params_named_param_rejected(): + """The kernel doesn't accept named bindings on the SEA wire; + surface that at the connector layer so the user gets a pointed + error instead of a server-side rejection.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + from databricks.sql.exc import NotSupportedError + + p = _mk_param(type="INT", value="42", ordinal=False, name="my_param") + stmt = _RecordingStmt() + with pytest.raises(NotSupportedError, match="(?i)named"): + bind_tspark_params(stmt, [p]) + # Nothing should have been forwarded before the rejection. + assert stmt.calls == [] + + +def test_bind_tspark_params_missing_type_defaults_to_string(): + """Defensive: a TSparkParameter with no `type` shouldn't crash + the mapper — fall back to STRING and let the kernel parse.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + from databricks.sql.thrift_api.TCLIService import ttypes + + p = ttypes.TSparkParameter(ordinal=True, name=None, type=None) + p.value = ttypes.TSparkParameterValue(stringValue="hello") + stmt = _RecordingStmt() + bind_tspark_params(stmt, [p]) + assert stmt.calls == [(1, "hello", "STRING")] + + +def test_bind_tspark_params_empty_list_is_noop(): + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + stmt = _RecordingStmt() + bind_tspark_params(stmt, []) + assert stmt.calls == []