This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 57415a9

robertwb authored and gildea committed
Optimize logging context
As this is called per element, it's important to minimize logging overhead here. Reduce enter/exit to a single thread-local access, and put the overwriting into the (already expensive) logging itself.

----Release Notes----
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=117290642
1 parent deebc93 commit 57415a9

2 files changed

Lines changed: 66 additions & 55 deletions
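
For orientation before the diffs: the change replaces per-key save/restore on a bare threading.local() with a thread-local stack of dicts, so entering and exiting a logging context is a single list append or pop, and the merge (where inner frames overwrite outer ones) happens only inside the already-expensive formatter call. The following is a self-contained sketch condensed from the diff below; the nested-usage demo at the end is added here for illustration and mirrors the new test:

import threading


class _PerThreadWorkerData(threading.local):

  def __init__(self):
    super(_PerThreadWorkerData, self).__init__()
    self.stack = []

  def get_data(self):
    # Merge in push order: later (inner) frames win for duplicate keys.
    all_data = {}
    for datum in self.stack:
      all_data.update(datum)
    return all_data


per_thread_worker_data = _PerThreadWorkerData()


class PerThreadLoggingContext(object):

  def __init__(self, **kwargs):
    self.kwargs = kwargs

  def __enter__(self):
    # Entering is now a single thread-local list append.
    per_thread_worker_data.stack.append(self.kwargs)

  def __exit__(self, exn_type, exn_value, exn_traceback):
    per_thread_worker_data.stack.pop()


# Illustrative nested usage (not repository code): the inner value wins
# while active, and the outer value reappears automatically on exit.
with PerThreadLoggingContext(step_name='step1'):
  with PerThreadLoggingContext(step_name='step2'):
    assert per_thread_worker_data.get_data()['step_name'] == 'step2'
  assert per_thread_worker_data.get_data()['step_name'] == 'step1'
assert 'step_name' not in per_thread_worker_data.get_data()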


google/cloud/dataflow/worker/logger.py

Lines changed: 23 additions & 35 deletions
@@ -23,45 +23,32 @@
 # Per-thread worker information. This is used only for logging to set
 # context information that changes while work items get executed:
 # work_item_id, step_name, stage_name.
-per_thread_worker_data = threading.local()
+class _PerThreadWorkerData(threading.local):
+
+  def __init__(self):
+    super(_PerThreadWorkerData, self).__init__()
+    self.stack = []
+
+  def get_data(self):
+    all_data = {}
+    for datum in self.stack:
+      all_data.update(datum)
+    return all_data
+
+per_thread_worker_data = _PerThreadWorkerData()


 class PerThreadLoggingContext(object):
   """A context manager to add per thread attributes."""
-  _instance = dict()
-
-  def __new__(cls, *args, **kwargs):
-    # TODO(robertwb): make the class generic, this is special-cased to save
-    # time on the DoFn
-    if not args and len(kwargs) == 1 and 'step_name' in kwargs:
-      k = kwargs['step_name']
-      if k not in cls._instance:
-        cls._instance[k] = super(PerThreadLoggingContext, cls).__new__(
-            cls, *args, **kwargs)
-      return cls._instance[k]
-    else:
-      return super(PerThreadLoggingContext, cls).__new__(cls, *args, **kwargs)

-  def __init__(self, *args, **kwargs):
-    if args:
-      raise ValueError(
-          'PerThreadLoggingContext expects only keyword arguments.')
+  def __init__(self, **kwargs):
     self.kwargs = kwargs
-    self.previous = {}

   def __enter__(self):
-    for key in self.kwargs:
-      if hasattr(per_thread_worker_data, key):
-        self.previous[key] = getattr(per_thread_worker_data, key)
-      setattr(per_thread_worker_data, key, self.kwargs[key])
-    return self
+    per_thread_worker_data.stack.append(self.kwargs)

   def __exit__(self, exn_type, exn_value, exn_traceback):
-    for key in self.kwargs:
-      if key in self.previous:
-        setattr(per_thread_worker_data, key, self.previous[key])
-      else:
-        delattr(per_thread_worker_data, key)
+    per_thread_worker_data.stack.pop()


 class JsonLogFormatter(logging.Formatter):
@@ -138,12 +125,13 @@ def format(self, record):
     # Stage, step and work item ID come from thread local storage since they
     # change with every new work item leased for execution. If there is no
     # work item ID then we make sure the step is undefined too.
-    if hasattr(per_thread_worker_data, 'work_item_id'):
-      output['work'] = getattr(per_thread_worker_data, 'work_item_id')
-    if hasattr(per_thread_worker_data, 'stage_name'):
-      output['stage'] = getattr(per_thread_worker_data, 'stage_name')
-    if hasattr(per_thread_worker_data, 'step_name'):
-      output['step'] = getattr(per_thread_worker_data, 'step_name')
+    data = per_thread_worker_data.get_data()
+    if 'work_item_id' in data:
+      output['work'] = data['work_item_id']
+    if 'stage_name' in data:
+      output['stage'] = data['stage_name']
+    if 'step_name' in data:
+      output['step'] = data['step_name']
     # All logging happens using the root logger. We will add the basename of the
     # file and the function name where the logging happened to make it easier
     # to identify who generated the record.
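
To see the per-element saving the commit message refers to, here is a rough micro-benchmark sketch contrasting the old enter/exit path (per-key hasattr/getattr/setattr/delattr) with the new one (a single append/pop). The helper names and iteration count are illustrative, not repository code, and absolute timings vary by machine and interpreter; only the ratio is of interest:

# Rough micro-benchmark of the enter/exit cost this change targets.
# old_style() and new_style() are made-up stand-ins for the old and new
# __enter__/__exit__ bodies shown in the diff above.
import threading
import timeit

data = threading.local()
data.stack = []
KWARGS = {'work_item_id': 'w', 'stage_name': 's', 'step_name': 'p'}


def old_style():
  previous = {}
  for key in KWARGS:  # old __enter__: per-key save and overwrite
    if hasattr(data, key):
      previous[key] = getattr(data, key)
    setattr(data, key, KWARGS[key])
  for key in KWARGS:  # old __exit__: per-key restore or delete
    if key in previous:
      setattr(data, key, previous[key])
    else:
      delattr(data, key)


def new_style():
  data.stack.append(KWARGS)  # new __enter__
  data.stack.pop()           # new __exit__


print('old: %.3fs' % timeit.timeit(old_style, number=100000))
print('new: %.3fs' % timeit.timeit(new_style, number=100000))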

google/cloud/dataflow/worker/logger_test.py

Lines changed: 43 additions & 20 deletions
@@ -26,45 +26,43 @@
 class PerThreadLoggingContextTest(unittest.TestCase):

   def thread_check_attribute(self, name):
-    self.assertFalse(hasattr(logger.per_thread_worker_data, name))
-    with logger.PerThreadLoggingContext(xyz='thread-value'):
+    self.assertFalse(name in logger.per_thread_worker_data.get_data())
+    with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
       self.assertEqual(
-          getattr(logger.per_thread_worker_data, name), 'thread-value')
-    self.assertFalse(hasattr(logger.per_thread_worker_data, name))
+          logger.per_thread_worker_data.get_data()[name], 'thread-value')
+    self.assertFalse(name in logger.per_thread_worker_data.get_data())

   def test_no_positional_args(self):
-    with self.assertRaises(ValueError) as exn:
+    with self.assertRaises(TypeError):
       with logger.PerThreadLoggingContext('something'):
         pass
-    self.assertEqual(
-        exn.exception.message,
-        'PerThreadLoggingContext expects only keyword arguments.')

   def test_per_thread_attribute(self):
-    self.assertFalse(hasattr(logger.per_thread_worker_data, 'xyz'))
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
     with logger.PerThreadLoggingContext(xyz='value'):
-      self.assertEqual(logger.per_thread_worker_data.xyz, 'value')
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
       thread = threading.Thread(
           target=self.thread_check_attribute, args=('xyz',))
       thread.start()
       thread.join()
-      self.assertEqual(logger.per_thread_worker_data.xyz, 'value')
-    self.assertFalse(hasattr(logger.per_thread_worker_data, 'xyz'))
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())

   def test_set_when_undefined(self):
-    self.assertFalse(hasattr(logger.per_thread_worker_data, 'xyz'))
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
     with logger.PerThreadLoggingContext(xyz='value'):
-      self.assertEqual(logger.per_thread_worker_data.xyz, 'value')
-    self.assertFalse(hasattr(logger.per_thread_worker_data, 'xyz'))
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())

   def test_set_when_already_defined(self):
-    self.assertFalse(hasattr(logger.per_thread_worker_data, 'xyz'))
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
     with logger.PerThreadLoggingContext(xyz='value'):
-      self.assertEqual(logger.per_thread_worker_data.xyz, 'value')
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
       with logger.PerThreadLoggingContext(xyz='value2'):
-        self.assertEqual(logger.per_thread_worker_data.xyz, 'value2')
-      self.assertEqual(logger.per_thread_worker_data.xyz, 'value')
-    self.assertFalse(hasattr(logger.per_thread_worker_data, 'xyz'))
+        self.assertEqual(
+            logger.per_thread_worker_data.get_data()['xyz'], 'value2')
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())


 class JsonLogFormatterTest(unittest.TestCase):
@@ -140,6 +138,31 @@ def test_record_with_per_thread_info(self):
         {'work': 'workitem', 'stage': 'stage', 'step': 'step'})
     self.assertEqual(log_output, expected_output)

+  def test_nested_with_per_thread_info(self):
+    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+    with logger.PerThreadLoggingContext(
+        work_item_id='workitem', stage_name='stage', step_name='step1'):
+      record = self.create_log_record(**self.SAMPLE_RECORD)
+      log_output1 = json.loads(formatter.format(record))
+
+      with logger.PerThreadLoggingContext(step_name='step2'):
+        record = self.create_log_record(**self.SAMPLE_RECORD)
+        log_output2 = json.loads(formatter.format(record))
+
+      record = self.create_log_record(**self.SAMPLE_RECORD)
+      log_output3 = json.loads(formatter.format(record))
+
+    record = self.create_log_record(**self.SAMPLE_RECORD)
+    log_output4 = json.loads(formatter.format(record))
+
+    self.assertEqual(log_output1, dict(
+        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
+    self.assertEqual(log_output2, dict(
+        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
+    self.assertEqual(log_output3, dict(
+        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
+    self.assertEqual(log_output4, self.SAMPLE_OUTPUT)
+
   def test_exception_record(self):
     formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
     try:
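
A note on the ValueError-to-TypeError change in test_no_positional_args: once __init__ accepts only **kwargs, the Python runtime itself rejects positional arguments, so the hand-written ValueError and its message check became dead code. A minimal demonstration (the Demo class is a hypothetical stand-in, not repository code):

class Demo(object):

  def __init__(self, **kwargs):
    self.kwargs = kwargs


try:
  Demo('something')
except TypeError as e:
  # On Python 2 the message reads roughly:
  # "__init__() takes exactly 1 argument (2 given)"
  print('rejected: %s' % e)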

0 commit comments