1717import logging
1818import unittest
1919
20+ from google .cloud .dataflow .io import iobase
2021from google .cloud .dataflow .worker import inmemory
2122
2223
@@ -41,31 +42,115 @@ def test_norange(self):
def test_in_memory_source_updates_progress_none(self):
    """A freshly opened reader over an empty source reports no progress."""
    source = inmemory.InMemorySource([], coder=FakeCoder())
    with source.reader() as reader:
        # Nothing has been read, so get_progress() is expected to be None.
        self.assertIsNone(reader.get_progress())
4546
def test_in_memory_source_updates_progress_one(self):
    """Progress follows the record index while reading a one-element source."""
    source = inmemory.InMemorySource([1], coder=FakeCoder())
    with source.reader() as reader:
        # No progress is reported until the first record has been read.
        self.assertIsNone(reader.get_progress())
        records_seen = 0
        for item in reader:
            # While a record is being handled, progress points at its index.
            self.assertEqual(
                records_seen, reader.get_progress().position.record_index)
            # The stored value 1 is expected to come back as 11 through
            # FakeCoder (defined outside this view).
            self.assertEqual(11, item)
            records_seen += 1
        self.assertEqual(1, records_seen)
        # After iteration the position stays on the last record read.
        self.assertEqual(0, reader.get_progress().position.record_index)
5758
def test_in_memory_source_updates_progress_many(self):
    """Progress follows the record index while reading a five-element source."""
    source = inmemory.InMemorySource([1, 2, 3, 4, 5], coder=FakeCoder())
    with source.reader() as reader:
        # No progress is reported until the first record has been read.
        self.assertIsNone(reader.get_progress())
        records_seen = 0
        for item in reader:
            # While a record is being handled, progress points at its index.
            self.assertEqual(
                records_seen, reader.get_progress().position.record_index)
            # Stored values 1..5 are expected to come back as 11..15 through
            # FakeCoder (defined outside this view).
            self.assertEqual(11 + records_seen, item)
            records_seen += 1
        self.assertEqual(5, records_seen)
        # After iteration the position stays on the last record read.
        self.assertEqual(4, reader.get_progress().position.record_index)
70+
def try_splitting_reader_at(self, reader, split_request, expected_response):
    """Request a dynamic split on `reader` and verify the response.

    Args:
      reader: object exposing request_dynamic_split().
      split_request: request handed unchanged to the reader.
      expected_response: None when the reader should return no split result;
        otherwise an object whose stop_position.record_index is the
        expected stop index.
    """
    actual_response = reader.request_dynamic_split(split_request)

    if expected_response is None:
        self.assertIsNone(actual_response)
        return

    # Fail with a test assertion, not an AttributeError, when a split was
    # expected but the reader returned None.
    self.assertIsNotNone(actual_response)
    self.assertIsNotNone(actual_response.stop_position)
    self.assertIsInstance(actual_response.stop_position,
                          iobase.ReaderPosition)
    self.assertIsNotNone(actual_response.stop_position.record_index)
    self.assertEqual(expected_response.stop_position.record_index,
                     actual_response.stop_position.record_index)
83+
def test_in_memory_source_dynamic_split(self):
    """Exercise dynamic split requests against readers in several states."""
    source = inmemory.InMemorySource([10, 20, 30, 40, 50, 60],
                                     coder=FakeCoder())

    def split_request_at(record_index):
        # Build a split request proposing `record_index` as the position.
        return iobase.DynamicSplitRequest(
            iobase.ReaderProgress(
                position=iobase.ReaderPosition(record_index=record_index)))

    def split_result_at(record_index):
        # Build the expected response stopping at `record_index`.
        return iobase.DynamicSplitResultWithPosition(
            stop_position=iobase.ReaderPosition(record_index=record_index))

    # Unstarted reader
    with source.reader() as reader:
        self.try_splitting_reader_at(reader, split_request_at(2), None)

    # Proposed split position out of range
    with source.reader() as reader:
        reader_iter = iter(reader)
        next(reader_iter)
        self.try_splitting_reader_at(reader, split_request_at(-1), None)
        self.try_splitting_reader_at(reader, split_request_at(10), None)

    # Already read past proposed split position
    with source.reader() as reader:
        reader_iter = iter(reader)
        next(reader_iter)
        next(reader_iter)
        next(reader_iter)
        self.try_splitting_reader_at(reader, split_request_at(1), None)
        self.try_splitting_reader_at(reader, split_request_at(2), None)

    # Successful split
    with source.reader() as reader:
        reader_iter = iter(reader)
        next(reader_iter)
        self.try_splitting_reader_at(
            reader, split_request_at(4), split_result_at(4))
        # A second, earlier split on the same reader is also expected to
        # succeed.
        self.try_splitting_reader_at(
            reader, split_request_at(2), split_result_at(2))
153+
69154
70155if __name__ == '__main__' :
71156 logging .getLogger ().setLevel (logging .INFO )
0 commit comments