Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unique element dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.data.util import sparse
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.ops import gen_dataset_ops
     25 
     26 
     27 def unique():
     28   """Creates a `Dataset` from another `Dataset`, discarding duplicates.
     29 
     30   Use this transformation to produce a dataset that contains one instance of
     31   each unique element in the input. For example:
     32 
     33   ```python
     34   dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
     35 
     36   # Using `unique()` will drop the duplicate elements.
     37   dataset = dataset.apply(tf.contrib.data.unique())  # ==> { 1, 37, 2 }
     38   ```
     39 
     40   Returns:
     41     A `Dataset` transformation function, which can be passed to
     42     @{tf.data.Dataset.apply}.
     43   """
     44 
     45   def _apply_fn(dataset):
     46     return UniqueDataset(dataset)
     47 
     48   return _apply_fn
     49 
     50 
     51 class UniqueDataset(dataset_ops.Dataset):
     52   """A `Dataset` contains the unique elements from its input."""
     53 
     54   def __init__(self, input_dataset):
     55     """See `unique()` for details."""
     56     super(UniqueDataset, self).__init__()
     57     self._input_dataset = input_dataset
     58     if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
     59                                           dtypes.string):
     60       raise TypeError(
     61           "`tf.contrib.data.unique()` only supports inputs with a single "
     62           "`tf.int32`, `tf.int64`, or `tf.string` component.")
     63 
     64   def _as_variant_tensor(self):
     65     return gen_dataset_ops.unique_dataset(
     66         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
     67         output_shapes=nest.flatten(
     68             sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
     69         output_types=nest.flatten(
     70             sparse.as_dense_types(self.output_types, self.output_classes)))
     71 
     72   @property
     73   def output_classes(self):
     74     return self._input_dataset.output_classes
     75 
     76   @property
     77   def output_shapes(self):
     78     return self._input_dataset.output_shapes
     79 
     80   @property
     81   def output_types(self):
     82     return self._input_dataset.output_types
     83