Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/framework/dataset.h"
     16 #include "tensorflow/core/framework/partial_tensor_shape.h"
     17 #include "tensorflow/core/framework/tensor.h"
     18 #include "tensorflow/core/lib/random/random.h"
     19 
     20 namespace tensorflow {
     21 
     22 namespace {
     23 
     24 // See documentation in ../ops/dataset_ops.cc for a high-level
     25 // description of the following op.
     26 
     27 class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
     28  public:
     29   explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
     30       : UnaryDatasetOpKernel(ctx) {}
     31 
     32   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     33                    DatasetBase** output) override {
     34     *output = new Dataset(ctx, input);
     35   }
     36 
     37  private:
     38   class Dataset : public GraphDatasetBase {
     39    public:
     40     explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
     41         : GraphDatasetBase(ctx), input_(input) {
     42       input_->Ref();
     43     }
     44 
     45     ~Dataset() override { input_->Unref(); }
     46 
     47     std::unique_ptr<IteratorBase> MakeIterator(
     48         const string& prefix) const override {
     49       return std::unique_ptr<IteratorBase>(
     50           new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
     51     }
     52 
     53     const DataTypeVector& output_dtypes() const override {
     54       return input_->output_dtypes();
     55     }
     56     const std::vector<PartialTensorShape>& output_shapes() const override {
     57       return input_->output_shapes();
     58     }
     59 
     60     string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; }
     61 
     62    protected:
     63     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
     64                               Node** output) const override {
     65       Node* input_graph_node = nullptr;
     66       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
     67       TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
     68       return Status::OK();
     69     }
     70 
     71    private:
     72     class Iterator : public DatasetIterator<Dataset> {
     73      public:
     74       explicit Iterator(const Params& params)
     75           : DatasetIterator<Dataset>(params),
     76             input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
     77 
     78       Status GetNextInternal(IteratorContext* ctx,
     79                              std::vector<Tensor>* out_tensors,
     80                              bool* end_of_sequence) override {
     81         {
     82           tf_shared_lock l(mu_);
     83           if (!input_impl_) {
     84             *end_of_sequence = true;
     85             return Status::OK();
     86           }
     87           Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
     88           while (!s.ok()) {
     89             out_tensors->clear();
     90             s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
     91           }
     92         }
     93         if (*end_of_sequence) {
     94           mutex_lock l(mu_);
     95           input_impl_.reset();
     96         }
     97         return Status::OK();
     98       }
     99 
    100      protected:
    101       Status SaveInternal(IteratorStateWriter* writer) override {
    102         mutex_lock l(mu_);
    103         if (input_impl_)
    104           TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
    105         else
    106           TF_RETURN_IF_ERROR(
    107               writer->WriteScalar(full_name("input_impls_empty"), ""));
    108         return Status::OK();
    109       }
    110 
    111       Status RestoreInternal(IteratorContext* ctx,
    112                              IteratorStateReader* reader) override {
    113         mutex_lock l(mu_);
    114         if (reader->Contains(full_name("input_impls_empty")))
    115           input_impl_.reset();
    116         else
    117           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
    118         return Status::OK();
    119       }
    120 
    121      private:
    122       mutex mu_;
    123       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    124     };
    125 
    126     const DatasetBase* const input_;
    127   };
    128 };
    129 
    130 REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
    131                         IgnoreErrorsDatasetOp);
    132 
    133 }  // namespace
    134 
    135 }  // namespace tensorflow
    136