# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unique element dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops


def unique():
  """Creates a `Dataset` from another `Dataset`, discarding duplicates.

  Use this transformation to produce a dataset that contains one instance of
  each unique element in the input. For example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])

  # Using `unique()` will drop the duplicate elements.
  dataset = dataset.apply(tf.contrib.data.unique())  # ==> { 1, 37, 2 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    # Deferred construction so this transformation composes with
    # `Dataset.apply`, which calls the returned function with the input
    # dataset.
    return UniqueDataset(dataset)

  return _apply_fn


class UniqueDataset(dataset_ops.Dataset):
  """A `Dataset` containing the unique elements from its input."""

  def __init__(self, input_dataset):
    """See `unique()` for details.

    Args:
      input_dataset: The input `Dataset` whose duplicate elements are dropped.

    Raises:
      TypeError: If `input_dataset` does not consist of a single component
        with type `tf.int32`, `tf.int64`, or `tf.string`.
    """
    super(UniqueDataset, self).__init__()
    self._input_dataset = input_dataset
    # The underlying UniqueDataset kernel only supports scalar-structured
    # datasets of these dtypes; a nested structure's `output_types` would be a
    # tuple/dict and thus also fail this membership test.
    if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
                                          dtypes.string):
      raise TypeError(
          "`tf.contrib.data.unique()` only supports inputs with a single "
          "`tf.int32`, `tf.int64`, or `tf.string` component.")

  def _as_variant_tensor(self):
    # Flatten (and densify any sparse components of) the structure metadata
    # into the flat lists that the op registration expects.
    return gen_dataset_ops.unique_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    # Dropping duplicates does not change the element structure.
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types