|
29 | 29 | DatabricksOAuthProvider, |
30 | 30 | ExternalAuthProvider, |
31 | 31 | ) |
| 32 | +from databricks.sql.auth.token_federation import TokenFederationProvider |
32 | 33 | from databricks import sql |
33 | 34 |
|
34 | 35 |
|
@@ -211,6 +212,64 @@ def test_auth_flow_detection(self): |
211 | 212 | # Test None auth provider |
212 | 213 | assert TelemetryHelper.get_auth_flow(None) is None |
213 | 214 |
|
| 215 | + def _make_real_federation(self, inner): |
| 216 | + """Build a real TokenFederationProvider so attribute renames break tests.""" |
| 217 | + return TokenFederationProvider( |
| 218 | + hostname="example.cloud.databricks.com", |
| 219 | + external_provider=inner, |
| 220 | + http_client=MagicMock(), |
| 221 | + ) |
| 222 | + |
| 223 | + def test_token_federation_unwraps_pat(self): |
| 224 | + fed = self._make_real_federation(AccessTokenAuthProvider("test-token")) |
| 225 | + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.PAT |
| 226 | + assert TelemetryHelper.get_auth_flow(fed) is None |
| 227 | + |
| 228 | + def test_token_federation_unwraps_m2m(self): |
| 229 | + fed = self._make_real_federation(MagicMock(spec=ExternalAuthProvider)) |
| 230 | + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OTHER |
| 231 | + assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.CLIENT_CREDENTIALS |
| 232 | + |
| 233 | + def test_token_federation_unwraps_oauth_browser(self): |
| 234 | + oauth = MagicMock(spec=DatabricksOAuthProvider) |
| 235 | + oauth._access_token = None |
| 236 | + oauth._refresh_token = None |
| 237 | + fed = self._make_real_federation(oauth) |
| 238 | + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH |
| 239 | + assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.BROWSER_BASED_AUTHENTICATION |
| 240 | + |
| 241 | + def test_token_federation_unwraps_oauth_passthrough(self): |
| 242 | + oauth = MagicMock(spec=DatabricksOAuthProvider) |
| 243 | + oauth._access_token = "a" |
| 244 | + oauth._refresh_token = "r" |
| 245 | + fed = self._make_real_federation(oauth) |
| 246 | + assert TelemetryHelper.get_auth_mechanism(fed) == AuthMech.OAUTH |
| 247 | + assert TelemetryHelper.get_auth_flow(fed) == AuthFlow.TOKEN_PASSTHROUGH |
| 248 | + |
| 249 | + def test_token_federation_payload_serialization(self): |
| 250 | + """End-to-end: federated PAT must serialize as PAT in the connection-params payload.""" |
| 251 | + fed = self._make_real_federation(AccessTokenAuthProvider("test-token")) |
| 252 | + params = DriverConnectionParameters( |
| 253 | + http_path="/sql/1.0/warehouses/abc", |
| 254 | + mode=DatabricksClientType.THRIFT, |
| 255 | + host_info=HostDetails(host_url="https://example.cloud.databricks.com", port=443), |
| 256 | + auth_mech=TelemetryHelper.get_auth_mechanism(fed), |
| 257 | + auth_flow=TelemetryHelper.get_auth_flow(fed), |
| 258 | + ) |
| 259 | + payload = json.loads(params.to_json()) |
| 260 | + assert payload["auth_mech"] == "PAT" |
| 261 | + assert "auth_flow" not in payload # None-valued fields are stripped |
| 262 | + |
| 263 | + def test_token_federation_with_no_inner_provider(self): |
| 264 | + """Federation with a None inner provider should not crash; both helpers return None.""" |
| 265 | + fed = TokenFederationProvider( |
| 266 | + hostname="example.cloud.databricks.com", |
| 267 | + external_provider=None, |
| 268 | + http_client=MagicMock(), |
| 269 | + ) |
| 270 | + assert TelemetryHelper.get_auth_mechanism(fed) is None |
| 271 | + assert TelemetryHelper.get_auth_flow(fed) is None |
| 272 | + |
214 | 273 |
|
215 | 274 | class TestTelemetryFactory: |
216 | 275 | """Tests for TelemetryClientFactory lifecycle and management.""" |
@@ -811,6 +870,38 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool |
811 | 870 | assert driver_params.async_poll_interval_millis == 2000 |
812 | 871 | assert driver_params.support_many_parameters is True |
813 | 872 |
|
| 873 | + def test_federated_pat_populates_telemetry_as_pat(self, mock_setup_pools, mock_session): |
| 874 | + """End-to-end: a TokenFederationProvider wrapping a PAT should report mech=PAT in the captured telemetry payload.""" |
| 875 | + federated_pat = TokenFederationProvider( |
| 876 | + hostname="workspace.databricks.com", |
| 877 | + external_provider=AccessTokenAuthProvider("token"), |
| 878 | + http_client=MagicMock(), |
| 879 | + ) |
| 880 | + mock_session_instance = MagicMock() |
| 881 | + mock_session_instance.guid_hex = "test-session-fed-pat" |
| 882 | + mock_session_instance.auth_provider = federated_pat |
| 883 | + mock_session_instance.is_open = False |
| 884 | + mock_session_instance.use_sea = False |
| 885 | + mock_session_instance.port = 443 |
| 886 | + mock_session_instance.host = "workspace.databricks.com" |
| 887 | + mock_session.return_value = mock_session_instance |
| 888 | + |
| 889 | + with patch( |
| 890 | + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log" |
| 891 | + ) as mock_export: |
| 892 | + sql.connect( |
| 893 | + server_hostname="workspace.databricks.com", |
| 894 | + http_path="/sql/1.0/warehouses/test", |
| 895 | + access_token="test-token", |
| 896 | + enable_telemetry=True, |
| 897 | + force_enable_telemetry=True, |
| 898 | + ) |
| 899 | + |
| 900 | + mock_export.assert_called_once() |
| 901 | + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") |
| 902 | + assert driver_params.auth_mech == AuthMech.PAT |
| 903 | + assert driver_params.auth_flow is None |
| 904 | + |
814 | 905 | def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session): |
815 | 906 | """Test that CloudFlare proxy fields default to False/None (not yet supported).""" |
816 | 907 | mock_session_instance = MagicMock() |
|
0 commit comments