2323
2424import collections
2525import itertools
26+ import logging
2627
2728from google .cloud .dataflow import coders
2829from google .cloud .dataflow import error
@@ -54,6 +55,14 @@ def __init__(self, cache=None):
5455 # Cache of values computed while the runner executes a pipeline.
5556 self ._cache = cache if cache is not None else PValueCache ()
5657 self ._counter_factory = counters .CounterFactory ()
58+ # Element counts used for debugging footprint issues in the direct runner.
59+ # The values computed are used only for logging and do not take part in
60+ # any decision making logic. The key for the counter dictionary is either
61+ # the full label for the transform producing the elements or a tuple
62+ # (full label, output tag) for ParDo transforms since they can output values
63+ # on multiple outputs.
64+ self .debug_counters = {}
65+ self .debug_counters ['element_counts' ] = collections .Counter ()
5766
5867 def get_pvalue (self , pvalue ):
5968 """Gets the PValue's computed value from the runner's cache."""
def skip_if_cached(func):  # pylint: disable=no-self-argument
  """Decorator skipping a transform's evaluation when its value is cached.

  The wrapped run_* method receives the runner instance as its first
  argument. If the target PValue is already present in the runner's cache
  the underlying evaluation is skipped entirely and the wrapper returns
  None; otherwise the wrapped method runs normally.

  Args:
    func: A run_* method with signature (self, pvalue, *args, **kwargs).

  Returns:
    The wrapping function.
  """

  def func_wrapper(self, pvalue, *args, **kwargs):
    # Lazy %-style argument: the counters dict is only formatted when
    # debug-level logging is actually enabled.
    logging.debug('Current: Debug counters: %s', self.debug_counters)
    # Guard clause instead of the original redundant if/else around an
    # early return.
    if self._cache.is_cached(pvalue):  # pylint: disable=protected-access
      return
    func(self, pvalue, *args, **kwargs)
  return func_wrapper
7888
def run(self, pipeline, node=None):
  """Runs the pipeline, then logs the final element-count debug counters.

  Args:
    pipeline: The Pipeline object to execute.
    node: Optional node passed through unchanged to the base class's run.

  Returns:
    Whatever the base class's run returns. The original discarded this
    value; propagating it is backward-compatible (callers previously
    received None, which they still do if the base returns None).
  """
  result = super(DirectPipelineRunner, self).run(pipeline, node)
  # Counters are for debugging memory-footprint issues only; they do not
  # affect execution.
  logging.info('Final: Debug counters: %s', self.debug_counters)
  return result
7993 @skip_if_cached
8094 def run_ParDo (self , transform_node ):
8195 transform = transform_node .transform
@@ -143,6 +157,8 @@ def __missing__(self, key):
143157
144158 self ._cache .cache_output (transform_node , [])
145159 for tag , value in results .items ():
160+ self .debug_counters ['element_counts' ][
161+ (transform_node .full_label , tag )] += len (value )
146162 self ._cache .cache_output (transform_node , tag , value )
147163
148164 @skip_if_cached
@@ -166,24 +182,29 @@ def run_GroupByKeyOnly(self, transform_node):
166182 'windowed key-value pairs. Instead received: %r.'
167183 % wv )
168184
169- self ._cache .cache_output (
170- transform_node ,
171- map (GlobalWindows .WindowedValue ,
172- ((key_coder .decode (k ), v ) for k , v in result_dict .iteritems ())))
185+ gbk_result = map (
186+ GlobalWindows .WindowedValue ,
187+ ((key_coder .decode (k ), v ) for k , v in result_dict .iteritems ()))
188+ self .debug_counters ['element_counts' ][
189+ transform_node .full_label ] += len (gbk_result )
190+ self ._cache .cache_output (transform_node , gbk_result )
173191
@skip_if_cached
def run_Create(self, transform_node):
  """Materializes a Create transform's literal values into the cache.

  Each value is wrapped as a GlobalWindows.WindowedValue; the number of
  elements produced is added to the debug element-count counters keyed by
  the transform's full label (used for logging only).
  """
  windowed_values = [
      GlobalWindows.WindowedValue(element)
      for element in transform_node.transform.value]
  self.debug_counters['element_counts'][
      transform_node.full_label] += len(windowed_values)
  self._cache.cache_output(transform_node, windowed_values)
180199
@skip_if_cached
def run_Flatten(self, transform_node):
  """Concatenates every cached input PCollection into one output list.

  The element count of the merged result is added to the debug counters
  keyed by the transform's full label (logging only).
  """
  merged = []
  for input_pvalue in transform_node.inputs:
    merged.extend(self._cache.get_pvalue(input_pvalue))
  self.debug_counters['element_counts'][
      transform_node.full_label] += len(merged)
  self._cache.cache_output(transform_node, merged)
187208
188209 @skip_if_cached
189210 def run_Read (self , transform_node ):
@@ -192,13 +213,16 @@ def run_Read(self, transform_node):
192213 source = transform_node .transform .source
193214 source .pipeline_options = transform_node .inputs [0 ].pipeline .options
194215 with source .reader () as reader :
195- self ._cache .cache_output (
196- transform_node , [GlobalWindows .WindowedValue (e ) for e in reader ])
216+ read_result = [GlobalWindows .WindowedValue (e ) for e in reader ]
217+ self .debug_counters ['element_counts' ][
218+ transform_node .full_label ] += len (read_result )
219+ self ._cache .cache_output (transform_node , read_result )
197220
@skip_if_cached
def run__NativeWrite(self, transform_node):
  """Streams the cached input elements into the transform's native sink.

  Each element written also bumps the debug element-count counter for this
  transform's full label (logging only).
  """
  sink = transform_node.transform.sink
  sink.pipeline_options = transform_node.inputs[0].pipeline.options
  # Hoist the invariant lookups out of the write loop.
  element_counts = self.debug_counters['element_counts']
  label = transform_node.full_label
  with sink.writer() as writer:
    for windowed_value in self._cache.get_pvalue(transform_node.inputs[0]):
      element_counts[label] += 1
      writer.Write(windowed_value.value)
0 commit comments