Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 8a31915

Browse files
robertwb
authored and aaltay committed
Implement fixed sharding in Text sink.
Added support for general sharding template specifications. Also ensures there is at least one file produced when writing the empty PCollection. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=124614940
1 parent 5ff2c3e commit 8a31915

4 files changed

Lines changed: 174 additions & 134 deletions

File tree

google/cloud/dataflow/io/fileio.py

Lines changed: 73 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434

3535
__all__ = ['TextFileSource', 'TextFileSink']
3636

37+
DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
38+
3739

3840
# Retrying is needed because there are transient errors that can happen.
3941
@retry.with_exponential_backoff(num_retries=4, retry_filter=lambda _: True)
@@ -150,77 +152,6 @@ def reader(self):
150152
return TextFileReader(self)
151153

152154

153-
def TextFileSink(file_path_prefix, # pylint: disable=invalid-name
154-
append_trailing_newlines=True,
155-
file_name_suffix='',
156-
num_shards=0,
157-
shard_name_template=None,
158-
validate=True,
159-
coder=coders.ToStringCoder()):
160-
"""Initialize a TextSink.
161-
162-
Args:
163-
file_path_prefix: The file path to write to. The files written will begin
164-
with this prefix, followed by a shard identifier (see num_shards), and
165-
end in a common extension, if given by file_name_suffix. In most cases,
166-
only this argument is specified and num_shards, shard_name_template, and
167-
file_name_suffix use default values.
168-
append_trailing_newlines: indicate whether this sink should write an
169-
additional newline char after writing each element.
170-
file_name_suffix: Suffix for the files written.
171-
num_shards: The number of files (shards) used for output. If not set, the
172-
service will decide on the optimal number of shards.
173-
Constraining the number of shards is likely to reduce
174-
the performance of a pipeline. Setting this value is not recommended
175-
unless you require a specific number of output files.
176-
shard_name_template: A template string containing placeholders for
177-
the shard number and shard count. Currently only '' and
178-
'-SSSSS-of-NNNNN' are patterns accepted by the service.
179-
When constructing a filename for a particular shard number, the
180-
upper-case letters 'S' and 'N' are replaced with the 0-padded shard
181-
number and shard count respectively. This argument can be '' in which
182-
case it behaves as if num_shards was set to 1 and only one file will be
183-
generated. The default pattern used is '-SSSSS-of-NNNNN'.
184-
validate: Enable path validation on pipeline creation.
185-
coder: Coder used to encode each line.
186-
187-
Raises:
188-
TypeError: if file_path is not a string.
189-
ValueError: if shard_name_template is not of expected format.
190-
191-
Returns:
192-
A TextFileSink object usable for writing.
193-
"""
194-
if not isinstance(file_path_prefix, basestring):
195-
raise TypeError(
196-
'TextFileSink: file_path_prefix must be a string; got %r instead' %
197-
file_path_prefix)
198-
if not isinstance(file_name_suffix, basestring):
199-
raise TypeError(
200-
'TextFileSink: file_name_suffix must be a string; got %r instead' %
201-
file_name_suffix)
202-
if shard_name_template not in (None, '', '-SSSSS-of-NNNNN'):
203-
raise ValueError(
204-
'The shard_name_template argument must be an empty string or the '
205-
'pattern -SSSSS-of-NNNNN instead of %s' % shard_name_template)
206-
if shard_name_template == '': # pylint: disable=g-explicit-bool-comparison
207-
num_shards = 1
208-
209-
if num_shards:
210-
return NativeTextFileSink(file_path_prefix,
211-
append_trailing_newlines=append_trailing_newlines,
212-
file_name_suffix=file_name_suffix,
213-
num_shards=num_shards,
214-
shard_name_template=shard_name_template,
215-
validate=validate,
216-
coder=coder)
217-
else:
218-
return PureTextFileSink(file_path_prefix,
219-
append_trailing_newlines=append_trailing_newlines,
220-
file_name_suffix=file_name_suffix,
221-
coder=coder)
222-
223-
224155
class ChannelFactory(object):
225156
# TODO(robertwb): Generalize into extensible framework.
226157

@@ -239,7 +170,7 @@ def open(path, mode, mime_type):
239170
if path.startswith('gs://'):
240171
# pylint: disable=g-import-not-at-top
241172
from google.cloud.dataflow.io import gcsio
242-
return gcsio.GcsIO().open(path, mode, mime_type)
173+
return gcsio.GcsIO().open(path, mode, mime_type=mime_type)
243174
else:
244175
return open(path, mode)
245176

