99from uuid import UUID
1010
1111from databricks .sql .result_set import ThriftResultSet
12-
12+ from databricks . sql . telemetry . models . event import StatementType
1313
1414if TYPE_CHECKING :
1515 from databricks .sql .client import Cursor
1616 from databricks .sql .result_set import ResultSet
17- from databricks .sql .telemetry .models .event import StatementType
1817
1918from databricks .sql .backend .types import (
2019 CommandState ,
@@ -833,7 +832,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
833832 return execute_response , is_direct_results
834833
835834 def get_execution_result (
836- self , command_id : CommandId , cursor : "Cursor" , statement_type : StatementType
835+ self , command_id : CommandId , cursor : "Cursor"
837836 ) -> "ResultSet" :
838837 thrift_handle = command_id .to_thrift_handle ()
839838 if not thrift_handle :
@@ -889,6 +888,7 @@ def get_execution_result(
889888 arrow_schema_bytes = schema_bytes ,
890889 result_format = t_result_set_metadata_resp .resultFormat ,
891890 )
891+ execute_response .command_id .set_statement_type (StatementType .QUERY )
892892
893893 return ThriftResultSet (
894894 connection = cursor .connection ,
@@ -902,7 +902,6 @@ def get_execution_result(
902902 ssl_options = self ._ssl_options ,
903903 is_direct_results = is_direct_results ,
904904 session_id_hex = self ._session_id_hex ,
905- statement_type = statement_type ,
906905 )
907906
908907 def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -968,7 +967,6 @@ def execute_command(
968967 max_bytes : int ,
969968 lz4_compression : bool ,
970969 cursor : Cursor ,
971- statement_type : StatementType ,
972970 use_cloud_fetch = True ,
973971 parameters = [],
974972 async_op = False ,
@@ -1030,6 +1028,8 @@ def execute_command(
10301028 if resp .directResults and resp .directResults .resultSet :
10311029 t_row_set = resp .directResults .resultSet .results
10321030
1031+ execute_response .command_id .set_statement_type (StatementType .QUERY )
1032+
10331033 return ThriftResultSet (
10341034 connection = cursor .connection ,
10351035 execute_response = execute_response ,
@@ -1042,7 +1042,6 @@ def execute_command(
10421042 ssl_options = self ._ssl_options ,
10431043 is_direct_results = is_direct_results ,
10441044 session_id_hex = self ._session_id_hex ,
1045- statement_type = statement_type ,
10461045 )
10471046
10481047 def get_catalogs (
@@ -1051,7 +1050,6 @@ def get_catalogs(
10511050 max_rows : int ,
10521051 max_bytes : int ,
10531052 cursor : "Cursor" ,
1054- statement_type : StatementType ,
10551053 ) -> "ResultSet" :
10561054 thrift_handle = session_id .to_thrift_handle ()
10571055 if not thrift_handle :
@@ -1073,6 +1071,8 @@ def get_catalogs(
10731071 if resp .directResults and resp .directResults .resultSet :
10741072 t_row_set = resp .directResults .resultSet .results
10751073
1074+ execute_response .command_id .set_statement_type (StatementType .METADATA )
1075+
10761076 return ThriftResultSet (
10771077 connection = cursor .connection ,
10781078 execute_response = execute_response ,
@@ -1085,7 +1085,6 @@ def get_catalogs(
10851085 ssl_options = self ._ssl_options ,
10861086 is_direct_results = is_direct_results ,
10871087 session_id_hex = self ._session_id_hex ,
1088- statement_id = statement_type ,
10891088 )
10901089
10911090 def get_schemas (
@@ -1094,7 +1093,6 @@ def get_schemas(
10941093 max_rows : int ,
10951094 max_bytes : int ,
10961095 cursor : Cursor ,
1097- statement_type : StatementType ,
10981096 catalog_name = None ,
10991097 schema_name = None ,
11001098 ) -> "ResultSet" :
@@ -1122,6 +1120,8 @@ def get_schemas(
11221120 if resp .directResults and resp .directResults .resultSet :
11231121 t_row_set = resp .directResults .resultSet .results
11241122
1123+ execute_response .command_id .set_statement_type (StatementType .METADATA )
1124+
11251125 return ThriftResultSet (
11261126 connection = cursor .connection ,
11271127 execute_response = execute_response ,
@@ -1134,7 +1134,6 @@ def get_schemas(
11341134 ssl_options = self ._ssl_options ,
11351135 is_direct_results = is_direct_results ,
11361136 session_id_hex = self ._session_id_hex ,
1137- statement_type = statement_type ,
11381137 )
11391138
11401139 def get_tables (
@@ -1143,7 +1142,6 @@ def get_tables(
11431142 max_rows : int ,
11441143 max_bytes : int ,
11451144 cursor : Cursor ,
1146- statement_type : StatementType ,
11471145 catalog_name = None ,
11481146 schema_name = None ,
11491147 table_name = None ,
@@ -1175,6 +1173,8 @@ def get_tables(
11751173 if resp .directResults and resp .directResults .resultSet :
11761174 t_row_set = resp .directResults .resultSet .results
11771175
1176+ execute_response .command_id .set_statement_type (StatementType .METADATA )
1177+
11781178 return ThriftResultSet (
11791179 connection = cursor .connection ,
11801180 execute_response = execute_response ,
@@ -1187,7 +1187,6 @@ def get_tables(
11871187 ssl_options = self ._ssl_options ,
11881188 is_direct_results = is_direct_results ,
11891189 session_id_hex = self ._session_id_hex ,
1190- statement_type = statement_type ,
11911190 )
11921191
11931192 def get_columns (
@@ -1196,7 +1195,6 @@ def get_columns(
11961195 max_rows : int ,
11971196 max_bytes : int ,
11981197 cursor : Cursor ,
1199- statement_type : StatementType ,
12001198 catalog_name = None ,
12011199 schema_name = None ,
12021200 table_name = None ,
@@ -1228,6 +1226,8 @@ def get_columns(
12281226 if resp .directResults and resp .directResults .resultSet :
12291227 t_row_set = resp .directResults .resultSet .results
12301228
1229+ execute_response .command_id .set_statement_type (StatementType .METADATA )
1230+
12311231 return ThriftResultSet (
12321232 connection = cursor .connection ,
12331233 execute_response = execute_response ,
@@ -1240,7 +1240,6 @@ def get_columns(
12401240 ssl_options = self ._ssl_options ,
12411241 is_direct_results = is_direct_results ,
12421242 session_id_hex = self ._session_id_hex ,
1243- statement_type = statement_type ,
12441243 )
12451244
12461245 def _handle_execute_response (self , resp , cursor ):
@@ -1275,7 +1274,6 @@ def fetch_results(
12751274 lz4_compressed : bool ,
12761275 arrow_schema_bytes ,
12771276 description ,
1278- statement_type ,
12791277 chunk_id : int ,
12801278 use_cloud_fetch = True ,
12811279 ):
@@ -1316,7 +1314,7 @@ def fetch_results(
13161314 ssl_options = self ._ssl_options ,
13171315 session_id_hex = self ._session_id_hex ,
13181316 statement_id = command_id .to_hex_guid (),
1319- statement_type = statement_type ,
1317+ statement_type = command_id . statement_type ,
13201318 chunk_id = chunk_id ,
13211319 )
13221320
0 commit comments