|
17 | 17 | from __future__ import absolute_import |
18 | 18 |
|
19 | 19 | import glob |
| 20 | +import gzip |
20 | 21 | import logging |
21 | 22 | import os |
22 | 23 | import re |
@@ -281,6 +282,31 @@ def rmdir(path): |
281 | 282 | except OSError as err: |
282 | 283 | raise IOError(err) |
283 | 284 |
|
| 285 | + @staticmethod |
| 286 | + def glob(path): |
| 287 | + if path.startswith('gs://'): |
| 288 | + # pylint: disable=g-import-not-at-top |
| 289 | + from google.cloud.dataflow.io import gcsio |
| 290 | + return gcsio.GcsIO().glob(path) |
| 291 | + else: |
| 292 | + return glob.glob(path) |
| 293 | + |
| 294 | + |
| 295 | +class _CompressionType(object): |
| 296 | + """Object representing single compression type.""" |
| 297 | + |
| 298 | + def __init__(self, identifier): |
| 299 | + self.identifier = identifier |
| 300 | + |
| 301 | + def __eq__(self, other): |
| 302 | + return self.identifier == other.identifier |
| 303 | + |
| 304 | + |
| 305 | +class CompressionTypes(object): |
| 306 | + """Enum-like class representing known compression types.""" |
| 307 | + NO_COMPRESSION = _CompressionType(1) # No compression. |
| 308 | + DEFLATE = _CompressionType(2) # 'Deflate' ie gzip compression. |
| 309 | + |
284 | 310 |
|
285 | 311 | class FileSink(iobase.Sink): |
286 | 312 | """A sink to a GCS or local files. |
@@ -396,13 +422,53 @@ def __init__(self, |
396 | 422 | file_path_prefix, |
397 | 423 | file_name_suffix='', |
398 | 424 | coder=coders.ToStringCoder(), |
| 425 | + compression_type=CompressionTypes.NO_COMPRESSION, |
399 | 426 | append_trailing_newlines=True): |
| 427 | + """Initialize a PureTextFileSink. |
| 428 | +
|
| 429 | + Args: |
| 430 | + file_path_prefix: The file path to write to. The files written will begin |
| 431 | + with this prefix, followed by a shard identifier (see num_shards), and |
| 432 | + end in a common extension, if given by file_name_suffix. In most cases, |
| 433 | + only this argument is specified and num_shards, shard_name_template, and |
| 434 | + file_name_suffix use default values. |
| 435 | + file_name_suffix: Suffix for the files written. |
| 436 | + coder: Coder used to encode each line. |
| 437 | + compression_type: Type of compression to use for this sink. |
| 438 | + append_trailing_newlines: indicate whether this sink should write an |
| 439 | + additional newline char after writing each element. |
| 440 | +
|
| 441 | + Raises: |
| 442 | + TypeError: if file_path is not a string or if compression_type is not |
| 443 | + member of CompressionTypes. |
| 444 | +
|
| 445 | + Returns: |
| 446 | + A PureTextFileSink object usable for writing. |
| 447 | + """ |
| 448 | + if not isinstance(compression_type, _CompressionType): |
| 449 | + raise TypeError('compression_type must be CompressionType object but ' |
| 450 | + 'was %s' % type(compression_type)) |
| 451 | + |
| 452 | + if compression_type == CompressionTypes.DEFLATE: |
| 453 | + mime_type = 'application/x-gzip' |
| 454 | + else: |
| 455 | + mime_type = 'text/plain' |
| 456 | + self.compression_type = compression_type |
| 457 | + |
400 | 458 | super(PureTextFileSink, self).__init__(file_path_prefix, |
401 | 459 | file_name_suffix=file_name_suffix, |
402 | 460 | coder=coder, |
403 | | - mime_type='text/plain') |
| 461 | + mime_type=mime_type) |
| 462 | + |
404 | 463 | self.append_trailing_newlines = append_trailing_newlines |
405 | 464 |
|
| 465 | + def open(self, temp_path): |
| 466 | + """Opens ''temp_path'', returning a writeable file object.""" |
| 467 | + fobj = ChannelFactory.open(temp_path, 'wb', self.mime_type) |
| 468 | + if self.compression_type == CompressionTypes.DEFLATE: |
| 469 | + return gzip.GzipFile(fileobj=fobj) |
| 470 | + return fobj |
| 471 | + |
406 | 472 | def write_encoded_record(self, file_handle, encoded_value): |
407 | 473 | file_handle.write(encoded_value) |
408 | 474 | if self.append_trailing_newlines: |
@@ -569,12 +635,7 @@ class TextMultiFileReader(iobase.NativeSourceReader): |
569 | 635 |
|
570 | 636 | def __init__(self, source): |
571 | 637 | self.source = source |
572 | | - if source.is_gcs_source: |
573 | | - # pylint: disable=g-import-not-at-top |
574 | | - from google.cloud.dataflow.io import gcsio |
575 | | - self.file_paths = gcsio.GcsIO().glob(self.source.file_path) |
576 | | - else: |
577 | | - self.file_paths = glob.glob(self.source.file_path) |
| 638 | + self.file_paths = ChannelFactory.glob(self.source.file_path) |
578 | 639 | if not self.file_paths: |
579 | 640 | raise RuntimeError( |
580 | 641 | 'No files found for path: %s' % self.source.file_path) |
|
0 commit comments