diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 55d845e46..6297688fc 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -34,6 +34,7 @@ DatabricksOAuthProvider, ExternalAuthProvider, ) +from databricks.sql.auth.token_federation import TokenFederationProvider import sys import platform import uuid @@ -90,6 +91,8 @@ def get_auth_mechanism(auth_provider): if not auth_provider: return None + if isinstance(auth_provider, TokenFederationProvider): + return TelemetryHelper.get_auth_mechanism(auth_provider.external_provider) if isinstance(auth_provider, AccessTokenAuthProvider): return AuthMech.PAT elif isinstance(auth_provider, DatabricksOAuthProvider): @@ -105,6 +108,8 @@ def get_auth_flow(auth_provider): if not auth_provider: return None + if isinstance(auth_provider, TokenFederationProvider): + return TelemetryHelper.get_auth_flow(auth_provider.external_provider) if isinstance(auth_provider, DatabricksOAuthProvider): if auth_provider._access_token and auth_provider._refresh_token: return AuthFlow.TOKEN_PASSTHROUGH diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 86f06aa8a..4f62fb833 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -29,6 +29,7 @@ DatabricksOAuthProvider, ExternalAuthProvider, ) +from databricks.sql.auth.token_federation import TokenFederationProvider from databricks import sql @@ -211,6 +212,64 @@ def test_auth_flow_detection(self): # Test None auth provider assert TelemetryHelper.get_auth_flow(None) is None + def _make_real_federation(self, inner): + """Build a real TokenFederationProvider so attribute renames break tests.""" + return TokenFederationProvider( + hostname="example.cloud.databricks.com", + external_provider=inner, + http_client=MagicMock(), + ) + + def test_token_federation_unwraps_pat(self): + fed = self._make_real_federation(AccessTokenAuthProvider("test-token")) + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.PAT + assert TelemetryHelper.get_auth_flow(fed) is None + + def test_token_federation_unwraps_m2m(self): + fed = self._make_real_federation(MagicMock(spec=ExternalAuthProvider)) + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OTHER + assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.CLIENT_CREDENTIALS + + def test_token_federation_unwraps_oauth_browser(self): + oauth = MagicMock(spec=DatabricksOAuthProvider) + oauth._access_token = None + oauth._refresh_token = None + fed = self._make_real_federation(oauth) + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH + assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.BROWSER_BASED_AUTHENTICATION + + def test_token_federation_unwraps_oauth_passthrough(self): + oauth = MagicMock(spec=DatabricksOAuthProvider) + oauth._access_token = "a" + oauth._refresh_token = "r" + fed = self._make_real_federation(oauth) + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH + assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.TOKEN_PASSTHROUGH + + def test_token_federation_payload_serialization(self): + """End-to-end: federated PAT must serialize as PAT in the connection-params payload.""" + fed = self._make_real_federation(AccessTokenAuthProvider("test-token")) + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc", + mode=DatabricksClientType.THRIFT, + host_info=HostDetails(host_url="https://example.cloud.databricks.com", port=443), + auth_mech=TelemetryHelper.get_auth_mechanism(fed), + auth_flow=TelemetryHelper.get_auth_flow(fed), + ) + payload = json.loads(params.to_json()) + assert payload["auth_mech"] == "PAT" + assert "auth_flow" not in payload # None-valued fields are stripped + + def test_token_federation_with_no_inner_provider(self): + """Federation with a None inner provider should not crash; both helpers return None.""" + fed = TokenFederationProvider( + hostname="example.cloud.databricks.com", + external_provider=None, + http_client=MagicMock(), + ) + assert TelemetryHelper.get_auth_mechanism(fed) is None + assert TelemetryHelper.get_auth_flow(fed) is None + class TestTelemetryFactory: """Tests for TelemetryClientFactory lifecycle and management.""" @@ -811,6 +870,38 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool assert driver_params.async_poll_interval_millis == 2000 assert driver_params.support_many_parameters is True + def test_federated_pat_populates_telemetry_as_pat(self, mock_setup_pools, mock_session): + """End-to-end: a TokenFederationProvider wrapping a PAT should report mech=PAT in the captured telemetry payload.""" + federated_pat = TokenFederationProvider( + hostname="workspace.databricks.com", + external_provider=AccessTokenAuthProvider("token"), + http_client=MagicMock(), + ) + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-fed-pat" + mock_session_instance.auth_provider = federated_pat + mock_session_instance.is_open = False + mock_session_instance.use_sea = False + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log" + ) as mock_export: + sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + assert driver_params.auth_mech == AuthMech.PAT + assert driver_params.auth_flow is None + def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session): """Test that CloudFlare proxy fields default to False/None (not yet supported).""" mock_session_instance = MagicMock()