Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 8484d34

Browse files
gildeasilviulica
authored and committed
Implement aggregated_values for DirectPipelineRunner
Create new class DirectPipelineResult, on which new method aggregated_values() is defined. This method lets you query the accumulated value of an Aggregator after the pipeline has run. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=120006979
1 parent 4bb35c6 commit 8484d34

4 files changed

Lines changed: 76 additions & 4 deletions

File tree

google/cloud/dataflow/examples/wordcount.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525

2626
empty_line_aggregator = df.Aggregator('emptyLines')
27+
average_word_size_aggregator = df.Aggregator('averageWordLength',
28+
df.combiners.Mean())
2729

2830

2931
class WordExtractingDoFn(df.DoFn):
@@ -43,7 +45,10 @@ def process(self, context):
4345
text_line = context.element.strip()
4446
if not text_line:
4547
context.aggregate_to(empty_line_aggregator, 1)
46-
return re.findall(r'[A-Za-z\']+', text_line)
48+
words = re.findall(r'[A-Za-z\']+', text_line)
49+
for w in words:
50+
context.aggregate_to(average_word_size_aggregator, float(len(w)))
51+
return words
4752

4853

4954
def run(argv=None):
@@ -81,7 +86,11 @@ def run(argv=None):
8186
output | df.io.Write('write', df.io.TextFileSink(known_args.output))
8287

8388
# Actually run the pipeline (all operations above are deferred).
84-
p.run()
89+
result = p.run()
90+
empty_line_values = result.aggregated_values(empty_line_aggregator)
91+
logging.info('number of empty lines: %d', sum(empty_line_values.values()))
92+
word_length_values = result.aggregated_values(average_word_size_aggregator)
93+
logging.info('average word lengths: %s', word_length_values.values())
8594

8695

8796
if __name__ == '__main__':

google/cloud/dataflow/runners/direct_runner.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
from google.cloud.dataflow.pvalue import EmptySideInput
3333
from google.cloud.dataflow.runners.common import DoFnRunner
3434
from google.cloud.dataflow.runners.common import DoFnState
35+
from google.cloud.dataflow.runners.runner import PipelineResult
3536
from google.cloud.dataflow.runners.runner import PipelineRunner
37+
from google.cloud.dataflow.runners.runner import PipelineState
3638
from google.cloud.dataflow.runners.runner import PValueCache
3739
from google.cloud.dataflow.transforms import DoFnProcessContext
3840
from google.cloud.dataflow.transforms.window import GlobalWindows
@@ -89,6 +91,8 @@ def func_wrapper(self, pvalue, *args, **kwargs):
8991
def run(self, pipeline, node=None):
  """Execute the pipeline synchronously and return its result.

  Args:
    pipeline: the Pipeline to execute.
    node: optional node to restrict execution to; passed through to the
      base runner unchanged.

  Returns:
    A DirectPipelineResult in the DONE state, wrapping the counters
    accumulated during the run so callers can query aggregated_values().
  """
  # The base class performs the actual (deferred) graph traversal.
  super(DirectPipelineRunner, self).run(pipeline, node)
  logging.info('Final: Debug counters: %s', self.debug_counters)
  # Execution is synchronous, so the pipeline is DONE on return.
  # NOTE(review): assumes self._counter_factory was populated during the
  # run -- confirm where it is initialized (not visible in this chunk).
  return DirectPipelineResult(state=PipelineState.DONE,
                              counter_factory=self._counter_factory)
9296

9397
@skip_if_cached
9498
def run_ParDo(self, transform_node):
@@ -226,3 +230,14 @@ def run__NativeWrite(self, transform_node):
226230
for v in self._cache.get_pvalue(transform_node.inputs[0]):
227231
self.debug_counters['element_counts'][transform_node.full_label] += 1
228232
writer.Write(v.value)
233+
234+
235+
class DirectPipelineResult(PipelineResult):
  """A DirectPipelineResult provides access to info about a pipeline.

  In particular it exposes aggregated_values(), which lets callers query
  the accumulated value of an Aggregator after the pipeline has run.
  """

  def __init__(self, state, counter_factory=None):
    """Initializes the result.

    Args:
      state: final PipelineState of the run.
      counter_factory: CounterFactory holding the counters accumulated
        during execution, or None when no counters were tracked.
    """
    super(DirectPipelineResult, self).__init__(state)
    self._counter_factory = counter_factory

  def aggregated_values(self, aggregator_or_name):
    """Return a dict of step names to values of the Aggregator."""
    if self._counter_factory is None:
      # Guard the default counter_factory=None case: previously this
      # raised AttributeError.  Mirror the base-class fallback of an
      # empty mapping instead.
      return {}
    return self._counter_factory.get_aggregator_values(aggregator_or_name)

