Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 5f104d4

Browse files
robertwb and silviulica
authored and committed
Enable support for all supported counter types
Also implement optimized int64/double sum/min/max/mean and any/all CombineFns. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=121067639
1 parent 0a52fae commit 5f104d4

12 files changed

Lines changed: 524 additions & 164 deletions

File tree

google/cloud/dataflow/examples/wordcount.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@
2525

2626
empty_line_aggregator = df.Aggregator('emptyLines')
2727
average_word_size_aggregator = df.Aggregator('averageWordLength',
28-
df.combiners.Mean())
28+
df.combiners.MeanCombineFn(),
29+
float)
2930

3031

3132
class WordExtractingDoFn(df.DoFn):
@@ -47,7 +48,7 @@ def process(self, context):
4748
context.aggregate_to(empty_line_aggregator, 1)
4849
words = re.findall(r'[A-Za-z\']+', text_line)
4950
for w in words:
50-
context.aggregate_to(average_word_size_aggregator, float(len(w)))
51+
context.aggregate_to(average_word_size_aggregator, len(w))
5152
return words
5253

5354

google/cloud/dataflow/internal/apiclient.py

Lines changed: 45 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from google.cloud.dataflow.internal.auth import get_service_credentials
2727
from google.cloud.dataflow.internal.json_value import to_json_value
2828
from google.cloud.dataflow.io import iobase
29+
from google.cloud.dataflow.transforms import cy_combiners
2930
from google.cloud.dataflow.utils import dependency
3031
from google.cloud.dataflow.utils import retry
3132
from google.cloud.dataflow.utils.names import PropertyNames
@@ -44,7 +45,7 @@
4445
STORAGE_API_SERVICE = 'storage.googleapis.com'
4546

4647

47-
def append_counter(status_object, counter, tentative=False):
48+
def append_counter(status_object, counter, tentative):
4849
"""Appends a counter to the status.
4950
5051
Args:
@@ -55,22 +56,23 @@ def append_counter(status_object, counter, tentative=False):
5556
logging.debug('Appending counter%s %s',
5657
' (tentative)' if tentative else '',
5758
counter)
59+
kind, setter = metric_translations[counter.combine_fn.__class__]
5860
append_metric(
59-
status_object, counter.name, counter.total,
60-
counter.elements if counter.aggregation_kind == counter.MEAN else None,
61-
tentative=tentative)
61+
status_object, counter.name, kind, counter.accumulator,
62+
setter, tentative=tentative)
6263

6364

64-
def append_metric(status_object, metric_name, value1, value2=None,
65+
def append_metric(status_object, metric_name, kind, value, setter=None,
6566
step=None, output_user_name=None, tentative=False,
6667
worker_id=None, cumulative=True):
6768
"""Creates and adds a MetricUpdate field to the passed-in protobuf.
6869
6970
Args:
7071
status_object: a work_item_status to which to add this metric
7172
metric_name: a string naming this metric
72-
value1: scalar for a Sum or mean_sum for a Mean
73-
value2: mean_count for a Mean aggregation (do not provide for a Sum).
73+
kind: dataflow counter kind (e.g. 'sum')
74+
value: accumulator value to encode
75+
setter: if not None, a lambda to use to update metric_update with value
7476
step: the name of the associated step
7577
output_user_name: the user-visible name to use
7678
tentative: whether this should be labeled as a tentative metric
@@ -103,19 +105,13 @@ def append_to_context(key, value):
103105
append_to_context('workerId', worker_id)
104106
if cumulative and is_counter:
105107
metric_update.cumulative = cumulative
106-
if value2 is None:
107-
if is_counter:
108-
# Counters are distinguished by having a kind; metrics do not.
109-
metric_update.kind = 'Sum'
110-
metric_update.scalar = to_json_value(value1, with_type=True)
111-
elif value2 > 0:
112-
metric_update.kind = 'Mean'
113-
metric_update.meanSum = to_json_value(value1, with_type=True)
114-
metric_update.meanCount = to_json_value(value2, with_type=True)
108+
if is_counter:
109+
# Counters are distinguished by having a kind; metrics do not.
110+
metric_update.kind = kind
111+
if setter:
112+
setter(value, metric_update)
115113
else:
116-
# A denominator of 0 will raise an error in the service.
117-
# What it means is we have nothing to report yet, so don't.
118-
pass
114+
metric_update.scalar = to_json_value(value, with_type=True)
119115
logging.debug('Appending metric_update: %s', metric_update)
120116
status_object.metricUpdates.append(metric_update)
121117

