Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 164caa4

Browse files
aaltaygildea
authored and committed
Run batchworker do_work under cProfile
Captures all defined command line options with a value and sends them to the worker as sdk_pipeline_options. Batchworker supports two new options: --profile : flag to enable per work item profiling; --profile_location gs://... : GCS location to save profile results. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=117274168
1 parent d0cb464 commit 164caa4

7 files changed

Lines changed: 187 additions & 10 deletions

File tree

google/cloud/dataflow/internal/apiclient.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,17 @@ def __init__(self, packages, options, environment_version):
262262
pool.dataDisks.append(disk)
263263
self.proto.workerPools.append(pool)
264264

265+
sdk_pipeline_options = options.get_all_options()
266+
if sdk_pipeline_options:
267+
self.proto.sdkPipelineOptions = (
268+
dataflow.Environment.SdkPipelineOptionsValue())
269+
270+
for k, v in sdk_pipeline_options.iteritems():
271+
if v is not None:
272+
self.proto.sdkPipelineOptions.additionalProperties.append(
273+
dataflow.Environment.SdkPipelineOptionsValue.AdditionalProperty(
274+
key=k, value=to_json_value(v)))
275+
265276

266277
class Job(object):
267278
"""Wrapper for a dataflow Job protobuf."""

google/cloud/dataflow/pipeline_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def test_defaults(self):
255255
def test_dir(self):
256256
options = Breakfast()
257257
self.assertEquals(
258-
['slices', 'style', 'view_as'],
258+
['get_all_options', 'slices', 'style', 'view_as'],
259259
[attr for attr in dir(options) if not attr.startswith('_')])
260260
self.assertEquals(
261-
['style', 'view_as'],
261+
['get_all_options', 'style', 'view_as'],
262262
[attr for attr in dir(options.view_as(Eggs))
263263
if not attr.startswith('_')])
264264

google/cloud/dataflow/utils/options.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919

2020
import argparse
2121

22-
# Raw (unparsed) options. They are also added by other modules that want to
23-
# contribute modules other than the ones defined in this file. See add_option(),
24-
# below.
25-
OPTIONS = []
26-
2722

2823
class PipelineOptions(object):
2924
"""Pipeline options class used as container for command line options.
@@ -86,6 +81,21 @@ def _add_argparse_args(cls, parser):
8681
# Override this in subclasses to provide options.
8782
pass
8883

84+
def get_all_options(self):
  """Collect every argument defined by any PipelineOptions subclass.

  Builds a throwaway argparse parser, lets each direct subclass of
  PipelineOptions register its arguments on it, and then parses the flags
  this instance was constructed with.  Flags the parser does not recognize
  are silently ignored.

  Returns:
    Dictionary of all defined args and their parsed values.
  """
  combined_parser = argparse.ArgumentParser()
  for options_cls in PipelineOptions.__subclasses__():
    options_cls._add_argparse_args(combined_parser)  # pylint: disable=protected-access
  parsed, _unknown = combined_parser.parse_known_args(self._flags)
  return vars(parsed)
98+
8999
def view_as(self, cls):
90100
view = cls(self._flags)
91101
view._all_options = self._all_options
@@ -300,6 +310,18 @@ def _add_argparse_args(cls, parser):
300310
help='Debug file to write the workflow specification.')
301311

302312

313+
class ProfilingOptions(PipelineOptions):
  """Command line options that control per work item profiling."""

  @classmethod
  def _add_argparse_args(cls, parser):
    parser.add_argument(
        '--profile',
        action='store_true',
        help='Enable work item profiling')
    parser.add_argument(
        '--profile_location',
        default=None,
        help='GCS path for saving profiler data.')
323+
324+
303325
class SetupOptions(PipelineOptions):
304326

305327
@classmethod
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for the pipeline options module."""
16+
17+
import logging
18+
import unittest
19+
20+
from google.cloud.dataflow.utils.options import PipelineOptions
21+
22+
23+
class SetupTest(unittest.TestCase):
  """Tests for collecting flags across all PipelineOptions subclasses."""

  def test_get_unknown_args(self):

    # Defines a brand-new flag so we can verify that flags contributed by
    # newly added PipelineOptions subclasses are picked up.
    class MockOptions(PipelineOptions):

      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_argument('--mock_flag',
                            action='store_true',
                            help='Enable work item profiling')

    # Each entry is (flags to parse, expected subset of get_all_options()).
    cases = [
        (['--num_workers', '5'],
         {'num_workers': 5, 'mock_flag': False}),
        (['--profile', '--profile_location', 'gs://bucket/', 'ignored'],
         {'profile': True,
          'profile_location': 'gs://bucket/',
          'mock_flag': False}),
        (['--num_workers', '5', '--mock_flag'],
         {'num_workers': 5, 'mock_flag': True}),
    ]

    for flags, expected in cases:
      options = PipelineOptions(flags=flags)
      self.assertDictContainsSubset(expected, options.get_all_options())
      self.assertEqual(options.view_as(MockOptions).mock_flag,
                       expected['mock_flag'])
55+
56+
if __name__ == '__main__':
  # Surface INFO-level log output while the tests run.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A profiler context manager based on cProfile.Profile objects."""