@@ -358,11 +289,19 @@ def __init__(self,
358289
file_path_prefix,
359290
coder,
360291
file_name_suffix='',
292+
num_shards=0,
293+
shard_name_template=None,
361294
mime_type='application/octet-stream'):
295+
if shard_name_template is None:
296+
shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
297+
elif shard_name_template is '':
298+
num_shards = 1
362299
self.file_path_prefix = file_path_prefix
363300
self.file_name_suffix = file_name_suffix
301+
self.num_shards = num_shards
364302
self.coder = coder
365303
self.mime_type = mime_type
304+
self.shard_name_format = self._template_to_format(shard_name_template)
366305

367306
def open(self, temp_path):
368307
"""Opens ``temp_path``, returning an opaque file handle object.
@@ -410,8 +349,11 @@ def finalize_write(self, init_result, writer_results):
410349
# TODO(robertwb): Threadpool?
411350
channel_factory = ChannelFactory()
412351
for shard_num, shard in enumerate(writer_results):
413-
final_name = '%s-%05d-of-%05d%s' % (self.file_path_prefix, shard_num,
414-
num_shards, self.file_name_suffix)
352+
final_name = ''.join([
353+
self.file_path_prefix,
354+
self.shard_name_format % dict(shard_num=shard_num,
355+
num_shards=num_shards),
356+
self.file_name_suffix])
415357
try:
416358
channel_factory.rename(shard, final_name)
417359
except IOError:
@@ -426,6 +368,22 @@ def finalize_write(self, init_result, writer_results):
426368
# May have already been removed.
427369
pass
428370

371+
@staticmethod
372+
def _template_to_format(shard_name_template):
373+
if not shard_name_template:
374+
return ''
375+
m = re.search('S+', shard_name_template)
376+
if m is None:
377+
raise ValueError("Shard number pattern S+ not found in template '%s'"
378+
% shard_name_template)
379+
shard_name_format = shard_name_template.replace(
380+
m.group(0), '%%(shard_num)0%dd' % len(m.group(0)))
381+
m = re.search('N+', shard_name_format)
382+
if m:
383+
shard_name_format = shard_name_format.replace(
384+
m.group(0), '%%(num_shards)0%dd' % len(m.group(0)))
385+
return shard_name_format
386+
429387
def __eq__(self, other):
430388
# TODO(robertwb): Clean up workitem_test which uses this.
431389
# pylint: disable=unidiomatic-typecheck
@@ -449,16 +407,19 @@ def close(self):
449407
return self.temp_shard_path
450408

451409

452-
class PureTextFileSink(FileSink):
410+
class TextFileSink(FileSink):
453411
"""A sink to a GCS or local text file or files."""
454412

455413
def __init__(self,
456414
file_path_prefix,
457415
file_name_suffix='',
416+
append_trailing_newlines=True,
417+
num_shards=0,
418+
shard_name_template=None,
458419
coder=coders.ToStringCoder(),
459420
compression_type=CompressionTypes.NO_COMPRESSION,
460-
append_trailing_newlines=True):
461-
"""Initialize a PureTextFileSink.
421+
):
422+
"""Initialize a TextFileSink.
462423
463424
Args:
464425
file_path_prefix: The file path to write to. The files written will begin
@@ -467,33 +428,57 @@ def __init__(self,
467428
only this argument is specified and num_shards, shard_name_template, and
468429
file_name_suffix use default values.
469430
file_name_suffix: Suffix for the files written.
470-
coder: Coder used to encode each line.
471-
compression_type: Type of compression to use for this sink.
472431
append_trailing_newlines: indicate whether this sink should write an
473432
additional newline char after writing each element.
433+
num_shards: The number of files (shards) used for output. If not set, the
434+
service will decide on the optimal number of shards.
435+
Constraining the number of shards is likely to reduce
436+
the performance of a pipeline. Setting this value is not recommended
437+
unless you require a specific number of output files.
438+
shard_name_template: A template string containing placeholders for
439+
the shard number and shard count. Currently only '' and
440+
'-SSSSS-of-NNNNN' are patterns accepted by the service.
441+
When constructing a filename for a particular shard number, the
442+
upper-case letters 'S' and 'N' are replaced with the 0-padded shard
443+
number and shard count respectively. This argument can be '' in which
444+
case it behaves as if num_shards was set to 1 and only one file will be
445+
generated. The default pattern used is '-SSSSS-of-NNNNN'.
446+
coder: Coder used to encode each line.
447+
compression_type: Type of compression to use for this sink.
474448
475449
Raises:
476-
TypeError: if file_path is not a string or if compression_type is not
477-
member of CompressionTypes.
450+
TypeError: if file path parameters are not a string or if compression_type
451+
is not member of CompressionTypes.
452+
ValueError: if shard_name_template is not of expected format.
478453
479454
Returns:
480-
A PureTextFileSink object usable for writing.
455+
A TextFileSink object usable for writing.
481456
"""
457+
if not isinstance(file_path_prefix, basestring):
458+
raise TypeError(
459+
'TextFileSink: file_path_prefix must be a string; got %r instead' %
460+
file_path_prefix)
461+
if not isinstance(file_name_suffix, basestring):
462+
raise TypeError(
463+
'TextFileSink: file_name_suffix must be a string; got %r instead' %
464+
file_name_suffix)
465+
482466
if not CompressionTypes.valid_compression_type(compression_type):
483467
raise TypeError('compression_type must be CompressionType object but '
484468
'was %s' % type(compression_type))
485-
486469
if compression_type == CompressionTypes.DEFLATE:
487470
mime_type = 'application/x-gzip'
488471
else:
489472
mime_type = 'text/plain'
490-
self.compression_type = compression_type
491473

