# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=line-too-long
"""Inputs and Readers.

See the @{$python/io_ops} guide.

@@placeholder
@@placeholder_with_default
@@sparse_placeholder
@@ReaderBase
@@TextLineReader
@@WholeFileReader
@@IdentityReader
@@TFRecordReader
@@LMDBReader
@@FixedLengthRecordReader
@@decode_csv
@@decode_raw
@@VarLenFeature
@@FixedLenFeature
@@FixedLenSequenceFeature
@@SparseFeature
@@parse_example
@@parse_single_example
@@parse_tensor
@@serialize_tensor
@@decode_json_example
@@QueueBase
@@FIFOQueue
@@PaddingFIFOQueue
@@RandomShuffleQueue
@@PriorityQueue
@@ConditionalAccumulatorBase
@@ConditionalAccumulator
@@SparseConditionalAccumulator
@@matching_files
@@read_file
@@write_file
@@match_filenames_once
@@limit_epochs
@@input_producer
@@range_input_producer
@@slice_input_producer
@@string_input_producer
@@batch
@@maybe_batch
@@batch_join
@@maybe_batch_join
@@shuffle_batch
@@maybe_shuffle_batch
@@shuffle_batch_join
@@maybe_shuffle_batch_join
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_io_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_io_ops import *
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import


# pylint: disable=protected-access
def _save(filename, tensor_names, tensors, tensor_slices=None, name="save"):
  """Save a list of tensors to a file with given names.

  Example usage without slice info:
    Save("/foo/bar", ["w", "b"], [w, b])

  Example usage with slices:
    Save("/foo/bar", ["w", "w"], [slice0, slice1],
         tensor_slices=["4 10 0,2:-", "4 10 2,2:-"])

  Args:
    filename: the file name of the checkpoint file.
    tensor_names: a list of strings.
    tensors: the list of tensors to be saved.
    tensor_slices: Optional list of strings to specify the shape and slices of
      a larger virtual tensor that each tensor is a part of.  If not specified
      each tensor is saved as a full slice.
    name: string.  Optional name for the op.

  Requires:
    The length of tensors must match the length of tensor_names and, if
    specified, of tensor_slices.

  Returns:
    An Operation that saves the tensors.
  """
  if tensor_slices is None:
    return gen_io_ops._save(filename, tensor_names, tensors, name=name)
  else:
    return gen_io_ops._save_slices(filename, tensor_names, tensor_slices,
                                   tensors, name=name)


def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restore a tensor slice from a set of files with a given pattern.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  base_type = dtypes.as_dtype(tensor_type).base_dtype
  return gen_io_ops._restore_slice(
      file_pattern, tensor_name, shape_and_slice, base_type,
      preferred_shard, name=name)


@tf_export("ReaderBase")
class ReaderBase(object):
  """Base class for different Reader types that produce a record every step.

  Conceptually, Readers convert string 'work units' into records (key,
  value pairs).  Typically the 'work units' are filenames and the
  records are extracted from the contents of those files.  We want a
  single record produced per step, but a work unit can correspond to
  many records.

  Therefore we introduce some decoupling using a queue.  The queue
  contains the work units, and the Reader dequeues from the queue when
  it is asked to produce a record (via Read()) but has finished the
  last work unit.

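  For example, a typical graph-mode pipeline pairs a concrete Reader
  subclass with a queue of filenames.  A minimal sketch (assuming CSV
  files `file0.csv` and `file1.csv` exist and queue runners are
  started):

  ```python
  filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
  reader = tf.TextLineReader()
  # Each Read() dequeues a filename when needed and returns one record.
  key, value = reader.read(filename_queue)

  with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run([key, value]))  # one (key, value) pair per step
    coord.request_stop()
    coord.join(threads)
  ```
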
  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, reader_ref, supports_serialize=False):
    """Creates a new ReaderBase.

    Args:
      reader_ref: The operation that implements the reader.
      supports_serialize: True if the reader implementation can
        serialize its state.

    Raises:
      RuntimeError: If eager execution is enabled.
    """
    if context.in_eager_mode():
      raise RuntimeError(
          "Readers are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    self._reader_ref = reader_ref
    self._supports_serialize = supports_serialize

  @property
  def reader_ref(self):
    """Op that implements the reader."""
    return self._reader_ref

  def read(self, queue, name=None):
    """Returns the next record (key, value) pair produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g. when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (key, value).
      key: A string scalar Tensor.
      value: A string scalar Tensor.
    """
    if isinstance(queue, ops.Tensor):
      queue_ref = queue
    else:
      queue_ref = queue.queue_ref
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_read_v2(self._reader_ref, queue_ref, name=name)
    else:
      # For compatibility with pre-resource queues, create a ref(string) tensor
      # which can be looked up as the same queue by a resource manager.
      old_queue_op = gen_data_flow_ops._fake_queue(queue_ref)
      return gen_io_ops._reader_read(self._reader_ref, old_queue_op, name=name)

  def read_up_to(self, queue, num_records,  # pylint: disable=invalid-name
                 name=None):
    """Returns up to num_records (key, value) pairs produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g., when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).
    It may return fewer than num_records even before the last batch.

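    For example, a sketch of batched reading (assuming `filename_queue`
    was created with `tf.train.string_input_producer`):

    ```python
    reader = tf.TextLineReader()
    # keys and values are 1-D string Tensors with up to 10 elements.
    keys, values = reader.read_up_to(filename_queue, 10)
    ```
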
    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      num_records: Number of records to read.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (keys, values).
      keys: A 1-D string Tensor.
      values: A 1-D string Tensor.
    """
    if isinstance(queue, ops.Tensor):
      queue_ref = queue
    else:
      queue_ref = queue.queue_ref
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_read_up_to_v2(self._reader_ref,
                                              queue_ref,
                                              num_records,
                                              name=name)
    else:
      # For compatibility with pre-resource queues, create a ref(string) tensor
      # which can be looked up as the same queue by a resource manager.
      old_queue_op = gen_data_flow_ops._fake_queue(queue_ref)
      return gen_io_ops._reader_read_up_to(self._reader_ref,
                                           old_queue_op,
                                           num_records,
                                           name=name)

  def num_records_produced(self, name=None):
    """Returns the number of records this reader has produced.

    This is the same as the number of Read executions that have
    succeeded.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_num_records_produced_v2(self._reader_ref,
                                                        name=name)
    else:
      return gen_io_ops._reader_num_records_produced(self._reader_ref,
                                                     name=name)

  def num_work_units_completed(self, name=None):
    """Returns the number of work units this reader has finished processing.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_num_work_units_completed_v2(self._reader_ref,
                                                            name=name)
    else:
      return gen_io_ops._reader_num_work_units_completed(self._reader_ref,
                                                         name=name)

  def serialize_state(self, name=None):
    """Produce a string tensor that encodes the state of a reader.

    Not all Readers support being serialized, so this can produce an
    Unimplemented error.

    Args:
      name: A name for the operation (optional).

    Returns:
      A string Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_serialize_state_v2(self._reader_ref, name=name)
    else:
      return gen_io_ops._reader_serialize_state(self._reader_ref, name=name)

  def restore_state(self, state, name=None):
    """Restore a reader to a previously saved state.

    Not all Readers support being restored, so this can produce an
    Unimplemented error.

    Args:
      state: A string Tensor.
        Result of a SerializeState of a Reader with matching type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_restore_state_v2(
          self._reader_ref, state, name=name)
    else:
      return gen_io_ops._reader_restore_state(
          self._reader_ref, state, name=name)

  @property
  def supports_serialize(self):
    """Whether the Reader implementation can serialize its state."""
    return self._supports_serialize

  def reset(self, name=None):
    """Restore a reader to its initial clean state.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created Operation.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops._reader_reset_v2(self._reader_ref, name=name)
    else:
      return gen_io_ops._reader_reset(self._reader_ref, name=name)


ops.NotDifferentiable("ReaderRead")
ops.NotDifferentiable("ReaderReadUpTo")
ops.NotDifferentiable("ReaderNumRecordsProduced")
ops.NotDifferentiable("ReaderNumWorkUnitsCompleted")
ops.NotDifferentiable("ReaderSerializeState")
ops.NotDifferentiable("ReaderRestoreState")
ops.NotDifferentiable("ReaderReset")


@tf_export("WholeFileReader")
class WholeFileReader(ReaderBase):
  """A Reader that outputs the entire contents of a file as a value.

  To use, enqueue filenames in a Queue.  The output of Read will
  be a filename (key) and the contents of that file (value).

  See ReaderBase for supported methods.

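  For example, a sketch of reading whole image files (assuming JPEG
  files matching the glob pattern exist on disk):

  ```python
  filename_queue = tf.train.string_input_producer(
      tf.train.match_filenames_once("./images/*.jpg"))
  reader = tf.WholeFileReader()
  # key is the filename; value is the raw contents of that file.
  key, value = reader.read(filename_queue)
  image = tf.image.decode_jpeg(value)
  ```
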
  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, name=None):
    """Create a WholeFileReader.

    Args:
      name: A name for the operation (optional).
    """
    rr = gen_io_ops._whole_file_reader_v2(name=name)
    super(WholeFileReader, self).__init__(rr, supports_serialize=True)


ops.NotDifferentiable("WholeFileReader")


@tf_export("TextLineReader")
class TextLineReader(ReaderBase):
  """A Reader that outputs the lines of a file delimited by newlines.

  Newlines are stripped from the output.
  See ReaderBase for supported methods.

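  For example, a sketch of parsing CSV lines as they are read (assuming
  `data.csv` holds two comma-separated float columns after a header):

  ```python
  filename_queue = tf.train.string_input_producer(["data.csv"])
  reader = tf.TextLineReader(skip_header_lines=1)
  key, value = reader.read(filename_queue)
  # record_defaults fills missing fields and fixes each column's dtype.
  col1, col2 = tf.decode_csv(value, record_defaults=[[0.0], [0.0]])
  ```
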
  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  def __init__(self, skip_header_lines=None, name=None):
    """Create a TextLineReader.

    Args:
      skip_header_lines: An optional int. Defaults to 0.  Number of lines
        to skip from the beginning of every file.
      name: A name for the operation (optional).
    """
    rr = gen_io_ops._text_line_reader_v2(skip_header_lines=skip_header_lines,
                                         name=name)
    super(TextLineReader, self).__init__(rr)


ops.NotDifferentiable("TextLineReader")


@tf_export("FixedLengthRecordReader")
class FixedLengthRecordReader(ReaderBase):
  """A Reader that outputs fixed-length records from a file.

  See ReaderBase for supported methods.

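  For example, a sketch of reading CIFAR-10-style binary records, where
  each record is a 1-byte label followed by a 32x32x3 image (the file
  name and record layout here are assumptions for illustration):

  ```python
  record_bytes = 1 + 32 * 32 * 3
  filename_queue = tf.train.string_input_producer(["data_batch_1.bin"])
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  key, value = reader.read(filename_queue)
  # Decode the raw bytes, then split the label byte from the pixels.
  record = tf.decode_raw(value, tf.uint8)
  label = tf.cast(tf.strided_slice(record, [0], [1]), tf.int32)
  image = tf.reshape(tf.strided_slice(record, [1], [record_bytes]),
                     [3, 32, 32])
  ```
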
  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  def __init__(self,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               hop_bytes=None,
               name=None,
               encoding=None):
    """Create a FixedLengthRecordReader.

    Args:
      record_bytes: An int. The number of bytes in each record.
      header_bytes: An optional int. Defaults to 0. Number of bytes to skip
        at the beginning of each file.
      footer_bytes: An optional int. Defaults to 0. Number of bytes to ignore
        at the end of each file.
      hop_bytes: An optional int. Defaults to 0. Number of bytes to hop before
        each read; 0 means hop by record_bytes.
      name: A name for the operation (optional).
      encoding: The type of encoding for the file. Defaults to none
        (no decompression).
    """
    rr = gen_io_ops._fixed_length_record_reader_v2(
        record_bytes=record_bytes,
        header_bytes=header_bytes,
        footer_bytes=footer_bytes,
        hop_bytes=hop_bytes,
        encoding=encoding,
        name=name)
    super(FixedLengthRecordReader, self).__init__(rr)


ops.NotDifferentiable("FixedLengthRecordReader")


@tf_export("TFRecordReader")
class TFRecordReader(ReaderBase):
  """A Reader that outputs the records from a TFRecords file.

  See ReaderBase for supported methods.

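  For example, a sketch of reading serialized `tf.train.Example` protos
  (the file name and feature spec are assumptions for illustration):

  ```python
  filename_queue = tf.train.string_input_producer(["examples.tfrecords"])
  reader = tf.TFRecordReader()
  key, serialized = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized,
      features={"label": tf.FixedLenFeature([], tf.int64)})
  ```
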
  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  def __init__(self, name=None, options=None):
    """Create a TFRecordReader.

    Args:
      name: A name for the operation (optional).
      options: A TFRecordOptions object (optional).
    """
    compression_type = python_io.TFRecordOptions.get_compression_type_string(
        options)

    rr = gen_io_ops._tf_record_reader_v2(
        name=name, compression_type=compression_type)
    super(TFRecordReader, self).__init__(rr)


ops.NotDifferentiable("TFRecordReader")


@tf_export("LMDBReader")
class LMDBReader(ReaderBase):
  """A Reader that outputs the records from an LMDB file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, name=None, options=None):
    """Create an LMDBReader.

    Args:
      name: A name for the operation (optional).
      options: An LMDBRecordOptions object (optional). Currently unused:
        the underlying reader op takes no options.
    """
    del options  # Unused; accepted for API consistency with other readers.
    rr = gen_io_ops._lmdb_reader(name=name)
    super(LMDBReader, self).__init__(rr)


ops.NotDifferentiable("LMDBReader")


@tf_export("IdentityReader")
class IdentityReader(ReaderBase):
  """A Reader that outputs the queued work as both the key and value.

  To use, enqueue strings in a Queue.  Read will take the front
  work string and output (work, work).

  See ReaderBase for supported methods.

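  For example, a minimal sketch (assuming queue runners are started):

  ```python
  queue = tf.train.string_input_producer(["a", "b"])
  reader = tf.IdentityReader()
  # key and value both evaluate to the dequeued string, e.g. b"a".
  key, value = reader.read(queue)
  ```
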
  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, name=None):
    """Create an IdentityReader.

    Args:
      name: A name for the operation (optional).
    """
    rr = gen_io_ops._identity_reader_v2(name=name)
    super(IdentityReader, self).__init__(rr, supports_serialize=True)


ops.NotDifferentiable("IdentityReader")