16+
17+
import cProfile
18+
import logging
19+
import os
20+
import pstats
21+
import StringIO
22+
import tempfile
23+
import time
24+
25+
26+
from google.cloud.dataflow.utils.dependency import _dependency_file_copy
27+
28+
29+
class Profile(object):
  """cProfile wrapper context for saving and logging profiler results.

  Wraps a block of code (typically one work item) in a cProfile.Profile
  session.  On exit the collected stats are optionally uploaded to
  profile_location and/or written to the log, sorted by cumulative time.
  """

  # Sort order used when printing profiler statistics.
  SORTBY = 'cumulative'

  def __init__(self, profile_id, profile_location=None, log_results=False):
    """Initializes the profiling context.

    Args:
      profile_id: Identifier used in log messages and in the name of the
        saved profile dump (converted to str).
      profile_location: Optional base path (presumably a gs:// prefix,
        given _dependency_file_copy is used — confirm with callers) under
        which the raw profiler dump is saved.  If None, nothing is copied.
      log_results: If True, a cumulative-time summary of the collected
        stats is written to the log on exit.
    """
    # Populated on __exit__ when log_results is True; None until then.
    self.stats = None
    self.profile_id = str(profile_id)
    self.profile_location = profile_location
    self.log_results = log_results

  def __enter__(self):
    logging.info('Start profiling: %s', self.profile_id)
    self.profile = cProfile.Profile()
    self.profile.enable()
    return self

  def __exit__(self, *args):
    self.profile.disable()
    logging.info('Stop profiling: %s', self.profile_id)

    if self.profile_location:
      dump_location = os.path.join(
          self.profile_location, 'profile',
          ('%s-%s' % (time.strftime('%Y-%m-%d_%H_%M_%S'), self.profile_id)))
      fd, filename = tempfile.mkstemp()
      # dump_stats opens the path itself, so the mkstemp descriptor is not
      # needed; close it immediately instead of holding two handles open.
      os.close(fd)
      try:
        self.profile.dump_stats(filename)
        logging.info('Copying profiler data to: [%s]', dump_location)
        _dependency_file_copy(filename, dump_location)  # pylint: disable=protected-access
      finally:
        # Always remove the temp file, even if the copy/upload fails.
        os.remove(filename)

    if self.log_results:
      s = StringIO.StringIO()
      self.stats = pstats.Stats(
          self.profile, stream=s).sort_stats(Profile.SORTBY)
      self.stats.print_stats()
      logging.info('Profiler data: [%s]', s.getvalue())

google/cloud/dataflow/worker/batchworker.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from google.cloud.dataflow.internal import auth
5050
from google.cloud.dataflow.internal import pickler
5151
from google.cloud.dataflow.utils import names
52+
from google.cloud.dataflow.utils import profiler
5253
from google.cloud.dataflow.utils import retry
5354
from google.cloud.dataflow.worker import executor
5455
from google.cloud.dataflow.worker import logger
@@ -69,7 +70,7 @@ class BatchWorker(object):
6970
STATUS_HTTP_PORT = 0 # A value of 0 will pick a random unused port.
7071
MEMORY_USAGE_REPORTING_INTERVAL_SECS = 5 * 60
7172

72-
def __init__(self, properties):
73+
def __init__(self, properties, sdk_pipeline_options):
7374
"""Initializes a worker object from command line arguments."""
7475
self.project_id = properties['project_id']
7576
self.job_id = properties['job_id']
@@ -109,6 +110,11 @@ def __init__(self, properties):
109110
# the currently set work item; does not send progress updates otherwise.
110111
self.report_progress = False
111112

113+
# If 'True' each work item will be profiled with cProfile. Results will
114+
# be logged and also saved to profile_location if set.
115+
self.work_item_profiling = sdk_pipeline_options.get('profile', False)
116+
self.profile_location = sdk_pipeline_options.get('profile_location', None)
117+
112118
@property
113119
def current_work_item(self):
114120
with self.lock:
@@ -518,7 +524,15 @@ def run(self):
518524
# failed. The progress reporting_thread will take care of sending
519525
# updates and updating in the workitem object the reporting indexes
520526
# and duration for the lease.
521-
self.do_work(work_item)
527+
528+
if self.work_item_profiling:
529+
with profiler.Profile(
530+
profile_id=work_item.proto.id,
531+
profile_location=self.profile_location, log_results=True):
532+
self.do_work(work_item)
533+
else:
534+
self.do_work(work_item)
535+
522536
logging.info('Completed work item: %s in %.9f seconds',
523537
work_item.proto.id, time.time() - start_time)
524538

google/cloud/dataflow/worker/start.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# thread-safety issue with datetime.datetime.strptime if this module is not
1818
# already imported.
1919
import _strptime # pylint: disable=unused-import
20+
import json
2021
import logging
2122
import random
2223
import re
@@ -49,6 +50,11 @@ def main():
4950

5051
logging.info('Worker started with properties: %s', properties)
5152

53+
sdk_pipeline_options = json.loads(
54+
properties.get('sdk_pipeline_options', '{}'))
55+
logging.info('Worker started with sdk_pipeline_options: %s',
56+
sdk_pipeline_options)
57+
5258
if unused_args:
5359
logging.warning('Unrecognized arguments %s', unused_args)
5460

@@ -64,7 +70,7 @@ def main():
6470
streamingworker.StreamingWorker(properties).run()
6571
else:
6672
logging.info('Starting batch worker.')
67-
batchworker.BatchWorker(properties).run()
73+
batchworker.BatchWorker(properties, sdk_pipeline_options).run()
6874

6975

7076
if __name__ == '__main__':

0 commit comments

Comments
 (0)