1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <map> 16 17 #include "tensorflow/core/common_runtime/function.h" 18 #include "tensorflow/core/framework/dataset.h" 19 #include "tensorflow/core/framework/partial_tensor_shape.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/kernels/data/captured_function.h" 22 #include "tensorflow/core/lib/random/random.h" 23 24 namespace tensorflow { 25 namespace data { 26 namespace { 27 28 // See documentation in ../../ops/dataset_ops.cc for a high-level 29 // description of the following op. 30 class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { 31 public: 32 explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx) 33 : UnaryDatasetOpKernel(ctx) { 34 OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); 35 OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); 36 OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); 37 OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); 38 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 39 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 40 } 41 42 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 43 DatasetBase** output) override { 44 std::unique_ptr<CapturedFunction> captured_key_func; 45 OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, 46 "key_func_other_arguments", 47 &captured_key_func)); 48 std::unique_ptr<CapturedFunction> captured_init_func; 49 OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx, 50 "init_func_other_arguments", 51 &captured_init_func)); 52 std::unique_ptr<CapturedFunction> captured_reduce_func; 53 OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, 54 "reduce_func_other_arguments", 55 &captured_reduce_func)); 56 std::unique_ptr<CapturedFunction> captured_finalize_func; 57 OP_REQUIRES_OK(ctx, 58 CapturedFunction::Create(finalize_func_, ctx, 59 "finalize_func_other_arguments", 60 &captured_finalize_func)); 61 62 *output = new Dataset( 63 ctx, input, std::move(captured_key_func), std::move(captured_init_func), 64 std::move(captured_reduce_func), std::move(captured_finalize_func), 65 output_types_, output_shapes_); 66 } 67 68 private: 69 class Dataset : public DatasetBase { 70 public: 71 Dataset(OpKernelContext* ctx, const DatasetBase* input, 72 std::unique_ptr<CapturedFunction> captured_key_func, 73 std::unique_ptr<CapturedFunction> captured_init_func, 74 std::unique_ptr<CapturedFunction> captured_reduce_func, 75 std::unique_ptr<CapturedFunction> captured_finalize_func, 76 const DataTypeVector& output_types, 77 const std::vector<PartialTensorShape>& output_shapes) 78 : DatasetBase(DatasetContext(ctx)), 79 input_(input), 80 captured_key_func_(std::move(captured_key_func)), 81 captured_init_func_(std::move(captured_init_func)), 82 captured_reduce_func_(std::move(captured_reduce_func)), 83 captured_finalize_func_(std::move(captured_finalize_func)), 84 output_types_(output_types), 85 output_shapes_(output_shapes) { 86 input_->Ref(); 87 } 88 89 ~Dataset() override { input_->Unref(); } 90 91 std::unique_ptr<IteratorBase> MakeIteratorInternal( 92 const string& prefix) const override { 93 return absl::make_unique<Iterator>( 94 Iterator::Params{this, strings::StrCat(prefix, "::GroupByReducer")}); 95 } 96 97 const DataTypeVector& output_dtypes() const override { 98 return output_types_; 99 } 100 const std::vector<PartialTensorShape>& output_shapes() const override { 101 return output_shapes_; 102 } 103 104 string DebugString() const override { 105 return "GroupByReducerDatasetOp::Dataset"; 106 } 107 108 protected: 109 Status AsGraphDefInternal(SerializationContext* ctx, 110 DatasetGraphDefBuilder* b, 111 Node** output) const override { 112 TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name())); 113 TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name())); 114 TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name())); 115 TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name())); 116 Node* input_graph_node = nullptr; 117 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); 118 119 std::vector<Node*> key_func_other_arguments_node; 120 DataTypeVector key_func_other_arguments_types; 121 TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( 122 ctx, b, captured_key_func_, &key_func_other_arguments_node, 123 &key_func_other_arguments_types)); 124 125 std::vector<Node*> init_func_other_arguments_node; 126 DataTypeVector init_func_other_arguments_types; 127 TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( 128 ctx, b, captured_init_func_, &init_func_other_arguments_node, 129 &init_func_other_arguments_types)); 130 131 std::vector<Node*> reduce_func_other_arguments_node; 132 DataTypeVector reduce_func_other_arguments_types; 133 TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( 134 ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node, 135 &reduce_func_other_arguments_types)); 136 137 std::vector<Node*> finalize_func_other_arguments_node; 138 DataTypeVector finalize_func_other_arguments_types; 139 TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( 140 ctx, b, captured_finalize_func_, &finalize_func_other_arguments_node, 141 &finalize_func_other_arguments_types)); 142 143 AttrValue key_func; 144 b->BuildAttrValue(this->key_func(), &key_func); 145 AttrValue init_func; 146 b->BuildAttrValue(this->init_func(), &init_func); 147 AttrValue reduce_func; 148 b->BuildAttrValue(this->reduce_func(), &reduce_func); 149 AttrValue finalize_func; 150 b->BuildAttrValue(this->finalize_func(), &finalize_func); 151 152 AttrValue key_func_other_arguments_types_attr; 153 b->BuildAttrValue(key_func_other_arguments_types, 154 &key_func_other_arguments_types_attr); 155 AttrValue init_func_other_arguments_types_attr; 156 b->BuildAttrValue(init_func_other_arguments_types, 157 &init_func_other_arguments_types_attr); 158 AttrValue reduce_func_other_arguments_types_attr; 159 b->BuildAttrValue(reduce_func_other_arguments_types, 160 &reduce_func_other_arguments_types_attr); 161 AttrValue finalize_func_other_arguments_types_attr; 162 b->BuildAttrValue(finalize_func_other_arguments_types, 163 &finalize_func_other_arguments_types_attr); 164 165 TF_RETURN_IF_ERROR(b->AddDataset( 166 this, {{0, input_graph_node}}, 167 {{1, key_func_other_arguments_node}, 168 {2, init_func_other_arguments_node}, 169 {3, reduce_func_other_arguments_node}, 170 {4, finalize_func_other_arguments_node}}, 171 {{"key_func", key_func}, 172 {"init_func", init_func}, 173 {"reduce_func", reduce_func}, 174 {"finalize_func", finalize_func}, 175 {"Tkey_func_other_arguments", key_func_other_arguments_types_attr}, 176 {"Tinit_func_other_arguments", init_func_other_arguments_types_attr}, 177 {"Treduce_func_other_arguments", 178 reduce_func_other_arguments_types_attr}, 179 {"Tfinalize_func_other_arguments", 180 finalize_func_other_arguments_types_attr}}, 181 output)); 182 return Status::OK(); 183 } 184 185 private: 186 class Iterator : public DatasetIterator<Dataset> { 187 public: 188 explicit Iterator(const Params& params) 189 : DatasetIterator<Dataset>(params) {} 190 191 Status Initialize(IteratorContext* ctx) override { 192 TF_RETURN_IF_ERROR( 193 dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); 194 TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate( 195 ctx, &instantiated_key_func_)); 196 TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate( 197 ctx, &instantiated_init_func_)); 198 TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate( 199 ctx, &instantiated_reduce_func_)); 200 TF_RETURN_IF_ERROR(dataset()->captured_finalize_func_->Instantiate( 201 ctx, &instantiated_finalize_func_)); 202 return Status::OK(); 203 } 204 205 Status GetNextInternal(IteratorContext* ctx, 206 std::vector<Tensor>* out_tensors, 207 bool* end_of_sequence) override { 208 mutex_lock l(mu_); 209 210 // Iterate through the input dataset, keying input elements to reducers. 211 while (!end_of_input_) { 212 std::vector<Tensor> next_input_element; 213 TF_RETURN_IF_ERROR( 214 input_impl_->GetNext(ctx, &next_input_element, &end_of_input_)); 215 216 if (!end_of_input_) { 217 // Run the key function on the input element. 218 std::vector<Tensor> key_func_output; 219 TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs( 220 ctx, next_input_element, &key_func_output)); 221 222 if (key_func_output.size() != 1 || 223 key_func_output[0].dtype() != DT_INT64 || 224 key_func_output[0].NumElements() != 1) { 225 // TODO(b/78665031): Support non-int64 keys. 226 return errors::InvalidArgument( 227 "`key_func` must return a scalar int64."); 228 } 229 const int64 key = key_func_output[0].scalar<int64>()(); 230 231 if (states_.find(key) == states_.end()) { 232 // Run the init function to create the initial state. 233 std::vector<Tensor> init_func_output; 234 TF_RETURN_IF_ERROR(instantiated_init_func_->Run( 235 ctx, std::move(key_func_output), &init_func_output)); 236 states_[key] = init_func_output; 237 } 238 239 // Run the reduce function to update the current state. 240 std::vector<Tensor> args; 241 args.reserve(states_[key].size() + next_input_element.size()); 242 std::copy(states_[key].begin(), states_[key].end(), 243 std::back_inserter(args)); 244 std::copy(next_input_element.begin(), next_input_element.end(), 245 std::back_inserter(args)); 246 247 std::vector<Tensor> reduce_func_output; 248 TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run( 249 ctx, std::move(args), &reduce_func_output)); 250 states_[key] = reduce_func_output; 251 } else { 252 keys_.resize(states_.size()); 253 int idx = 0; 254 for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) { 255 keys_[idx] = it->first; 256 } 257 } 258 } 259 260 if (keys_index_ == keys_.size()) { 261 *end_of_sequence = true; 262 return Status::OK(); 263 } 264 TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs( 265 ctx, states_[keys_[keys_index_++]], out_tensors)); 266 *end_of_sequence = false; 267 return Status::OK(); 268 } 269 270 protected: 271 std::shared_ptr<model::Node> CreateNode( 272 IteratorContext* ctx, model::Node::Args args) const override { 273 return model::MakeUnknownRatioNode(std::move(args)); 274 } 275 276 Status SaveInternal(IteratorStateWriter* writer) override { 277 mutex_lock l(mu_); 278 TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); 279 280 if (end_of_input_) { 281 TF_RETURN_IF_ERROR( 282 writer->WriteScalar(full_name("end_of_input"), "")); 283 } 284 285 // Saving states_. 286 if (!states_.empty()) { 287 TF_RETURN_IF_ERROR( 288 writer->WriteScalar(full_name("states_size"), states_.size())); 289 int idx = 0; 290 for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) { 291 int64 key = it->first; 292 TF_RETURN_IF_ERROR(writer->WriteScalar( 293 full_name(strings::StrCat("states[", idx, "]->key")), key)); 294 if (!it->second.empty()) { 295 TF_RETURN_IF_ERROR(writer->WriteScalar( 296 full_name(strings::StrCat("states[", idx, "]->state_size")), 297 it->second.size())); 298 for (int j = 0; j < it->second.size(); ++j) { 299 TF_RETURN_IF_ERROR(writer->WriteTensor( 300 full_name( 301 strings::StrCat("states[", idx, "]->state[", j, "]")), 302 it->second[j])); 303 } 304 } 305 } 306 } 307 308 // Saving keys_index_ and keys_. 309 if (end_of_input_) { 310 TF_RETURN_IF_ERROR( 311 writer->WriteScalar(full_name("keys_index"), keys_index_)); 312 if (!keys_.empty()) { 313 TF_RETURN_IF_ERROR( 314 writer->WriteScalar(full_name("keys_size"), keys_.size())); 315 for (int idx = 0; idx < keys_.size(); ++idx) { 316 TF_RETURN_IF_ERROR(writer->WriteScalar( 317 full_name(strings::StrCat("keys[", idx, "]")), keys_[idx])); 318 } 319 } 320 } 321 322 return Status::OK(); 323 } 324 325 Status RestoreInternal(IteratorContext* ctx, 326 IteratorStateReader* reader) override { 327 mutex_lock l(mu_); 328 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); 329 330 if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; 331 332 // Restoring states_. 333 if (reader->Contains(full_name("states_size"))) { 334 int64 size; 335 TF_RETURN_IF_ERROR( 336 reader->ReadScalar(full_name("states_size"), &size)); 337 for (int idx = 0; idx < size; ++idx) { 338 int64 key; 339 TF_RETURN_IF_ERROR(reader->ReadScalar( 340 full_name(strings::StrCat("states[", idx, "]->key")), &key)); 341 std::vector<Tensor> state; 342 if (reader->Contains(full_name( 343 strings::StrCat("states[", idx, "]->state_size")))) { 344 int64 state_size; 345 TF_RETURN_IF_ERROR(reader->ReadScalar( 346 full_name(strings::StrCat("states[", idx, "]->state_size")), 347 &state_size)); 348 state.resize(state_size); 349 for (int j = 0; j < state_size; ++j) { 350 TF_RETURN_IF_ERROR(reader->ReadTensor( 351 full_name( 352 strings::StrCat("states[", idx, "]->state[", j, "]")), 353 &state[j])); 354 } 355 } 356 states_[key] = state; 357 } 358 } 359 360 // Restoring keys_index_ and keys_. 361 if (end_of_input_) { 362 TF_RETURN_IF_ERROR( 363 reader->ReadScalar(full_name("keys_index"), &keys_index_)); 364 if (reader->Contains(full_name("keys_size"))) { 365 int64 size; 366 TF_RETURN_IF_ERROR( 367 reader->ReadScalar(full_name("keys_size"), &size)); 368 keys_.resize(size); 369 for (int idx = 0; idx < size; ++idx) { 370 int64 key; 371 TF_RETURN_IF_ERROR(reader->ReadScalar( 372 full_name(strings::StrCat("keys[", idx, "]")), &key)); 373 keys_[idx] = key; 374 } 375 } 376 } 377 378 return Status::OK(); 379 } 380 381 private: 382 mutex mu_; 383 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 384 bool end_of_input_ GUARDED_BY(mu_) = false; 385 std::map<int64, std::vector<Tensor>> states_ GUARDED_BY(mu_); 386 std::vector<int64> keys_ GUARDED_BY(mu_); 387 int64 keys_index_ GUARDED_BY(mu_) = 0; 388 std::unique_ptr<InstantiatedCapturedFunction> instantiated_key_func_; 389 std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_; 390 std::unique_ptr<InstantiatedCapturedFunction> instantiated_reduce_func_; 391 std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_; 392 }; 393 394 const NameAttrList& key_func() const { return captured_key_func_->func(); } 395 396 const NameAttrList& init_func() const { 397 return captured_init_func_->func(); 398 } 399 400 const NameAttrList& reduce_func() const { 401 return captured_reduce_func_->func(); 402 } 403 404 const NameAttrList& finalize_func() const { 405 return captured_finalize_func_->func(); 406 } 407 408 Status OtherArgumentsNodeAndType( 409 SerializationContext* ctx, DatasetGraphDefBuilder* b, 410 const std::unique_ptr<CapturedFunction>& captured_func, 411 std::vector<Node*>* other_arguments_node, 412 DataTypeVector* other_arguments_types) const { 413 other_arguments_node->reserve(captured_func->captured_inputs().size()); 414 other_arguments_types->reserve(captured_func->captured_inputs().size()); 415 for (const Tensor& t : captured_func->captured_inputs()) { 416 Node* node; 417 DatasetBase* input; 418 Status s = GetDatasetFromVariantTensor(t, &input); 419 if (s.ok()) { 420 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); 421 } else { 422 TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); 423 } 424 other_arguments_node->emplace_back(node); 425 other_arguments_types->emplace_back(t.dtype()); 426 } 427 return Status::OK(); 428 } 429 430 const DatasetBase* const input_; 431 const std::unique_ptr<CapturedFunction> captured_key_func_; 432 const std::unique_ptr<CapturedFunction> captured_init_func_; 433 const std::unique_ptr<CapturedFunction> captured_reduce_func_; 434 const std::unique_ptr<CapturedFunction> captured_finalize_func_; 435 const DataTypeVector output_types_; 436 const std::vector<PartialTensorShape> output_shapes_; 437 }; 438 439 DataTypeVector output_types_; 440 std::vector<PartialTensorShape> output_shapes_; 441 NameAttrList key_func_; 442 NameAttrList init_func_; 443 NameAttrList reduce_func_; 444 NameAttrList finalize_func_; 445 }; 446 447 REGISTER_KERNEL_BUILDER( 448 Name("ExperimentalGroupByReducerDataset").Device(DEVICE_CPU), 449 GroupByReducerDatasetOp); 450 451 } // namespace 452 } // namespace data 453 } // namespace tensorflow 454