1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/common_runtime/function.h" 16 #include "tensorflow/core/common_runtime/graph_runner.h" 17 #include "tensorflow/core/common_runtime/renamed_device.h" 18 #include "tensorflow/core/common_runtime/threadpool_device.h" 19 #include "tensorflow/core/framework/iterator.pb.h" 20 #include "tensorflow/core/framework/partial_tensor_shape.h" 21 #include "tensorflow/core/framework/resource_op_kernel.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/variant_op_registry.h" 24 #include "tensorflow/core/graph/graph_constructor.h" 25 #include "tensorflow/core/kernels/data/dataset.h" 26 #include "tensorflow/core/kernels/data/stats_aggregator.h" 27 #include "tensorflow/core/kernels/ops_util.h" 28 #include "tensorflow/core/lib/core/threadpool.h" 29 #include "tensorflow/core/lib/gtl/cleanup.h" 30 #include "tensorflow/core/lib/random/random.h" 31 #include "tensorflow/core/lib/strings/strcat.h" 32 #include "tensorflow/core/lib/strings/stringprintf.h" 33 #include "tensorflow/core/platform/env.h" 34 #include "tensorflow/core/public/session_options.h" 35 36 namespace tensorflow { 37 38 namespace { 39 40 // See documentation in ../ops/dataset_ops.cc for a high-level 41 // description of the following ops. 42 43 const char kIteratorVariantTypeName[] = "tensorflow::Iterator"; 44 45 Status VerifyTypesMatch(const DataTypeVector& expected, 46 const DataTypeVector& received) { 47 if (expected.size() != received.size()) { 48 return errors::InvalidArgument( 49 "Number of components does not match: expected ", expected.size(), 50 " types but got ", received.size(), "."); 51 } 52 for (size_t i = 0; i < expected.size(); ++i) { 53 if (expected[i] != received[i]) { 54 return errors::InvalidArgument("Data type mismatch at component ", i, 55 ": expected ", DataTypeString(expected[i]), 56 " but got ", DataTypeString(received[i]), 57 "."); 58 } 59 } 60 return Status::OK(); 61 } 62 63 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, 64 const std::vector<PartialTensorShape>& received) { 65 if (expected.size() != received.size()) { 66 return errors::InvalidArgument( 67 "Number of components does not match: expected ", expected.size(), 68 " shapes but got ", received.size(), "."); 69 } 70 for (size_t i = 0; i < expected.size(); ++i) { 71 if (!expected[i].IsCompatibleWith(received[i])) { 72 return errors::InvalidArgument("Incompatible shapes at component ", i, 73 ": expected ", expected[i].DebugString(), 74 " but got ", received[i].DebugString(), 75 "."); 76 } 77 } 78 79 return Status::OK(); 80 } 81 82 class IteratorResource : public ResourceBase { 83 public: 84 IteratorResource(const DataTypeVector& output_dtypes, 85 const std::vector<PartialTensorShape>& output_shapes, 86 const int /*unused: graph_def_version*/, 87 std::unique_ptr<DeviceMgr> device_mgr, 88 std::unique_ptr<FunctionLibraryDefinition> flib_def, 89 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, 90 FunctionLibraryRuntime* lib) 91 : device_mgr_(std::move(device_mgr)), 92 flib_def_(std::move(flib_def)), 93 pflr_(std::move(pflr)), 94 lib_(lib), 95 iterator_(nullptr), 96 output_dtypes_(output_dtypes), 97 output_shapes_(output_shapes) {} 98 99 Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, 100 bool* end_of_sequence) { 101 std::shared_ptr<IteratorBase> captured_iterator(iterator_); 102 if (captured_iterator) { 103 if (lib_ != nullptr) { 104 ctx->set_lib(lib_); 105 } 106 return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence); 107 } else { 108 return errors::FailedPrecondition( 109 "GetNext() failed because the iterator has not been initialized. " 110 "Ensure that you have run the initializer operation for this " 111 "iterator before getting the next element."); 112 } 113 } 114 115 Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { 116 std::shared_ptr<IteratorBase> captured_iterator(iterator_); 117 if (captured_iterator) { 118 return captured_iterator->Save(ctx, writer); 119 } else { 120 return errors::FailedPrecondition( 121 "Save() failed because the iterator has not been initialized. " 122 "Ensure that you have run the initializer operation for this " 123 "iterator before saving it."); 124 } 125 } 126 127 Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { 128 string serialized_graph_def; 129 TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey, 130 &serialized_graph_def)); 131 GraphDef graph_def; 132 if (!graph_def.ParseFromString(serialized_graph_def)) { 133 return errors::Internal("Error parsing dataset GraphDef."); 134 } 135 string output_node; 136 TF_RETURN_IF_ERROR(reader->ReadScalar( 137 GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node)); 138 DatasetBase* dataset = nullptr; 139 Graph graph(OpRegistry::Global()); 140 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); 141 std::vector<Tensor> outputs; 142 GraphRunner graph_runner(ctx->env()); 143 144 // Build a new FLR that knows about the functions in the graph. 145 std::shared_ptr<FunctionLibraryDefinition> flib_def( 146 new FunctionLibraryDefinition( 147 *ctx->function_library()->GetFunctionLibraryDefinition())); 148 TF_RETURN_IF_ERROR(flib_def->AddLibrary(graph_def.library())); 149 150 TF_RETURN_IF_ERROR( 151 graph_runner.Run(&graph, lib_, {}, {output_node}, &outputs)); 152 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); 153 154 TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator"))); 155 std::shared_ptr<IteratorBase> captured_iterator(iterator_); 156 157 if (captured_iterator) { 158 IteratorContext::Params params; 159 params.env = ctx->env(); 160 params.runner = *(ctx->runner()); 161 params.function_library = flib_def; 162 params.lib = lib_; 163 DeviceBase* device = lib_->device(); 164 params.allocator_getter = [device](AllocatorAttributes attrs) { 165 return device->GetAllocator(attrs); 166 }; 167 IteratorContext iter_ctx(std::move(params)); 168 169 TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader)); 170 mutex_lock l(mu_); 171 lib_def_ = std::move(flib_def); 172 return Status::OK(); 173 } else { 174 return errors::FailedPrecondition( 175 "Failed to restore iterator. Make sure the checkpoint ", 176 "is not corrupt. If the checkpoint does not contain the GraphDef, ", 177 "you will need to initialize your iterator before restoring."); 178 } 179 } 180 181 std::shared_ptr<const FunctionLibraryDefinition> function_library() { 182 tf_shared_lock l(mu_); 183 return lib_def_; 184 } 185 186 // Transfers ownership of iterator to this. This method is thread-safe. 187 Status set_iterator(std::unique_ptr<IteratorBase> iterator) { 188 if (iterator) { 189 TF_RETURN_IF_ERROR( 190 VerifyTypesMatch(output_dtypes_, iterator->output_dtypes())); 191 TF_RETURN_IF_ERROR( 192 VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); 193 } 194 iterator_.reset(iterator.release()); 195 return Status::OK(); 196 } 197 198 void set_stats_aggregator(std::shared_ptr<StatsAggregator> stats_aggregator) { 199 mutex_lock l(mu_); 200 stats_aggregator_ = std::move(stats_aggregator); 201 } 202 203 std::shared_ptr<StatsAggregator> stats_aggregator() { 204 tf_shared_lock l(mu_); 205 return stats_aggregator_; 206 } 207 208 string DebugString() override { return "Iterator resource"; } 209 210 const DataTypeVector& output_dtypes() const { return output_dtypes_; } 211 212 const std::vector<PartialTensorShape>& output_shapes() const { 213 return output_shapes_; 214 } 215 216 private: 217 // The following (device_mgr_, flib_def_, pflr_) are only used when the 218 // IteratorResource is shared between sessions and in that case we create 219 // a new FLR. Otherwise these are set to null. 220 std::unique_ptr<DeviceMgr> device_mgr_; 221 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 222 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; 223 FunctionLibraryRuntime* lib_ = nullptr; // not owned. 224 std::shared_ptr<IteratorBase> iterator_; 225 mutex mu_; 226 std::shared_ptr<StatsAggregator> stats_aggregator_ GUARDED_BY(mu_); 227 std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_); 228 const DataTypeVector output_dtypes_; 229 const std::vector<PartialTensorShape> output_shapes_; 230 }; 231 232 // Helper class for reading data from a VariantTensorData object. 233 class VariantTensorDataReader : public IteratorStateReader { 234 public: 235 explicit VariantTensorDataReader(const VariantTensorData* data) 236 : data_(data) { 237 PreProcess(); 238 } 239 240 // Returns OK iff the initialization was successful, i.e., 241 // pre-processing did not have errors. 242 Status status() const { return status_; } 243 244 Status ReadScalar(StringPiece key, int64* val) override { 245 return ReadScalarInternal(key, val); 246 } 247 248 Status ReadScalar(StringPiece key, string* val) override { 249 return ReadScalarInternal(key, val); 250 } 251 252 Status ReadTensor(StringPiece key, Tensor* val) override { 253 return ReadTensorInternal(key, val); 254 } 255 256 bool Contains(StringPiece key) override { 257 return map_.find(key.ToString()) != map_.end(); 258 } 259 260 private: 261 void PreProcess() { 262 string metadata; 263 data_->get_metadata(&metadata); 264 IteratorStateMetadata proto; 265 if (!proto.ParseFromString(metadata)) { 266 status_ = errors::Internal("Error parsing IteratorStateMetadata."); 267 return; 268 } 269 size_t num_entries = proto.keys_size(); 270 CHECK_EQ(num_entries, data_->tensors_size()); 271 for (size_t i = 0; i < num_entries; i++) { 272 map_[proto.keys(i)] = i; 273 } 274 } 275 276 template <typename T> 277 Status ReadScalarInternal(StringPiece key, T* val) { 278 if (map_.find(key.ToString()) == map_.end()) { 279 return errors::NotFound(key); 280 } 281 *val = data_->tensors(map_[key.ToString()]).scalar<T>()(); 282 return Status::OK(); 283 } 284 285 Status ReadTensorInternal(StringPiece key, Tensor* val) { 286 if (map_.find(key.ToString()) == map_.end()) { 287 return errors::NotFound(key); 288 } 289 *val = data_->tensors(map_[key.ToString()]); 290 return Status::OK(); 291 } 292 293 std::map<string, size_t> map_; 294 const VariantTensorData* data_; // Not owned. 295 Status status_; 296 }; 297 298 // Helper class for writing data to a VariantTensorData object. 299 class VariantTensorDataWriter : public IteratorStateWriter { 300 public: 301 // Does not take ownership of data. 302 explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {} 303 304 Status WriteScalar(StringPiece key, const int64 val) override { 305 return WriteScalarInternal(key, val); 306 } 307 308 Status WriteScalar(StringPiece key, const string& val) override { 309 return WriteScalarInternal(key, val); 310 } 311 312 Status WriteTensor(StringPiece key, const Tensor& val) override { 313 return WriteTensorInternal(key, val); 314 } 315 316 // Writes the metadata to `data_`. 317 Status Flush() { 318 string metadata; 319 if (!metadata_proto_.SerializeToString(&metadata)) { 320 return errors::Internal("Unable to serialize IteratorStateMetadata."); 321 } 322 data_->set_metadata(metadata); 323 return Status::OK(); 324 } 325 326 private: 327 template <typename T> 328 Status WriteScalarInternal(StringPiece key, const T& val) { 329 Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({})); 330 val_t.scalar<T>()() = val; 331 return WriteTensorInternal(key, val_t); 332 } 333 334 Status WriteTensorInternal(StringPiece key, const Tensor& val) { 335 // Write key to the metadata proto. This gets written to `data_` 336 // when `Flush()` is called. We do this lazily to avoid multiple 337 // serialization calls. 338 metadata_proto_.add_keys(key.ToString()); 339 340 // Update tensors. 341 *(data_->add_tensors()) = val; 342 return Status::OK(); 343 } 344 345 VariantTensorData* data_; 346 // TODO(srbs): Set the version string. 347 IteratorStateMetadata metadata_proto_; 348 }; 349 350 // Wrapper for encoding/decoding the iterator state stored in a Variant tensor. 351 // The get() method returns an IteratorStateReader which can be used 352 // to restore iterator state. 353 // 354 // Usage example: 355 // 356 // Encoding: 357 // 358 // Tensor t(DT_VARIANT, TensorShape({})); 359 // t->scalar<Variant>()() = IteratorStateVariant(iterator_resource); 360 // 361 // Encode() sets the type_name of the VariantTensorData object to 362 // IteratorStateVariant::TypeName(). 363 // 364 // Decoding: 365 // 366 // Variant v = <VariantTensorDataProto object>; 367 // DecodeUnaryVariant(&v); 368 // IteratorStateVariant* wrapper = v.get<IteratorStateVariant>(); 369 // iterator_resource->Restore(ctx, wrapper->get()) 370 // 371 // The type_name of the VariantTensorData object to be decoded must 372 // match IteratorStateVariant::TypeName(). 373 class IteratorStateVariant { 374 public: 375 IteratorStateVariant() : data_(nullptr) {} 376 IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) { 377 if (other.data_) { 378 Decode(*other.data_); 379 } 380 } 381 // Initializes this object with the current state of the iterator so 382 // that it can be written on the next call to Encode(). 383 Status InitializeFromIterator(OpKernelContext* ctx, 384 IteratorResource* iterator_resource) { 385 data_.reset(new VariantTensorData()); 386 data_->set_type_name(TypeName()); 387 VariantTensorDataWriter writer(data_.get()); 388 TF_RETURN_IF_ERROR(iterator_resource->Save(ctx, &writer)); 389 TF_RETURN_IF_ERROR(writer.Flush()); 390 return Status::OK(); 391 } 392 string TypeName() const { return kIteratorVariantTypeName; } 393 void Encode(VariantTensorData* data) const { *data = *data_; } 394 bool Decode(const VariantTensorData& data) { 395 if (data.type_name() != TypeName()) { 396 return false; 397 } 398 std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData); 399 *tensor_data = data; 400 std::unique_ptr<VariantTensorDataReader> reader( 401 new VariantTensorDataReader(tensor_data.get())); 402 status_ = reader->status(); 403 if (!status_.ok()) { 404 return false; 405 } 406 data_ = std::move(tensor_data); 407 reader_ = std::move(reader); 408 return true; 409 } 410 IteratorStateReader* get() { return reader_.get(); } 411 Status status() const { return status_; } 412 string DebugString() const { 413 if (data_) { 414 return strings::StrCat("IteratorStateVariant<", 415 "data: ", data_->DebugString(), 416 " status: ", status_.ToString(), ">"); 417 } else { 418 return strings::StrCat("IteratorStateVariant<empty>"); 419 } 420 } 421 422 private: 423 std::unique_ptr<IteratorStateReader> reader_; 424 Status status_; 425 std::unique_ptr<VariantTensorData> data_; 426 }; 427 428 // Register the reader class in the global variant decode_fn registry 429 // so that a Variant containing a serialized representation of iterator state 430 // can be decoded using DecodeUnaryVariant. If we don't do this we will need 431 // to manually decode the returned Variant using MaybeDecodeAndCopy in 432 // DeserializeIteratorOp which is not recommended. 433 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, 434 kIteratorVariantTypeName); 435 436 class IteratorHandleOp : public OpKernel { 437 public: 438 explicit IteratorHandleOp(OpKernelConstruction* ctx) 439 : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { 440 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); 441 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 442 OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); 443 } 444 445 // The resource is deleted from the resource manager only when it is private 446 // to kernel. Ideally the resource should be deleted when it is no longer held 447 // by anyone, but it would break backward compatibility. 448 ~IteratorHandleOp() override { 449 if (resource_ != nullptr) { 450 resource_->Unref(); 451 if (cinfo_.resource_is_private_to_kernel()) { 452 if (!cinfo_.resource_manager() 453 ->template Delete<IteratorResource>(cinfo_.container(), 454 cinfo_.name()) 455 .ok()) { 456 // Do nothing; the resource can have been deleted by session resets. 457 } 458 } 459 } 460 } 461 462 void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { 463 { 464 mutex_lock l(mu_); 465 if (resource_ == nullptr) { 466 FunctionLibraryRuntime* lib; 467 std::unique_ptr<DeviceMgr> device_mgr(nullptr); 468 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); 469 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); 470 // If the iterator is shared then we construct a new FLR, and pass that 471 // in. NOTE(mrry,rohanj): In this case it is not possible to call remote 472 // functions from the iterator. We may add this functionality if there 473 // is sufficient demand, but it will require a significant refactoring. 474 if (!name_.empty()) { 475 lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); 476 } else { 477 OP_REQUIRES_OK(context, context->function_library()->Clone( 478 &flib_def, &pflr, &lib)); 479 } 480 481 ResourceMgr* mgr = context->resource_manager(); 482 OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); 483 484 IteratorResource* resource; 485 OP_REQUIRES_OK( 486 context, 487 mgr->LookupOrCreate<IteratorResource>( 488 cinfo_.container(), cinfo_.name(), &resource, 489 [lib, &device_mgr, &flib_def, &pflr, 490 this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 491 *ret = new IteratorResource( 492 output_dtypes_, output_shapes_, graph_def_version_, 493 std::move(device_mgr), std::move(flib_def), 494 std::move(pflr), lib); 495 return Status::OK(); 496 })); 497 498 Status s = VerifyResource(resource); 499 if (TF_PREDICT_FALSE(!s.ok())) { 500 resource->Unref(); 501 context->SetStatus(s); 502 return; 503 } 504 505 resource_ = resource; 506 } 507 } 508 OP_REQUIRES_OK(context, MakeResourceHandleToOutput( 509 context, 0, cinfo_.container(), cinfo_.name(), 510 MakeTypeIndex<IteratorResource>())); 511 } 512 513 private: 514 // During the first Compute(), resource is either created or looked up using 515 // shared_name. In the latter case, the resource found should be verified if 516 // it is compatible with this op's configuration. The verification may fail in 517 // cases such as two graphs asking queues of the same shared name to have 518 // inconsistent capacities. 519 Status VerifyResource(IteratorResource* resource) { 520 TF_RETURN_IF_ERROR( 521 VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); 522 TF_RETURN_IF_ERROR( 523 VerifyShapesCompatible(output_shapes_, resource->output_shapes())); 524 return Status::OK(); 525 } 526 527 template <typename To, typename From> // use like this: down_cast<T*>(foo); 528 static inline To down_cast(From* f) { // so we only accept pointers 529 static_assert( 530 (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), 531 "target type not derived from source type"); 532 533 // We skip the assert and hence the dynamic_cast if RTTI is disabled. 534 #if !defined(__GNUC__) || defined(__GXX_RTTI) 535 // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. 536 assert(f == nullptr || dynamic_cast<To>(f) != nullptr); 537 #endif // !defined(__GNUC__) || defined(__GXX_RTTI) 538 return static_cast<To>(f); 539 } 540 541 FunctionLibraryRuntime* CreatePrivateFLR( 542 OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, 543 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 544 std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) { 545 // Wrap the existing device in order to see any captured resources 546 // in its resource manager. The existing device will outlive the 547 // IteratorResource, because we are storing the IteratorResource 548 // in that device's resource manager. 549 Device* wrapped_device = RenamedDevice::NewRenamedDevice( 550 ctx->device()->name(), down_cast<Device*>(ctx->device()), 551 false /* owns_underlying */, false /* isolate_session_state */); 552 device_mgr->reset(new DeviceMgr({wrapped_device})); 553 flib_def->reset(new FunctionLibraryDefinition( 554 *ctx->function_library()->GetFunctionLibraryDefinition())); 555 pflr->reset(new ProcessFunctionLibraryRuntime( 556 device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(), 557 {} /* TODO(mrry): OptimizerOptions? */, 558 nullptr /* TODO(mrry): ClusterFLR */)); 559 560 return (*pflr)->GetFLR(ctx->device()->name()); 561 } 562 563 mutex mu_; 564 ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. 565 IteratorResource* resource_ GUARDED_BY(mu_) = nullptr; 566 DataTypeVector output_dtypes_; 567 std::vector<PartialTensorShape> output_shapes_; 568 const int graph_def_version_; 569 string name_; 570 }; 571 572 class MakeIteratorOp : public OpKernel { 573 public: 574 explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 575 576 void Compute(OpKernelContext* ctx) override { 577 DatasetBase* dataset; 578 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); 579 IteratorResource* iterator_resource; 580 OP_REQUIRES_OK( 581 ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); 582 OP_REQUIRES_OK(ctx, iterator_resource->set_iterator( 583 dataset->MakeIterator("Iterator"))); 584 iterator_resource->Unref(); 585 } 586 }; 587 588 class ToSingleElementOp : public AsyncOpKernel { 589 public: 590 explicit ToSingleElementOp(OpKernelConstruction* ctx) 591 : AsyncOpKernel(ctx), 592 thread_pool_(new thread::ThreadPool( 593 ctx->env(), ThreadOptions(), 594 strings::StrCat("to_single_element_op_thread_", 595 SanitizeThreadSuffix(name())), 596 1 /* num_threads */, false /* low_latency_hint */)) {} 597 598 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 599 // The call to `iterator->GetNext()` may block and depend on an 600 // inter-op thread pool thread, so we issue the call from the 601 // owned thread pool. 602 thread_pool_->Schedule([ctx, done]() { 603 DatasetBase* dataset; 604 OP_REQUIRES_OK_ASYNC( 605 ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); 606 auto iterator = dataset->MakeIterator("SingleElementIterator"); 607 608 IteratorContext::Params params; 609 params.env = ctx->env(); 610 params.runner = *(ctx->runner()); 611 params.lib = ctx->function_library(); 612 DeviceBase* device = ctx->function_library()->device(); 613 params.allocator_getter = [device](AllocatorAttributes attrs) { 614 return device->GetAllocator(attrs); 615 }; 616 617 IteratorContext iter_ctx(std::move(params)); 618 619 std::vector<Tensor> components; 620 components.reserve(dataset->output_dtypes().size()); 621 bool end_of_sequence; 622 623 OP_REQUIRES_OK_ASYNC( 624 ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), 625 done); 626 OP_REQUIRES_ASYNC(ctx, !end_of_sequence, 627 errors::InvalidArgument("Dataset was empty."), done); 628 629 for (int i = 0; i < components.size(); ++i) { 630 // TODO(mrry): Check that the shapes match the shape attrs. 631 ctx->set_output(i, components[i]); 632 } 633 634 components.clear(); 635 OP_REQUIRES_OK_ASYNC( 636 ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), 637 done); 638 OP_REQUIRES_ASYNC( 639 ctx, end_of_sequence, 640 errors::InvalidArgument("Dataset had more than one element."), done); 641 642 done(); 643 }); 644 } 645 646 private: 647 std::unique_ptr<thread::ThreadPool> thread_pool_; 648 }; 649 650 class OneShotIteratorOp : public AsyncOpKernel { 651 public: 652 explicit OneShotIteratorOp(OpKernelConstruction* ctx) 653 : AsyncOpKernel(ctx), 654 thread_pool_(new thread::ThreadPool( 655 ctx->env(), ThreadOptions(), 656 strings::StrCat("one_shot_iterator_initialization_thread_", 657 SanitizeThreadSuffix(name())), 658 1 /* num_threads */, false /* low_latency_hint */)), 659 graph_def_version_(ctx->graph_def_version()) 660 661 { 662 string shared_name; 663 OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name)); 664 OP_REQUIRES(ctx, shared_name.empty(), 665 errors::InvalidArgument("OneShotIteratorOp does not currently " 666 "support the 'shared_name' attr.")); 667 OP_REQUIRES_OK(ctx, 668 ctx->GetAttr("dataset_factory", &dataset_factory_func_)); 669 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); 670 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 671 } 672 673 ~OneShotIteratorOp() override { 674 if (iterator_resource_ != nullptr) { 675 iterator_resource_->Unref(); 676 if (!cinfo_.resource_manager() 677 ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name()) 678 .ok()) { 679 // Do nothing; the resource can have been deleted by session resets. 680 } 681 } 682 } 683 684 // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, 685 // but due to the fact that `ResourceOpKernel<T>::CreateResource()` 686 // does not provide access to the `OpKernelContext*` and we need 687 // this to invoke the factory function, it's not possible to 688 // implement this kernel by implementing `CreateResource()`. 689 // Furthermore, due to the fact that this kernel might block when 690 // running the initialization function, we must implement this 691 // kernel as an async kernel. 692 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 693 { 694 mutex_lock l(mu_); 695 if (iterator_resource_ == nullptr && initialization_status_.ok()) { 696 // The initialization thread will call `done`. 697 if (!initialization_started_) { 698 // TODO(mrry): Convert the initialization code to use 699 // callbacks instead of wasting a thread. 700 thread_pool_->Schedule([this, ctx, done]() { Init(ctx, done); }); 701 initialization_started_ = true; 702 } else { 703 done_callbacks_.emplace_back(ctx, std::move(done)); 704 } 705 return; 706 } 707 } 708 ProduceOutput(ctx, std::move(done)); 709 } 710 711 private: 712 void Init(OpKernelContext* ctx, DoneCallback done) { 713 IteratorResource* iterator = nullptr; 714 ContainerInfo cinfo; 715 Status s = TryInit(ctx, &iterator, &cinfo); 716 717 std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run; 718 { 719 mutex_lock l(mu_); 720 if (s.ok()) { 721 iterator_resource_ = iterator; 722 cinfo_ = cinfo; 723 } 724 initialization_status_ = s; 725 std::swap(done_callbacks_, callbacks_to_run); 726 } 727 728 for (auto&& ctx_done : callbacks_to_run) { 729 ProduceOutput(ctx_done.first, std::move(ctx_done.second)); 730 } 731 ProduceOutput(ctx, std::move(done)); 732 } 733 734 Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, 735 ContainerInfo* cinfo) { 736 TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def())); 737 738 FunctionLibraryRuntime* lib; 739 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); 740 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); 741 TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib)); 742 743 // Create an IteratorResource that will hold the iterator for this op. 744 TF_RETURN_IF_ERROR( 745 ctx->resource_manager()->LookupOrCreate<IteratorResource>( 746 cinfo->container(), cinfo->name(), iterator, 747 [lib, this, &flib_def, &pflr](IteratorResource** ret) 748 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 749 *ret = new IteratorResource( 750 output_dtypes_, output_shapes_, graph_def_version_, 751 nullptr, std::move(flib_def), std::move(pflr), lib); 752 return Status::OK(); 753 })); 754 755 core::ScopedUnref unref_iterator(*iterator); 756 757 TF_RETURN_IF_ERROR( 758 VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes())); 759 TF_RETURN_IF_ERROR( 760 VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes())); 761 762 // Call the dataset_factory_func_ to create a new dataset, 763 // over which this op will iterate. 764 FunctionLibraryRuntime::Handle f_handle; 765 TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( 766 dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()), 767 &f_handle)); 768 FunctionLibraryRuntime::Options opts; 769 opts.cancellation_manager = ctx->cancellation_manager(); 770 // Choose a step ID that is guaranteed not to clash with any 771 // Session-generated step ID. DirectSession only generates 772 // non-negative step IDs (contiguous, starting from 0), and 773 // MasterSession generates 56-bit random step IDs whose MSB is 774 // always 0, so a negative random step ID should suffice. 775 opts.step_id = -std::abs(static_cast<int64>(random::New64())); 776 ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) { 777 ctx->resource_manager()->Cleanup(name).IgnoreError(); 778 }); 779 opts.step_container = &step_container; 780 opts.runner = ctx->runner(); 781 Notification n; 782 Status factory_status; 783 std::vector<Tensor> return_values; 784 ctx->function_library()->Run(opts, f_handle, {}, &return_values, 785 [&n, &factory_status](Status s) { 786 factory_status.Update(s); 787 n.Notify(); 788 }); 789 n.WaitForNotification(); 790 TF_RETURN_IF_ERROR(factory_status); 791 if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT || 792 !TensorShapeUtils::IsScalar(return_values[0].shape())) { 793 return errors::InvalidArgument( 794 "The `dataset_factory` function must return " 795 "a single scalar of dtype DT_VARIANT."); 796 } 797 798 // Create an iterator for the dataset that was created in the 799 // factory function. 800 DatasetBase* dataset; 801 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); 802 TF_RETURN_IF_ERROR( 803 (*iterator)->set_iterator(dataset->MakeIterator("Iterator"))); 804 805 (*iterator)->Ref(); 806 return Status::OK(); 807 } 808 809 void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) { 810 Tensor* handle; 811 OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle), 812 done); 813 Status s; 814 { 815 mutex_lock l(mu_); 816 s = initialization_status_; 817 if (s.ok()) { 818 handle->scalar<ResourceHandle>()() = 819 MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(), 820 cinfo_.name()); 821 } 822 } 823 OP_REQUIRES_OK_ASYNC(ctx, s, done); 824 done(); 825 } 826 827 NameAttrList dataset_factory_func_; 828 DataTypeVector output_dtypes_; 829 std::vector<PartialTensorShape> output_shapes_; 830 831 std::unique_ptr<thread::ThreadPool> thread_pool_; 832 833 mutex mu_; 834 ContainerInfo cinfo_ GUARDED_BY(mu_); 835 IteratorResource* iterator_resource_ GUARDED_BY(mu_) = nullptr; 836 837 bool initialization_started_ GUARDED_BY(mu_) = false; 838 Status initialization_status_ GUARDED_BY(mu_); 839 std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_ 840 GUARDED_BY(mu_); 841 const int graph_def_version_; 842 }; 843 844 class IteratorGetNextOp : public AsyncOpKernel { 845 public: 846 explicit IteratorGetNextOp(OpKernelConstruction* ctx) 847 : AsyncOpKernel(ctx), 848 thread_pool_(new thread::ThreadPool( 849 ctx->env(), ThreadOptions(), 850 strings::StrCat("iterator_get_next_thread_", 851 SanitizeThreadSuffix(name())), 852 1 /* num_threads */, false /* low_latency_hint */)) {} 853 854 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 855 IteratorResource* iterator; 856 OP_REQUIRES_OK_ASYNC( 857 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); 858 // The call to `iterator->GetNext()` may block and depend on an 859 // inter-op thread pool thread, so we issue the call from the 860 // owned thread pool. 861 thread_pool_->Schedule(std::bind( 862 [this, ctx, iterator](DoneCallback done) { 863 core::ScopedUnref unref_iterator(iterator); 864 865 std::vector<Tensor> components; 866 bool end_of_sequence = false; 867 868 IteratorContext::Params params; 869 params.env = ctx->env(); 870 params.stats_aggregator_getter = [iterator]() { 871 return iterator->stats_aggregator(); 872 }; 873 params.runner = *(ctx->runner()); 874 params.function_library = iterator->function_library(); 875 DeviceBase* device = ctx->function_library()->device(); 876 params.allocator_getter = [device](AllocatorAttributes attrs) { 877 return device->GetAllocator(attrs); 878 }; 879 IteratorContext iter_ctx(std::move(params)); 880 881 OP_REQUIRES_OK_ASYNC( 882 ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), 883 done); 884 OP_REQUIRES_ASYNC(ctx, !end_of_sequence, 885 errors::OutOfRange("End of sequence"), done); 886 887 for (int i = 0; i < components.size(); ++i) { 888 // TODO(mrry): Check that the shapes match the shape attrs. 889 ctx->set_output(i, components[i]); 890 } 891 892 done(); 893 }, 894 std::move(done))); 895 } 896 897 private: 898 std::unique_ptr<thread::ThreadPool> thread_pool_; 899 }; 900 901 class IteratorGetNextSyncOp : public OpKernel { 902 public: 903 explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 904 905 void Compute(OpKernelContext* ctx) override { 906 IteratorResource* iterator; 907 OP_REQUIRES_OK(ctx, 908 LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); 909 core::ScopedUnref unref_iterator(iterator); 910 911 std::vector<Tensor> components; 912 bool end_of_sequence = false; 913 914 IteratorContext::Params params; 915 params.env = ctx->env(); 916 params.stats_aggregator_getter = [iterator]() { 917 return iterator->stats_aggregator(); 918 }; 919 params.runner = *(ctx->runner()); 920 params.function_library = iterator->function_library(); 921 DeviceBase* device = ctx->function_library()->device(); 922 params.allocator_getter = [device](AllocatorAttributes attrs) { 923 return device->GetAllocator(attrs); 924 }; 925 IteratorContext iter_ctx(std::move(params)); 926 927 OP_REQUIRES_OK(ctx, 928 iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); 929 OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); 930 931 for (int i = 0; i < components.size(); ++i) { 932 // TODO(mrry): Check that the shapes match the shape attrs. 933 ctx->set_output(i, components[i]); 934 } 935 } 936 }; 937 938 class IteratorToStringHandleOp : public OpKernel { 939 public: 940 explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) 941 : OpKernel(ctx) {} 942 943 void Compute(OpKernelContext* ctx) override { 944 const Tensor& resource_handle_t = ctx->input(0); 945 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), 946 errors::InvalidArgument("resource_handle must be a scalar")); 947 948 // Validate that the handle corresponds to a real resource, and 949 // that it is an IteratorResource. 950 IteratorResource* iterator_resource; 951 OP_REQUIRES_OK( 952 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); 953 iterator_resource->Unref(); 954 955 Tensor* string_handle_t; 956 OP_REQUIRES_OK(ctx, 957 ctx->allocate_output(0, TensorShape({}), &string_handle_t)); 958 string_handle_t->scalar<string>()() = 959 resource_handle_t.scalar<ResourceHandle>()().SerializeAsString(); 960 } 961 }; 962 963 class IteratorFromStringHandleOp : public OpKernel { 964 public: 965 explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx) 966 : OpKernel(ctx) { 967 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); 968 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 969 OP_REQUIRES( 970 ctx, 971 output_dtypes_.empty() || output_shapes_.empty() || 972 output_dtypes_.size() == output_shapes_.size(), 973 errors::InvalidArgument("If both 'output_types' and 'output_shapes' " 974 "are set, they must have the same length.")); 975 } 976 977 void Compute(OpKernelContext* ctx) override { 978 const Tensor& string_handle_t = ctx->input(0); 979 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()), 980 errors::InvalidArgument("string_handle must be a scalar")); 981 982 ResourceHandle resource_handle; 983 OP_REQUIRES( 984 ctx, 985 resource_handle.ParseFromString(string_handle_t.scalar<string>()()), 986 errors::InvalidArgument( 987 "Could not parse string_handle as a valid ResourceHandle")); 988 989 OP_REQUIRES( 990 ctx, resource_handle.device() == ctx->device()->attributes().name(), 991 errors::InvalidArgument("Attempted create an iterator on device \"", 992 ctx->device()->attributes().name(), 993 "\" from handle defined on device \"", 994 resource_handle.device(), "\"")); 995 996 // Validate that the handle corresponds to a real resource, and 997 // that it is an IteratorResource. 998 IteratorResource* iterator_resource; 999 OP_REQUIRES_OK(ctx, 1000 LookupResource(ctx, resource_handle, &iterator_resource)); 1001 core::ScopedUnref unref_iterator(iterator_resource); 1002 if (!output_dtypes_.empty()) { 1003 OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_, 1004 iterator_resource->output_dtypes())); 1005 } 1006 if (!output_shapes_.empty()) { 1007 OP_REQUIRES_OK( 1008 ctx, VerifyShapesCompatible(output_shapes_, 1009 iterator_resource->output_shapes())); 1010 } 1011 1012 Tensor* resource_handle_t; 1013 OP_REQUIRES_OK( 1014 ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); 1015 resource_handle_t->scalar<ResourceHandle>()() = resource_handle; 1016 } 1017 1018 private: 1019 DataTypeVector output_dtypes_; 1020 std::vector<PartialTensorShape> output_shapes_; 1021 }; 1022 1023 class SerializeIteratorOp : public OpKernel { 1024 public: 1025 explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 1026 1027 void Compute(OpKernelContext* ctx) override { 1028 const Tensor& resource_handle_t = ctx->input(0); 1029 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), 1030 errors::InvalidArgument("resource_handle must be a scalar")); 1031 1032 // Validate that the handle corresponds to a real resource, and 1033 // that it is an IteratorResource. 1034 IteratorResource* iterator_resource; 1035 OP_REQUIRES_OK( 1036 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); 1037 core::ScopedUnref unref_iterator(iterator_resource); 1038 Tensor* variant_t; 1039 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t)); 1040 IteratorStateVariant v; 1041 OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource)); 1042 variant_t->scalar<Variant>()() = v; 1043 } 1044 }; 1045 1046 class DeserializeIteratorOp : public OpKernel { 1047 public: 1048 explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 1049 1050 void Compute(OpKernelContext* ctx) override { 1051 // Validate that the handle corresponds to a real resource, and 1052 // that it is an IteratorResource. 1053 IteratorResource* iterator_resource; 1054 OP_REQUIRES_OK( 1055 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); 1056 1057 Variant variant = ctx->input(1).scalar<Variant>()(); 1058 auto* wrapper = variant.get<IteratorStateVariant>(); 1059 OP_REQUIRES(ctx, wrapper != nullptr, 1060 errors::InvalidArgument( 1061 "DeserializeIteratorOp: Unable to parse variant tensor.")); 1062 OP_REQUIRES_OK(ctx, wrapper->status()); 1063 OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get())); 1064 } 1065 }; 1066 1067 class IteratorSetStatsAggregatorOp : public OpKernel { 1068 public: 1069 explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx) 1070 : OpKernel(ctx) {} 1071 1072 void Compute(OpKernelContext* ctx) override { 1073 IteratorResource* iterator_resource; 1074 OP_REQUIRES_OK( 1075 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); 1076 core::ScopedUnref unref_iterator(iterator_resource); 1077 1078 StatsAggregatorResource* stats_aggregator_resource; 1079 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), 1080 &stats_aggregator_resource)); 1081 core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); 1082 // TODO(mrry): Consider allowing multiple StatsAggregator ops to 1083 // subscribe to updates, and/or unsubscribing. 1084 OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(), 1085 errors::FailedPrecondition( 1086 "Iterator already associated with a StatsAggregator")); 1087 iterator_resource->set_stats_aggregator( 1088 stats_aggregator_resource->stats_aggregator()); 1089 } 1090 }; 1091 1092 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); 1093 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU), 1094 MakeIteratorOp); 1095 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU), 1096 ToSingleElementOp); 1097 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), 1098 OneShotIteratorOp); 1099 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU), 1100 IteratorGetNextOp); 1101 REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU), 1102 IteratorGetNextSyncOp); 1103 REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU), 1104 IteratorToStringHandleOp); 1105 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU), 1106 IteratorFromStringHandleOp); 1107 REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), 1108 SerializeIteratorOp); 1109 REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), 1110 DeserializeIteratorOp); 1111 REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU), 1112 IteratorSetStatsAggregatorOp); 1113 1114 } // namespace 1115 1116 } // namespace tensorflow 1117