@@ -840,3 +836,33 @@ def cloud_position_to_reader_position(cloud_position):
840836
def approximate_progress_to_dynamic_split_request(approximate_progress):
841837
return iobase.DynamicSplitRequest(cloud_progress_to_reader_progress(
842838
approximate_progress))
839+
840+
841+
def set_scalar(accumulator, metric_update):
842+
metric_update.scalar = to_json_value(accumulator.value, with_type=True)
843+
844+
845+
def set_mean(accumulator, metric_update):
846+
if accumulator.count:
847+
metric_update.meanSum = to_json_value(accumulator.sum, with_type=True)
848+
metric_update.meanCount = to_json_value(accumulator.count, with_type=True)
849+
else:
850+
# A denominator of 0 will raise an error in the service.
851+
# What it means is we have nothing to report yet, so don't.
852+
metric_update.kind = None
853+
854+
855+
# To enable a counter on the service, add it to this dictionary.
856+
metric_translations = {
857+
cy_combiners.CountCombineFn: ('sum', set_scalar),
858+
cy_combiners.SumInt64Fn: ('sum', set_scalar),
859+
cy_combiners.MinInt64Fn: ('min', set_scalar),
860+
cy_combiners.MaxInt64Fn: ('max', set_scalar),
861+
cy_combiners.MeanInt64Fn: ('mean', set_mean),
862+
cy_combiners.SumFloatFn: ('sum', set_scalar),
863+
cy_combiners.MinFloatFn: ('min', set_scalar),
864+
cy_combiners.MaxFloatFn: ('max', set_scalar),
865+
cy_combiners.MeanFloatFn: ('mean', set_mean),
866+
cy_combiners.AllCombineFn: ('and', set_scalar),
867+
cy_combiners.AnyCombineFn: ('or', set_scalar),
868+
}

google/cloud/dataflow/transforms/aggregator.py

Lines changed: 18 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -38,8 +38,7 @@ def process(self, context):
3838

3939
from __future__ import absolute_import
4040

41-
from google.cloud.dataflow.transforms import combiners
42-
from google.cloud.dataflow.utils.counters import Counter
41+
from google.cloud.dataflow.transforms import core
4342

4443

