1- from enum import Enum
21from typing import Optional , List
32
43from databricks .sql .auth .authenticators import (
54 AuthProvider ,
65 AccessTokenAuthProvider ,
76 ExternalAuthProvider ,
87 DatabricksOAuthProvider ,
8+ AzureServicePrincipalCredentialProvider ,
99)
10-
11-
12- class AuthType (Enum ):
13- DATABRICKS_OAUTH = "databricks-oauth"
14- AZURE_OAUTH = "azure-oauth"
15- # other supported types (access_token) can be inferred
16- # we can add more types as needed later
17-
18-
19- class ClientContext :
20- def __init__ (
21- self ,
22- hostname : str ,
23- access_token : Optional [str ] = None ,
24- auth_type : Optional [str ] = None ,
25- oauth_scopes : Optional [List [str ]] = None ,
26- oauth_client_id : Optional [str ] = None ,
27- oauth_redirect_port_range : Optional [List [int ]] = None ,
28- use_cert_as_auth : Optional [str ] = None ,
29- tls_client_cert_file : Optional [str ] = None ,
30- oauth_persistence = None ,
31- credentials_provider = None ,
32- ):
33- self .hostname = hostname
34- self .access_token = access_token
35- self .auth_type = auth_type
36- self .oauth_scopes = oauth_scopes
37- self .oauth_client_id = oauth_client_id
38- self .oauth_redirect_port_range = oauth_redirect_port_range
39- self .use_cert_as_auth = use_cert_as_auth
40- self .tls_client_cert_file = tls_client_cert_file
41- self .oauth_persistence = oauth_persistence
42- self .credentials_provider = credentials_provider
10+ from databricks .sql .auth .common import AuthType , ClientContext
4311
4412
4513def get_auth_provider (cfg : ClientContext ):
4614 if cfg .credentials_provider :
4715 return ExternalAuthProvider (cfg .credentials_provider )
48- if cfg .auth_type in [AuthType .DATABRICKS_OAUTH .value , AuthType .AZURE_OAUTH .value ]:
16+ elif cfg .auth_type == AuthType .AZURE_SP_M2M .value :
17+ return ExternalAuthProvider (
18+ AzureServicePrincipalCredentialProvider (
19+ cfg .hostname ,
20+ cfg .azure_client_id ,
21+ cfg .azure_client_secret ,
22+ cfg .azure_tenant_id ,
23+ cfg .azure_workspace_resource_id ,
24+ )
25+ )
26+ elif cfg .auth_type in [AuthType .DATABRICKS_OAUTH .value , AuthType .AZURE_OAUTH .value ]:
4927 assert cfg .oauth_redirect_port_range is not None
5028 assert cfg .oauth_client_id is not None
5129 assert cfg .oauth_scopes is not None
@@ -102,10 +80,13 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
10280
10381
10482def get_python_sql_connector_auth_provider (hostname : str , ** kwargs ):
83+ # TODO : unify all the auth mechanisms with the Python SDK
84+
10585 auth_type = kwargs .get ("auth_type" )
10686 (client_id , redirect_port_range ) = get_client_id_and_redirect_port (
10787 auth_type == AuthType .AZURE_OAUTH .value
10888 )
89+
10990 if kwargs .get ("username" ) or kwargs .get ("password" ):
11091 raise ValueError (
11192 "Username/password authentication is no longer supported. "
@@ -120,6 +101,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
120101 tls_client_cert_file = kwargs .get ("_tls_client_cert_file" ),
121102 oauth_scopes = PYSQL_OAUTH_SCOPES ,
122103 oauth_client_id = kwargs .get ("oauth_client_id" ) or client_id ,
104+ azure_client_id = kwargs .get ("azure_client_id" ),
105+ azure_client_secret = kwargs .get ("azure_client_secret" ),
106+ azure_tenant_id = kwargs .get ("azure_tenant_id" ),
107+ azure_workspace_resource_id = kwargs .get ("azure_workspace_resource_id" ),
123108 oauth_redirect_port_range = [kwargs ["oauth_redirect_port" ]]
124109 if kwargs .get ("oauth_client_id" ) and kwargs .get ("oauth_redirect_port" )
125110 else redirect_port_range ,
0 commit comments