Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DatabricksOAuthProvider,
ExternalAuthProvider,
)
from databricks.sql.auth.token_federation import TokenFederationProvider
import sys
import platform
import uuid
Expand Down Expand Up @@ -90,6 +91,8 @@ def get_auth_mechanism(auth_provider):

if not auth_provider:
return None
if isinstance(auth_provider, TokenFederationProvider):
return TelemetryHelper.get_auth_mechanism(auth_provider.external_provider)
if isinstance(auth_provider, AccessTokenAuthProvider):
return AuthMech.PAT
elif isinstance(auth_provider, DatabricksOAuthProvider):
Expand All @@ -105,6 +108,8 @@ def get_auth_flow(auth_provider):

if not auth_provider:
return None
if isinstance(auth_provider, TokenFederationProvider):
return TelemetryHelper.get_auth_flow(auth_provider.external_provider)
if isinstance(auth_provider, DatabricksOAuthProvider):
if auth_provider._access_token and auth_provider._refresh_token:
return AuthFlow.TOKEN_PASSTHROUGH
Expand Down
91 changes: 91 additions & 0 deletions tests/unit/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DatabricksOAuthProvider,
ExternalAuthProvider,
)
from databricks.sql.auth.token_federation import TokenFederationProvider
from databricks import sql


Expand Down Expand Up @@ -211,6 +212,64 @@ def test_auth_flow_detection(self):
# Test None auth provider
assert TelemetryHelper.get_auth_flow(None) is None

def _make_real_federation(self, inner):
"""Build a real TokenFederationProvider so attribute renames break tests."""
return TokenFederationProvider(
hostname="example.cloud.databricks.com",
external_provider=inner,
http_client=MagicMock(),
)

def test_token_federation_unwraps_pat(self):
fed = self._make_real_federation(AccessTokenAuthProvider("test-token"))
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.PAT
assert TelemetryHelper.get_auth_flow(fed) is None

def test_token_federation_unwraps_m2m(self):
fed = self._make_real_federation(MagicMock(spec=ExternalAuthProvider))
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OTHER
assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.CLIENT_CREDENTIALS

def test_token_federation_unwraps_oauth_browser(self):
oauth = MagicMock(spec=DatabricksOAuthProvider)
oauth._access_token = None
oauth._refresh_token = None
fed = self._make_real_federation(oauth)
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH
assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.BROWSER_BASED_AUTHENTICATION

def test_token_federation_unwraps_oauth_passthrough(self):
oauth = MagicMock(spec=DatabricksOAuthProvider)
oauth._access_token = "a"
oauth._refresh_token = "r"
fed = self._make_real_federation(oauth)
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH
assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.TOKEN_PASSTHROUGH

def test_token_federation_payload_serialization(self):
"""End-to-end: federated PAT must serialize as PAT in the connection-params payload."""
fed = self._make_real_federation(AccessTokenAuthProvider("test-token"))
params = DriverConnectionParameters(
http_path="/sql/1.0/warehouses/abc",
mode=DatabricksClientType.THRIFT,
host_info=HostDetails(host_url="https://example.cloud.databricks.com", port=443),
auth_mech=TelemetryHelper.get_auth_mechanism(fed),
auth_flow=TelemetryHelper.get_auth_flow(fed),
)
payload = json.loads(params.to_json())
assert payload["auth_mech"] == "PAT"
assert "auth_flow" not in payload # None-valued fields are stripped

def test_token_federation_with_no_inner_provider(self):
"""Federation with a None inner provider should not crash; both helpers return None."""
fed = TokenFederationProvider(
hostname="example.cloud.databricks.com",
external_provider=None,
http_client=MagicMock(),
)
assert TelemetryHelper.get_auth_mechanism(fed) is None
assert TelemetryHelper.get_auth_flow(fed) is None


class TestTelemetryFactory:
"""Tests for TelemetryClientFactory lifecycle and management."""
Expand Down Expand Up @@ -811,6 +870,38 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool
assert driver_params.async_poll_interval_millis == 2000
assert driver_params.support_many_parameters is True

def test_federated_pat_populates_telemetry_as_pat(self, mock_setup_pools, mock_session):
"""End-to-end: a TokenFederationProvider wrapping a PAT should report mech=PAT in the captured telemetry payload."""
federated_pat = TokenFederationProvider(
hostname="workspace.databricks.com",
external_provider=AccessTokenAuthProvider("token"),
http_client=MagicMock(),
)
mock_session_instance = MagicMock()
mock_session_instance.guid_hex = "test-session-fed-pat"
mock_session_instance.auth_provider = federated_pat
mock_session_instance.is_open = False
mock_session_instance.use_sea = False
mock_session_instance.port = 443
mock_session_instance.host = "workspace.databricks.com"
mock_session.return_value = mock_session_instance

with patch(
"databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log"
) as mock_export:
sql.connect(
server_hostname="workspace.databricks.com",
http_path="/sql/1.0/warehouses/test",
access_token="test-token",
enable_telemetry=True,
force_enable_telemetry=True,
)

mock_export.assert_called_once()
driver_params = mock_export.call_args.kwargs.get("driver_connection_params")
assert driver_params.auth_mech == AuthMech.PAT
assert driver_params.auth_flow is None

def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session):
"""Test that CloudFlare proxy fields default to False/None (not yet supported)."""
mock_session_instance = MagicMock()
Expand Down
Loading