@@ -201,6 +201,7 @@ def cloud_time_to_timestamp(self, cloud_time_string):
201201 def report_status (self ,
202202 completed = False ,
203203 progress = None ,
204+ source_operation_response = None ,
204205 exception_details = None ):
205206 """Reports to the service status of a work item (completion or progress).
206207
@@ -209,6 +210,7 @@ def report_status(self,
209210 either because it succeeded or because it failed. False if this is a
210211 progress report.
211212 progress: Progress of processing the work_item.
213+ source_operation_response: Response to a custom source operation
212214 exception_details: A string representation of the stack trace for an
213215 exception raised while executing the work item. The string is the
214216 output of the standard traceback.format_exc() function.
@@ -226,7 +228,8 @@ def report_status(self,
226228 completed ,
227229 progress if not completed else None ,
228230 self .dynamic_split_result_to_report if not completed else None ,
229- exception_details )
231+ source_operation_response = source_operation_response ,
232+ exception_details = exception_details )
230233
231234 # Resetting dynamic_split_result_to_report after reporting status
232235 # successfully.
@@ -368,6 +371,7 @@ def report_completion_status(
368371 self ,
369372 current_work_item ,
370373 progress_reporter ,
374+ source_operation_response = None ,
371375 exception_details = None ):
372376 """Reports to the service a work item completion (successful or failed).
373377
@@ -383,6 +387,7 @@ def report_completion_status(
383387 current_work_item: A WorkItem instance describing the work.
384388 progress_reporter: A ProgressReporter configured to process work item
385389 current_work_item.
390+ source_operation_response: Response to a custom source operation.
386391 exception_details: A string representation of the stack trace for an
387392 exception raised while executing the work item. The string is the
388393 output of the standard traceback.format_exc() function.
@@ -395,8 +400,10 @@ def report_completion_status(
395400 'successfully' if exception_details is None
396401 else 'with exception' )
397402
398- progress_reporter .report_status (completed = True ,
399- exception_details = exception_details )
403+ progress_reporter .report_status (
404+ completed = True ,
405+ source_operation_response = source_operation_response ,
406+ exception_details = exception_details )
400407
401408 @staticmethod
402409 def log_memory_usage_if_needed (worker_id , force = False ):
@@ -416,12 +423,21 @@ def log_memory_usage_if_needed(worker_id, force=False):
416423 def shutdown (self ):
417424 self ._shutdown = True
418425
426+ def get_executor_for_work_item (self , work_item ):
427+ if work_item .map_task is not None :
428+ return executor .MapTaskExecutor (work_item .map_task )
429+ elif work_item .source_operation_split_task is not None :
430+ return executor .CustomSourceSplitExecutor (
431+ work_item .source_operation_split_task )
432+ else :
433+ raise ValueError ('Unknown type of work item : %s' , work_item )
434+
419435 def do_work (self , work_item , deferred_exception_details = None ):
420436 """Executes worker operations and adds any failures to the report status."""
421437 logging .info ('Executing %s' , work_item )
422438 BatchWorker .log_memory_usage_if_needed (self .worker_id , force = True )
423439
424- work_executor = executor . MapTaskExecutor ( )
440+ work_executor = self . get_executor_for_work_item ( work_item )
425441 progress_reporter = ProgressReporter (
426442 work_item , work_executor , self , self .client )
427443
@@ -441,7 +457,7 @@ def do_work(self, work_item, deferred_exception_details=None):
441457 exception_details = None
442458 try :
443459 progress_reporter .start_reporting_progress ()
444- work_executor .execute (work_item . map_task )
460+ work_executor .execute ()
445461 except Exception : # pylint: disable=broad-except
446462 exception_details = traceback .format_exc ()
447463 logging .error ('An exception was raised when trying to execute the '
@@ -464,8 +480,14 @@ def do_work(self, work_item, deferred_exception_details=None):
464480 exception_details = traceback .format_exc ()
465481
466482 with work_item .lock :
467- self .report_completion_status (work_item , progress_reporter ,
468- exception_details = exception_details )
483+ source_split_response = None
484+ if isinstance (work_executor , executor .CustomSourceSplitExecutor ):
485+ source_split_response = work_executor .response
486+
487+ self .report_completion_status (
488+ work_item , progress_reporter ,
489+ source_operation_response = source_split_response ,
490+ exception_details = exception_details )
469491 work_item .done = True
470492
471493 def status_server (self ):
@@ -559,9 +581,13 @@ def run(self):
559581 time .sleep (1.0 * (1 - 0.5 * random .random ()))
560582 continue
561583
584+ stage_name = None
585+ if work_item .map_task :
586+ stage_name = work_item .map_task .stage_name
587+
562588 with logger .PerThreadLoggingContext (
563589 work_item_id = work_item .proto .id ,
564- stage_name = work_item . map_task . stage_name ):
590+ stage_name = stage_name ):
565591 # TODO(silviuc): Add more detailed timing and profiling support.
566592 start_time = time .time ()
567593
0 commit comments