492-
super(PureTextFileSink, self).__init__(file_path_prefix,
493-
file_name_suffix=file_name_suffix,
494-
coder=coder,
495-
mime_type=mime_type)
474+
super(TextFileSink, self).__init__(file_path_prefix,
475+
file_name_suffix=file_name_suffix,
476+
num_shards=num_shards,
477+
shard_name_template=shard_name_template,
478+
coder=coder,
479+
mime_type=mime_type)
496480

481+
self.compression_type = compression_type
497482
self.append_trailing_newlines = append_trailing_newlines
498483

499484
def open(self, temp_path):

google/cloud/dataflow/io/fileio_test.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tempfile
2222
import unittest
2323

24+
import google.cloud.dataflow as df
2425
from google.cloud.dataflow import coders
2526
from google.cloud.dataflow.io import fileio
2627
from google.cloud.dataflow.io import iobase
@@ -355,14 +356,14 @@ def _write_lines(self, sink, lines):
355356
sink.close(f)
356357

357358
def test_write_text_file(self):
358-
sink = fileio.PureTextFileSink(self.path)
359+
sink = fileio.TextFileSink(self.path)
359360
self._write_lines(sink, self.lines)
360361

361362
with open(self.path, 'r') as f:
362363
self.assertEqual(f.read().splitlines(), self.lines)
363364

364365
def test_write_gzip_file(self):
365-
sink = fileio.PureTextFileSink(
366+
sink = fileio.TextFileSink(
366367
self.path, compression_type=fileio.CompressionTypes.DEFLATE)
367368
self._write_lines(sink, self.lines)
368369

@@ -373,7 +374,9 @@ def test_write_gzip_file(self):
373374
class MyFileSink(fileio.FileSink):
374375

375376
def open(self, temp_path):
376-
file_handle = super(MyFileSink, self).open(temp_path)
377+
# TODO(robertwb): Fix main session pickling.
378+
# file_handle = super(MyFileSink, self).open(temp_path)
379+
file_handle = fileio.FileSink.open(self, temp_path)
377380
file_handle.write('[start]')
378381
return file_handle
379382

@@ -384,7 +387,9 @@ def write_encoded_record(self, file_handle, encoded_value):
384387

385388
def close(self, file_handle):
386389
file_handle.write('[end]')
387-
file_handle = super(MyFileSink, self).close(file_handle)
390+
# TODO(robertwb): Fix main session pickling.
391+
# file_handle = super(MyFileSink, self).close(file_handle)
392+
file_handle = fileio.FileSink.close(self, file_handle)
388393

389394

390395
class TestFileSink(unittest.TestCase):
@@ -423,6 +428,34 @@ def test_file_sink_writing(self):
423428
# Check that any temp files are deleted.
424429
self.assertEqual([shard1, shard2], sorted(glob.glob(temp_path + '*')))
425430

431+
def test_empty_write(self):
432+
temp_path = tempfile.NamedTemporaryFile().name
433+
sink = MyFileSink(temp_path,
434+
file_name_suffix='.foo',
435+
coder=coders.ToStringCoder())
436+
p = df.Pipeline('DirectPipelineRunner')
437+
p | df.Create([]) | df.io.Write(sink) # pylint: disable=expression-not-assigned
438+
p.run()
439+
440+
self.assertEqual(open(temp_path + '-00000-of-00001.foo').read(),
441+
'[start][end]')
442+
443+
def test_fixed_shard_write(self):
444+
temp_path = tempfile.NamedTemporaryFile().name
445+
sink = MyFileSink(temp_path,
446+
file_name_suffix='.foo',
447+
num_shards=3,
448+
shard_name_template='_NN_SSS_',
449+
coder=coders.ToStringCoder())
450+
p = df.Pipeline('DirectPipelineRunner')
451+
p | df.Create(['a', 'b']) | df.io.Write(sink) # pylint: disable=expression-not-assigned
452+
453+
p.run()
454+
455+
concat = ''.join(open(temp_path + '_03_%03d_.foo' % shard_num).read()
456+
for shard_num in range(3))
457+
self.assertTrue('][a][' in concat, concat)
458+
self.assertTrue('][b][' in concat, concat)
426459

427460
if __name__ == '__main__':
428461
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)