diff --git a/src/databricks/sql/backend/adbc/__init__.py b/src/databricks/sql/backend/adbc/__init__.py
new file mode 100644
index 000000000..806658540
--- /dev/null
+++ b/src/databricks/sql/backend/adbc/__init__.py
@@ -0,0 +1,9 @@
+"""ADBC-Rust-kernel-backed backend for databricks-sql-python (POC).
+
+Wraps the PyO3 binding `databricks_adbc_pyo3` and adapts it to the
+`DatabricksClient` / `ResultSet` interfaces used by the rest of the connector.
+"""
+
+from databricks.sql.backend.adbc.client import AdbcDatabricksClient
+
+__all__ = ["AdbcDatabricksClient"]
diff --git a/src/databricks/sql/backend/adbc/client.py b/src/databricks/sql/backend/adbc/client.py
new file mode 100644
index 000000000..cb94d1664
--- /dev/null
+++ b/src/databricks/sql/backend/adbc/client.py
@@ -0,0 +1,203 @@
+"""DatabricksClient backed by the Rust ADBC kernel via PyO3 (POC).
+
+Implements the connector's `DatabricksClient` interface by delegating to the
+`databricks_adbc_pyo3` extension module, which loads the Rust kernel
+(`databricks-adbc`) in-process. PAT-only for now; metadata and async
+operations raise NotImplementedError.
+"""
+
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
+
+from databricks.sql.backend.databricks_client import DatabricksClient
+from databricks.sql.backend.types import (
+    BackendType,
+    CommandId,
+    CommandState,
+    SessionId,
+)
+from databricks.sql.backend.adbc.result_set import AdbcResultSet
+from databricks.sql.exc import DatabaseError, OperationalError, ProgrammingError
+from databricks.sql.thrift_api.TCLIService import ttypes
+
+if TYPE_CHECKING:
+    from databricks.sql.client import Cursor
+    from databricks.sql.result_set import ResultSet
+
+logger = logging.getLogger(__name__)
+
+try:
+    import databricks_adbc_pyo3 as _rust_kernel
+except ImportError as exc:  # pragma: no cover - import-time error surfaces clearly
+    raise ImportError(
+        "use_sea=True requires the databricks_adbc_pyo3 extension. Install it from "
+        "the databricks-adbc/rust-pyo3 directory with `maturin develop --release` "
+        "in your venv."
+    ) from exc
+
+
+class AdbcDatabricksClient(DatabricksClient):
+    """DatabricksClient that routes execution through the Rust ADBC kernel.
+
+    Construction does not open a Rust connection — that happens in
+    `open_session` so the same Session lifecycle that today gates Thrift's
+    `TOpenSession` gates the Rust kernel's connection setup too.
+    """
+
+    def __init__(
+        self,
+        server_hostname: str,
+        http_path: str,
+        access_token: Optional[str] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+        **kwargs,
+    ):
+        if not access_token:
+            raise ProgrammingError(
+                "AdbcDatabricksClient (use_sea=True) currently supports only PAT auth. "
+                "Pass access_token=<token>."
+            )
+        # Auth provider is built upstream but the Rust kernel re-does PAT auth itself,
+        # so we just need the raw token here.
+        self._server_hostname = server_hostname
+        self._http_path = http_path
+        self._access_token = access_token
+        self._initial_catalog = catalog
+        self._initial_schema = schema
+
+        # Per-session state. We support a single open session at a time; opening
+        # a second one will raise. Matches the current Session lifecycle.
+        self._connection: Optional[_rust_kernel.Connection] = None
+        self._session_id: Optional[SessionId] = None
+
+    # ----- session lifecycle -----
+
+    def open_session(
+        self,
+        session_configuration: Optional[Dict[str, Any]],
+        catalog: Optional[str],
+        schema: Optional[str],
+    ) -> SessionId:
+        if self._connection is not None:
+            raise OperationalError("AdbcDatabricksClient already has an open session.")
+        if session_configuration:
+            logger.warning(
+                "AdbcDatabricksClient ignores session_configuration in POC: %s",
+                list(session_configuration.keys()),
+            )
+        try:
+            self._connection = _rust_kernel.Connection(
+                self._server_hostname,
+                self._http_path,
+                self._access_token,
+                catalog=catalog or self._initial_catalog,
+                schema=schema or self._initial_schema,
+            )
+        except RuntimeError as exc:
+            raise OperationalError(f"Failed to open Rust ADBC session: {exc}") from exc
+
+        # Mint a synthetic SEA-style session id; the kernel manages real session
+        # lifecycle internally and does not surface its session GUID today.
+        self._session_id = SessionId.from_sea_session_id(str(uuid.uuid4()))
+        logger.info("Opened ADBC-Rust session %s", self._session_id)
+        return self._session_id
+
+    def close_session(self, session_id: SessionId) -> None:
+        if self._connection is None:
+            return
+        # PyO3 Connection has no explicit close in the POC — drop the reference
+        # and let Drop release the Rust-side resources.
+        self._connection = None
+        self._session_id = None
+
+    # ----- query execution -----
+
+    def execute_command(
+        self,
+        operation: str,
+        session_id: SessionId,
+        max_rows: int,
+        max_bytes: int,
+        lz4_compression: bool,
+        cursor: "Cursor",
+        use_cloud_fetch: bool,
+        parameters: List[ttypes.TSparkParameter],
+        async_op: bool,
+        enforce_embedded_schema_correctness: bool,
+        row_limit: Optional[int] = None,
+        query_tags: Optional[Dict[str, Optional[str]]] = None,
+    ) -> Union["ResultSet", None]:
+        if self._connection is None:
+            raise OperationalError("Cannot execute_command on closed session.")
+        if async_op:
+            raise NotImplementedError(
+                "async_op is not supported by the Rust ADBC backend (POC)."
+            )
+        if parameters:
+            raise NotImplementedError(
+                "Parameter binding is not supported by the Rust ADBC backend (POC)."
+            )
+
+        try:
+            rs = self._connection.execute(operation)
+        except RuntimeError as exc:
+            raise DatabaseError(f"Rust ADBC execution failed: {exc}") from exc
+
+        # The kernel does not surface its statement_id today; mint a synthetic one.
+        command_id = CommandId.from_sea_statement_id(str(uuid.uuid4()))
+        cursor.active_command_id = command_id
+
+        return AdbcResultSet(
+            connection=cursor.connection,
+            backend=self,
+            rust_result_set=rs,
+            command_id=command_id,
+            arraysize=cursor.arraysize,
+            buffer_size_bytes=cursor.buffer_size_bytes,
+        )
+
+    def cancel_command(self, command_id: CommandId) -> None:
+        # POC: execute_command is fully synchronous and the result is materialized
+        # before it returns, so there is nothing to cancel after the fact.
+        logger.debug("cancel_command is a no-op in the Rust ADBC POC backend")
+
+    def close_command(self, command_id: CommandId) -> None:
+        # Result set is already drained on the Rust side.
+        logger.debug("close_command is a no-op in the Rust ADBC POC backend")
+
+    def get_query_state(self, command_id: CommandId) -> CommandState:
+        # All commands run synchronously and reach SUCCEEDED before returning.
+        return CommandState.SUCCEEDED
+
+    def get_execution_result(
+        self,
+        command_id: CommandId,
+        cursor: "Cursor",
+    ) -> "ResultSet":
+        raise NotImplementedError(
+            "get_execution_result requires async execution (not supported in POC)."
+        )
+
+    # ----- metadata (not yet wired) -----
+
+    def get_catalogs(self, *args, **kwargs):
+        raise NotImplementedError("get_catalogs is not supported by the Rust ADBC backend (POC).")
+
+    def get_schemas(self, *args, **kwargs):
+        raise NotImplementedError("get_schemas is not supported by the Rust ADBC backend (POC).")
+
+    def get_tables(self, *args, **kwargs):
+        raise NotImplementedError("get_tables is not supported by the Rust ADBC backend (POC).")
+
+    def get_columns(self, *args, **kwargs):
+        raise NotImplementedError("get_columns is not supported by the Rust ADBC backend (POC).")
+
+    @property
+    def max_download_threads(self) -> int:
+        # The kernel manages its own CloudFetch parallelism; this property is
+        # only consulted by Thrift code paths that don't run for use_sea=True.
+        return 10
diff --git a/src/databricks/sql/backend/adbc/result_set.py b/src/databricks/sql/backend/adbc/result_set.py
new file mode 100644
index 000000000..7d66f589e
--- /dev/null
+++ b/src/databricks/sql/backend/adbc/result_set.py
@@ -0,0 +1,231 @@
+"""ResultSet for the ADBC-Rust-kernel backend (POC, streaming).
+
+Wraps the streaming `databricks_adbc_pyo3.ResultSet` and pulls record batches
+lazily as the connector calls `fetch*`. No full materialization on execute.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections import deque
+from typing import Deque, List, Optional, TYPE_CHECKING
+
+import pyarrow
+
+from databricks.sql.backend.types import CommandId, CommandState
+from databricks.sql.result_set import ResultSet
+from databricks.sql.types import Row
+
+if TYPE_CHECKING:
+    from databricks.sql.client import Connection
+    from databricks.sql.backend.adbc.client import AdbcDatabricksClient
+
+logger = logging.getLogger(__name__)
+
+
+def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str:
+    """Map a pyarrow type to a Databricks SQL type string for PEP-249 description."""
+    if pyarrow.types.is_boolean(arrow_type):
+        return "boolean"
+    if pyarrow.types.is_int8(arrow_type):
+        return "tinyint"
+    if pyarrow.types.is_int16(arrow_type):
+        return "smallint"
+    if pyarrow.types.is_int32(arrow_type):
+        return "int"
+    if pyarrow.types.is_int64(arrow_type):
+        return "bigint"
+    if pyarrow.types.is_float32(arrow_type):
+        return "float"
+    if pyarrow.types.is_float64(arrow_type):
+        return "double"
+    if pyarrow.types.is_decimal(arrow_type):
+        return "decimal"
+    if pyarrow.types.is_string(arrow_type) or pyarrow.types.is_large_string(arrow_type):
+        return "string"
+    if pyarrow.types.is_binary(arrow_type) or pyarrow.types.is_large_binary(arrow_type):
+        return "binary"
+    if pyarrow.types.is_date(arrow_type):
+        return "date"
+    if pyarrow.types.is_timestamp(arrow_type):
+        return "timestamp"
+    if pyarrow.types.is_list(arrow_type) or pyarrow.types.is_large_list(arrow_type):
+        return "array"
+    if pyarrow.types.is_struct(arrow_type):
+        return "struct"
+    if pyarrow.types.is_map(arrow_type):
+        return "map"
+    return str(arrow_type)
+
+
+def description_from_arrow_schema(schema: pyarrow.Schema) -> List[tuple]:
+    """Build a PEP-249 description list from a pyarrow Schema."""
+    return [
+        (field.name, _arrow_type_to_dbapi_string(field.type), None, None, None, None, None)
+        for field in schema
+    ]
+
+
+class 
AdbcResultSet(ResultSet):
+    """Streaming ResultSet adapter over `databricks_adbc_pyo3.ResultSet`.
+
+    Holds a small in-memory buffer of batches to support partial fetches
+    (`fetchmany(n)`, `fetchone()`) without re-fetching from the kernel. The
+    buffer is fed by pulling one batch at a time from the kernel reader.
+    """
+
+    def __init__(
+        self,
+        connection: "Connection",
+        backend: "AdbcDatabricksClient",
+        rust_result_set,  # databricks_adbc_pyo3.ResultSet
+        command_id: CommandId,
+        arraysize: int,
+        buffer_size_bytes: int,
+    ):
+        schema = rust_result_set.arrow_schema()
+        super().__init__(
+            connection=connection,
+            backend=backend,
+            arraysize=arraysize,
+            buffer_size_bytes=buffer_size_bytes,
+            command_id=command_id,
+            status=CommandState.RUNNING,
+            has_been_closed_server_side=False,
+            has_more_rows=True,
+            results_queue=None,
+            description=description_from_arrow_schema(schema),
+            is_staging_operation=False,
+            lz4_compressed=False,
+            arrow_schema_bytes=None,
+        )
+        self._rust_rs = rust_result_set
+        self._schema: pyarrow.Schema = schema
+        # FIFO of (batch, offset_in_batch) for partial consumption.
+        self._buffer: Deque[pyarrow.RecordBatch] = deque()
+        self._buffer_offset: int = 0  # how many rows already fetched out of buffer[0]
+        self._exhausted: bool = False
+
+    # ----- internal helpers -----
+
+    def _pull_one_batch(self) -> bool:
+        """Pull the next batch from the kernel into the buffer.
+        Returns True if a batch was added, False if reader is exhausted."""
+        if self._exhausted:
+            return False
+        batch = self._rust_rs.fetch_next_batch()
+        if batch is None:
+            self._exhausted = True
+            self.has_more_rows = False
+            self.status = CommandState.SUCCEEDED
+            return False
+        if batch.num_rows > 0:
+            self._buffer.append(batch)
+        return True
+
+    def _ensure_buffered(self, n_rows: int) -> int:
+        """Pull batches until we have at least n_rows buffered, or reader exhausted.
+        Returns total rows currently buffered (>= min(n_rows, total_remaining))."""
+        target = n_rows
+        while self._buffered_rows() < target:
+            if not self._pull_one_batch():
+                break
+        return self._buffered_rows()
+
+    def _buffered_rows(self) -> int:
+        if not self._buffer:
+            return 0
+        first = self._buffer[0].num_rows - self._buffer_offset
+        rest = sum(b.num_rows for b in list(self._buffer)[1:])
+        return first + rest
+
+    def _take_buffered(self, n: int) -> pyarrow.Table:
+        """Slice up to n rows out of the buffer; advances state."""
+        slices: List[pyarrow.RecordBatch] = []
+        remaining = n
+        while remaining > 0 and self._buffer:
+            head = self._buffer[0]
+            avail = head.num_rows - self._buffer_offset
+            take = min(avail, remaining)
+            slices.append(head.slice(self._buffer_offset, take))
+            self._buffer_offset += take
+            remaining -= take
+            if self._buffer_offset >= head.num_rows:
+                self._buffer.popleft()
+                self._buffer_offset = 0
+        self._next_row_index += (n - remaining)
+        if not slices:
+            return pyarrow.Table.from_batches([], schema=self._schema)
+        return pyarrow.Table.from_batches(slices, schema=self._schema)
+
+    def _drain(self) -> pyarrow.Table:
+        """Consume everything that's left and return as a single Table."""
+        chunks: List[pyarrow.RecordBatch] = []
+        # First flush any partially consumed head batch.
+        if self._buffer and self._buffer_offset > 0:
+            head = self._buffer.popleft()
+            chunks.append(head.slice(self._buffer_offset, head.num_rows - self._buffer_offset))
+            self._buffer_offset = 0
+        # Then everything else already buffered.
+        while self._buffer:
+            chunks.append(self._buffer.popleft())
+        # Then pull whatever remains from the kernel.
+        if not self._exhausted:
+            while True:
+                batch = self._rust_rs.fetch_next_batch()
+                if batch is None:
+                    self._exhausted = True
+                    self.has_more_rows = False
+                    self.status = CommandState.SUCCEEDED
+                    break
+                if batch.num_rows > 0:
+                    chunks.append(batch)
+        rows = sum(c.num_rows for c in chunks)
+        self._next_row_index += rows
+        if not chunks:
+            return pyarrow.Table.from_batches([], schema=self._schema)
+        return pyarrow.Table.from_batches(chunks, schema=self._schema)
+
+    # ----- Arrow fetches -----
+
+    def fetchall_arrow(self) -> pyarrow.Table:
+        return self._drain()
+
+    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
+        if size < 0:
+            raise ValueError(f"fetchmany_arrow size must be >= 0, got {size}")
+        if size == 0:
+            return pyarrow.Table.from_batches([], schema=self._schema)
+        self._ensure_buffered(size)
+        return self._take_buffered(size)
+
+    # ----- Row fetches -----
+
+    def fetchone(self) -> Optional[Row]:
+        self._ensure_buffered(1)
+        if self._buffered_rows() == 0:
+            return None
+        table = self._take_buffered(1)
+        rows = self._convert_arrow_table(table)
+        return rows[0] if rows else None
+
+    def fetchmany(self, size: int) -> List[Row]:
+        if size < 0:
+            raise ValueError(f"fetchmany size must be >= 0, got {size}")
+        if size == 0:
+            return []
+        self._ensure_buffered(size)
+        table = self._take_buffered(size)
+        return self._convert_arrow_table(table)
+
+    def fetchall(self) -> List[Row]:
+        return self._convert_arrow_table(self._drain())
+
+    def close(self) -> None:
+        # Drop our handle to the streaming reader; PyO3 Drop releases
+        # kernel-side resources (HTTP connections, buffered chunks).
+        self._buffer.clear()
+        self._rust_rs = None
+        self._exhausted = True
+        self.has_been_closed_server_side = True
+        self.status = CommandState.CLOSED
diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py
index 65c0d6aca..0f466b865 100644
--- a/src/databricks/sql/session.py
+++ b/src/databricks/sql/session.py
@@ -9,7 +9,7 @@
 from databricks.sql import __version__
 from databricks.sql import USER_AGENT_NAME
 from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.adbc.client import AdbcDatabricksClient
 from databricks.sql.backend.databricks_client import DatabricksClient
 from databricks.sql.backend.types import SessionId, BackendType
 from databricks.sql.common.unified_http_client import UnifiedHttpClient