google/cloud/dataflow/runners/runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def visit_transform(self, transform_node):
8383
raise
8484

8585
pipeline.visit(RunVisitor(self), node=node)
86-
return PipelineResult(state=PipelineState.DONE)
8786

8887
def clear(self, pipeline, node=None):
8988
"""Clear all nodes or nodes reachable from node of materialized values.
@@ -267,3 +266,10 @@ def __init__(self, state):
267266
def current_state(self):
  """Report the state the pipeline run is currently in."""
  state = self._state
  return state
269+
270+
# pylint: disable=unused-argument
def aggregated_values(self, aggregator_or_name):
  """Return a dict of step names to values of the Aggregator.

  Base implementation for runners that do not track aggregator values;
  runner-specific results are expected to override this.

  Args:
    aggregator_or_name: an Aggregator object or the name of one.

  Returns:
    An empty dict, since nothing was recorded.
  """
  # logging.warn is a deprecated alias; logging.warning is the
  # supported spelling on both Python 2.7 and 3.
  logging.warning('%s does not implement aggregated_values',
                  self.__class__.__name__)
  return {}

google/cloud/dataflow/utils/counters.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ def _update_small(self, delta):
8989
def total(self):
9090
return self.c_total + self.py_total
9191

92+
def value(self):
  """Return this counter's aggregated value.

  SUM counters report the running total; MEAN counters report the
  total divided by the number of accumulated elements.
  """
  kind = self.aggregation_kind
  if kind == self.SUM:
    return self.total
  if kind == self.MEAN:
    return float(self.total) / self.elements
  # Should be unreachable: the aggregation kind is validated in __init__.
  raise TypeError('%s.value(): unsupported aggregation_kind' % self)
100+
92101
def __str__(self):
93102
return '<%s>' % self._str_internal()
94103

@@ -125,6 +134,10 @@ def __init__(self, name='unnamed'):
125134
Counter.SUM)
126135

127136

137+
# Counters that represent Accumulators have names starting with this
138+
USER_COUNTER_PREFIX = 'user-'
139+
140+
128141
class CounterFactory(object):
129142
"""Keeps track of unique counters."""
130143

@@ -168,7 +181,7 @@ def get_aggregator_counter(self, step_name, aggregator):
168181
A new or existing counter.
169182
"""
170183
with self._lock:
171-
name = 'user-%s-%s' % (step_name, aggregator.name)
184+
name = '%s%s-%s' % (USER_COUNTER_PREFIX, step_name, aggregator.name)
172185
aggregation_kind = aggregator.aggregation_kind
173186
counter = self.counters.get(name, None)
174187
if counter:
@@ -190,3 +203,32 @@ def get_counters(self):
190203
"""
191204
with self._lock:
192205
return self.counters.values()
206+
207+
def get_aggregator_values(self, aggregator_or_name):
  """Return a mapping from step names to this aggregator's values."""
  def _counter_value(counter):
    return counter.value()
  # Snapshot under the lock so concurrent counter updates don't race
  # with the dict traversal.
  with self._lock:
    return get_aggregator_values(
        aggregator_or_name, self.counters, _counter_value)
212+
213+
214+
def get_aggregator_values(aggregator_or_name, counter_dict,
                          value_extractor=None):
  """Extracts the named aggregator value from a set of counters.

  Args:
    aggregator_or_name: an Aggregator object or the name of one.
    counter_dict: a dict object of {name: value_wrapper}
    value_extractor: a function to convert the value_wrapper into a value.
      If None, no extraction is done and the value is returned unchanged.

  Returns:
    dict of step names to values of the aggregator.
  """
  if value_extractor is None:
    value_extractor = lambda x: x
  # Accept either an Aggregator (use its .name) or a plain name string.
  # getattr with a default replaces the Python-2-only isinstance check
  # against basestring, so this works on both Python 2 and 3.
  name = getattr(aggregator_or_name, 'name', aggregator_or_name)
  # Aggregator counters are named '<USER_COUNTER_PREFIX><step>-<name>'
  # (see CounterFactory.get_aggregator_counter).
  suffix = '-%s' % name
  return {n: value_extractor(c) for n, c in counter_dict.items()
          if n.startswith(USER_COUNTER_PREFIX) and n.endswith(suffix)}

0 commit comments

Comments
 (0)