Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 6084ff5

Browse files
gildeasilviulica
authored andcommitted
Make element iterators observable
Mix class ObservableMixin into various values iterators, so that we have a hook to estimate the size of the elements we iterate over. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=119993968
1 parent 29083f5 commit 6084ff5

3 files changed

Lines changed: 19 additions & 7 deletions

File tree

google/cloud/dataflow/transforms/trigger.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import collections
2323
import copy
2424

25+
from google.cloud.dataflow.coders import observable
2526
from google.cloud.dataflow.transforms import combiners
2627
from google.cloud.dataflow.transforms import core
2728
from google.cloud.dataflow.transforms.timeutil import MAX_TIMESTAMP
@@ -721,9 +722,12 @@ def process_elements(self, state, windowed_values, unused_output_watermark):
721722
if isinstance(windowed_values, list):
722723
unwindowed = [wv.value for wv in windowed_values]
723724
else:
724-
class UnwindowedValues(object):
725+
class UnwindowedValues(observable.ObservableMixin):
725726
def __iter__(self):
726-
return (wv.value for wv in windowed_values)
727+
for wv in windowed_values:
728+
unwindowed_value = wv.value
729+
self.notify_observers(unwindowed_value)
730+
yield unwindowed_value
727731
def __repr__(self):
728732
return '<UnwindowedValues of %s>' % windowed_values
729733
unwindowed = UnwindowedValues()

google/cloud/dataflow/worker/shuffle.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import logging
3939
import struct
4040

41+
from google.cloud.dataflow.coders import observable
4142
from google.cloud.dataflow.io import iobase
4243
from google.cloud.dataflow.io import range_trackers
4344

@@ -246,7 +247,7 @@ def clone(self, start_position, end_position, key):
246247
self.iterable.reader, start_position, end_position, key))
247248

248249

249-
class ShuffleKeyValuesIterable(object):
250+
class ShuffleKeyValuesIterable(observable.ObservableMixin):
250251
"""An iterable over all values associated with a key.
251252
252253
The class supports reiteration over the values by cloning the underlying
@@ -257,6 +258,7 @@ class ShuffleKeyValuesIterable(object):
257258

258259
def __init__(self, entries_iterator, key, value_coder,
259260
start_position, end_position=''):
261+
super(ShuffleKeyValuesIterable, self).__init__()
260262
self.key = key
261263
self.value_coder = value_coder
262264
self.start_position = start_position
@@ -290,7 +292,9 @@ def values_iterator(self):
290292
self.end_position = entry.position
291293
self.entries_iterator.push_back(entry)
292294
break
293-
yield self.value_coder.decode(entry.value)
295+
decoded_value = self.value_coder.decode(entry.value)
296+
self.notify_observers(entry.value, is_encoded=True)
297+
yield decoded_value
294298

295299
def __str__(self):
296300
return '<%s>' % self._str_internal()

google/cloud/dataflow/worker/windmillio.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import absolute_import
2020

21+
from google.cloud.dataflow.coders import observable
2122
from google.cloud.dataflow.io import coders
2223
from google.cloud.dataflow.io import iobase
2324
from google.cloud.dataflow.io import pubsub
@@ -216,10 +217,11 @@ def __repr__(self):
216217
self.state_family)
217218

218219

219-
class KeyedWorkItem(object):
220+
class KeyedWorkItem(observable.ObservableMixin):
220221
"""Keyed work item used by a StreamingGroupAlsoByWindowsOperation."""
221222

222223
def __init__(self, work_item, coder):
224+
super(KeyedWorkItem, self).__init__()
223225
self.work_item = work_item
224226
self.coder = coder
225227
self.key_coder = coder.key_coder()
@@ -236,7 +238,9 @@ def __init__(self, work_item, coder):
236238
def elements(self):
237239
for bundle in self.work_item.message_bundles:
238240
for message in bundle.messages:
239-
yield self.wv_coder.decode(message.data)
241+
element = self.wv_coder.decode(message.data)
242+
self.notify_observers(message.data, is_encoded=True)
243+
yield element
240244

241245
def timers(self):
242246
if self.work_item.timers:
@@ -252,7 +256,7 @@ def timers(self):
252256
state_family=timer_item.state_family)
253257

254258
def __repr__(self):
255-
return 'KeyedWorkItem(%r)' % self.key
259+
return '<%s %s>' % (self.__class__.__name__, self.key)
256260

257261

258262
class WindowingWindmillSource(iobase.Source):

0 commit comments

Comments
 (0)