@@ -123,14 +123,20 @@ def _create_backend(
         """Create and return the appropriate backend client."""
         self.use_sea = kwargs.get("use_sea", False)
 
-        databricks_client_class: Type[DatabricksClient]
         if self.use_sea:
-            logger.debug("Creating SEA backend client")
-            databricks_client_class = SeaDatabricksClient
-        else:
-            logger.debug("Creating Thrift backend client")
-            databricks_client_class = ThriftDatabricksClient
-
+            logger.debug("Creating ADBC-Rust SEA backend client")
+            # The Rust kernel handles its own HTTP, auth, headers, and SSL — we
+            # only need to forward the connection params it understands.
+            adbc_args = {
+                "server_hostname": server_hostname,
+                "http_path": http_path,
+                "access_token": kwargs.get("access_token"),
+                "catalog": kwargs.get("catalog"),
+                "schema": kwargs.get("schema"),
+            }
+            return AdbcDatabricksClient(**adbc_args)
+
+        logger.debug("Creating Thrift backend client")
         common_args = {
             "server_hostname": server_hostname,
             "port": self.port,
@@ -142,7 +148,7 @@
             "_use_arrow_native_complex_types": _use_arrow_native_complex_types,
             **kwargs,
         }
-        return databricks_client_class(**common_args)
+        return ThriftDatabricksClient(**common_args)
 
     @staticmethod
     def _extract_spog_headers(http_path, existing_headers):