Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Enumerate dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.framework import dtypes
     24 
     25 
     26 def enumerate_dataset(start=0):
     27   """A transformation that enumerate the elements of a dataset.
     28 
     29   It is Similar to python's `enumerate`.
     30   For example:
     31 
     32   ```python
     33   # NOTE: The following examples use `{ ... }` to represent the
     34   # contents of a dataset.
     35   a = { 1, 2, 3 }
     36   b = { (7, 8), (9, 10) }
     37 
     38   # The nested structure of the `datasets` argument determines the
     39   # structure of elements in the resulting dataset.
     40   a.apply(tf.contrib.data.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
     41   b.apply(tf.contrib.data.enumerate()) == { (0, (7, 8)), (1, (9, 10)) }
     42   ```
     43 
     44   Args:
     45     start: A `tf.int64` scalar `tf.Tensor`, representing the start
     46       value for enumeration.
     47 
     48   Returns:
     49     A `Dataset` transformation function, which can be passed to
     50     @{tf.data.Dataset.apply}.
     51   """
     52 
     53   def _apply_fn(dataset):
     54     max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
     55     return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
     56                                     dataset))
     57 
     58   return _apply_fn
     59