1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/partial_tensor_shape.h" 16 #include "tensorflow/core/framework/tensor.h" 17 #include "tensorflow/core/kernels/data/dataset.h" 18 19 namespace tensorflow { 20 21 namespace { 22 23 // See documentation in ../ops/dataset_ops.cc for a high-level 24 // description of the following op. 25 26 class ConcatenateDatasetOp : public BinaryDatasetOpKernel { 27 public: 28 explicit ConcatenateDatasetOp(OpKernelConstruction* ctx) 29 : BinaryDatasetOpKernel(ctx) {} 30 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 31 DatasetBase* to_concatenate, DatasetBase** output) override { 32 OP_REQUIRES(ctx, input->output_dtypes() == to_concatenate->output_dtypes(), 33 errors::InvalidArgument( 34 "input dataset and dataset to concatenate" 35 " have different output_types %s and %s", 36 (DataTypeVectorString(input->output_dtypes()), 37 DataTypeVectorString(to_concatenate->output_dtypes())))); 38 *output = new Dataset(ctx, input, to_concatenate); 39 } 40 41 private: 42 class Dataset : public GraphDatasetBase { 43 public: 44 explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, 45 const DatasetBase* to_concatenate) 46 : GraphDatasetBase(ctx), 47 input_(input), 48 to_concatenate_(to_concatenate) { 49 input_->Ref(); 50 to_concatenate_->Ref(); 51 52 auto os_input = input->output_shapes(); 53 auto os_concatenate = to_concatenate->output_shapes(); 54 for (int i = 0; i < os_input.size(); i++) { 55 output_shapes_.push_back( 56 MostSpecificCompatibleShape(os_input[i], os_concatenate[i])); 57 } 58 } 59 ~Dataset() override { 60 input_->Unref(); 61 to_concatenate_->Unref(); 62 } 63 64 std::unique_ptr<IteratorBase> MakeIterator( 65 const string& prefix) const override { 66 return std::unique_ptr<IteratorBase>( 67 new Iterator({this, strings::StrCat(prefix, "::Concatenate")})); 68 } 69 70 const DataTypeVector& output_dtypes() const override { 71 return input_->output_dtypes(); 72 } 73 74 const std::vector<PartialTensorShape>& output_shapes() const override { 75 return output_shapes_; 76 } 77 78 string DebugString() override { return "ConcatenateDatasetOp::Dataset"; } 79 80 protected: 81 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 82 Node** output) const override { 83 Node* input_graph = nullptr; 84 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph)); 85 Node* to_concatenate_graph = nullptr; 86 TF_RETURN_IF_ERROR( 87 b->AddParentDataset(ctx, to_concatenate_, &to_concatenate_graph)); 88 TF_RETURN_IF_ERROR( 89 b->AddDataset(this, {input_graph, to_concatenate_graph}, output)); 90 return Status::OK(); 91 } 92 93 private: 94 class Iterator : public DatasetIterator<Dataset> { 95 public: 96 explicit Iterator(const Params& params) 97 : DatasetIterator<Dataset>(params), 98 i_(0), 99 input_impl_(params.dataset->input_->MakeIterator( 100 strings::StrCat(params.prefix, "[0]"))) {} 101 102 Status GetNextInternal(IteratorContext* ctx, 103 std::vector<Tensor>* out_tensors, 104 bool* end_of_sequence) override { 105 mutex_lock l(mu_); 106 if (!input_impl_) { 107 *end_of_sequence = true; 108 return Status::OK(); 109 } 110 while (i_ < 2) { 111 TF_RETURN_IF_ERROR( 112 input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); 113 if (!*end_of_sequence) { 114 return Status::OK(); 115 } 116 if (++i_ < 2) { 117 input_impl_ = dataset()->to_concatenate_->MakeIterator( 118 strings::StrCat(prefix(), "[1]")); 119 } 120 } 121 *end_of_sequence = true; 122 input_impl_.reset(); 123 return Status::OK(); 124 } 125 126 protected: 127 Status SaveInternal(IteratorStateWriter* writer) override { 128 mutex_lock l(mu_); 129 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); 130 if (input_impl_) { 131 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 132 } else { 133 TF_RETURN_IF_ERROR( 134 writer->WriteScalar(full_name("input_impl_uninitialized"), "")); 135 } 136 return Status::OK(); 137 } 138 139 Status RestoreInternal(IteratorContext* ctx, 140 IteratorStateReader* reader) override { 141 mutex_lock l(mu_); 142 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); 143 if (reader->Contains(full_name("input_impl_uninitialized"))) { 144 input_impl_.reset(); 145 return Status::OK(); 146 } 147 if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2)) 148 return errors::InvalidArgument("i_ must be in range [0, 2]."); 149 if (i_ == 1) { 150 input_impl_ = dataset()->to_concatenate_->MakeIterator( 151 strings::StrCat(prefix(), "[1]")); 152 } else if (i_ == 2) { 153 input_impl_.reset(); 154 } 155 if (input_impl_) { 156 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 157 } 158 return Status::OK(); 159 } 160 161 private: 162 mutex mu_; 163 int64 i_ GUARDED_BY(mu_); 164 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 165 }; 166 167 static PartialTensorShape MostSpecificCompatibleShape( 168 const PartialTensorShape& ts1, const PartialTensorShape& ts2) { 169 PartialTensorShape output_tensorshape; 170 if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) 171 return output_tensorshape; 172 auto dims1 = ts1.dim_sizes(); 173 auto dims2 = ts2.dim_sizes(); 174 for (int d = 0; d < ts1.dims(); d++) { 175 if (dims1[d] == dims2[d]) 176 output_tensorshape.Concatenate(dims1[d]); 177 else 178 output_tensorshape.Concatenate(-1); 179 } 180 return output_tensorshape; 181 } 182 183 const DatasetBase* input_; 184 const DatasetBase* to_concatenate_; 185 std::vector<PartialTensorShape> output_shapes_; 186 }; 187 }; 188 189 REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU), 190 ConcatenateDatasetOp); 191 192 } // namespace 193 194 } // namespace tensorflow 195