4544
class Aggregator(object):
@@ -49,14 +48,12 @@ class Aggregator(object):
4948
combine_fn: how to combine values input to the aggregation.
5049
It must be one of these arithmetic functions:
5150
52-
- Python's built-in sum
53-
- Python's built-in min
54-
- Python's built-in max
55-
- df.Mean()
51+
- Python's built-in sum, min, max, any, and all.
52+
- df.combiners.MeanCombineFn()
5653
57-
The default is sum.
54+
The default is sum of 64-bit ints.
5855
59-
type: describes the numeric type that will be accepted as input
56+
type: describes the type that will be accepted as input
6057
for aggregation; by default types appropriate to the combine_fn
6158
are accepted.
6259
@@ -67,13 +64,16 @@ class Aggregator(object):
6764
complex_counter = df.Aggregator('other-counter', df.Mean(), float)
6865
"""
6966

70-
def __init__(self,
71-
name,
72-
combine_fn=sum,
73-
input_type=None): # inferred from combine_fn
67+
def __init__(self, name, combine_fn=sum, input_type=int):
68+
combine_fn = core.CombineFn.maybe_from_callable(combine_fn).for_input_type(
69+
input_type)
70+
if not _is_supported_kind(combine_fn):
71+
raise ValueError(
72+
'combine_fn %r (class %r) '
73+
'does not map to a supported aggregation kind'
74+
% (combine_fn, combine_fn.__class__))
7475
self.name = name
7576
self.combine_fn = combine_fn
76-
self.aggregation_kind = self._aggregator_counter_kind(combine_fn)
7777
self.input_type = input_type
7878

7979
def __str__(self):
@@ -98,30 +98,8 @@ def get_name(thing):
9898
combine_call = ' %s%s' % (combine_fn_str, input_arg)
9999
return 'Aggregator %s%s' % (self.name, combine_call)
100100

101-
@staticmethod
102-
def _aggregator_counter_kind(combine_fn):
103-
"""Returns the counter aggregation kind for the combine_fn passed in.
104-
105-
Args:
106-
combine_fn: The combining function used in an Aggregator.
107-
108-
Returns:
109-
The aggregation_kind (to use in a Counter) that matches combine_fn.
110-
111-
Raises:
112-
ValueError if the combine_fn doesn't map to any supported
113-
aggregation kind.
114-
"""
115-
# We don't have combiner types that implement AND or OR.
116-
combine_kind_map = {sum: Counter.SUM, max: Counter.MAX, min: Counter.MIN,
117-
combiners.Mean: Counter.MEAN}
118-
try:
119-
return combine_kind_map[combine_fn]
120-
except KeyError:
121-
try:
122-
return combine_kind_map[combine_fn.__class__]
123-
except KeyError:
124-
raise ValueError(
125-
'combine_fn %r (class %r) '
126-
'does not map to a supported aggregation kind'
127-
% (combine_fn, combine_fn.__class__))
101+
102+
def _is_supported_kind(combine_fn):
103+
# pylint: disable=g-import-not-at-top
104+
from google.cloud.dataflow.internal.apiclient import metric_translations
105+
return combine_fn.__class__ in metric_translations

google/cloud/dataflow/transforms/aggregator_test.py

Lines changed: 38 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
import unittest
1818

19+
import google.cloud.dataflow as df
1920
from google.cloud.dataflow.transforms import combiners
2021
from google.cloud.dataflow.transforms.aggregator import Aggregator
2122

@@ -24,16 +25,48 @@ class AggregatorTest(unittest.TestCase):
2425

2526
def test_str(self):
2627
basic = Aggregator('a-name')
27-
self.assertEqual('<Aggregator a-name>', str(basic))
28+
self.assertEqual('<Aggregator a-name SumInt64Fn(int)>', str(basic))
2829

2930
for_max = Aggregator('max-name', max)
30-
self.assertEqual('<Aggregator max-name max>', str(for_max))
31+
self.assertEqual('<Aggregator max-name MaxInt64Fn(int)>', str(for_max))
3132

3233
for_float = Aggregator('f-name', sum, float)
33-
self.assertEqual('<Aggregator f-name sum(float)>', str(for_float))
34+
self.assertEqual('<Aggregator f-name SumFloatFn(float)>', str(for_float))
3435

35-
for_mean = Aggregator('m-name', combiners.Mean(), float)
36-
self.assertEqual('<Aggregator m-name Mean(float)>', str(for_mean))
36+
for_mean = Aggregator('m-name', combiners.MeanCombineFn(), float)
37+
self.assertEqual('<Aggregator m-name MeanFloatFn(float)>', str(for_mean))
38+
39+
def test_aggregation(self):
40+
41+
mean = combiners.MeanCombineFn()
42+
mean.__name__ = 'mean'
43+
counter_types = [
44+
(sum, int, 6),
45+
(min, int, 0),
46+
(max, int, 3),
47+
(mean, int, 1),
48+
(sum, float, 6.0),
49+
(min, float, 0.0),
50+
(max, float, 3.0),
51+
(mean, float, 1.5),
52+
(any, int, True),
53+
(all, float, False),
54+
]
55+
aggeregators = [Aggregator('%s_%s' % (f.__name__, t.__name__), f, t)
56+
for f, t, _ in counter_types]
57+
58+
class UpdateAggregators(df.DoFn):
59+
def process(self, context):
60+
for a in aggeregators:
61+
context.aggregate_to(a, context.element)
62+
63+
p = df.Pipeline('DirectPipelineRunner')
64+
p | df.Create([0, 1, 2, 3]) | df.ParDo(UpdateAggregators())
65+
res = p.run()
66+
for (_, _, expected), a in zip(counter_types, aggeregators):
67+
actual = res.aggregated_values(a).values()[0]
68+
self.assertEqual(expected, actual)
69+
self.assertEqual(type(expected), type(actual))
3770

3871

3972
if __name__ == '__main__':

google/cloud/dataflow/transforms/combiners.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import random
2222

2323
from google.cloud.dataflow.transforms import core
24+
from google.cloud.dataflow.transforms import cy_combiners
2425
from google.cloud.dataflow.transforms import ptransform
2526
from google.cloud.dataflow.typehints import Any
2627
from google.cloud.dataflow.typehints import Dict
@@ -81,6 +82,14 @@ def extract_output(self, (sum_, count)):
8182
return float('NaN')
8283
return sum_ / float(count)
8384

85+
def for_input_type(self, input_type):
86+
if input_type is int:
87+
return cy_combiners.MeanInt64Fn()
88+
elif input_type is float:
89+
return cy_combiners.MeanFloatFn()
90+
else:
91+
return self
92+
8493

8594
class Count(object):
8695
"""Combiners for counting elements."""

google/cloud/dataflow/transforms/core.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,6 +341,16 @@ def apply(self, elements, *args, **kwargs):
341341
*args, **kwargs),
342342
*args, **kwargs)
343343

344+
def for_input_type(self, input_type):
345+
"""Returns a specialized implementation of self, if it exists.
346+
347+
Otherwise, returns self.
348+
349+
Args:
350+
input_type: the type of input elements.
351+
"""
352+
return self
353+
344354
@staticmethod
345355
def from_callable(fn):
346356
return CallableWrapperCombineFn(fn)
@@ -431,6 +441,24 @@ def default_type_hints(self):
431441
hints.set_input_types(*input_args, **input_kwargs)
432442
return hints
433443

444+
def for_input_type(self, input_type):
445+
# Avoid circular imports.
446+
from google.cloud.dataflow.transforms import cy_combiners
447+
if self._fn is any:
448+
return cy_combiners.AnyCombineFn()
449+
elif self._fn is all:
450+
return cy_combiners.AllCombineFn()
451+
else:
452+
known_types = {
453+
(sum, int): cy_combiners.SumInt64Fn(),
454+
(min, int): cy_combiners.MinInt64Fn(),
455+
(max, int): cy_combiners.MaxInt64Fn(),
456+
(sum, float): cy_combiners.SumFloatFn(),
457+
(min, float): cy_combiners.MinFloatFn(),
458+
(max, float): cy_combiners.MaxFloatFn(),
459+
}
460+
return known_types.get((self._fn, input_type), self)
461+
434462

435463
class PartitionFn(WithTypeHints):
436464
"""A function object used by a Partition transform.

0 commit comments

Comments (0)