This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit c2c1bc7 (1 parent: 4e38084)

Use multiple file rename threads in finalize_write

----Release Notes----
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=124755869

2 files changed: 111 additions(+), 9 deletions(-)

google/cloud/dataflow/io/fileio.py (49 additions, 7 deletions)

@@ -19,6 +19,7 @@
 import glob
 import gzip
 import logging
+from multiprocessing.pool import ThreadPool
 import os
 import re
 import shutil
@@ -285,6 +286,13 @@ class FileSink(iobase.Sink):
   The output of this write is a PCollection of all written shards.
   """

+  # Approximate number of write results to be assigned to each rename thread.
+  _WRITE_RESULTS_PER_RENAME_THREAD = 100
+
+  # Max number of threads to be used for renaming even if it means each thread
+  # will process more write results.
+  _MAX_RENAME_THREADS = 64
+
   def __init__(self,
                file_path_prefix,
                coder,
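
For a sense of how these two constants interact, here is a small sketch (not part of the commit) of the thread-count formula that finalize_write builds from them below. Note the SDK is Python 2, where / on ints is floor division; the // here makes the sketch behave the same under Python 3.

_WRITE_RESULTS_PER_RENAME_THREAD = 100
_MAX_RENAME_THREADS = 64

def num_rename_threads(num_shards):
    # One thread per ~100 write results, clamped to the range [1, 64].
    return max(1, min(num_shards // _WRITE_RESULTS_PER_RENAME_THREAD,
                      _MAX_RENAME_THREADS))

assert num_rename_threads(50) == 1        # under 100 shards: a single thread
assert num_rename_threads(1000) == 10     # 1000 // 100 = 10 threads
assert num_rename_threads(100000) == 64   # capped by _MAX_RENAME_THREADS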
@@ -346,22 +354,56 @@ def open_writer(self, init_result, uid):
   def finalize_write(self, init_result, writer_results):
     writer_results = sorted(writer_results)
     num_shards = len(writer_results)
-    # TODO(robertwb): Threadpool?
     channel_factory = ChannelFactory()
+    num_threads = max(1, min(
+        num_shards / FileSink._WRITE_RESULTS_PER_RENAME_THREAD,
+        FileSink._MAX_RENAME_THREADS))
+
+    rename_ops = []
     for shard_num, shard in enumerate(writer_results):
       final_name = ''.join([
           self.file_path_prefix,
           self.shard_name_format % dict(shard_num=shard_num,
                                         num_shards=num_shards),
           self.file_name_suffix])
+      rename_ops.append((shard, final_name))
+
+    logging.info(
+        'Starting finalize_write threads with num_shards: %d, num_threads: %d',
+        num_shards, num_threads)
+    start_time = time.time()
+
+    # Use a thread pool for renaming operations.
+    def _rename_file(rename_op):
+      """Executes a single (old_name, new_name) rename operation."""
+      old_name, final_name = rename_op
       try:
-        channel_factory.rename(shard, final_name)
-      except IOError:
+        channel_factory.rename(old_name, final_name)
+      except IOError as e:
         # May have already been copied.
-        print shard, final_name, os.path.exists(final_name)
-        if not channel_factory.exists(final_name):
-          raise
-      yield final_name
+        exists = channel_factory.exists(final_name)
+        if not exists:
+          logging.warning(('IOError in _rename_file. old_name: %s, '
+                           'final_name: %s, err: %s'), old_name, final_name, e)
+          return (None, e)
+      except Exception as e:  # pylint: disable=broad-except
+        logging.warning(('Exception in _rename_file. old_name: %s, '
+                         'final_name: %s, err: %s'), old_name, final_name, e)
+        return (None, e)
+      return (final_name, None)
+
+    rename_results = ThreadPool(num_threads).map(_rename_file, rename_ops)
+
+    for final_name, err in rename_results:
+      if err:
+        logging.warning('Error when processing rename_results: %s', err)
+        raise err
+      else:
+        yield final_name
+
+    logging.info('Renamed %d shards in %.2f seconds.',
+                 num_shards, time.time() - start_time)
+
     try:
       channel_factory.rmdir(init_result)
     except IOError:
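
The core pattern above is worth isolating: each worker returns a (result, error) tuple instead of raising, so one failed shard cannot abort the other renames mid-map, and the caller re-raises on its own thread after every rename has been attempted. A minimal standalone sketch, with os.rename standing in for ChannelFactory (the names _rename and finalize are illustrative, not SDK API):

from multiprocessing.pool import ThreadPool
import os

def _rename(op):
    # Worker: never raises; failures are reported as (None, error).
    old_name, new_name = op
    try:
        os.rename(old_name, new_name)
    except OSError as e:
        return (None, e)
    return (new_name, None)

def finalize(rename_ops, num_threads=4):
    # Attempt every rename, then surface the first failure on this thread.
    for name, err in ThreadPool(num_threads).map(_rename, rename_ops):
        if err:
            raise err
        yield name

An exception raised inside a Pool worker would also propagate out of map(), but collecting errors explicitly lets the benign "already renamed" IOError be filtered out first, as the commit does with channel_factory.exists(final_name).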

google/cloud/dataflow/io/fileio_test.py (62 additions, 2 deletions)

@@ -18,6 +18,7 @@
 import glob
 import gzip
 import logging
+import os
 import tempfile
 import unittest

@@ -426,7 +427,7 @@ def test_file_sink_writing(self):
     self.assertEqual(open(shard2).read(), '[start][x][y][z][end]')

     # Check that any temp files are deleted.
-    self.assertEqual([shard1, shard2], sorted(glob.glob(temp_path + '*')))
+    self.assertItemsEqual([shard1, shard2], glob.glob(temp_path + '*'))

   def test_empty_write(self):
     temp_path = tempfile.NamedTemporaryFile().name
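
The swap from assertEqual(sorted(...)) to assertItemsEqual is behavior-preserving: assertItemsEqual (Python 2.7 unittest; renamed assertCountEqual in Python 3) compares contents regardless of order, so the explicit sort becomes unnecessary. A toy illustration:

import unittest

class ItemsEqualDemo(unittest.TestCase):
    def test_order_does_not_matter(self):
        # Same elements, different order: passes without an explicit sort.
        self.assertItemsEqual(['b.txt', 'a.txt'], ['a.txt', 'b.txt'])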
@@ -436,7 +437,6 @@ def test_empty_write(self):
     p = df.Pipeline('DirectPipelineRunner')
     p | df.Create([]) | df.io.Write(sink)  # pylint: disable=expression-not-assigned
     p.run()
-
     self.assertEqual(open(temp_path + '-00000-of-00001.foo').read(),
                      '[start][end]')

@@ -457,6 +457,66 @@ def test_fixed_shard_write(self):
     self.assertTrue('][a][' in concat, concat)
     self.assertTrue('][b][' in concat, concat)

+  def test_file_sink_multi_shards(self):
+    temp_path = tempfile.NamedTemporaryFile().name
+    sink = MyFileSink(temp_path,
+                      file_name_suffix='.foo',
+                      coder=coders.ToStringCoder())
+
+    # Manually invoke the generic Sink API.
+    init_token = sink.initialize_write()
+
+    num_shards = 1000
+    writer_results = []
+    for i in range(num_shards):
+      uuid = 'uuid-%05d' % i
+      writer = sink.open_writer(init_token, uuid)
+      writer.write('a')
+      writer.write('b')
+      writer.write(uuid)
+      writer_results.append(writer.close())
+
+    res_first = list(sink.finalize_write(init_token, writer_results))
+    # Retry the finalize operation (as if the first attempt was lost).
+    res_second = list(sink.finalize_write(init_token, writer_results))
+
+    self.assertItemsEqual(res_first, res_second)
+
+    res = sorted(res_second)
+    for i in range(num_shards):
+      shard_name = '%s-%05d-of-%05d.foo' % (temp_path, i, num_shards)
+      uuid = 'uuid-%05d' % i
+      self.assertEqual(res[i], shard_name)
+      self.assertEqual(
+          open(shard_name).read(), ('[start][a][b][%s][end]' % uuid))
+
+    # Check that any temp files are deleted.
+    self.assertItemsEqual(res, glob.glob(temp_path + '*'))
+
+  def test_file_sink_io_error(self):
+    temp_path = tempfile.NamedTemporaryFile().name
+    sink = MyFileSink(temp_path,
+                      file_name_suffix='.foo',
+                      coder=coders.ToStringCoder())
+
+    # Manually invoke the generic Sink API.
+    init_token = sink.initialize_write()
+
+    writer1 = sink.open_writer(init_token, '1')
+    writer1.write('a')
+    writer1.write('b')
+    res1 = writer1.close()
+
+    writer2 = sink.open_writer(init_token, '2')
+    writer2.write('x')
+    writer2.write('y')
+    writer2.write('z')
+    res2 = writer2.close()
+
+    os.remove(res2)
+    with self.assertRaises(IOError):
+      list(sink.finalize_write(init_token, [res1, res2]))
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
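
The retry in test_file_sink_multi_shards leans on the idempotence of the rename step: a second finalize_write gets an IOError (the temp file is already gone) but finds the destination present and treats it as success, while test_file_sink_io_error exercises the opposite case, where the missing destination makes the error re-raise. A sketch of that decision, with os-level calls standing in for ChannelFactory (assumed names, not SDK API):

import os

def rename_idempotent(old_name, new_name):
    # Mirrors _rename_file's policy: a failed rename only counts as an error
    # if the destination does not already exist from an earlier attempt.
    try:
        os.rename(old_name, new_name)
    except OSError:
        if not os.path.exists(new_name):
            raise  # genuine failure: source not moved, target not present
        # Otherwise a previous attempt already renamed it; succeed quietly.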
