Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit cb60c6c

Browse files
chamikaramj authored and aaltay committed
Dynamic work rebalancing support for InMemory reader.
----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=123075083
1 parent 1d44fdb commit cb60c6c

3 files changed

Lines changed: 149 additions & 35 deletions

File tree

google/cloud/dataflow/worker/executor_test.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def __init__(self, source):
9595
def get_progress(self):
9696
next_progress = super(ProgressRequestRecordingInMemoryReader,
9797
self).get_progress()
98-
self.progress_record.append(next_progress.percent_complete)
98+
self.progress_record.append(next_progress.position.record_index)
9999
return next_progress
100100

101101

@@ -332,11 +332,7 @@ def test_in_memory_source_progress_reporting(self):
332332
]))
333333
self.assertEqual(elements, output_buffer)
334334

335-
expected_progress_record = []
336-
len_elements = len(elements)
337-
for i in range(len_elements):
338-
expected_progress_record.append(float(i + 1) / len_elements)
339-
335+
expected_progress_record = range(len(elements))
340336
self.assertEqual(expected_progress_record,
341337
source.last_reader.progress_record)
342338

google/cloud/dataflow/worker/inmemory.py

Lines changed: 53 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,11 @@
1515
"""In-memory input source."""
1616

1717
import itertools
18+
import logging
1819

1920
from google.cloud.dataflow import coders
2021
from google.cloud.dataflow.io import iobase
22+
from google.cloud.dataflow.io import range_trackers
2123

2224

2325
class InMemorySource(iobase.NativeSource):
@@ -53,11 +55,14 @@ class InMemoryReader(iobase.NativeSourceReader):
5355
"""A reader for in-memory source."""
5456

5557
def __init__(self, source):
56-
self.source = source
58+
self._source = source
5759

58-
# Index of the next item to be read by the InMemoryReader.
59-
# Starts at source.start_index.
60-
self.current_index = source.start_index
60+
# Index of the last item returned by InMemoryReader.
61+
# Initialized to None.
62+
self._current_index = None
63+
64+
self._range_tracker = range_trackers.OffsetRangeTracker(
65+
self._source.start_index, self._source.end_index)
6166

6267
def __enter__(self):
6368
return self
@@ -66,21 +71,49 @@ def __exit__(self, exception_type, exception_value, traceback):
6671
pass
6772

6873
def __iter__(self):
69-
for value in itertools.islice(self.source.elements,
70-
self.source.start_index,
71-
self.source.end_index):
72-
self.current_index += 1
73-
yield self.source.coder.decode(value)
74+
for value in itertools.islice(self._source.elements,
75+
self._source.start_index,
76+
self._source.end_index):
77+
claimed = False
78+
if self._current_index is None:
79+
claimed = self._range_tracker.try_claim(
80+
self._source.start_index)
81+
else:
82+
claimed = self._range_tracker.try_claim(
83+
self._current_index + 1)
84+
85+
if claimed:
86+
if self._current_index is None:
87+
self._current_index = self._source.start_index
88+
else:
89+
self._current_index += 1
90+
91+
yield self._source.coder.decode(value)
92+
else:
93+
return
7494

7595
def get_progress(self):
76-
if (self.current_index >= self.source.end_index or
77-
self.source.start_index >= self.source.end_index):
78-
percent_complete = 1
79-
elif self.current_index == self.source.start_index:
80-
percent_complete = 0
81-
else:
82-
percent_complete = (
83-
float(self.current_index - self.source.start_index) / (
84-
self.source.end_index - self.source.start_index))
85-
86-
return iobase.ReaderProgress(percent_complete=percent_complete)
96+
if self._current_index is None:
97+
return None
98+
99+
return iobase.ReaderProgress(
100+
position=iobase.ReaderPosition(record_index=self._current_index))
101+
102+
def request_dynamic_split(self, dynamic_split_request):
103+
assert dynamic_split_request is not None
104+
progress = dynamic_split_request.progress
105+
split_position = progress.position
106+
if split_position is None:
107+
logging.debug('InMemory reader only supports split requests that are '
108+
'based on positions. Received : %r', dynamic_split_request)
109+
return None
110+
111+
index_position = split_position.record_index
112+
if index_position is None:
113+
logging.debug('InMemory reader only supports split requests that are '
114+
'based on index positions. Received : %r',
115+
dynamic_split_request)
116+
return None
117+
118+
if self._range_tracker.try_split(index_position):
119+
return iobase.DynamicSplitResultWithPosition(split_position)

google/cloud/dataflow/worker/inmemory_test.py

Lines changed: 94 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import unittest
1919

20+
from google.cloud.dataflow.io import iobase
2021
from google.cloud.dataflow.worker import inmemory
2122

2223

@@ -41,31 +42,115 @@ def test_norange(self):
4142
def test_in_memory_source_updates_progress_none(self):
4243
source = inmemory.InMemorySource([], coder=FakeCoder())
4344
with source.reader() as reader:
44-
self.assertEqual(1, reader.get_progress().percent_complete)
45+
self.assertEqual(None, reader.get_progress())
4546

