Skip to content

Commit cbd6a88

Browse files
authored
Telemetry: unwrap TokenFederationProvider to report inner auth mech/flow (#781)
* Telemetry: report wrapped provider's auth mech/flow under TokenFederation TokenFederationProvider wraps an underlying auth provider (PAT, OAuth, M2M) and only adds a token-exchange step. The telemetry helpers previously fell through to AuthMech.OTHER with no flow, hiding the actual auth method. Unwrap and recurse on external_provider so federated PAT reports PAT, federated M2M reports CLIENT_CREDENTIALS, etc. Co-authored-by: Isaac * Strengthen telemetry token-federation tests Build real TokenFederationProvider instances (instead of MagicMocks) so attribute renames on external_provider break the test rather than passing silently. Add a payload-serialization assertion confirming the federated PAT case emits "auth_mech": "PAT" in the JSON event, and a None-inner- provider edge case. Co-authored-by: Isaac * Add end-to-end telemetry test for federated PAT through Connection Sets the mocked Session's auth_provider to a real TokenFederationProvider wrapping AccessTokenAuthProvider, then asserts the captured DriverConnectionParameters reports auth_mech=PAT. This catches regressions in the wiring at client.py:383-384 (e.g., wrong provider passed to TelemetryHelper) that the helper-only tests would miss. Co-authored-by: Isaac
1 parent ee63b81 commit cbd6a88

2 files changed

Lines changed: 96 additions & 0 deletions

File tree

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
DatabricksOAuthProvider,
3535
ExternalAuthProvider,
3636
)
37+
from databricks.sql.auth.token_federation import TokenFederationProvider
3738
import sys
3839
import platform
3940
import uuid
@@ -90,6 +91,8 @@ def get_auth_mechanism(auth_provider):
9091

9192
if not auth_provider:
9293
return None
94+
if isinstance(auth_provider, TokenFederationProvider):
95+
return TelemetryHelper.get_auth_mechanism(auth_provider.external_provider)
9396
if isinstance(auth_provider, AccessTokenAuthProvider):
9497
return AuthMech.PAT
9598
elif isinstance(auth_provider, DatabricksOAuthProvider):
@@ -105,6 +108,8 @@ def get_auth_flow(auth_provider):
105108

106109
if not auth_provider:
107110
return None
111+
if isinstance(auth_provider, TokenFederationProvider):
112+
return TelemetryHelper.get_auth_flow(auth_provider.external_provider)
108113
if isinstance(auth_provider, DatabricksOAuthProvider):
109114
if auth_provider._access_token and auth_provider._refresh_token:
110115
return AuthFlow.TOKEN_PASSTHROUGH

tests/unit/test_telemetry.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DatabricksOAuthProvider,
3030
ExternalAuthProvider,
3131
)
32+
from databricks.sql.auth.token_federation import TokenFederationProvider
3233
from databricks import sql
3334

3435

@@ -211,6 +212,64 @@ def test_auth_flow_detection(self):
211212
# Test None auth provider
212213
assert TelemetryHelper.get_auth_flow(None) is None
213214

215+
def _make_real_federation(self, inner):
216+
"""Build a real TokenFederationProvider so attribute renames break tests."""
217+
return TokenFederationProvider(
218+
hostname="example.cloud.databricks.com",
219+
external_provider=inner,
220+
http_client=MagicMock(),
221+
)
222+
223+
def test_token_federation_unwraps_pat(self):
224+
fed = self._make_real_federation(AccessTokenAuthProvider("test-token"))
225+
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.PAT
226+
assert TelemetryHelper.get_auth_flow(fed) is None
227+
228+
def test_token_federation_unwraps_m2m(self):
229+
fed = self._make_real_federation(MagicMock(spec=ExternalAuthProvider))
230+
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OTHER
231+
assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.CLIENT_CREDENTIALS
232+
233+
def test_token_federation_unwraps_oauth_browser(self):
234+
oauth = MagicMock(spec=DatabricksOAuthProvider)
235+
oauth._access_token = None
236+
oauth._refresh_token = None
237+
fed = self._make_real_federation(oauth)
238+
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH
239+
assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.BROWSER_BASED_AUTHENTICATION
240+
241+
def test_token_federation_unwraps_oauth_passthrough(self):
242+
oauth = MagicMock(spec=DatabricksOAuthProvider)
243+
oauth._access_token = "a"
244+
oauth._refresh_token = "r"
245+
fed = self._make_real_federation(oauth)
246+
assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH
247+
assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.TOKEN_PASSTHROUGH
248+
249+
def test_token_federation_payload_serialization(self):
250+
"""End-to-end: federated PAT must serialize as PAT in the connection-params payload."""
251+
fed = self._make_real_federation(AccessTokenAuthProvider("test-token"))
252+
params = DriverConnectionParameters(
253+
http_path="/sql/1.0/warehouses/abc",
254+
mode=DatabricksClientType.THRIFT,
255+
host_info=HostDetails(host_url="https://example.cloud.databricks.com", port=443),
256+
auth_mech=TelemetryHelper.get_auth_mechanism(fed),
257+
auth_flow=TelemetryHelper.get_auth_flow(fed),
258+
)
259+
payload = json.loads(params.to_json())
260+
assert payload["auth_mech"] == "PAT"
261+
assert "auth_flow" not in payload # None-valued fields are stripped
262+
263+
def test_token_federation_with_no_inner_provider(self):
264+
"""Federation with a None inner provider should not crash; both helpers return None."""
265+
fed = TokenFederationProvider(
266+
hostname="example.cloud.databricks.com",
267+
external_provider=None,
268+
http_client=MagicMock(),
269+
)
270+
assert TelemetryHelper.get_auth_mechanism(fed) is None
271+
assert TelemetryHelper.get_auth_flow(fed) is None
272+
214273

215274
class TestTelemetryFactory:
216275
"""Tests for TelemetryClientFactory lifecycle and management."""
@@ -811,6 +870,38 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool
811870
assert driver_params.async_poll_interval_millis == 2000
812871
assert driver_params.support_many_parameters is True
813872

873+
def test_federated_pat_populates_telemetry_as_pat(self, mock_setup_pools, mock_session):
874+
"""End-to-end: a TokenFederationProvider wrapping a PAT should report mech=PAT in the captured telemetry payload."""
875+
federated_pat = TokenFederationProvider(
876+
hostname="workspace.databricks.com",
877+
external_provider=AccessTokenAuthProvider("token"),
878+
http_client=MagicMock(),
879+
)
880+
mock_session_instance = MagicMock()
881+
mock_session_instance.guid_hex = "test-session-fed-pat"
882+
mock_session_instance.auth_provider = federated_pat
883+
mock_session_instance.is_open = False
884+
mock_session_instance.use_sea = False
885+
mock_session_instance.port = 443
886+
mock_session_instance.host = "workspace.databricks.com"
887+
mock_session.return_value = mock_session_instance
888+
889+
with patch(
890+
"databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log"
891+
) as mock_export:
892+
sql.connect(
893+
server_hostname="workspace.databricks.com",
894+
http_path="/sql/1.0/warehouses/test",
895+
access_token="test-token",
896+
enable_telemetry=True,
897+
force_enable_telemetry=True,
898+
)
899+
900+
mock_export.assert_called_once()
901+
driver_params = mock_export.call_args.kwargs.get("driver_connection_params")
902+
assert driver_params.auth_mech == AuthMech.PAT
903+
assert driver_params.auth_flow is None
904+
814905
def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session):
815906
"""Test that CloudFlare proxy fields default to False/None (not yet supported)."""
816907
mock_session_instance = MagicMock()

0 commit comments

Comments
 (0)