3434
3535__all__ = ['TextFileSource' , 'TextFileSink' ]
3636
37+ DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
38+
3739
3840# Retrying is needed because there are transient errors that can happen.
3941@retry .with_exponential_backoff (num_retries = 4 , retry_filter = lambda _ : True )
@@ -150,77 +152,6 @@ def reader(self):
150152 return TextFileReader (self )
151153
152154
153- def TextFileSink (file_path_prefix , # pylint: disable=invalid-name
154- append_trailing_newlines = True ,
155- file_name_suffix = '' ,
156- num_shards = 0 ,
157- shard_name_template = None ,
158- validate = True ,
159- coder = coders .ToStringCoder ()):
160- """Initialize a TextSink.
161-
162- Args:
163- file_path_prefix: The file path to write to. The files written will begin
164- with this prefix, followed by a shard identifier (see num_shards), and
165- end in a common extension, if given by file_name_suffix. In most cases,
166- only this argument is specified and num_shards, shard_name_template, and
167- file_name_suffix use default values.
168- append_trailing_newlines: indicate whether this sink should write an
169- additional newline char after writing each element.
170- file_name_suffix: Suffix for the files written.
171- num_shards: The number of files (shards) used for output. If not set, the
172- service will decide on the optimal number of shards.
173- Constraining the number of shards is likely to reduce
174- the performance of a pipeline. Setting this value is not recommended
175- unless you require a specific number of output files.
176- shard_name_template: A template string containing placeholders for
177- the shard number and shard count. Currently only '' and
178- '-SSSSS-of-NNNNN' are patterns accepted by the service.
179- When constructing a filename for a particular shard number, the
180- upper-case letters 'S' and 'N' are replaced with the 0-padded shard
181- number and shard count respectively. This argument can be '' in which
182- case it behaves as if num_shards was set to 1 and only one file will be
183- generated. The default pattern used is '-SSSSS-of-NNNNN'.
184- validate: Enable path validation on pipeline creation.
185- coder: Coder used to encode each line.
186-
187- Raises:
188- TypeError: if file_path is not a string.
189- ValueError: if shard_name_template is not of expected format.
190-
191- Returns:
192- A TextFileSink object usable for writing.
193- """
194- if not isinstance (file_path_prefix , basestring ):
195- raise TypeError (
196- 'TextFileSink: file_path_prefix must be a string; got %r instead' %
197- file_path_prefix )
198- if not isinstance (file_name_suffix , basestring ):
199- raise TypeError (
200- 'TextFileSink: file_name_suffix must be a string; got %r instead' %
201- file_name_suffix )
202- if shard_name_template not in (None , '' , '-SSSSS-of-NNNNN' ):
203- raise ValueError (
204- 'The shard_name_template argument must be an empty string or the '
205- 'pattern -SSSSS-of-NNNNN instead of %s' % shard_name_template )
206- if shard_name_template == '' : # pylint: disable=g-explicit-bool-comparison
207- num_shards = 1
208-
209- if num_shards :
210- return NativeTextFileSink (file_path_prefix ,
211- append_trailing_newlines = append_trailing_newlines ,
212- file_name_suffix = file_name_suffix ,
213- num_shards = num_shards ,
214- shard_name_template = shard_name_template ,
215- validate = validate ,
216- coder = coder )
217- else :
218- return PureTextFileSink (file_path_prefix ,
219- append_trailing_newlines = append_trailing_newlines ,
220- file_name_suffix = file_name_suffix ,
221- coder = coder )
222-
223-
224155class ChannelFactory (object ):
225156 # TODO(robertwb): Generalize into extensible framework.
226157
@@ -239,7 +170,7 @@ def open(path, mode, mime_type):
239170 if path .startswith ('gs://' ):
240171 # pylint: disable=g-import-not-at-top
241172 from google .cloud .dataflow .io import gcsio
242- return gcsio .GcsIO ().open (path , mode , mime_type )
173+ return gcsio .GcsIO ().open (path , mode , mime_type = mime_type )
243174 else :
244175 return open (path , mode )
245176
@@ -358,11 +289,19 @@ def __init__(self,
358289 file_path_prefix ,
359290 coder ,
360291 file_name_suffix = '' ,
292+ num_shards = 0 ,
293+ shard_name_template = None ,
361294 mime_type = 'application/octet-stream' ):
295+ if shard_name_template is None :
296+ shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
297+ elif shard_name_template is '' :
298+ num_shards = 1
362299 self .file_path_prefix = file_path_prefix
363300 self .file_name_suffix = file_name_suffix
301+ self .num_shards = num_shards
364302 self .coder = coder
365303 self .mime_type = mime_type
304+ self .shard_name_format = self ._template_to_format (shard_name_template )
366305
367306 def open (self , temp_path ):
368307 """Opens ``temp_path``, returning an opaque file handle object.
@@ -410,8 +349,11 @@ def finalize_write(self, init_result, writer_results):
410349 # TODO(robertwb): Threadpool?
411350 channel_factory = ChannelFactory ()
412351 for shard_num , shard in enumerate (writer_results ):
413- final_name = '%s-%05d-of-%05d%s' % (self .file_path_prefix , shard_num ,
414- num_shards , self .file_name_suffix )
352+ final_name = '' .join ([
353+ self .file_path_prefix ,
354+ self .shard_name_format % dict (shard_num = shard_num ,
355+ num_shards = num_shards ),
356+ self .file_name_suffix ])
415357 try :
416358 channel_factory .rename (shard , final_name )
417359 except IOError :
@@ -426,6 +368,22 @@ def finalize_write(self, init_result, writer_results):
426368 # May have already been removed.
427369 pass
428370
371+ @staticmethod
372+ def _template_to_format (shard_name_template ):
373+ if not shard_name_template :
374+ return ''
375+ m = re .search ('S+' , shard_name_template )
376+ if m is None :
377+ raise ValueError ("Shard number pattern S+ not found in template '%s'"
378+ % shard_name_template )
379+ shard_name_format = shard_name_template .replace (
380+ m .group (0 ), '%%(shard_num)0%dd' % len (m .group (0 )))
381+ m = re .search ('N+' , shard_name_format )
382+ if m :
383+ shard_name_format = shard_name_format .replace (
384+ m .group (0 ), '%%(num_shards)0%dd' % len (m .group (0 )))
385+ return shard_name_format
386+
429387 def __eq__ (self , other ):
430388 # TODO(robertwb): Clean up workitem_test which uses this.
431389 # pylint: disable=unidiomatic-typecheck
@@ -449,16 +407,19 @@ def close(self):
449407 return self .temp_shard_path
450408
451409
452- class PureTextFileSink (FileSink ):
410+ class TextFileSink (FileSink ):
453411 """A sink to a GCS or local text file or files."""
454412
455413 def __init__ (self ,
456414 file_path_prefix ,
457415 file_name_suffix = '' ,
416+ append_trailing_newlines = True ,
417+ num_shards = 0 ,
418+ shard_name_template = None ,
458419 coder = coders .ToStringCoder (),
459420 compression_type = CompressionTypes .NO_COMPRESSION ,
460- append_trailing_newlines = True ):
461- """Initialize a PureTextFileSink .
421+ ):
422+ """Initialize a TextFileSink .
462423
463424 Args:
464425 file_path_prefix: The file path to write to. The files written will begin
@@ -467,33 +428,57 @@ def __init__(self,
467428 only this argument is specified and num_shards, shard_name_template, and
468429 file_name_suffix use default values.
469430 file_name_suffix: Suffix for the files written.
470- coder: Coder used to encode each line.
471- compression_type: Type of compression to use for this sink.
472431 append_trailing_newlines: indicate whether this sink should write an
473432 additional newline char after writing each element.
433+ num_shards: The number of files (shards) used for output. If not set, the
434+ service will decide on the optimal number of shards.
435+ Constraining the number of shards is likely to reduce
436+ the performance of a pipeline. Setting this value is not recommended
437+ unless you require a specific number of output files.
438+ shard_name_template: A template string containing placeholders for
439+ the shard number and shard count. Currently only '' and
440+ '-SSSSS-of-NNNNN' are patterns accepted by the service.
441+ When constructing a filename for a particular shard number, the
442+ upper-case letters 'S' and 'N' are replaced with the 0-padded shard
443+ number and shard count respectively. This argument can be '' in which
444+ case it behaves as if num_shards was set to 1 and only one file will be
445+ generated. The default pattern used is '-SSSSS-of-NNNNN'.
446+ coder: Coder used to encode each line.
447+ compression_type: Type of compression to use for this sink.
474448
475449 Raises:
476- TypeError: if file_path is not a string or if compression_type is not
477- member of CompressionTypes.
450+ TypeError: if file path parameters are not a string or if compression_type
451+ is not member of CompressionTypes.
452+ ValueError: if shard_name_template is not of expected format.
478453
479454 Returns:
480- A PureTextFileSink object usable for writing.
455+ A TextFileSink object usable for writing.
481456 """
457+ if not isinstance (file_path_prefix , basestring ):
458+ raise TypeError (
459+ 'TextFileSink: file_path_prefix must be a string; got %r instead' %
460+ file_path_prefix )
461+ if not isinstance (file_name_suffix , basestring ):
462+ raise TypeError (
463+ 'TextFileSink: file_name_suffix must be a string; got %r instead' %
464+ file_name_suffix )
465+
482466 if not CompressionTypes .valid_compression_type (compression_type ):
483467 raise TypeError ('compression_type must be CompressionType object but '
484468 'was %s' % type (compression_type ))
485-
486469 if compression_type == CompressionTypes .DEFLATE :
487470 mime_type = 'application/x-gzip'
488471 else :
489472 mime_type = 'text/plain'
490- self .compression_type = compression_type
491473
492- super (PureTextFileSink , self ).__init__ (file_path_prefix ,
493- file_name_suffix = file_name_suffix ,
494- coder = coder ,
495- mime_type = mime_type )
474+ super (TextFileSink , self ).__init__ (file_path_prefix ,
475+ file_name_suffix = file_name_suffix ,
476+ num_shards = num_shards ,
477+ shard_name_template = shard_name_template ,
478+ coder = coder ,
479+ mime_type = mime_type )
496480
481+ self .compression_type = compression_type
497482 self .append_trailing_newlines = append_trailing_newlines
498483
499484 def open (self , temp_path ):
0 commit comments