# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util.tf_export import tf_export


# TODO(b/64974358): Increase default buffer size to 256 MB.
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB


@tf_export("data.TextLineDataset")
class TextLineDataset(Dataset):
  """A `Dataset` comprising lines from one or more text files.
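
  A minimal usage sketch (`file1.txt` and `file2.txt` are hypothetical
  paths):

  ```python
  # Each element is one line of an input file, as a scalar `tf.string`
  # tensor. Compressed inputs would additionally pass
  # `compression_type="GZIP"` or `"ZLIB"`.
  dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
  ```
  """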

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TextLineDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 causes a default buffer size to be chosen
        based on the compression type.
    """
    super(TextLineDataset, self).__init__()
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
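    # `convert.optional_param_to_tensor` substitutes the given default
    # when the corresponding argument is None.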
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
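
  # The variant-typed tensor returned below represents this dataset in the
  # graph; iterator and dataset-transformation ops consume it.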
  def _as_variant_tensor(self):
    return gen_dataset_ops.text_line_dataset(
        self._filenames, self._compression_type, self._buffer_size)

  @property
  def output_classes(self):
    return ops.Tensor

  @property
  def output_shapes(self):
    return tensor_shape.scalar()

  @property
  def output_types(self):
    return dtypes.string


@tf_export("data.TFRecordDataset")
class TFRecordDataset(Dataset):
  """A `Dataset` comprising records from one or more TFRecord files.
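
  A minimal usage sketch (`example.tfrecord` and the feature spec are
  hypothetical):

  ```python
  # Each element is one serialized record, as a scalar `tf.string` tensor;
  # records are typically parsed with `tf.parse_single_example`.
  dataset = tf.data.TFRecordDataset(["example.tfrecord"])
  dataset = dataset.map(
      lambda record: tf.parse_single_example(
          record, {"label": tf.FixedLenFeature([], tf.int64)}))
  ```
  """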

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
    """
    super(TFRecordDataset, self).__init__()
    # Force the type to string even if filenames is an empty list.
    self._filenames = ops.convert_to_tensor(
        filenames, dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)

  def _as_variant_tensor(self):
    return gen_dataset_ops.tf_record_dataset(
        self._filenames, self._compression_type, self._buffer_size)

  @property
  def output_classes(self):
    return ops.Tensor

  @property
  def output_shapes(self):
    return tensor_shape.TensorShape([])

  @property
  def output_types(self):
    return dtypes.string


@tf_export("data.FixedLengthRecordDataset")
class FixedLengthRecordDataset(Dataset):
  """A `Dataset` of fixed-length records from one or more binary files.
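
  A minimal usage sketch, modeled on the CIFAR-10 binary format (the
  filename is hypothetical):

  ```python
  # Each CIFAR-10 record is a 1-byte label followed by a 32x32x3 image;
  # each element of the dataset is one such record, as a scalar
  # `tf.string` tensor.
  dataset = tf.data.FixedLengthRecordDataset(
      ["cifar10_data.bin"], record_bytes=1 + 32 * 32 * 3)
  ```
  """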

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in
        each record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
    """
    super(FixedLengthRecordDataset, self).__init__()
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._record_bytes = ops.convert_to_tensor(
        record_bytes, dtype=dtypes.int64, name="record_bytes")

    self._header_bytes = convert.optional_param_to_tensor(
        "header_bytes", header_bytes)
    self._footer_bytes = convert.optional_param_to_tensor(
        "footer_bytes", footer_bytes)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
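
  # Note that the kernel takes `header_bytes` before `record_bytes`, the
  # reverse of this constructor's argument order.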
  def _as_variant_tensor(self):
    return gen_dataset_ops.fixed_length_record_dataset(
        self._filenames, self._header_bytes, self._record_bytes,
        self._footer_bytes, self._buffer_size)

  @property
  def output_classes(self):
    return ops.Tensor

  @property
  def output_shapes(self):
    return tensor_shape.scalar()

  @property
  def output_types(self):
    return dtypes.string