
Commit 574a29e

chamikaramj authored and aaltay committed
Adds support for reading custom sources using DataflowPipelineRunner.
Adds support for performing custom source split operations and reading
sub-sources generated by split operations. Generalizes WorkItem execution by
introducing an interface 'Executor'. Adds an executor for performing custom
source split operations. Adds a local runner based integration test for
custom sources.

----Release Notes----
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=123344217
1 parent 67033fc commit 574a29e

12 files changed: 648 additions & 124 deletions
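The headline change: a pipeline can now hand DataflowPipelineRunner a user-defined bounded source instead of one of the built-in formats. A minimal sketch of such a source follows; the CountingSource class and the SourceBundle fields are illustrative assumptions inferred from the split() contract that splits_to_split_response() consumes in the diff below, not code from this commit.

# Illustrative only: a custom bounded source of the shape this commit makes
# readable on the Dataflow service. The split() contract (yielding bundles
# with .source, .start_position and .stop_position) is inferred from
# splits_to_split_response() in the diff below; everything else is assumed.
from google.cloud.dataflow.io import iobase


class CountingSource(iobase.BoundedSource):
  """Emits the integers in [start, end)."""

  def __init__(self, start, end):
    self._start = start
    self._end = end

  def split(self, desired_bundle_size, start_position=None,
            stop_position=None):
    # Carve the range into fixed-size sub-sources; the worker pickles each
    # (source, start_position, stop_position) triple when responding to a
    # split request.
    bundle_start = self._start
    while bundle_start < self._end:
      bundle_stop = min(bundle_start + desired_bundle_size, self._end)
      yield iobase.SourceBundle(
          weight=bundle_stop - bundle_start,
          source=CountingSource(bundle_start, bundle_stop),
          start_position=bundle_start,
          stop_position=bundle_stop)
      bundle_start = bundle_stop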


google/cloud/dataflow/internal/apiclient.py

Lines changed: 54 additions & 2 deletions
@@ -21,13 +21,16 @@
 import re
 import time
 
+
 from google.cloud.dataflow import utils
 from google.cloud.dataflow import version
+from google.cloud.dataflow.internal import pickler
 from google.cloud.dataflow.internal.auth import get_service_credentials
 from google.cloud.dataflow.internal.json_value import to_json_value
 from google.cloud.dataflow.io import iobase
 from google.cloud.dataflow.transforms import cy_combiners
 from google.cloud.dataflow.utils import dependency
+from google.cloud.dataflow.utils import names
 from google.cloud.dataflow.utils import retry
 from google.cloud.dataflow.utils.names import PropertyNames
 from google.cloud.dataflow.utils.options import GoogleCloudOptions
@@ -665,6 +668,7 @@ def report_status(self,
                     completed,
                     progress,
                     dynamic_split_result_to_report=None,
+                    source_operation_response=None,
                     exception_details=None):
     """Reports status for a work item (success or failure).
 
@@ -688,6 +692,9 @@ def report_status(self,
         handling the work item.
       dynamic_split_result_to_report: A successful dynamic split result that
         should be sent to the Dataflow service along with the status report.
+      source_operation_response: Response to a source operation request from
+        the service. This will be sent to the service along with the status
+        report.
       exception_details: A string representation of the stack trace for an
         exception raised while executing the work item. The string is the
         output of the standard traceback.format_exc() function.
@@ -746,9 +753,13 @@ def report_status(self,
       status.message = exception_details
       work_item_status.errors.append(status)
 
+    if source_operation_response is not None:
+      work_item_status.sourceOperationResponse = source_operation_response
+
     # Look through the work item for metrics to send.
-    for counter in work_item.map_task.itercounters():
-      append_counter(work_item_status, counter, tentative=not completed)
+    if work_item.map_task:
+      for counter in work_item.map_task.itercounters():
+        append_counter(work_item_status, counter, tentative=not completed)
 
     report_request = dataflow.ReportWorkItemStatusRequest()
     report_request.currentWorkerTime = worker_info.formatted_current_time
@@ -871,3 +882,44 @@ def set_mean(accumulator, metric_update):
     cy_combiners.AllCombineFn: ('and', set_scalar),
     cy_combiners.AnyCombineFn: ('or', set_scalar),
 }
+
+
+def splits_to_split_response(bundles):
+  """Generates a response to a custom source split request.
+
+  Args:
+    bundles: a set of bundles generated by a BoundedSource.split() invocation.
+  Returns:
+    a SourceOperationResponse object.
+  """
+  derived_sources = []
+  for bundle in bundles:
+    derived_source = dataflow.DerivedSource()
+    derived_source.derivationMode = (
+        dataflow.DerivedSource.DerivationModeValueValuesEnum
+        .SOURCE_DERIVATION_MODE_INDEPENDENT)
+    derived_source.source = dataflow.Source()
+    derived_source.source.doesNotNeedSplitting = True
+
+    derived_source.source.spec = dataflow.Source.SpecValue()
+    derived_source.source.spec.additionalProperties.append(
+        dataflow.Source.SpecValue.AdditionalProperty(
+            key=names.SERIALIZED_SOURCE_KEY,
+            value=to_json_value(pickler.dumps(
+                (bundle.source, bundle.start_position, bundle.stop_position)),
+                with_type=True)))
+    derived_source.source.spec.additionalProperties.append(
+        dataflow.Source.SpecValue.AdditionalProperty(key='@type',
+                                                     value=to_json_value(
+                                                         names.SOURCE_TYPE)))
+    derived_sources.append(derived_source)
+
+  split_response = dataflow.SourceSplitResponse()
+  split_response.bundles = derived_sources
+  split_response.outcome = (
+      dataflow.SourceSplitResponse.OutcomeValueValuesEnum
+      .SOURCE_SPLIT_OUTCOME_SPLITTING_HAPPENED)
+
+  response = dataflow.SourceOperationResponse()
+  response.split = split_response
+  return response
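The new splits_to_split_response() converts whatever BoundedSource.split() yields into the SourceOperationResponse proto that report_status() now forwards to the service. A hedged usage sketch, assuming a source like the CountingSource sketched above:

# Assumed wiring, not from this commit.
bundles = list(CountingSource(0, 10000).split(desired_bundle_size=1000))
response = splits_to_split_response(bundles)

# response.split.bundles holds one DerivedSource per bundle, each carrying a
# pickled (source, start_position, stop_position) triple under the
# names.SERIALIZED_SOURCE_KEY ('serialized_source') property;
# response.split.outcome reports SOURCE_SPLIT_OUTCOME_SPLITTING_HAPPENED.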

google/cloud/dataflow/runners/dataflow_runner.py

Lines changed: 24 additions & 11 deletions
@@ -34,6 +34,7 @@
 from google.cloud.dataflow.runners.runner import PipelineState
 from google.cloud.dataflow.runners.runner import PValueCache
 from google.cloud.dataflow.typehints import typehints
+from google.cloud.dataflow.utils import names
 from google.cloud.dataflow.utils.names import PropertyNames
 from google.cloud.dataflow.utils.names import TransformNames
 from google.cloud.dataflow.utils.options import StandardOptions
@@ -456,12 +457,17 @@ def run_Read(self, transform_node):
     # TODO(mairbek): refactor if-else tree to use registerable functions.
     # Initialize the source specific properties.
 
-    if isinstance(transform.source, iobase.BoundedSource):
-      raise ValueError('DataflowPipelineRunner does not support reading '
-                       'BoundedSource implementations yet. Please use a source '
-                       'provided by Dataflow SDK or use DirectPipelineRunner.')
-
-    if transform.source.format == 'text':
+    if not hasattr(transform.source, 'format'):
+      # If a format is not set, we assume the source to be a custom source.
+      source_dict = dict()
+      spec_dict = dict()
+
+      spec_dict[names.SERIALIZED_SOURCE_KEY] = pickler.dumps(transform.source)
+      spec_dict['@type'] = names.SOURCE_TYPE
+      source_dict['spec'] = spec_dict
+      step.add_property(PropertyNames.SOURCE_STEP_INPUT,
+                        source_dict)
+    elif transform.source.format == 'text':
       step.add_property(PropertyNames.FILE_PATTERN, transform.source.path)
     elif transform.source.format == 'bigquery':
       # TODO(silviuc): Add table validation if transform.source.validate.
@@ -494,15 +500,22 @@ def run_Read(self, transform_node):
       if transform.source.id_label:
         step.add_property(PropertyNames.PUBSUB_ID_LABEL,
                           transform.source.id_label)
-    elif transform.source.format == 'custom':
-      # TODO(silviuc): Implement custom sources.
-      raise NotImplementedError
     else:
       raise ValueError(
          'Source %r has unexpected format %s.' % (
              transform.source, transform.source.format))
-    step.add_property(PropertyNames.FORMAT, transform.source.format)
-    step.encoding = self._get_cloud_encoding(transform.source.coder)
+
+    if not hasattr(transform.source, 'format'):
+      step.add_property(PropertyNames.FORMAT, names.SOURCE_FORMAT)
+    else:
+      step.add_property(PropertyNames.FORMAT, transform.source.format)
+
+    if isinstance(transform.source, iobase.BoundedSource):
+      coder = transform.source.default_output_coder()
+    else:
+      coder = transform.source.coder
+
+    step.encoding = self._get_cloud_encoding(coder)
     step.add_property(
         PropertyNames.OUTPUT_INFO,
         [{PropertyNames.USER_NAME: (
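The custom-source branch of run_Read() replaces the format-specific properties with a pickled copy of the source object itself, so the only contract between runner and worker is pickler plus the shared names constants. A sketch of the round trip; the worker-side read is not part of this diff, so the loads() call is an assumption:

from google.cloud.dataflow.internal import pickler
from google.cloud.dataflow.utils import names

# Runner side (mirrors the diff above); my_source stands in for any
# user-defined source object.
spec_dict = {
    names.SERIALIZED_SOURCE_KEY: pickler.dumps(my_source),
    '@type': names.SOURCE_TYPE,
}

# Worker side (assumed, not shown in this excerpt): recover the object.
restored_source = pickler.loads(spec_dict[names.SERIALIZED_SOURCE_KEY])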

google/cloud/dataflow/utils/names.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,11 @@
 PICKLED_MAIN_SESSION_FILE = 'pickled_main_session'
 DATAFLOW_SDK_TARBALL_FILE = 'dataflow_python_sdk.tar'
 
+# String constants related to sources framework
+SOURCE_FORMAT = 'custom_source'
+SOURCE_TYPE = 'CustomSourcesType'
+SERIALIZED_SOURCE_KEY = 'serialized_source'
+
 
 class TransformNames(object):
   """Transform strings as they are expected in the CloudWorkflow protos."""
@@ -61,6 +66,7 @@ class PropertyNames(object):
   PUBSUB_ID_LABEL = 'pubsub_id_label'
   SERIALIZED_FN = 'serialized_fn'
   SHARD_NAME_TEMPLATE = 'shard_template'
+  SOURCE_STEP_INPUT = 'custom_source_step_input'
   STEP_NAME = 'step_name'
   USER_FN = 'user_fn'
   USER_NAME = 'user_name'
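Together the three module-level constants and PropertyNames.SOURCE_STEP_INPUT pin down the JSON shape of a custom source step. An approximate sketch of the property the runner attaches (payload abbreviated, layout assumed):

# Approximate shape of the property added under
# PropertyNames.SOURCE_STEP_INPUT in dataflow_runner.py above:
{
    'spec': {
        '@type': 'CustomSourcesType',                   # names.SOURCE_TYPE
        'serialized_source': '<pickled source bytes>',  # names.SERIALIZED_SOURCE_KEY
    }
}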

google/cloud/dataflow/worker/batchworker.py

Lines changed: 34 additions & 8 deletions
@@ -201,6 +201,7 @@ def cloud_time_to_timestamp(self, cloud_time_string):
   def report_status(self,
                     completed=False,
                     progress=None,
+                    source_operation_response=None,
                     exception_details=None):
     """Reports to the service status of a work item (completion or progress).
 
@@ -209,6 +210,7 @@ def report_status(self,
         either because it succeeded or because it failed. False if this is a
         progress report.
       progress: Progress of processing the work_item.
+      source_operation_response: Response to a custom source operation
       exception_details: A string representation of the stack trace for an
         exception raised while executing the work item. The string is the
         output of the standard traceback.format_exc() function.
@@ -226,7 +228,8 @@ def report_status(self,
         completed,
         progress if not completed else None,
         self.dynamic_split_result_to_report if not completed else None,
-        exception_details)
+        source_operation_response=source_operation_response,
+        exception_details=exception_details)
 
     # Resetting dynamic_split_result_to_report after reporting status
     # successfully.
@@ -368,6 +371,7 @@ def report_completion_status(
       self,
       current_work_item,
       progress_reporter,
+      source_operation_response=None,
       exception_details=None):
     """Reports to the service a work item completion (successful or failed).
 
@@ -383,6 +387,7 @@ def report_completion_status(
       current_work_item: A WorkItem instance describing the work.
       progress_reporter: A ProgressReporter configured to process work item
         current_work_item.
+      source_operation_response: Response to a custom source operation.
       exception_details: A string representation of the stack trace for an
         exception raised while executing the work item. The string is the
         output of the standard traceback.format_exc() function.
@@ -395,8 +400,10 @@ def report_completion_status(
                  'successfully' if exception_details is None
                  else 'with exception')
 
-    progress_reporter.report_status(completed=True,
-                                    exception_details=exception_details)
+    progress_reporter.report_status(
+        completed=True,
+        source_operation_response=source_operation_response,
+        exception_details=exception_details)
 
   @staticmethod
   def log_memory_usage_if_needed(worker_id, force=False):
@@ -416,12 +423,21 @@ def log_memory_usage_if_needed(worker_id, force=False):
   def shutdown(self):
     self._shutdown = True
 
+  def get_executor_for_work_item(self, work_item):
+    if work_item.map_task is not None:
+      return executor.MapTaskExecutor(work_item.map_task)
+    elif work_item.source_operation_split_task is not None:
+      return executor.CustomSourceSplitExecutor(
+          work_item.source_operation_split_task)
+    else:
+      raise ValueError('Unknown type of work item : %s', work_item)
+
   def do_work(self, work_item, deferred_exception_details=None):
     """Executes worker operations and adds any failures to the report status."""
     logging.info('Executing %s', work_item)
     BatchWorker.log_memory_usage_if_needed(self.worker_id, force=True)
 
-    work_executor = executor.MapTaskExecutor()
+    work_executor = self.get_executor_for_work_item(work_item)
     progress_reporter = ProgressReporter(
         work_item, work_executor, self, self.client)
 
@@ -441,7 +457,7 @@ def do_work(self, work_item, deferred_exception_details=None):
     exception_details = None
     try:
       progress_reporter.start_reporting_progress()
-      work_executor.execute(work_item.map_task)
+      work_executor.execute()
     except Exception:  # pylint: disable=broad-except
       exception_details = traceback.format_exc()
       logging.error('An exception was raised when trying to execute the '
@@ -464,8 +480,14 @@ def do_work(self, work_item, deferred_exception_details=None):
       exception_details = traceback.format_exc()
 
     with work_item.lock:
-      self.report_completion_status(work_item, progress_reporter,
-                                    exception_details=exception_details)
+      source_split_response = None
+      if isinstance(work_executor, executor.CustomSourceSplitExecutor):
+        source_split_response = work_executor.response
+
+      self.report_completion_status(
+          work_item, progress_reporter,
+          source_operation_response=source_split_response,
+          exception_details=exception_details)
       work_item.done = True
 
   def status_server(self):
@@ -559,9 +581,13 @@ def run(self):
         time.sleep(1.0 * (1 - 0.5 * random.random()))
         continue
 
+      stage_name = None
+      if work_item.map_task:
+        stage_name = work_item.map_task.stage_name
+
       with logger.PerThreadLoggingContext(
           work_item_id=work_item.proto.id,
-          stage_name=work_item.map_task.stage_name):
+          stage_name=stage_name):
         # TODO(silviuc): Add more detailed timing and profiling support.
         start_time = time.time()

google/cloud/dataflow/worker/batchworker_test.py

Lines changed: 12 additions & 8 deletions
@@ -87,9 +87,9 @@ def test_worker_starts_and_stops_progress_reporter(
     worker.do_work(mock_work_item)
 
     mock_report_status.assert_called_with(
-        completed=True, exception_details=None)
+        completed=True, source_operation_response=None, exception_details=None)
     mock_start.assert_called_once_with()
-    mock_execute.assert_called_once_with(mock.ANY)
+    mock_execute.assert_called_once_with()
     mock_stop.assert_called_once_with()
 
   @patch.object(executor.MapTaskExecutor, 'execute')
@@ -103,7 +103,9 @@ def test_worker_fails_for_deferred_exceptions(
     worker.do_work(mock_work_item, deferred_exception_details='deferred_exc')
 
     mock_report_status.assert_called_with(
-        completed=True, exception_details='deferred_exc')
+        completed=True,
+        source_operation_response=None,
+        exception_details='deferred_exc')
     assert not mock_stop.called
     assert not mock_start.called
     assert not mock_execute.called
@@ -121,10 +123,11 @@ def __eq__(self, other):
 
     mock_report_status.assert_called_with(
         completed=True,
+        source_operation_response=None,
         exception_details=AnyStringWith(expected_exception))
 
     mock_start.assert_called_once_with()
-    mock_execute.assert_called_once_with(mock.ANY)
+    mock_execute.assert_called_once_with()
     mock_stop.assert_called_once_with()
 
   @patch.object(executor.MapTaskExecutor, 'execute')
@@ -167,8 +170,8 @@ class ProgressReporterTest(unittest.TestCase):
   @patch.object(batchworker.ProgressReporter, 'process_report_status_response')
   def test_progress_reporter_reports_progress(
       self, mock_report_response, mock_next_progress):  # pylint: disable=unused-argument
-    work_item = workitem.BatchWorkItem(
-        proto=mock.MagicMock(), map_task=mock.MagicMock())
+    work_item = workitem.BatchWorkItem(proto=mock.MagicMock())
+    work_item.map_task = mock.MagicMock()
     mock_work_executor = mock.MagicMock()
     mock_batch_worker = mock.MagicMock()
     mock_client = mock.MagicMock()
@@ -181,7 +184,8 @@ def test_progress_reporter_reports_progress(
     time.sleep(10)
     progress_reporter.stop_reporting_progress()
     mock_client.report_status.assert_called_with(
-        mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock.ANY)
+        mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock.ANY,
+        exception_details=mock.ANY, source_operation_response=mock.ANY)
 
   @patch.object(batchworker.ProgressReporter, 'next_progress_report_interval')
   @patch.object(batchworker.ProgressReporter, 'process_report_status_response')
@@ -200,7 +204,7 @@ def test_progress_reporter_sends_last_update(
     progress_reporter.stop_reporting_progress()
     mock_client.report_status.assert_called_with(
         mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock.ANY, mock_split_result,
-        mock.ANY)
+        exception_details=mock.ANY, source_operation_response=mock.ANY)
 
 
 if __name__ == '__main__':
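AnyStringWith, used in the exception_details assertion above (its __eq__ appears in a hunk header), is the usual containment matcher for mock assertions. A sketch consistent with that usage:

class AnyStringWith(str):
  """Compares equal to any string containing this fragment."""

  def __eq__(self, other):
    return self in other

# e.g. exception_details=AnyStringWith(expected_exception) matches any
# traceback text that contains the expected exception's message.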
