Home | History | Annotate | Download | only in data
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/framework/partial_tensor_shape.h"
     16 #include "tensorflow/core/framework/tensor.h"
     17 #include "tensorflow/core/kernels/data/dataset.h"
     18 
     19 namespace tensorflow {
     20 
     21 namespace {
     22 
     23 // See documentation in ../ops/dataset_ops.cc for a high-level
     24 // description of the following op.
     25 
     26 class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
     27  public:
     28   explicit ConcatenateDatasetOp(OpKernelConstruction* ctx)
     29       : BinaryDatasetOpKernel(ctx) {}
     30   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     31                    DatasetBase* to_concatenate, DatasetBase** output) override {
     32     OP_REQUIRES(ctx, input->output_dtypes() == to_concatenate->output_dtypes(),
     33                 errors::InvalidArgument(
     34                     "input dataset and dataset to concatenate"
     35                     " have different output_types %s and %s",
     36                     (DataTypeVectorString(input->output_dtypes()),
     37                      DataTypeVectorString(to_concatenate->output_dtypes()))));
     38     *output = new Dataset(ctx, input, to_concatenate);
     39   }
     40 
     41  private:
     42   class Dataset : public GraphDatasetBase {
     43    public:
     44     explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
     45                      const DatasetBase* to_concatenate)
     46         : GraphDatasetBase(ctx),
     47           input_(input),
     48           to_concatenate_(to_concatenate) {
     49       input_->Ref();
     50       to_concatenate_->Ref();
     51 
     52       auto os_input = input->output_shapes();
     53       auto os_concatenate = to_concatenate->output_shapes();
     54       for (int i = 0; i < os_input.size(); i++) {
     55         output_shapes_.push_back(
     56             MostSpecificCompatibleShape(os_input[i], os_concatenate[i]));
     57       }
     58     }
     59     ~Dataset() override {
     60       input_->Unref();
     61       to_concatenate_->Unref();
     62     }
     63 
     64     std::unique_ptr<IteratorBase> MakeIterator(
     65         const string& prefix) const override {
     66       return std::unique_ptr<IteratorBase>(
     67           new Iterator({this, strings::StrCat(prefix, "::Concatenate")}));
     68     }
     69 
     70     const DataTypeVector& output_dtypes() const override {
     71       return input_->output_dtypes();
     72     }
     73 
     74     const std::vector<PartialTensorShape>& output_shapes() const override {
     75       return output_shapes_;
     76     }
     77 
     78     string DebugString() override { return "ConcatenateDatasetOp::Dataset"; }
     79 
     80    protected:
     81     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
     82                               Node** output) const override {
     83       Node* input_graph = nullptr;
     84       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph));
     85       Node* to_concatenate_graph = nullptr;
     86       TF_RETURN_IF_ERROR(
     87           b->AddParentDataset(ctx, to_concatenate_, &to_concatenate_graph));
     88       TF_RETURN_IF_ERROR(
     89           b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
     90       return Status::OK();
     91     }
     92 
     93    private:
     94     class Iterator : public DatasetIterator<Dataset> {
     95      public:
     96       explicit Iterator(const Params& params)
     97           : DatasetIterator<Dataset>(params),
     98             i_(0),
     99             input_impl_(params.dataset->input_->MakeIterator(
    100                 strings::StrCat(params.prefix, "[0]"))) {}
    101 
    102       Status GetNextInternal(IteratorContext* ctx,
    103                              std::vector<Tensor>* out_tensors,
    104                              bool* end_of_sequence) override {
    105         mutex_lock l(mu_);
    106         if (!input_impl_) {
    107           *end_of_sequence = true;
    108           return Status::OK();
    109         }
    110         while (i_ < 2) {
    111           TF_RETURN_IF_ERROR(
    112               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
    113           if (!*end_of_sequence) {
    114             return Status::OK();
    115           }
    116           if (++i_ < 2) {
    117             input_impl_ = dataset()->to_concatenate_->MakeIterator(
    118                 strings::StrCat(prefix(), "[1]"));
    119           }
    120         }
    121         *end_of_sequence = true;
    122         input_impl_.reset();
    123         return Status::OK();
    124       }
    125 
    126      protected:
    127       Status SaveInternal(IteratorStateWriter* writer) override {
    128         mutex_lock l(mu_);
    129         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
    130         if (input_impl_) {
    131           TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
    132         } else {
    133           TF_RETURN_IF_ERROR(
    134               writer->WriteScalar(full_name("input_impl_uninitialized"), ""));
    135         }
    136         return Status::OK();
    137       }
    138 
    139       Status RestoreInternal(IteratorContext* ctx,
    140                              IteratorStateReader* reader) override {
    141         mutex_lock l(mu_);
    142         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
    143         if (reader->Contains(full_name("input_impl_uninitialized"))) {
    144           input_impl_.reset();
    145           return Status::OK();
    146         }
    147         if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
    148           return errors::InvalidArgument("i_ must be in range [0, 2].");
    149         if (i_ == 1) {
    150           input_impl_ = dataset()->to_concatenate_->MakeIterator(
    151               strings::StrCat(prefix(), "[1]"));
    152         } else if (i_ == 2) {
    153           input_impl_.reset();
    154         }
    155         if (input_impl_) {
    156           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
    157         }
    158         return Status::OK();
    159       }
    160 
    161      private:
    162       mutex mu_;
    163       int64 i_ GUARDED_BY(mu_);
    164       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    165     };
    166 
    167     static PartialTensorShape MostSpecificCompatibleShape(
    168         const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
    169       PartialTensorShape output_tensorshape;
    170       if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
    171         return output_tensorshape;
    172       auto dims1 = ts1.dim_sizes();
    173       auto dims2 = ts2.dim_sizes();
    174       for (int d = 0; d < ts1.dims(); d++) {
    175         if (dims1[d] == dims2[d])
    176           output_tensorshape.Concatenate(dims1[d]);
    177         else
    178           output_tensorshape.Concatenate(-1);
    179       }
    180       return output_tensorshape;
    181     }
    182 
    183     const DatasetBase* input_;
    184     const DatasetBase* to_concatenate_;
    185     std::vector<PartialTensorShape> output_shapes_;
    186   };
    187 };
    188 
    189 REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
    190                         ConcatenateDatasetOp);
    191 
    192 }  // namespace
    193 
    194 }  // namespace tensorflow
    195