1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/dataset.h" 16 #include "tensorflow/core/framework/partial_tensor_shape.h" 17 #include "tensorflow/core/framework/tensor.h" 18 #include "tensorflow/core/lib/random/random.h" 19 20 namespace tensorflow { 21 22 namespace { 23 24 // See documentation in ../ops/dataset_ops.cc for a high-level 25 // description of the following op. 26 27 class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { 28 public: 29 explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx) 30 : UnaryDatasetOpKernel(ctx) {} 31 32 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 33 DatasetBase** output) override { 34 *output = new Dataset(ctx, input); 35 } 36 37 private: 38 class Dataset : public GraphDatasetBase { 39 public: 40 explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) 41 : GraphDatasetBase(ctx), input_(input) { 42 input_->Ref(); 43 } 44 45 ~Dataset() override { input_->Unref(); } 46 47 std::unique_ptr<IteratorBase> MakeIterator( 48 const string& prefix) const override { 49 return std::unique_ptr<IteratorBase>( 50 new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); 51 } 52 53 const DataTypeVector& output_dtypes() const override { 54 return input_->output_dtypes(); 55 } 56 const std::vector<PartialTensorShape>& output_shapes() const override { 57 return input_->output_shapes(); 58 } 59 60 string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } 61 62 protected: 63 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 64 Node** output) const override { 65 Node* input_graph_node = nullptr; 66 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); 67 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); 68 return Status::OK(); 69 } 70 71 private: 72 class Iterator : public DatasetIterator<Dataset> { 73 public: 74 explicit Iterator(const Params& params) 75 : DatasetIterator<Dataset>(params), 76 input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} 77 78 Status GetNextInternal(IteratorContext* ctx, 79 std::vector<Tensor>* out_tensors, 80 bool* end_of_sequence) override { 81 { 82 tf_shared_lock l(mu_); 83 if (!input_impl_) { 84 *end_of_sequence = true; 85 return Status::OK(); 86 } 87 Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); 88 while (!s.ok()) { 89 out_tensors->clear(); 90 s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); 91 } 92 } 93 if (*end_of_sequence) { 94 mutex_lock l(mu_); 95 input_impl_.reset(); 96 } 97 return Status::OK(); 98 } 99 100 protected: 101 Status SaveInternal(IteratorStateWriter* writer) override { 102 mutex_lock l(mu_); 103 if (input_impl_) 104 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 105 else 106 TF_RETURN_IF_ERROR( 107 writer->WriteScalar(full_name("input_impls_empty"), "")); 108 return Status::OK(); 109 } 110 111 Status RestoreInternal(IteratorContext* ctx, 112 IteratorStateReader* reader) override { 113 mutex_lock l(mu_); 114 if (reader->Contains(full_name("input_impls_empty"))) 115 input_impl_.reset(); 116 else 117 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 118 return Status::OK(); 119 } 120 121 private: 122 mutex mu_; 123 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 124 }; 125 126 const DatasetBase* const input_; 127 }; 128 }; 129 130 REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU), 131 IgnoreErrorsDatasetOp); 132 133 } // namespace 134 135 } // namespace tensorflow 136