# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`ignore_errors` dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.data.python.ops import contrib_op_loader  # pylint: disable=unused-import
from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse


def ignore_errors():
  """Creates a `Dataset` from another `Dataset` and silently ignores any errors.

  Use this transformation to produce a dataset that contains the same elements
  as the input, but silently drops any elements that caused an error. For
  example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])

  # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
  dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))

  # Using `ignore_errors()` will drop the element that causes an error.
  dataset = dataset.apply(
      tf.contrib.data.ignore_errors())  # ==> { 1., 0.5, 0.25 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    # Wrap the input dataset so that elements whose computation raises an
    # error are dropped instead of surfacing the error to the consumer.
    return IgnoreErrorsDataset(dataset)

  return _apply_fn


class IgnoreErrorsDataset(dataset_ops.Dataset):
  """A `Dataset` that silently ignores errors when computing its input."""

  def __init__(self, input_dataset):
    """Wraps `input_dataset`; see `tf.contrib.data.ignore_errors()` for details.

    Args:
      input_dataset: The `Dataset` whose erroring elements should be dropped.
    """
    super(IgnoreErrorsDataset, self).__init__()
    self._input_dataset = input_dataset

  def _as_variant_tensor(self):
    """Builds the `ignore_errors_dataset` op over the input's variant tensor."""
    # The op receives flat lists of shapes and types. `sparse.as_dense_*`
    # presumably maps any sparse components (per `output_classes`) to the
    # dense representation the op expects before the structure is flattened.
    return gen_dataset_ops.ignore_errors_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    """Same as the input dataset's `output_classes`."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Same as the input dataset's `output_shapes`."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Same as the input dataset's `output_types`."""
    return self._input_dataset.output_types