Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit a1ca465

Browse files
aaltaysilviulica
authored and committed
Make sdk pipeline options available in the DoFn context
There is also a bug fix in get_all_options() to apply the overrides. ----Release Notes---- Adds a new API DoFnProcessContext.get_sdk_pipeline_options() that returns the pipeline options. [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=118083641
1 parent 2c4c414 commit a1ca465

8 files changed

Lines changed: 129 additions & 40 deletions

File tree

google/cloud/dataflow/pipeline_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def test_defaults(self):
255255
def test_dir(self):
256256
options = Breakfast()
257257
self.assertEquals(
258-
['get_all_options', 'slices', 'style', 'view_as'],
258+
['from_dictionary', 'get_all_options', 'slices', 'style', 'view_as'],
259259
[attr for attr in dir(options) if not attr.startswith('_')])
260260
self.assertEquals(
261-
['get_all_options', 'style', 'view_as'],
261+
['from_dictionary', 'get_all_options', 'style', 'view_as'],
262262
[attr for attr in dir(options.view_as(Eggs))
263263
if not attr.startswith('_')])
264264

google/cloud/dataflow/runners/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,14 @@ def existing_windows(self):
171171

172172
class DoFnState(object):
173173
"""Keeps track of state that DoFns want, currently, user counters.
174+
175+
Attributes:
176+
pipeline_options: a PipelineOptions object associated with this DoFn.
177+
step_name: name of the step as a string.
174178
"""
175179

176-
def __init__(self):
180+
def __init__(self, pipeline_options):
181+
self.pipeline_options = pipeline_options
177182
self.step_name = ''
178183
self._user_counters = {}
179184

google/cloud/dataflow/runners/direct_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ def func_wrapper(self, pvalue, *args, **kwargs):
7777
@skip_if_cached
7878
def run_ParDo(self, transform_node):
7979
transform = transform_node.transform
80+
options = transform_node.inputs[0].pipeline.options
8081
# TODO(gildea): what is the appropriate object to attach the state to?
81-
context = DoFnProcessContext(label=transform.label, state=DoFnState())
82+
context = DoFnProcessContext(label=transform.label,
83+
state=DoFnState(options))
8284

8385
# Construct the list of values from side-input PCollections that we'll
8486
# substitute into the arguments for DoFn methods.
@@ -105,7 +107,6 @@ def get_side_input_value(si):
105107

106108
# TODO(robertwb): Do this type checking inside DoFnRunner to get it on
107109
# remote workers as well?
108-
options = transform_node.inputs[0].pipeline.options
109110
if options is not None and options.view_as(TypeOptions).runtime_type_check:
110111
transform.dofn = TypeCheckWrapperDoFn(
111112
transform.dofn, transform.get_type_hints())

google/cloud/dataflow/transforms/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class DoFnProcessContext(object):
6161
(in process method only; always None in start_bundle and finish_bundle)
6262
windows: windows of the element
6363
(in process method only; always None in start_bundle and finish_bundle)
64+
pipeline_options: PipelineOptions object used for creating the pipeline.
6465
state: a DoFnState object, which holds the runner's internal state
6566
for this element. For example, aggregator state is here.
6667
Not used by the pipeline code.
@@ -106,6 +107,10 @@ def aggregate_to(self, aggregator, input_value):
106107
"""
107108
self.state.counter_for(aggregator).update(input_value)
108109

110+
@property
111+
def pipeline_options(self):
112+
return self.state.pipeline_options
113+
109114

110115
class DoFn(WithTypeHints):
111116
"""A function object used by a transform with custom processing.

google/cloud/dataflow/utils/options.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,26 @@ def _add_argparse_args(cls, parser):
8181
# Override this in subclasses to provide options.
8282
pass
8383

84+
@classmethod
85+
def from_dictionary(cls, options):
86+
"""Returns a PipelineOptions from a dictionary of arguments.
87+
88+
Args:
89+
options: Dictionary of argument/value pairs.
90+
91+
Returns:
92+
A PipelineOptions object representing the given arguments.
93+
"""
94+
flags = []
95+
for k, v in options.iteritems():
96+
if isinstance(v, bool):
97+
if v:
98+
flags.append('--%s' % k)
99+
else:
100+
flags.append('--%s=%s' % (k, v))
101+
102+
return cls(flags)
103+
84104
def get_all_options(self):
85105
"""Returns a dictionary of all defined arguments.
86106
@@ -94,7 +114,14 @@ def get_all_options(self):
94114
for cls in PipelineOptions.__subclasses__():
95115
cls._add_argparse_args(parser) # pylint: disable=protected-access
96116
known_args, _ = parser.parse_known_args(self._flags)
97-
return vars(known_args)
117+
result = vars(known_args)
118+
119+
# Apply the overrides if any
120+
for k in result:
121+
if k in self._all_options:
122+
result[k] = self._all_options[k]
123+
124+
return result
98125

99126
def view_as(self, cls):
100127
view = cls(self._flags)

google/cloud/dataflow/utils/pipeline_options_test.py

Lines changed: 75 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,84 @@
2020
from google.cloud.dataflow.utils.options import PipelineOptions
2121

2222

23-
class SetupTest(unittest.TestCase):
24-
25-
def test_get_unknown_args(self):
26-
27-
# Used for testing newly added flags.
28-
class MockOptions(PipelineOptions):
29-
30-
@classmethod
31-
def _add_argparse_args(cls, parser):
32-
parser.add_argument('--mock_flag',
33-
action='store_true',
34-
help='Enable work item profiling')
35-
36-
test_cases = [
37-
{'flags': ['--num_workers', '5'],
38-
'expected': {'num_workers': 5, 'mock_flag': False}},
39-
{
40-
'flags': [
41-
'--profile', '--profile_location', 'gs://bucket/', 'ignored'],
42-
'expected': {
43-
'profile': True, 'profile_location': 'gs://bucket/',
44-
'mock_flag': False}
45-
},
46-
{'flags': ['--num_workers', '5', '--mock_flag'],
47-
'expected': {'num_workers': 5, 'mock_flag': True}},
48-
]
49-
50-
for case in test_cases:
23+
class PipelineOptionsTest(unittest.TestCase):
24+
25+
TEST_CASES = [
26+
{'flags': ['--num_workers', '5'],
27+
'expected': {'num_workers': 5, 'mock_flag': False, 'mock_option': None}},
28+
{
29+
'flags': [
30+
'--profile', '--profile_location', 'gs://bucket/', 'ignored'],
31+
'expected': {
32+
'profile': True, 'profile_location': 'gs://bucket/',
33+
'mock_flag': False, 'mock_option': None}
34+
},
35+
{'flags': ['--num_workers', '5', '--mock_flag'],
36+
'expected': {'num_workers': 5, 'mock_flag': True, 'mock_option': None}},
37+
{'flags': ['--mock_option', 'abc'],
38+
'expected': {'mock_flag': False, 'mock_option': 'abc'}},
39+
{'flags': ['--mock_option', ' abc def '],
40+
'expected': {'mock_flag': False, 'mock_option': ' abc def '}},
41+
{'flags': ['--mock_option= abc xyz '],
42+
'expected': {'mock_flag': False, 'mock_option': ' abc xyz '}},
43+
{'flags': ['--mock_option=gs://my bucket/my folder/my file'],
44+
'expected': {'mock_flag': False,
45+
'mock_option': 'gs://my bucket/my folder/my file'}},
46+
]
47+
48+
# Used for testing newly added flags.
49+
class MockOptions(PipelineOptions):
50+
51+
@classmethod
52+
def _add_argparse_args(cls, parser):
53+
parser.add_argument('--mock_flag', action='store_true', help='mock flag')
54+
parser.add_argument('--mock_option', help='mock option')
55+
parser.add_argument('--option with space', help='mock option with space')
56+
57+
def test_get_all_options(self):
58+
for case in PipelineOptionsTest.TEST_CASES:
5159
options = PipelineOptions(flags=case['flags'])
5260
self.assertDictContainsSubset(case['expected'], options.get_all_options())
53-
self.assertEqual(options.view_as(MockOptions).mock_flag,
61+
self.assertEqual(options.view_as(
62+
PipelineOptionsTest.MockOptions).mock_flag,
5463
case['expected']['mock_flag'])
64+
self.assertEqual(options.view_as(
65+
PipelineOptionsTest.MockOptions).mock_option,
66+
case['expected']['mock_option'])
67+
68+
def test_from_dictionary(self):
69+
for case in PipelineOptionsTest.TEST_CASES:
70+
options = PipelineOptions(flags=case['flags'])
71+
all_options_dict = options.get_all_options()
72+
options_from_dict = PipelineOptions.from_dictionary(all_options_dict)
73+
self.assertEqual(options_from_dict.view_as(
74+
PipelineOptionsTest.MockOptions).mock_flag,
75+
case['expected']['mock_flag'])
76+
self.assertEqual(options.view_as(
77+
PipelineOptionsTest.MockOptions).mock_option,
78+
case['expected']['mock_option'])
79+
80+
def test_option_with_spcae(self):
81+
options = PipelineOptions(flags=['--option with space= value with space'])
82+
self.assertEqual(
83+
getattr(options.view_as(PipelineOptionsTest.MockOptions),
84+
'option with space'), ' value with space')
85+
options_from_dict = PipelineOptions.from_dictionary(
86+
options.get_all_options())
87+
self.assertEqual(
88+
getattr(options_from_dict.view_as(PipelineOptionsTest.MockOptions),
89+
'option with space'), ' value with space')
90+
91+
def test_override_options(self):
92+
base_flags = ['--num_workers', '5']
93+
options = PipelineOptions(base_flags)
94+
self.assertEqual(options.get_all_options()['num_workers'], 5)
95+
self.assertEqual(options.get_all_options()['mock_flag'], False)
96+
97+
options.view_as(PipelineOptionsTest.MockOptions).mock_flag = True
98+
self.assertEqual(options.get_all_options()['num_workers'], 5)
99+
self.assertEqual(options.get_all_options()['mock_flag'], True)
100+
55101

56102
if __name__ == '__main__':
57103
logging.getLogger().setLevel(logging.INFO)

google/cloud/dataflow/worker/batchworker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from google.cloud.dataflow.internal import auth
5050
from google.cloud.dataflow.internal import pickler
5151
from google.cloud.dataflow.utils import names
52+
from google.cloud.dataflow.utils import options
5253
from google.cloud.dataflow.utils import profiler
5354
from google.cloud.dataflow.utils import retry
5455
from google.cloud.dataflow.worker import executor
@@ -76,6 +77,8 @@ def __init__(self, properties, sdk_pipeline_options):
7677
self.job_id = properties['job_id']
7778
self.worker_id = properties['worker_id']
7879
self.service_path = properties['service_path']
80+
self.pipeline_options = options.PipelineOptions.from_dictionary(
81+
sdk_pipeline_options)
7982
self.capabilities = [self.worker_id, 'remote_source', 'custom_source']
8083
self.work_types = ['map_task', 'seq_map_task', 'remote_source_task']
8184
# The following properties are passed to the worker when its container
@@ -402,7 +405,8 @@ def do_work(self, work_item):
402405
self.dynamic_split_result_to_report = None
403406

404407
self.set_current_work_item_and_executor(work_item,
405-
executor.MapTaskExecutor())
408+
executor.MapTaskExecutor(
409+
self.pipeline_options))
406410
self.report_progress = True
407411
self.current_executor.execute(work_item.map_task)
408412
except Exception: # pylint: disable=broad-except

google/cloud/dataflow/worker/executor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,9 @@ def process(self, o):
383383
class DoOperation(Operation):
384384
"""A Do operation that will execute a custom DoFn for each input element."""
385385

386-
def __init__(self, spec):
386+
def __init__(self, spec, pipeline_options):
387387
super(DoOperation, self).__init__(spec)
388-
self.state = common.DoFnState()
388+
self.state = common.DoFnState(pipeline_options)
389389

390390
def _read_side_inputs(self, tags_and_types):
391391
"""Generator reading side inputs in the order prescribed by tags_and_types.
@@ -728,7 +728,8 @@ class MapTaskExecutor(object):
728728
multiple_read_instruction_error_msg = (
729729
'Found more than one \'read instruction\' in a single \'map task\'')
730730

731-
def __init__(self):
731+
def __init__(self, pipeline_options=None):
732+
self.pipeline_options = pipeline_options
732733
self._ops = []
733734
self._read_operation = None
734735

@@ -775,7 +776,7 @@ def execute(self, map_task, test_shuffle_source=None, test_shuffle_sink=None):
775776
elif isinstance(spec, maptask.WorkerPartialGroupByKey):
776777
op = create_pgbk_op(spec)
777778
elif isinstance(spec, maptask.WorkerDoFn):
778-
op = DoOperation(spec)
779+
op = DoOperation(spec, self.pipeline_options)
779780
elif isinstance(spec, maptask.WorkerGroupingShuffleRead):
780781
op = GroupedShuffleReadOperation(
781782
spec, shuffle_source=test_shuffle_source)

0 commit comments

Comments
 (0)