4647
def test_in_memory_source_updates_progress_one(self):
4748
source = inmemory.InMemorySource([1], coder=FakeCoder())
4849
with source.reader() as reader:
49-
self.assertEqual(0, reader.get_progress().percent_complete)
50+
self.assertEqual(None, reader.get_progress())
5051
i = 0
5152
for item in reader:
52-
i += 1
53+
self.assertEqual(i, reader.get_progress().position.record_index)
5354
self.assertEqual(11, item)
54-
self.assertEqual(1, reader.get_progress().percent_complete)
55+
i += 1
5556
self.assertEqual(1, i)
56-
self.assertEqual(1, reader.get_progress().percent_complete)
57+
self.assertEqual(0, reader.get_progress().position.record_index)
5758

5859
def test_in_memory_source_updates_progress_many(self):
5960
source = inmemory.InMemorySource([1, 2, 3, 4, 5], coder=FakeCoder())
6061
with source.reader() as reader:
61-
self.assertEqual(0, reader.get_progress().percent_complete)
62+
self.assertEqual(None, reader.get_progress())
6263
i = 0
6364
for item in reader:
65+
self.assertEqual(i, reader.get_progress().position.record_index)
66+
self.assertEqual(11 + i, item)
6467
i += 1
65-
self.assertEqual(i + 10, item)
66-
self.assertEqual(float(i) / 5, reader.get_progress().percent_complete)
6768
self.assertEqual(5, i)
68-
self.assertEqual(1, reader.get_progress().percent_complete)
69+
self.assertEqual(4, reader.get_progress().position.record_index)
70+
71+
def try_splitting_reader_at(self, reader, split_request, expected_response):
72+
actual_response = reader.request_dynamic_split(split_request)
73+
74+
if expected_response is None:
75+
self.assertIsNone(actual_response)
76+
else:
77+
self.assertIsNotNone(actual_response.stop_position)
78+
self.assertIsInstance(actual_response.stop_position,
79+
iobase.ReaderPosition)
80+
self.assertIsNotNone(actual_response.stop_position.record_index)
81+
self.assertEqual(expected_response.stop_position.record_index,
82+
actual_response.stop_position.record_index)
83+
84+
def test_in_memory_source_dynamic_split(self):
85+
source = inmemory.InMemorySource([10, 20, 30, 40, 50, 60],
86+
coder=FakeCoder())
87+
88+
# Unstarted reader
89+
with source.reader() as reader:
90+
self.try_splitting_reader_at(
91+
reader,
92+
iobase.DynamicSplitRequest(
93+
iobase.ReaderProgress(
94+
position=iobase.ReaderPosition(record_index=2))),
95+
None)
96+
97+
# Proposed split position out of range
98+
with source.reader() as reader:
99+
reader_iter = iter(reader)
100+
next(reader_iter)
101+
self.try_splitting_reader_at(
102+
reader,
103+
iobase.DynamicSplitRequest(
104+
iobase.ReaderProgress(
105+
position=iobase.ReaderPosition(record_index=-1))),
106+
None)
107+
self.try_splitting_reader_at(
108+
reader,
109+
iobase.DynamicSplitRequest(
110+
iobase.ReaderProgress(
111+
position=iobase.ReaderPosition(record_index=10))),
112+
None)
113+
114+
# Already read past proposed split position
115+
with source.reader() as reader:
116+
reader_iter = iter(reader)
117+
next(reader_iter)
118+
next(reader_iter)
119+
next(reader_iter)
120+
self.try_splitting_reader_at(
121+
reader,
122+
iobase.DynamicSplitRequest(
123+
iobase.ReaderProgress(
124+
position=iobase.ReaderPosition(record_index=1))),
125+
None)
126+
127+
self.try_splitting_reader_at(
128+
reader,
129+
iobase.DynamicSplitRequest(
130+
iobase.ReaderProgress(
131+
position=iobase.ReaderPosition(record_index=2))),
132+
None)
133+
134+
# Successful split
135+
with source.reader() as reader:
136+
reader_iter = iter(reader)
137+
next(reader_iter)
138+
self.try_splitting_reader_at(
139+
reader,
140+
iobase.DynamicSplitRequest(
141+
iobase.ReaderProgress(
142+
position=iobase.ReaderPosition(record_index=4))),
143+
iobase.DynamicSplitResultWithPosition(
144+
stop_position=iobase.ReaderPosition(record_index=4)))
145+
146+
self.try_splitting_reader_at(
147+
reader,
148+
iobase.DynamicSplitRequest(
149+
iobase.ReaderProgress(
150+
position=iobase.ReaderPosition(record_index=2))),
151+
iobase.DynamicSplitResultWithPosition(
152+
stop_position=iobase.ReaderPosition(record_index=2)))
153+
69154

70155
if __name__ == '__main__':
71156
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)