Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit b051312

Browse files
sbilacaaltay
authored and committed
Enable gzip compression on text files sink.
----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=122491595
1 parent a724abb commit b051312

2 files changed

Lines changed: 97 additions & 7 deletions

File tree

google/cloud/dataflow/io/fileio.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import absolute_import
1818

1919
import glob
20+
import gzip
2021
import logging
2122
import os
2223
import re
@@ -281,6 +282,31 @@ def rmdir(path):
281282
except OSError as err:
282283
raise IOError(err)
283284

285+
@staticmethod
def glob(path):
  """Expand ``path`` into the list of matching file paths.

  Paths beginning with 'gs://' are globbed through the GCS client;
  everything else is delegated to the standard ``glob`` module.
  """
  if path.startswith('gs://'):
    # pylint: disable=g-import-not-at-top
    from google.cloud.dataflow.io import gcsio
    return gcsio.GcsIO().glob(path)
  return glob.glob(path)
293+
294+
295+
class _CompressionType(object):
296+
"""Object representing single compression type."""
297+
298+
def __init__(self, identifier):
299+
self.identifier = identifier
300+
301+
def __eq__(self, other):
302+
return self.identifier == other.identifier
303+
304+
305+
class CompressionTypes(object):
  """Enum-like class representing known compression types."""

  # Data is written as-is, with no compression applied.
  NO_COMPRESSION = _CompressionType(1)
  # 'Deflate' compression, i.e. gzip.
  DEFLATE = _CompressionType(2)
284310

285311
class FileSink(iobase.Sink):
286312
"""A sink to a GCS or local files.
def __init__(self,
             file_path_prefix,
             file_name_suffix='',
             coder=coders.ToStringCoder(),
             compression_type=CompressionTypes.NO_COMPRESSION,
             append_trailing_newlines=True):
  """Initialize a PureTextFileSink.

  Args:
    file_path_prefix: The file path to write to. The files written will
      begin with this prefix, followed by a shard identifier (see
      num_shards), and end in a common extension, if given by
      file_name_suffix. In most cases, only this argument is specified and
      num_shards, shard_name_template, and file_name_suffix use default
      values.
    file_name_suffix: Suffix for the files written.
    coder: Coder used to encode each line.
    compression_type: Type of compression to use for this sink.
    append_trailing_newlines: indicate whether this sink should write an
      additional newline char after writing each element.

  Raises:
    TypeError: if file_path is not a string or if compression_type is not
      member of CompressionTypes.

  Returns:
    A PureTextFileSink object usable for writing.
  """
  if not isinstance(compression_type, _CompressionType):
    raise TypeError('compression_type must be CompressionType object but '
                    'was %s' % type(compression_type))

  # Compressed output is declared with its own MIME type so downstream
  # consumers (e.g. GCS) know the payload is gzip data, not plain text.
  mime_type = ('application/x-gzip'
               if compression_type == CompressionTypes.DEFLATE
               else 'text/plain')
  self.compression_type = compression_type

  super(PureTextFileSink, self).__init__(file_path_prefix,
                                         file_name_suffix=file_name_suffix,
                                         coder=coder,
                                         mime_type=mime_type)
  self.append_trailing_newlines = append_trailing_newlines
405464

465+
def open(self, temp_path):
  """Opens ``temp_path``, returning a writeable file object."""
  fobj = ChannelFactory.open(temp_path, 'wb', self.mime_type)
  if self.compression_type == CompressionTypes.DEFLATE:
    # Pass the mode explicitly: without it, GzipFile infers the mode from
    # fobj.mode and silently falls back to *read* mode when the file-like
    # object (e.g. a GCS channel) has no 'mode' attribute.
    return gzip.GzipFile(fileobj=fobj, mode='wb')
  return fobj
406472
def write_encoded_record(self, file_handle, encoded_value):
407473
file_handle.write(encoded_value)
408474
if self.append_trailing_newlines:
@@ -569,12 +635,7 @@ class TextMultiFileReader(iobase.NativeSourceReader):
569635

570636
def __init__(self, source):
  """Create a reader over every file matched by ``source.file_path``."""
  self.source = source
  # ChannelFactory.glob handles both GCS ('gs://') and local patterns.
  matched_paths = ChannelFactory.glob(self.source.file_path)
  if not matched_paths:
    raise RuntimeError(
        'No files found for path: %s' % self.source.file_path)
  self.file_paths = matched_paths

google/cloud/dataflow/io/fileio_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Unit tests for local and GCS sources and sinks."""
1717

1818
import glob
19+
import gzip
1920
import logging
2021
import tempfile
2122
import unittest
@@ -341,6 +342,34 @@ def test_write_entire_file(self):
341342
self.assertEqual(f.read().splitlines(), lines)
342343

343344

345+
class TestPureTextFileSink(unittest.TestCase):
  """Tests for PureTextFileSink plain-text and gzip output."""

  def setUp(self):
    self.lines = ['Line %d' % d for d in range(100)]
    # The original used tempfile.NamedTemporaryFile().name, which is racy:
    # the file is deleted as soon as the unreferenced object is collected,
    # and the name may be reused. delete=False keeps the file on disk so
    # the sink can reopen it by name on any platform.
    with tempfile.NamedTemporaryFile(delete=False) as f:
      self.path = f.name

  def _write_lines(self, sink, lines):
    # Drive the sink through its public open/write_record/close cycle.
    f = sink.open(self.path)
    for line in lines:
      sink.write_record(f, line)
    sink.close(f)

  def test_write_text_file(self):
    sink = fileio.PureTextFileSink(self.path)
    self._write_lines(sink, self.lines)

    with open(self.path, 'r') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file(self):
    sink = fileio.PureTextFileSink(
        self.path, compression_type=fileio.CompressionTypes.DEFLATE)
    self._write_lines(sink, self.lines)

    # Reading back through GzipFile verifies a valid gzip stream was
    # produced, not just raw bytes.
    with gzip.GzipFile(self.path, 'r') as f:
      self.assertEqual(f.read().splitlines(), self.lines)
344373
class MyFileSink(fileio.FileSink):
345374

346375
def open(self, temp_path):

0 commit comments

Comments
 (0)