/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_

#include <atomic>
#include <functional>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/unique_tensor_references.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
struct SyclDevice;
}  // end namespace Eigen

namespace tensorflow {

namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

class AsyncOpKernel;
class CallFrameInterface;
class FunctionLibraryRuntime;
class OpKernelConstruction;  // declared below
class OpKernelContext;       // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
class StepStatsCollectorInterface;

class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that enables the descendant to provide a different
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  explicit OpKernel(OpKernelConstruction* context,
                    std::unique_ptr<const NodeDef> node_def);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe, as they
  // may be called concurrently (e.g. by multiple executions of the same graph
  // running concurrently).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A few special kernels might need to be asynchronous to bound the
  // number of threads (e.g., network receive operations). These
  // kernels must subclass AsyncOpKernel and override
  // AsyncOpKernel::ComputeAsync().
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and return a status via context->SetStatus(). They must be
  // thread-safe.

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }
  virtual const AsyncOpKernel* AsAsync() const { return nullptr; }

  // Initial time (in CPU cycles) we expect an operation to take. Used to
  // determine whether an operation should be placed in a threadpool.
  // Operations start out "expensive".
  static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
  static const uint64 kOpIsExpensiveThresholdCycles = 5000;
  static const uint64 kCostDecay = 10;

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution, for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() {
    return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) >
                          kOpIsExpensiveThresholdCycles);
  }

  // Updates the dynamic cost estimate, which is used to determine whether this
  // op is expensive. The new cost estimate is a weighted average of the old
  // cost estimate and the latest cost.
  void UpdateCostEstimate(uint64 elapsed_cycles) {
    // N.B. Updates to `cost_estimate_` are atomic but unlocked. Simultaneous
    // updates may result in one or more updates being ignored. This does not
    // affect correctness but may slow down the update frequency.
    cost_estimate_.store(
        (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) /
                kCostDecay +
            (elapsed_cycles / kCostDecay),
        std::memory_order_relaxed);
  }
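
  // Worked example of the decay above (illustrative arithmetic only, not
  // part of the API): with kCostDecay = 10, each call keeps 9/10 of the old
  // estimate and mixes in 1/10 of the latest measurement:
  //
  //   new_estimate = (9 * old_estimate) / 10 + elapsed_cycles / 10
  //
  // Starting from kInitialCostEstimateCycles (1e8 cycles), an op that
  // repeatedly measures ~1000 cycles decays toward 1000; once the estimate
  // drops below kOpIsExpensiveThresholdCycles (5000), IsExpensive() begins
  // returning false.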

  // Accessors.
  const NodeDef& def() const { return *def_; }
  const string& name() const;              // Same as def().name()
  const string& type_string() const;       // Same as def().op()
  const string& requested_device() const;  // Same as def().device()
  bool is_internal() const { return is_internal_; }

  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeVector& input_types() const { return input_types_; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }
  const string& requested_input(int i) const;  // Same as def().input(i)

  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int o) const { return output_types_[o]; }
  const DataTypeVector& output_types() const { return output_types_; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // We allow legacy scalars within Google up until GraphDef version 6.
  // TODO(irving): Remove when we can drop support for GraphDef version 5.
  bool allow_legacy_scalars() const {
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
    return graph_def_version_ < 6;
#else
    return false;
#endif
  }

  // Allow either scalars or (if allowing legacy scalars) shape (1,).
  bool IsLegacyScalar(const TensorShape& shape) const {
    return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
                                 shape.dim_size(0) == 1);
  }

  // Allow rank 1 or (if allowing legacy scalars) rank 0.
  bool IsLegacyVector(const TensorShape& shape) const {
    return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
  }

  // Turn a shape Tensor into a TensorShape
  // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
  Status MakeShape(const Tensor& shape, TensorShape* out) const;

  static int DeviceNumaNode(const DeviceBase* device);

 private:
  const std::unique_ptr<const NodeDef> def_;
  const DataTypeVector input_types_;
  const MemoryTypeVector input_memory_types_;
  const DataTypeVector output_types_;
  const MemoryTypeVector output_memory_types_;
  const int graph_def_version_;
  const bool is_internal_;  // True if this is an internal operation
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  bool expensive_;
  std::atomic_uint_fast64_t cost_estimate_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must run "done" to signal the
  // completion of the computation. "context" is guaranteed to be
  // alive until the "done" callback starts.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() final { return this; }
  const AsyncOpKernel* AsAsync() const final { return this; }

  void Compute(OpKernelContext* context) final;
};
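
// A minimal sketch of an asynchronous kernel (the class and the
// StartDeviceWork() helper are hypothetical). The "done" callback must be
// run to signal completion, on both the success and the error paths:
//
//   class MyAsyncOp : public AsyncOpKernel {
//    public:
//     explicit MyAsyncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}
//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//       StartDeviceWork(ctx, /*on_completion=*/[ctx, done]() {
//         // ... set outputs, or ctx->SetStatus(...) on error ...
//         done();
//       });
//     }
//   };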

// Wraps a tensor that is held by an Op across calls to Compute(). For
// memory safety when using asynchronous devices like GPUs, the system
// must be notified when a Tensor is used inside an Op execution. The
// wrapper ensures that all uses of the Tensor are tracked, because in
// order to retrieve the Tensor the caller must use AccessTensor, which
// notifies the context.
class PersistentTensor {
 public:
  PersistentTensor() {}
  explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}

  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelConstruction* context);
  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelContext* context);

  // The check for initialization does not need to access the
  // underlying tensor buffer.
  bool IsInitialized() const { return tensor_.IsInitialized(); }

  int64 NumElements() const { return tensor_.NumElements(); }

  int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }

 private:
  Tensor tensor_;
};

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors between
  // invocations. If such a Tensor is allocated during kernel
  // construction, this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }
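
  // A typical kernel constructor reads its attrs through this context and
  // reports failures with the OP_REQUIRES_OK macro, which records the error
  // via SetStatus() and returns early (sketch; MyOp and "my_attr" are
  // hypothetical):
  //
  //   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  //     OP_REQUIRES_OK(ctx, ctx->GetAttr("my_attr", &my_attr_));
  //   }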

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == input_types() and expected_outputs ==
  // output_types(), returns OK, else returns INVALID_ARGUMENT with an error
  // message. Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Looks up the attr with name attr_name and sets *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Returns true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Returns the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros.
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};

// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  bool operator==(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  OpArgIterator operator++() {  // prefix ++it
    ++i_;
    return *this;
  }

  OpArgIterator operator++(int) {  // postfix it++
    OpArgIterator old_value = *this;
    ++i_;
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  const ListType* const list_;
  int i_;
};

// Utility class for representing a list of immutable input tensors
// that are passed to the op as a single named argument.
class OpInputList {
 public:
  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpInputList& operator=(const OpInputList& other) = default;
  const Tensor& operator[](int i) const;
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of mutable ("ref") input tensors
// that are passed to the op as a single named argument.
class OpMutableInputList {
 public:
  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
  Tensor at(int i, bool lock_held);
  mutex* ref_mutex(int i);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};
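
// A sketch of reading a list-typed input inside Compute(), assuming an op
// whose OpDef declares a list input named "values":
//
//   OpInputList values;
//   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
//   for (const Tensor& t : values) {
//     // ... use t ...
//   }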

// Utility class for representing a list of output tensors that are
// grouped as a single named output.
class OpOutputList {
 public:
  typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpOutputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpOutputList& operator=(const OpOutputList& other) = default;
  Tensor* operator[](int i);
  bool required(int i) const;
  DataType expected_output_dtype(int i) const;
  Status allocate(int i, const TensorShape& shape, Tensor** output);
  void set(int i, const Tensor& tensor);
  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
      : mutex_if_ref(nullptr), tensor(t) {}
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};

// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
  mutex mu;
  std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
  GraphDef raw_graph GUARDED_BY(mu);
  GraphDef optimized_graph GUARDED_BY(mu);

  bool dirty GUARDED_BY(mu);

  GraphCollector() : dirty(false) {}

  void CollectRawGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    raw_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectOptimizedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    optimized_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectPartitionedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    partitioned_graphs.push_back(graph);
    dirty = true;
  }

  void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
    raw_graph.Clear();
    optimized_graph.Clear();
    partitioned_graphs.clear();
    dirty = false;
  }

  bool HasUpdatedGraphs() {
    mutex_lock ml(mu);
    return dirty;
  }
};

class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator.
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext.
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;

    // Array indexed by output number for this node.
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    Rendezvous* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    string session_handle;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device contexts.
    const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
        nullptr;
    DeviceContext* op_device_context = nullptr;

    // Control-flow op support.
    FrameAndIter frame_iter;

    // Function call support.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static const int kNeverForward = -2;
    static const int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops.
    std::function<void()> inc_num_deferred_ops_function = []() {};
    std::function<void()> dec_num_deferred_ops_function = []() {};
  };

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int noutputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64 step_id() const { return params_->step_id; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Input/output signature.

  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index);

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   Tensor& t = context->mutable_input(index);
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef. If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // If non-null, kernels should populate with any partition subgraphs created.
  GraphCollector* graph_collector() { return params_->graph_collector; }

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index] (reshaped to
  // output_shape), which is safe to use for in-place computation, was written
  // to *output. Returns false if input[input_index] has a refcount greater
  // than one, or if its type does not match the expected output type of
  // output[output_index], or if the number of elements in input[input_index]
  // does not equal the number of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;
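
  // A sketch of an in-place unary op using the boolean variant above: try to
  // alias input 0 as output 0, and fall back to a fresh allocation when the
  // buffer cannot be reused. (forward_input_or_allocate_output below combines
  // both steps into one call.)
  //
  //   const Tensor& input = ctx->input(0);
  //   Tensor* output = nullptr;
  //   if (!ctx->forward_input_to_output_with_shape(0, 0, input.shape(),
  //                                                &output)) {
  //     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
  //   }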

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * the refcount on the underlying buffer is one,
  //   * either there is no forwarding reservation for input_index or
  //     output_index, or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //        OR (input_index is not in forward_from_array AND
  //            (output_index == kNoReservation OR
  //             forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may produce the output, but is not required to
  // do so.
  // TODO(mrry): Convert this to return Status, and implement a string
  // name version.
  bool output_required(int index) const {
    return true;  // TODO(josh11b): implement
  }
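
  // A sketch of producing a list-typed output, assuming an op whose OpDef
  // declares a list output named "outputs" (ShapeForElement() is a
  // hypothetical helper):
  //
  //   OpOutputList outputs;
  //   OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
  //   for (int i = 0; i < outputs.size(); ++i) {
  //     Tensor* out = nullptr;
  //     OP_REQUIRES_OK(ctx, outputs.allocate(i, ShapeForElement(i), &out));
  //     // ... fill *out ...
  //   }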

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are three methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_persistent. This is only needed for Tensors that will
  // be stored by the Op between invocations, and it *must* be used
  // for those Tensors. The call returns a PersistentTensor, and that
  // is the only object the Op is allowed to hold on to between
  // invocations. When the Tensor is needed in a subsequent
  // invocation, it can be retrieved from the PersistentTensor using
  // the AccessTensor method. This ensures that the system is made
  // aware of any use of the tensor's allocated memory, which is
  // needed for correctness on asynchronous devices such as GPUs.
  //
  // 2) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 3) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a PersistentTensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
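
  // The common pattern in a synchronous Compute() body (sketch): read an
  // input, allocate the matching output, then write into it.
  //
  //   void Compute(OpKernelContext* ctx) override {
  //     const Tensor& input = ctx->input(0);
  //     Tensor* output = nullptr;
  //     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
  //     // ... write into output->flat<T>() ...
  //   }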

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor, AllocatorAttributes attr);
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor) {
    return allocate_persistent(type, shape, out_persistent, out_tensor,
                               AllocatorAttributes());
  }

  // Copies a tensor (allocated by the caller) to the specified output
  // index. REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);

  // To output a reference. Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Records device-specific state about how the input tensors were
  // computed.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Get the DeviceContext used for the index input. Returns nullptr
  // if no DeviceContext was provided.
  template <typename T>
  T* input_device_context(int index);
  DeviceContext* input_device_context(int index);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_gpu_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs == nullptr) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs->size());
      return (*params_->input_alloc_attrs)[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
    mutex_lock lock(mu_);
    gtl::InlinedVector<WrappedAllocator, 4> retrieved;
    retrieved.swap(wrapped_allocators_);
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with the outside environment through
  // Rendezvous Send() and Recv().
  Rendezvous* rendezvous() const { return params_->rendezvous; }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // Unique identifier of the session it belongs to. Can be empty.
  string session_handle() const { return params_->session_handle; }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel can invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollectorInterface* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;
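
  // A sketch of dispatching an elementwise Eigen expression on the op's
  // device; the same pattern works with eigen_gpu_device() in GPU kernels:
  //
  //   const Eigen::ThreadPoolDevice& d = ctx->eigen_cpu_device();
  //   output->flat<float>().device(d) = input.flat<float>() * 2.0f;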

  // Error handling.

  // If expected_inputs == input_types() and expected_outputs ==
  // output_types(), returns OK, else returns INVALID_ARGUMENT with an error
  // message. Recommended for Ops with dynamic signatures, where validation
  // can only be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Retrieves the list of referenced tensors in out_vector. Once this is
  // called, it is not legal to reference any more tensors. Should
  // not be called from Op kernels.
  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Helper routines for the OP_REQUIRES macros.
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  void set_output(int index, const Tensor& tensor);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. The Tensor object is recorded to identify
  // the case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns the recorded size of temporary memory.
  int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  // Records persistent memory allocation; size can be negative, indicating
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns the recorded size and ids of persistent memory.
  int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);

  bool input_is_ref(int index) const;

  // Used by OpKernel implementations to track actively running deferred ops.
  //
  // A deferred op is one whose Compute method returns (or whose ComputeAsync
  // method invokes the callback) when work is scheduled onto a device. At that
  // point, we don't know when the work will actually complete (or if it has
  // already completed) on the device. These functions allow the executor to
  // track the status of deferred ops and act accordingly.
  //
  // Deferred OpKernel implementations must use these methods to get two
  // functions. They must then call these two functions in pairs, before and
  // after device execution, respectively.
  TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
    return params_->inc_num_deferred_ops_function;
  }
  TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
    return params_->dec_num_deferred_ops_function;
  }

 private:
  Allocator* get_allocator(AllocatorAttributes attr);

  // Internal method to add a tensor's buffer to the list of buffers
  // referenced during the execution of the Op, so that GPUs may
  // accurately track the memory that may not be reused until the Op
  // execution completes.
  void record_tensor_reference(const Tensor& tensor);
  void really_record_tensor_reference(const Tensor& tensor);

  // Internal common method used when allocating tensor memory.
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor,
                         AllocatorAttributes allocator_attr) {
    return allocate_tensor(type, shape, out_tensor, allocator_attr,
                           AllocationAttributes());
  }

  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // This is called by PersistentTensor::AccessTensor whenever the
  // wrapped tensor is retrieved, to ensure the runtime knows that the
  // Tensor is being accessed within an Op. This is necessary for
  // memory safety of devices like GPUs that queue Ops for
  // asynchronous execution after the Compute() method completes.
  friend class PersistentTensor;
  void NotifyUseOfPersistentTensor(const Tensor& tensor);

  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;    // not owned
  mutable mutex mu_;  // mutable so const accessors can acquire the lock
  gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Constructed only if <params->record_tensor_accesses>.
  ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);

  // The following data members are only used when allocation tracking is
  // enabled.
  mutable mutex stats_mu_;
  int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
  int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
  std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
      temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
  std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
      GUARDED_BY(stats_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};
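
// A sketch of the deferred-op bookkeeping described above: fetch the pair of
// functions, call the increment before scheduling device work, and call the
// matching decrement when that work completes (EnqueueOnDevice is a
// hypothetical helper):
//
//   auto inc = ctx->inc_num_deferred_ops_function();
//   auto dec = ctx->dec_num_deferred_ops_function();
//   inc();
//   EnqueueOnDevice(..., /*on_complete=*/[dec, done]() {
//     dec();
//     done();
//   });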

// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate. Examples:
//
//  // A kernel that supports all types.
//  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
//  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
//  // (You would then repeat this for every type supported by "Sub".)
//
//  // This form allows you to specify a list of types as the constraint.
//  REGISTER_KERNEL_BUILDER(Name("Sub")
//                              .Device(DEVICE_CPU)
//                              .TypeConstraint("T", {DT_FLOAT}),
//                          SubOp<float>);
//
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
// See kernel_def_builder for details.

// Instantiates an OpKernel that has been registered. Returns nullptr if no
// kernel is registered for that device / input signature combination (and
// sets a NOT_FOUND *status), or if there is an error in construction (and
// sets an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership
// of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& def,
                                         int graph_def_version, Status* status);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& def, int graph_def_version,
                      OpKernel** kernel);

// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types);

// Returns a message with a description of the kernels registered for op
// `op_name`.
string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not
  // used by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This mechanism stops that registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal. Note that this method alone is
// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types);

// Returns a message with a description of the kernels registered for op
// `op_name`.
string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not
  // used by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This class stops that registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal.  Note that this method alone is not
  // sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so this method doesn't actually reduce code size.
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

namespace system {

class Name : public KernelDefBuilder {
 public:
  // For system kernels, we ignore selective registration and
  // unconditionally register the kernel.
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace system

}  // namespace register_kernel

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts like
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
                                             __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
  static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
      registrar__body__##ctr##__object(                                  \
          ::tensorflow::register_kernel::system::kernel_builder.Build(), \
          #__VA_ARGS__,                                                  \
          [](::tensorflow::OpKernelConstruction* context)                \
              -> ::tensorflow::OpKernel* {                               \
            return new __VA_ARGS__(context);                             \
          });
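// As an illustration, a complete synchronous CPU kernel and its registration
// might look as follows (a minimal sketch; the op name "MyAddOne" and the
// class MyAddOneOp are hypothetical):
//
//   class MyAddOneOp : public OpKernel {
//    public:
//     explicit MyAddOneOp(OpKernelConstruction* context)
//         : OpKernel(context) {}
//     void Compute(OpKernelContext* context) override {
//       const Tensor& input = context->input(0);
//       Tensor* output = nullptr;
//       OP_REQUIRES_OK(context,
//                      context->allocate_output(0, input.shape(), &output));
//       auto in = input.flat<float>();
//       auto out = output->flat<float>();
//       for (int i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
//     }
//   };
//   REGISTER_KERNEL_BUILDER(Name("MyAddOne").Device(DEVICE_CPU), MyAddOneOp);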
// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If node_def has a corresponding kernel registered on device_type,
// returns OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Gets a list of all registered kernels for which `predicate` returns true.
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);

// Gets a list of all registered kernels for a given op.
KernelList GetRegisteredKernelsForOp(StringPiece op_name);
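// For instance, the predicate form can list only the kernels registered for
// a particular device (a minimal sketch):
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });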
namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory's Create() method when it determines that a kernel matching the
  // given KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, std::move(factory));
    }
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

 private:
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(
        OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore.

template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}
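// By way of example, a kernel constructor typically reads its attrs via
// GetAttr like this (a minimal sketch; the attr name "alpha" and the member
// alpha_ are hypothetical):
//
//   explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
//     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
//   }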
inline DataType OpKernelContext::input_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  const TensorValue& value((*params_->inputs)[index]);
  if (value.is_ref()) {
    return MakeRefType(value->dtype());
  } else {
    return value->dtype();
  }
}

inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}

inline bool OpKernelContext::input_is_ref(int index) const {
  const TensorValue& value((*params_->inputs)[index]);
  return value.is_ref();
}

inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
            params_->record_tensor_accesses);
  if (params_->record_tensor_accesses) {
    really_record_tensor_reference(tensor);
  }
}

inline void OpKernelContext::retrieve_accessed_tensors(
    TensorReferenceVector* out_vector) {
  if (params_->record_tensor_accesses) {
    mutex_lock l(mu_);
    referenced_tensors_->FreezeAndReturnReferences(out_vector);
  }
}

// Returns true iff the given input slot holds a tensor; a nullptr tensor
// means there is no input.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return (*params_->inputs)[index].tensor != nullptr;
}

inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return (*params_->inputs)[index].mutex_if_ref;
}

inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
  if (t.IsInitialized()) {
    record_tensor_reference(t);
  }
}

inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  // No need to record_tensor_reference since the output must already
  // have been set by a call that did so.
  return outputs_[index].tensor;
}

inline TensorValue OpKernelContext::release_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  TensorValue value = outputs_[index];
  outputs_[index] = TensorValue();
  return value;
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<int> candidate_input_indices, int output_index,
    const TensorShape& output_shape, Tensor** output) {
  for (int input_index : candidate_input_indices) {
    if (forward_input_to_output_with_shape(input_index, output_index,
                                           output_shape, output)) {
      return Status::OK();
    }
  }
  return allocate_output(output_index, output_shape, output);
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  for (const StringPiece& input_name : candidate_input_names) {
    if (forward_input_to_output_with_shape(input_name, output_name,
                                           output_shape, output)
            .ok()) {
      return Status::OK();
    }
  }
  return allocate_output(output_name, output_shape, output);
}
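// For example, an element-wise kernel can try to reuse its input buffer for
// the output before falling back to a fresh allocation (a minimal sketch
// from inside a hypothetical Compute method):
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
//                               {0}, 0, context->input(0).shape(), &output));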
template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

template <typename T>
T* OpKernelContext::input_device_context(int index) {
  DCHECK_NE(params_->input_device_contexts, nullptr);
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>((*params_->input_device_contexts)[index]);
}

inline DeviceContext* OpKernelContext::input_device_context(int index) {
  DCHECK_NE(params_->input_device_contexts, nullptr);
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  return (*params_->input_device_contexts)[index];
}

inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}

inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  // Offset the list-relative index by start_, as the other accessors above do.
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
}

// Convenience macros for asserting and handling exceptional conditions.
// Analogous to the CHECK* macros provided by logging.h.
//
// Example use:
// void Compute(OpKernelContext* context) {
//   OP_REQUIRES(context, context->num_inputs() == 2,
//               errors::InvalidArgument("FooOp requires 2 arguments"));
//   ...
//   Status status = SomeUncertainMethod();
//   OP_REQUIRES_OK(context, status);
//   ...
// }

// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
// AsyncOpKernel implementations. If these macros are used and the condition
// does not hold, the `done` callback will never be called and the system will
// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK]
// macros are legal to use in AsyncOpKernel constructors, we use overload
// resolution to distinguish between OpKernelConstruction* and
// OpKernelContext* context types.
class XlaOpKernelContext;
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);

#define OP_REQUIRES(CTX, EXP, STATUS)                     \
  do {                                                    \
    if (!TF_PREDICT_TRUE(EXP)) {                          \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
      return;                                             \
    }                                                     \
  } while (0)

#define OP_REQUIRES_OK(CTX, ...)                             \
  do {                                                       \
    ::tensorflow::Status _s(__VA_ARGS__);                    \
    if (!TF_PREDICT_TRUE(_s.ok())) {                         \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
      return;                                                \
    }                                                        \
  } while (0)

#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      (CALLBACK)();                                    \
      return;                                          \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)();                                         \
      return;                                               \
    }                                                       \
  } while (0)

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_