Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Ignore_errors dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.contrib.data.python.ops import contrib_op_loader  # pylint: disable=unused-import
     21 from tensorflow.contrib.data.python.ops import gen_dataset_ops
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.data.util import nest
     24 from tensorflow.python.data.util import sparse
     25 
     26 
     27 def ignore_errors():
     28   """Creates a `Dataset` from another `Dataset` and silently ignores any errors.
     29 
     30   Use this transformation to produce a dataset that contains the same elements
     31   as the input, but silently drops any elements that caused an error. For
     32   example:
     33 
     34   ```python
     35   dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
     36 
     37   # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
     38   dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
     39 
     40   # Using `ignore_errors()` will drop the element that causes an error.
     41   dataset =
     42       dataset.apply(tf.contrib.data.ignore_errors())  # ==> { 1., 0.5, 0.2 }
     43   ```
     44 
     45   Returns:
     46     A `Dataset` transformation function, which can be passed to
     47     @{tf.data.Dataset.apply}.
     48   """
     49 
     50   def _apply_fn(dataset):
     51     return IgnoreErrorsDataset(dataset)
     52 
     53   return _apply_fn
     54 
     55 
     56 class IgnoreErrorsDataset(dataset_ops.Dataset):
     57   """A `Dataset` that silently ignores errors when computing its input."""
     58 
     59   def __init__(self, input_dataset):
     60     """See `Dataset.ignore_errors()` for details."""
     61     super(IgnoreErrorsDataset, self).__init__()
     62     self._input_dataset = input_dataset
     63 
     64   def _as_variant_tensor(self):
     65     return gen_dataset_ops.ignore_errors_dataset(
     66         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
     67         output_shapes=nest.flatten(
     68             sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
     69         output_types=nest.flatten(
     70             sparse.as_dense_types(self.output_types, self.output_classes)))
     71 
     72   @property
     73   def output_classes(self):
     74     return self._input_dataset.output_classes
     75 
     76   @property
     77   def output_shapes(self):
     78     return self._input_dataset.output_shapes
     79 
     80   @property
     81   def output_types(self):
     82     return self._input_dataset.output_types
     83