1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ 17 #define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ 18 19 #include <functional> 20 21 #include <utility> 22 #include <vector> 23 #include "tensorflow/core/framework/allocator.h" 24 #include "tensorflow/core/framework/cancellation.h" 25 #include "tensorflow/core/framework/control_flow.h" 26 #include "tensorflow/core/framework/device_base.h" 27 #include "tensorflow/core/framework/kernel_def_builder.h" 28 #include "tensorflow/core/framework/node_def_util.h" 29 #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove 30 #include "tensorflow/core/framework/rendezvous.h" 31 #include "tensorflow/core/framework/selective_registration.h" 32 #include "tensorflow/core/framework/session_state.h" 33 #include "tensorflow/core/framework/tensor.h" 34 #include "tensorflow/core/framework/tensor_shape.h" 35 #include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove 36 #include "tensorflow/core/framework/tracking_allocator.h" 37 #include "tensorflow/core/framework/types.h" 38 #include "tensorflow/core/framework/types.pb.h" 39 #include "tensorflow/core/framework/unique_tensor_references.h" 40 #include "tensorflow/core/lib/core/errors.h" 41 #include "tensorflow/core/lib/core/status.h" 42 #include "tensorflow/core/lib/gtl/array_slice.h" 43 #include "tensorflow/core/lib/gtl/manual_constructor.h" 44 #include "tensorflow/core/platform/env.h" 45 #include "tensorflow/core/platform/logging.h" 46 #include "tensorflow/core/platform/macros.h" 47 #include "tensorflow/core/platform/mutex.h" 48 #include "tensorflow/core/platform/thread_annotations.h" 49 #include "tensorflow/core/platform/types.h" 50 51 namespace Eigen { 52 struct ThreadPoolDevice; 53 struct GpuDevice; 54 struct SyclDevice; 55 } // end namespace Eigen 56 57 namespace tensorflow { 58 59 namespace checkpoint { 60 class TensorSliceReaderCacheWrapper; 61 } // namespace checkpoint 62 63 class AsyncOpKernel; 64 class CallFrameInterface; 65 class FunctionLibraryRuntime; 66 class OpKernelConstruction; // declared below 67 class OpKernelContext; // declared below 68 class OpRegistryInterface; 69 class ResourceMgr; 70 class ScopedStepContainer; 71 class StepStatsCollector; 72 73 class OpKernel { 74 public: 75 // OpKernel won't be instantiated by the scheduler, so you may perform 76 // expensive initialization in the descendant's constructor. 77 explicit OpKernel(OpKernelConstruction* context); 78 79 // Specialized constructor that enables the descendant to provide a different 80 // `NodeDef` value. For example, this constructor can be used to provide a 81 // stripped-down `NodeDef` that does not contain the full set of attrs (such 82 // as tensor values) if the descendant stores them in a different form. 
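  //
  // For illustration only, a descendant that stores a large attr value (e.g.
  // a constant tensor) in its own field could pass a reduced NodeDef to this
  // constructor. The kernel and the StripTensorAttrs() helper below are
  // hypothetical sketches, not part of this API:
  //
  //   class BigConstantOp : public OpKernel {
  //    public:
  //     explicit BigConstantOp(OpKernelConstruction* ctx)
  //         : OpKernel(ctx, StripTensorAttrs(ctx->def())) {
  //       // The construction context still carries the full NodeDef, so the
  //       // attr can be read here and cached in a member.
  //       OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &value_));
  //     }
  //     void Compute(OpKernelContext* ctx) override {
  //       ctx->set_output(0, value_);
  //     }
  //
  //    private:
  //     Tensor value_;
  //   };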
83 explicit OpKernel(OpKernelConstruction* context, 84 std::unique_ptr<const NodeDef> node_def); 85 86 virtual ~OpKernel(); 87 88 // An OpKernel's computation can be either synchronous or 89 // asynchronous. All OpKernel Compute() methods must be thread-safe as they 90 // may be called concurrently (e.g. by multiple executions of the same graph 91 // concurrently). 92 // 93 // Most OpKernels should compute synchronously. They should 94 // subclass OpKernel and override the Compute() method and have it 95 // return after completing the supplied work. 96 // 97 // A few special kernels might need to be asynchronous to bound the 98 // number of threads (e.g., network receive operations). These 99 // kernels must subclass AsyncOpKernel and override 100 // AsyncOpKernel::ComputeAsync(). 101 // 102 // In both cases, implementations of Compute() and ComputeAsync() 103 // get inputs and write outputs through the given OpKernelContext 104 // and returns a status via context->SetStatus(). They must be 105 // thread-safe. 106 107 // Synchronous compute. 108 // 109 // "context" is guaranteed to be alive until Compute() returns. 110 virtual void Compute(OpKernelContext* context) = 0; 111 112 // Returns nullptr iff this op kernel is synchronous. 113 virtual AsyncOpKernel* AsAsync() { return nullptr; } 114 115 // Returns true iff this op kernel is considered "expensive". The 116 // runtime may use this flag to optimize graph execution for example 117 // to "inline" inexpensive kernels. 118 virtual bool IsExpensive() { return expensive_; } 119 120 // Accessors. 121 const NodeDef& def() const { return *def_; } 122 const string& name() const; // Same as def().name() 123 const string& type_string() const; // Same as def().op() 124 const string& requested_device() const; // Same as def().device() 125 bool is_internal() const { return is_internal_; } 126 127 int num_inputs() const { return input_types_.size(); } 128 DataType input_type(int i) const { return input_types_[i]; } 129 const DataTypeVector& input_types() const { return input_types_; } 130 const MemoryTypeVector& input_memory_types() const { 131 return input_memory_types_; 132 } 133 const string& requested_input(int i) const; // Same as def().input(i) 134 135 int num_outputs() const { return output_types_.size(); } 136 DataType output_type(int o) const { return output_types_[o]; } 137 const DataTypeVector& output_types() const { return output_types_; } 138 const MemoryTypeVector& output_memory_types() const { 139 return output_memory_types_; 140 } 141 142 Status InputRange(StringPiece input_name, int* start, int* stop) const; 143 Status OutputRange(StringPiece output_name, int* start, int* stop) const; 144 145 // We allow legacy scalars within Google up until GraphDef version 6. 146 // TODO(irving): Remove when we can drop support for GraphDef version 5. 147 bool allow_legacy_scalars() const { 148 #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) 149 return graph_def_version_ < 6; 150 #else 151 return false; 152 #endif 153 } 154 155 // Allow either scalars or (if allowing legacy scalars) shape (1,). 156 bool IsLegacyScalar(const TensorShape& shape) const { 157 return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 && 158 shape.dim_size(0) == 1); 159 } 160 161 // Allow rank 1 or (if allowing legacy scalars) rank 0. 
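  //
  // For example (illustrative; `shape_input` stands for a shape-carrying
  // input tensor), a kernel might validate such an input in Compute() with:
  //
  //   OP_REQUIRES(ctx, IsLegacyVector(shape_input.shape()),
  //               errors::InvalidArgument("shape must be a vector, got: ",
  //                                       shape_input.shape().DebugString()));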
162 bool IsLegacyVector(const TensorShape& shape) const { 163 return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0); 164 } 165 166 // Turn a shape Tensor into a TensorShape 167 // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars 168 Status MakeShape(const Tensor& shape, TensorShape* out) const; 169 170 private: 171 const std::unique_ptr<const NodeDef> def_; 172 const DataTypeVector input_types_; 173 const MemoryTypeVector input_memory_types_; 174 const DataTypeVector output_types_; 175 const MemoryTypeVector output_memory_types_; 176 const int graph_def_version_; 177 const bool is_internal_; // True if this is an internal operation 178 NameRangeMap input_name_map_; 179 NameRangeMap output_name_map_; 180 bool expensive_; 181 182 TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); 183 }; 184 185 class AsyncOpKernel : public OpKernel { 186 public: 187 using OpKernel::OpKernel; // Lift OpKernel constructors. 188 189 // Asynchronous compute. 190 // 191 // Implementations of ComputeAsync() must run "done" to signal the 192 // completion of the computation. "context" is guaranteed to be 193 // alive until the "done" callback starts. 194 typedef std::function<void()> DoneCallback; 195 virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; 196 197 AsyncOpKernel* AsAsync() final { return this; } 198 199 void Compute(OpKernelContext* context) final; 200 201 bool IsExpensive() override { return true; } 202 }; 203 204 // Wraps a tensor that is held by an Op across calls to Compute(). For 205 // memory safety when using asynchronous devices like GPUs, the system 206 // must be notified when a Tensor is used inside an Op execution. The 207 // wrapper ensures that all uses of the Tensor are tracked, because in 208 // order to retrieve the Tensor the caller must use AccessTensor which 209 // notifies the context. 210 class PersistentTensor { 211 public: 212 PersistentTensor() {} 213 explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {} 214 215 // Caller does not own the returned Tensor*. 216 Tensor* AccessTensor(OpKernelConstruction* context); 217 // Caller does not own the returned Tensor*. 218 Tensor* AccessTensor(OpKernelContext* context); 219 220 // The check for initialization does not need to access the 221 // underlying tensor buffer. 222 bool IsInitialized() const { return tensor_.IsInitialized(); } 223 224 int64 NumElements() const { return tensor_.NumElements(); } 225 226 int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); } 227 228 private: 229 Tensor tensor_; 230 }; 231 232 class OpKernelConstruction { 233 public: 234 OpKernelConstruction(DeviceType device_type, DeviceBase* device, 235 Allocator* allocator, const NodeDef* node_def, 236 const OpDef* op_def, FunctionLibraryRuntime* flib, 237 const DataTypeSlice& input_types, 238 const MemoryTypeSlice& input_memory_types, 239 const DataTypeSlice& output_types, 240 const MemoryTypeSlice& output_memory_types, 241 int graph_def_version, Status* status); 242 243 Env* env() const { return device_->env(); } 244 245 // Allocation of tensors during kernel construction: 246 // 247 // It is legal to temporarily allocate scratch tensor storage during 248 // Op kernel construction. Scratch tensors should be allocated using 249 // allocate_temp below. Some kernels need to keep tensors in between 250 // invocations. 
If such a Tensor is allocated during kernel 251 // construction this must be done using allocate_persistent, and the 252 // Op may only store the returned PersistentTensor object. When the 253 // Tensor is needed in a subsequent invocation, it can be retrieved 254 // from the PersistentTensor using the AccessTensor method. This 255 // ensures that the system is made aware of any use of the tensor's 256 // allocated memory, which is needed for correctness on asynchronous 257 // devices such as GPUs. 258 259 // Allocates a temporary Tensor of the specified type and shape. The 260 // Tensor must not be used after kernel construction is 261 // complete. See comment above. 262 Status allocate_temp(DataType type, const TensorShape& shape, 263 Tensor* out_temp); 264 265 // Allocates a Tensor of the specified type and shape which the Op 266 // plans to maintain as persistent state. out_persistent holds the 267 // PersistentTensor which is the object the caller should store. For 268 // convenience, if out_tensor is non-null then it will be filled in 269 // with a Tensor* pointing to the newly-allocated tensor which the 270 // caller can use instead of calling 271 // out_persistent->AccessTensor. The caller does not own out_tensor 272 // and should not keep a copy of it. See comment above. 273 Status allocate_persistent(DataType type, const TensorShape& shape, 274 PersistentTensor* out_persistent, 275 Tensor** out_tensor); 276 277 // User-supplied configuration of this operation. 278 const NodeDef& def() const { return *def_; } 279 280 // For inspecting the inputs to this operation. 281 int num_inputs() const { return input_types_.size(); } 282 DataType input_type(int i) const { return input_types_[i]; } 283 const DataTypeSlice& input_types() const { return input_types_; } 284 const MemoryTypeSlice& input_memory_types() const { 285 return input_memory_types_; 286 } 287 288 // For inspecting the outputs expected from this operation. 289 int num_outputs() const { return output_types_.size(); } 290 DataType output_type(int i) const { return output_types_[i]; } 291 const DataTypeSlice& output_types() const { return output_types_; } 292 const MemoryTypeSlice& output_memory_types() const { 293 return output_memory_types_; 294 } 295 296 // If expected_inputs == inputs() and expected_outputs == output_types(), 297 // returns OK, else returns INVALID_ARGUMENT with an error message. 298 // Recommended for Ops with dynamic signatures. 299 Status MatchSignature(const DataTypeSlice expected_inputs, 300 const DataTypeSlice expected_outputs); 301 302 // For recording configuration errors during construction. 303 void SetStatus(const Status& status); 304 const Status& status() const { return *status_; } 305 306 // Look up the attr with name attr_name and set *value to its value. If no 307 // attr with attr_name is found in def(), or the attr does not have 308 // a matching type, a non-ok status will be returned. 309 template <class T> 310 Status GetAttr(StringPiece attr_name, T* value) const; 311 312 // Return true if the attr_name is defined in def(). 313 bool HasAttr(StringPiece attr_name) const; 314 315 // Return the device type. 316 const DeviceType& device_type() const { return device_type_; } 317 318 // If not nullptr, the kernel can instantiate functions defined in 319 // the library. E.g., 320 // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). 321 FunctionLibraryRuntime* function_library() const { return flib_; } 322 323 // The GraphDef version whose behavior we should follow. 
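  //
  // For example (sketch; the kernel name and version cutoff are hypothetical),
  // a kernel that preserves old behavior for graphs produced before a given
  // version might record the decision in its constructor:
  //
  //   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  //     use_new_rounding_ = ctx->graph_def_version() >= 9;
  //   }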
324 int graph_def_version() const { return graph_def_version_; } 325 326 // Helper routines for the OP_REQUIRES macros 327 void CtxFailure(const Status& s); 328 void CtxFailureWithWarning(const Status& s); 329 void CtxFailure(const char* file, int line, const Status& s); 330 void CtxFailureWithWarning(const char* file, int line, const Status& s); 331 332 // Unrecommended functions: these are functions that have some 333 // current uses but are not recommended for use, and may go away at 334 // some future major version release. 335 336 // May be used, e.g., to get GPU handles, etc. 337 // 338 // Currently only used to call MakeTensorFromProto() for 339 // implementing ConstantOp for every device. See comments 340 // on Device::MakeTensorFromProto for longer-term replacement 341 // ideas. 342 DeviceBase* device() const { return device_; } 343 344 private: 345 const DeviceType device_type_; 346 DeviceBase* const device_; 347 Allocator* allocator_; 348 const NodeDef* def_; 349 const OpDef* op_def_; 350 FunctionLibraryRuntime* flib_; 351 DataTypeSlice input_types_; 352 MemoryTypeSlice input_memory_types_; 353 DataTypeSlice output_types_; 354 MemoryTypeSlice output_memory_types_; 355 const int graph_def_version_; 356 Status* status_; 357 358 // Allow op_def_ across from OpKernel, but not from subclasses. 359 // TODO(irving): Remove protos from this header entirely. 360 friend class OpKernel; 361 362 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); 363 }; 364 365 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading 366 // tensorflow::gtl::iterator_range to make the below container classes 367 // unnecessary. 368 template <typename ListType, typename ElementType> 369 class OpArgIterator { 370 public: 371 typedef OpArgIterator<ListType, ElementType> ME; 372 OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} 373 bool operator==(const ME& rhs) { 374 DCHECK(list_ == rhs.list_); 375 return i_ == rhs.i_; 376 } 377 bool operator!=(const ME& rhs) { 378 DCHECK(list_ == rhs.list_); 379 return i_ != rhs.i_; 380 } 381 void operator++() { ++i_; } 382 ElementType& operator*() { return (*list_)[i_]; } 383 384 private: 385 const ListType* const list_; 386 int i_; 387 }; 388 389 // Utility class for representing a list of immutable input tensors 390 // that are passed to the op as a single named argument. 391 class OpInputList { 392 public: 393 typedef OpArgIterator<OpInputList, const Tensor&> Iterator; 394 OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} 395 OpInputList(OpKernelContext* ctx, int start, int stop) 396 : ctx_(ctx), start_(start), stop_(stop) {} 397 OpInputList& operator=(const OpInputList& other) = default; 398 const Tensor& operator[](int i) const; 399 int size() const { return stop_ - start_; } 400 Iterator begin() const { return Iterator(this, 0); } 401 Iterator end() const { return Iterator(this, size()); } 402 403 private: 404 OpKernelContext* ctx_; // not owned 405 int start_; 406 int stop_; 407 }; 408 409 // Utility class for representing a list of mutable ("ref") input tensors 410 // that are passed to the op as a single named argument. 
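//
// For example (illustrative; the input name "values" is hypothetical), a
// kernel can walk a ref-input list like so:
//
//   OpMutableInputList values;
//   OP_REQUIRES_OK(ctx, ctx->mutable_input_list("values", &values));
//   for (int i = 0; i < values.size(); ++i) {
//     mutex_lock l(*values.ref_mutex(i));
//     Tensor t = values.at(i, /* lock_held */ true);
//     // modify the values in t
//   }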
411 class OpMutableInputList { 412 public: 413 typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator; 414 OpMutableInputList(OpKernelContext* ctx, int start, int stop) 415 : ctx_(ctx), start_(start), stop_(stop) {} 416 OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {} 417 OpMutableInputList& operator=(const OpMutableInputList& other) = default; 418 Tensor at(int i, bool lock_held); 419 mutex* ref_mutex(int i); 420 int size() const { return stop_ - start_; } 421 Iterator begin() const { return Iterator(this, 0); } 422 Iterator end() const { return Iterator(this, size()); } 423 424 private: 425 OpKernelContext* ctx_; // not owned 426 int start_; 427 int stop_; 428 }; 429 430 // Utility class for representing a list of output tensors that are 431 // grouped as a single named output. 432 class OpOutputList { 433 public: 434 typedef OpArgIterator<OpOutputList, const Tensor*> Iterator; 435 OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {} 436 OpOutputList(OpKernelContext* ctx, int start, int stop) 437 : ctx_(ctx), start_(start), stop_(stop) {} 438 OpOutputList& operator=(const OpOutputList& other) = default; 439 Tensor* operator[](int i); 440 bool required(int i) const; 441 DataType expected_output_dtype(int i) const; 442 Status allocate(int i, const TensorShape& shape, Tensor** output); 443 void set(int i, const Tensor& tensor); 444 void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); 445 int size() const { return stop_ - start_; } 446 Iterator begin() const { return Iterator(this, 0); } 447 Iterator end() const { return Iterator(this, size()); } 448 449 private: 450 OpKernelContext* ctx_; // not owned 451 int start_; 452 int stop_; 453 }; 454 455 // Holds a tensor or tensor reference. For tensor references, we need 456 // a mutex to prevent concurrent access to the tensor. 457 struct TensorValue { 458 TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} 459 TensorValue(Tensor* t) // NOLINT(runtime/explicit) 460 : mutex_if_ref(nullptr), tensor(t) {} 461 TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} 462 Tensor* operator->() const { return tensor; } 463 bool is_ref() const { return mutex_if_ref != nullptr; } 464 465 mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref 466 Tensor* tensor; 467 }; 468 469 class OpKernelContext { 470 public: 471 // The first element of a WrappedAllocator is a "base" Allocator and 472 // the second element is that Allocator wrapped by a 473 // TrackingAllocator 474 typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator; 475 476 // TODO(zhifengc): Do some cleanup of Params. 477 // The Params struct is passed in to initialize an OpKernelContext, 478 // and must outlive the OpKernelContext. 479 struct Params { 480 ~Params() { delete eigen_gpu_device; } 481 482 // The step being executed. 483 int64 step_id = 0; 484 485 // The op kernel being computed. 486 OpKernel* op_kernel = nullptr; 487 488 // The device on which the kernel is running. 489 DeviceBase* device = nullptr; 490 491 // The Eigen GPU device wrapper, which may include a per-op 492 // wrapped allocator. The concrete type of this object depends on 493 // the type of this->device, so eigen_gpu_device can't be an 494 // inline member and must be heap allocated. However, we don't 495 // want to allocate a new eigen_gpu_device for every Op that is 496 // executed. 
Instead this member is allocated on first use using 497 // ensure_eigen_gpu_device, and then if the Params structure is 498 // re-used for subsequent Ops, the eigen_gpu_device is 499 // ReInitialized in the OpKernelContext constructor. Unlike the 500 // other pointers in Params, this one is owned by Params. 501 PerOpGpuDevice* eigen_gpu_device = nullptr; 502 503 inline void ensure_eigen_gpu_device() { 504 DCHECK(device); 505 if (nullptr == eigen_gpu_device) { 506 // Surprisingly, MakeGpuDevice will return nullptr if the 507 // device is not a GPU device. This is ok, since those devices 508 // will never use eigen_gpu_device. It seems better to have 509 // ensure_eigen_gpu_device fall through and regenerate the 510 // nullptr every time an OpKernelContext is instantiated, than 511 // to do an unnecessary allocation of a dummy eigen GPU 512 // device for CPU device Ops. 513 eigen_gpu_device = device->MakeGpuDevice(); 514 } 515 } 516 517 bool track_allocations = false; 518 bool log_memory = false; 519 bool record_tensor_accesses = false; 520 521 // Array indexed by output number for this node 522 const AllocatorAttributes* output_attr_array = nullptr; 523 524 // Shared resources accessible by this op kernel invocation. 525 ResourceMgr* resource_manager = nullptr; 526 527 // Per-step resources accessible by this op kernel invocation should be 528 // stored in this container.. 529 ScopedStepContainer* step_container = nullptr; 530 531 // Mechanism used by this op kernel invocation to communicate with 532 // computations running on other devices. 533 Rendezvous* rendezvous = nullptr; 534 535 // The session state for this op. 536 SessionState* session_state = nullptr; 537 538 // The tensor store for this op. 539 TensorStore* tensor_store = nullptr; 540 541 // Mechanism used by this op kernel invocation to register a callback 542 // for its cancellation. 543 CancellationManager* cancellation_manager = nullptr; 544 545 // Inputs to this op kernel. 546 const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr; 547 bool is_input_dead = false; 548 549 const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs = 550 nullptr; 551 552 // Device contexts. 553 const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts = 554 nullptr; 555 DeviceContext* op_device_context = nullptr; 556 557 // Control-flow op supports. 558 FrameAndIter frame_iter; 559 560 // Function call supports. 561 CallFrameInterface* call_frame = nullptr; 562 FunctionLibraryRuntime* function_library = nullptr; 563 std::function<void(std::function<void()>)>* runner = nullptr; 564 StepStatsCollector* stats_collector = nullptr; 565 566 // TensorSliceReaderCache support. 567 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; 568 }; 569 570 // params must outlive the OpKernelContext. 571 explicit OpKernelContext(Params* params); 572 OpKernelContext(Params* params, int noutputs); 573 ~OpKernelContext(); 574 575 Env* env() const { return params_->device->env(); } 576 577 int64 step_id() const { return params_->step_id; } 578 579 const OpKernel& op_kernel() const { return *params_->op_kernel; } 580 581 // Input/output signature. 
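  //
  // For example (illustrative), a kernel with a dynamic signature can check
  // the signature it was given at the top of Compute():
  //
  //   OP_REQUIRES(ctx, ctx->num_inputs() == 2,
  //               errors::InvalidArgument("Expected 2 inputs, got ",
  //                                       ctx->num_inputs()));
  //   OP_REQUIRES(ctx, ctx->input_dtype(0) == DT_INT32,
  //               errors::InvalidArgument("argument 0 must be int32"));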
  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index);

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   Tensor t = context->mutable_input(index, /* lock_held */ true);
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef. If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
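  //
  // For example (illustrative; `updated_tensor` is assumed to hold the new
  // value), an assignment-style kernel might do:
  //
  //   mutex_lock l(*ctx->input_ref_mutex(0));
  //   ctx->replace_ref_input(0, updated_tensor, /* lock_held */ true);
  //   ctx->forward_ref_input_to_ref_output(0, 0);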
654 void replace_ref_input(int index, const Tensor& tensor, bool lock_held); 655 656 // Replace the corresponding named Ref Input to use the storage 657 // buffer used by tensor. If !lock_held the input mutex will be 658 // acquired before returning the Tensor. 659 // REQUIRES: IsRefType(input_dtype(index)). 660 Status replace_ref_input(StringPiece name, const Tensor& tensor, 661 bool lock_held); 662 663 // Deletes the Tensor object used as the Ref Input at 664 // input_index. This is not usually necessary and should be used 665 // with caution. If !lock_held the input mutex will be acquired 666 // before returning the Tensor. 667 // REQUIRES: IsRefType(input_dtype(input_index)). 668 void delete_ref_input(int input_index, bool lock_held); 669 670 // Return true if there is input at the given index. An operator has no 671 // input at index if its tensor is null. This is primarily used by the 672 // merge operator. 673 // TODO(mrry): Convert this to return Status. 674 bool has_input(int index) const; 675 676 // Returns true if all inputs are the same shape, otherwise sets the 677 // status to a non-OK value and returns false. 678 // Usage: if (!context->ValidateInputsAreSameShape(this)) return; 679 bool ValidateInputsAreSameShape(OpKernel* op); 680 681 // Input to output forwarding. 682 683 // Set the output Ref Tensor at output_index to be an alias of the 684 // input Ref Tensor at input_index. 685 // REQUIRES: IsRefType(input_dtype(input_index)). 686 // REQUIRES: IsRefType(output_dtype(output_index)). 687 void forward_ref_input_to_ref_output(int input_index, int output_index); 688 689 // Returns true when an alias to input[input_index], reshaped to output_shape, 690 // which is safe to use for in-place computation was written to *output. 691 // Returns false if input[input_index] has a refcount greater than one, or if 692 // its type does not match the expected output type of output[output_index], 693 // or the number of elements in input[input_index] does not equal the number 694 // of elements in output_shape. 695 bool forward_input_to_output_with_shape(int input_index, int output_index, 696 const TensorShape& output_shape, 697 Tensor** output) TF_MUST_USE_RESULT; 698 Status forward_input_to_output_with_shape(StringPiece input_name, 699 StringPiece output_name, 700 const TensorShape& output_shape, 701 Tensor** output) TF_MUST_USE_RESULT; 702 703 // Returns a pointer to a Tensor aliasing the underlying buffer backing 704 // input[input_index] iff 705 // * input[input_index] is not a ref, 706 // * the data type, shape, memory type, and allocator attributes of 707 // input[input_index] are compatible with those given in dtype, shape, 708 // memory_type, and attr, 709 // * refcount on the underlying buffer is one. 710 // Otherwise returns nullptr. 711 // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic, 712 // forwarding is only safe if there are no reads via __ldg() after writes 713 // to the same address. 714 std::unique_ptr<Tensor> forward_input( 715 int input_index, DataType dtype, const TensorShape& shape, 716 MemoryType memory_type, 717 const AllocatorAttributes& attr) TF_MUST_USE_RESULT; 718 719 // Tries to forward one of the inputs given in input_indices to 720 // output[output_index]. If none of the given inputs can be forwarded, calls 721 // allocate_output() to allocate a new output buffer. 
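  //
  // For example (illustrative; `input` is the tensor of input 0), an
  // element-wise kernel that may run in place on its first input:
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
  //                           {0}, 0, input.shape(), &output));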
722 Status forward_input_or_allocate_output( 723 gtl::ArraySlice<int> candidate_input_indices, int output_index, 724 const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; 725 Status forward_input_or_allocate_output( 726 gtl::ArraySlice<StringPiece> candidate_input_names, 727 StringPiece output_name, const TensorShape& output_shape, 728 Tensor** output) TF_MUST_USE_RESULT; 729 730 // Tries to reuse one of the inputs given in input_indices as a temporary. 731 // If none of the given inputs can be forwarded, calls 732 // allocate_temp() to allocate a new temporary buffer. 733 Status forward_input_or_allocate_temp( 734 gtl::ArraySlice<int> candidate_input_indices, DataType type, 735 const TensorShape& shape, const AllocatorAttributes& allocator_attr, 736 Tensor* out_temp) TF_MUST_USE_RESULT; 737 738 Status forward_input_or_allocate_temp( 739 gtl::ArraySlice<int> candidate_input_indices, DataType type, 740 const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT { 741 return forward_input_or_allocate_temp(candidate_input_indices, type, shape, 742 AllocatorAttributes(), out_temp); 743 } 744 745 // Output 746 747 // Returns the named list-valued output in "list", as defined in the OpDef. 748 // If the named output is not list-valued, returns a one-element list. 749 Status output_list(StringPiece name, OpOutputList* list); 750 751 // If output_required(index) returns true, the OpKernel's Compute() method 752 // should call allocate_output(index, ...), set_output(index, ...), 753 // set_output_ref(index, ...), or set the status to a non-ok value. 754 // If it returns false, it may output, but is not required to do so. 755 // TODO(mrry): Convert this to return Status, and implement a string 756 // name version. 757 bool output_required(int index) const { 758 return true; // TODO(josh11b): implement 759 } 760 761 // Allocation of tensors during kernel execution inside the Compute 762 // method: 763 // 764 // There are three methods to allocate Tensors when an Op kernel 765 // executes. 766 // 767 // 1) allocate_persistent. This is only needed for Tensors that will 768 // be stored by the Op between invocations, and it *must* be used 769 // for those Tensors. The call returns a PersistentTensor, and that 770 // is the only object the Op is allowed to hold on to between 771 // invocations. When the Tensor is needed in a subsequent 772 // invocation, it can be retrieved from the PersistentTensor using 773 // the AccessTensor method. This ensures that the system is made 774 // aware of any use of the tensor's allocated memory, which is 775 // needed for correctness on asynchronous devices such as GPUs. 776 // 777 // 2) allocate_output. This should be used to allocate any tensor 778 // that is going to be used as an output from the Op at the end of 779 // the current execution. The caller indicates which output the 780 // Tensor will be assigned to, and the call returns the 781 // newly-allocated Tensor. The Tensor can subsequently be assigned 782 // to during kernel execution, and will be used as the designated 783 // output when the kernel execution completes. 784 // 785 // 3) allocate_temp. This should be used to allocate any scratch 786 // storage that is needed while the kernel is executing, and will 787 // not be retained by the Op. 788 // 789 // In some cases a Tensor needs to be used as an output even though 790 // it was previously allocated elsewhere. 
The Tensor may have been 791 // passed as an input, or stored in a PersistentTensor during a 792 // previous kernel execution, or allocated earlier in the kernel 793 // execution at a time when it was not known which output it would 794 // be assigned to. In this case the kernel can use set_output or 795 // set_output_ref to indicate that the tensor should be used as the 796 // designated output. It is legal to use any previously-allocated 797 // Tensor as an argument to set_output or set_output_ref, including 798 // Tensors allocated via allocate_temp. There may be a performance 799 // penalty to using a Tensor that was not allocated using 800 // allocate_output. This is because allocate_output uses the 801 // AllocatorAttributes stored in output_attr_array for the 802 // designated output. In some cases, using the wrong attributes may 803 // cause an extra copy of the Tensor's buffer. 804 805 // Allocates output for the specified output index with shape. 806 // OpKernelContext retains ownership of the returned pointer. See 807 // comment above. 808 // 809 // If memory allocation fails, returns an error status. 810 // 811 // REQUIRES: !IsRefType(expected_output_dtype(index)) 812 Status allocate_output(int index, const TensorShape& shape, 813 Tensor** tensor) TF_MUST_USE_RESULT; 814 Status allocate_output(StringPiece name, const TensorShape& shape, 815 Tensor** tensor) TF_MUST_USE_RESULT; 816 // The following methods use the supplied attributes instead of 817 // those in output_attr_array. The caller is responsible for 818 // ensuring that the attributes are "compatible" with the 819 // output_attr_array, e.g. the tensor is allocated on the correct 820 // device. See comment above. 821 Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, 822 AllocatorAttributes attr) TF_MUST_USE_RESULT; 823 Status allocate_output(StringPiece name, const TensorShape& shape, 824 Tensor** tensor, 825 AllocatorAttributes attr) TF_MUST_USE_RESULT; 826 827 // Allocates a temporary Tensor of the specified type and 828 // shape. Devices such as GPUs that enqueue Ops for lazy execution 829 // may retain references to the temporary tensors after the Op's 830 // Compute method has run. See comment above. 831 Status allocate_temp(DataType type, const TensorShape& shape, 832 Tensor* out_temp, AllocatorAttributes allocator_attr, 833 const AllocationAttributes& allocation_attr); 834 Status allocate_temp(DataType type, const TensorShape& shape, 835 Tensor* out_temp, AllocatorAttributes allocator_attr) { 836 return allocate_temp(type, shape, out_temp, allocator_attr, 837 AllocationAttributes()); 838 } 839 Status allocate_temp(DataType type, const TensorShape& shape, 840 Tensor* out_temp) { 841 return allocate_temp(type, shape, out_temp, AllocatorAttributes()); 842 } 843 844 // Allocates a Tensor of the specified type and shape which the Op 845 // plans to maintain as persistent state. out_persistent holds the 846 // PersistentTensor which is the object the caller should store. For 847 // convenience, if out_tensor is non-null then it will be filled in 848 // with a Tensor* pointing to the newly-allocated tensor which the 849 // caller can use instead of calling 850 // out_persistent->AccessTensor. The caller does not own out_tensor 851 // and should not keep a copy of it. See comment above. 
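  //
  // For example (illustrative sketch; the shape and member name are
  // hypothetical):
  //
  //   PersistentTensor scratch_persistent_;  // stored as a class member
  //   ...
  //   Tensor* scratch = nullptr;
  //   OP_REQUIRES_OK(ctx,
  //                  ctx->allocate_persistent(DT_FLOAT, TensorShape({128}),
  //                                           &scratch_persistent_, &scratch));
  //   // Later invocations retrieve the buffer with
  //   // scratch_persistent_.AccessTensor(ctx).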
852 Status allocate_persistent(DataType type, const TensorShape& shape, 853 PersistentTensor* out_persistent, 854 Tensor** out_tensor, AllocatorAttributes attr); 855 Status allocate_persistent(DataType type, const TensorShape& shape, 856 PersistentTensor* out_persistent, 857 Tensor** out_tensor) { 858 return allocate_persistent(type, shape, out_persistent, out_tensor, 859 AllocatorAttributes()); 860 } 861 862 // Copies a tensor (allocated by the caller) to the specified output 863 // index. REQUIRES: !IsRefType(expected_output_dtype(index)) 864 // REQUIRES: 'tensor' must have the same MemoryType as 865 // output_memory_types[index]. See comment above. 866 Status set_output(StringPiece name, const Tensor& tensor); 867 868 // To output a reference. Caller retains ownership of mu and tensor_for_ref, 869 // and they must outlive all uses within the step. See comment above. 870 // REQUIRES: IsRefType(expected_output_dtype(index)) 871 Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref); 872 873 // Returns nullptr if allocate_output() or set_output() have not been called. 874 Status mutable_output(StringPiece name, Tensor** tensor); 875 876 // Transfers ownership of an output tensor to the caller. 877 // NOTE: For non-reference outputs, the caller takes responsibility 878 // for deletion. For reference outputs, the caller does NOT take 879 // responsibility for deletion. 880 Status release_output(StringPiece name, TensorValue* value); 881 882 // Records device specific state about how the input tensors were 883 // computed. 884 // 885 // If using the templated function, the type must be a subclass 886 // of DeviceContext. 887 // 888 // Get the DeviceContext used for the index input. Returns nullptr 889 // if no DeviceContext was provided. 890 template <typename T> 891 T* input_device_context(int index); 892 DeviceContext* input_device_context(int index); 893 894 // Return the DeviceContext that should be used for this Op. 895 // 896 // If using the templated function, the type must be a subclass 897 // of DeviceContext. 898 // 899 // Returns nullptr if the device did not provide one. 900 template <typename T> 901 T* op_device_context(); 902 DeviceContext* op_device_context() { 903 DeviceContext* ret = params_->op_device_context; 904 if (ret == nullptr) { 905 auto* dev_info = device()->tensorflow_gpu_device_info(); 906 if (dev_info) ret = dev_info->default_context; 907 } 908 return ret; 909 } 910 911 AllocatorAttributes input_alloc_attr(int index) const { 912 if (params_->input_alloc_attrs == nullptr) { 913 return AllocatorAttributes(); 914 } else { 915 DCHECK_GE(index, 0); 916 DCHECK_LT(index, params_->input_alloc_attrs->size()); 917 return (*params_->input_alloc_attrs)[index]; 918 } 919 } 920 921 AllocatorAttributes output_alloc_attr(int index) const { 922 return params_->output_attr_array[index]; 923 } 924 925 gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const { 926 mutex_lock lock(mu_); 927 gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_; 928 return retrieved; 929 } 930 931 // Communication. 932 // 933 // An op kernel communicates with outside environment through 934 // Rendezvous Send() and Recv(). 935 Rendezvous* rendezvous() const { return params_->rendezvous; } 936 937 // An op kernel can access the session state it belongs to. 938 SessionState* session_state() const { return params_->session_state; } 939 940 // An op kernel can access the tensor store of the run it belongs to. 
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel can invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollector* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }
  bool* is_output_dead() { return &is_output_dead_; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Retrieve list of referenced tensors in out_vector. Once this is
  // called, it is not legal to reference any more tensors. Should
  // not be called from Op kernels.
  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);

  // Per-step container for use by white-listed internal ops.
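  //
  // For example (illustrative; MyResource is hypothetical), per-step state is
  // typically keyed off this container in the resource manager:
  //
  //   MyResource* r = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MyResource>(
  //                           ctx->step_container()->name(), "my_resource", &r,
  //                           [](MyResource** ret) {
  //                             *ret = new MyResource();
  //                             return Status::OK();
  //                           }));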
1026 ScopedStepContainer* step_container() const { 1027 return params_->step_container; 1028 } 1029 1030 // Helper routines for the OP_REQUIRES macros 1031 void CtxFailure(const Status& s); 1032 void CtxFailureWithWarning(const Status& s); 1033 void CtxFailure(const char* file, int line, const Status& s); 1034 void CtxFailureWithWarning(const char* file, int line, const Status& s); 1035 1036 // Unrecommended functions: these are functions that have some 1037 // current uses but are not recommended for use, and may go away at 1038 // some future major version release. 1039 // 1040 // The following functions all have versions that return Status 1041 // to capture error conditions, and are strongly preferred. 1042 Tensor* mutable_output(int index); 1043 void set_output(int index, const Tensor& tensor); 1044 mutex* input_ref_mutex(int index); 1045 void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); 1046 TensorValue release_output(int index); 1047 1048 bool track_allocations() const { return params_->track_allocations; } 1049 1050 // Records temp memory allocation. Tensor object is recorded to identify the 1051 // case where temp memory is used as output memory. 1052 void record_temp_memory_allocation(int64 size, const Tensor& t) 1053 LOCKS_EXCLUDED(stats_mu_); 1054 1055 // Returns recorded size of temporary memory; 1056 int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); 1057 1058 // Records persistent memory allocation, size can be negative indicating 1059 // deallocation. 1060 void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1) 1061 LOCKS_EXCLUDED(stats_mu_); 1062 1063 // Returns recorded size and ids of persistent memory. 1064 int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); 1065 1066 std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_); 1067 1068 // Resets counters for temp and persistent memory and recorded ids. 1069 void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_); 1070 1071 bool input_is_ref(int index) const; 1072 1073 private: 1074 Allocator* get_allocator(AllocatorAttributes attr); 1075 1076 // Internal method to add a tensor's buffer to the list of buffers 1077 // referenced during the execution of the Op, so that GPUs may 1078 // accurately track the memory that may not be reused until the Op 1079 // execution completes. 1080 void record_tensor_reference(const Tensor& tensor); 1081 void really_record_tensor_reference(const Tensor& tensor); 1082 1083 // Internal common method used when allocating tensor memory 1084 Status allocate_tensor(DataType type, const TensorShape& shape, 1085 Tensor* out_tensor, 1086 AllocatorAttributes allocator_attr) { 1087 return allocate_tensor(type, shape, out_tensor, allocator_attr, 1088 AllocationAttributes()); 1089 } 1090 1091 Status allocate_tensor(DataType type, const TensorShape& shape, 1092 Tensor* out_tensor, AllocatorAttributes allocator_attr, 1093 const AllocationAttributes& allocation_attr); 1094 1095 // This is called by PersistentTensor::AccessTensor whenever the 1096 // wrapped tensor is retrieved, to ensure the runtime knows that the 1097 // Tensor is being accessed within an Op. This is necessary for 1098 // memory safety of devices like GPUs that queue Ops for 1099 // asynchronous execution after the Compute() method completes. 
1100 friend class PersistentTensor; 1101 void NotifyUseOfPersistentTensor(const Tensor& tensor); 1102 1103 Status status_; 1104 Params* params_; // not owned 1105 mutable mutex mu_; // mutable so const accessors can acquire the lock 1106 gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); 1107 gtl::InlinedVector<TensorValue, 4> outputs_; 1108 1109 // Constructed only if <params->record_tensor_accesses>. 1110 ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_); 1111 1112 bool is_output_dead_ = false; 1113 1114 // The following data members are only used when allocation tracking is 1115 // enabled. 1116 mutable mutex stats_mu_; 1117 int64 temp_memory_allocated_ GUARDED_BY(stats_mu_); 1118 int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_); 1119 std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>> 1120 temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_); 1121 std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_ 1122 GUARDED_BY(stats_mu_); 1123 1124 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); 1125 }; 1126 1127 // Register your OpKernel by specifying the Op's name, the device the 1128 // kernel runs on, any type attr constraints for this kernel, any 1129 // host-memory args, and the class to instantiate. Examples: 1130 // 1131 // // A kernel that supports all types. 1132 // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); 1133 // 1134 // // The following are equivalent ways of specifying that the kernel only 1135 // // works if the "T" type attr is set to DT_FLOAT. 1136 // REGISTER_KERNEL_BUILDER( 1137 // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"), 1138 // SubOp<float>); 1139 // // (You would then repeat this for every type supported by "Sub".) 1140 // 1141 // // This form allows you to specify a list of types as the constraint. 1142 // REGISTER_KERNEL_BUILDER(Name("Sub") 1143 // .Device(DEVICE_CPU) 1144 // .TypeConstraint("T", {DT_FLOAT}), 1145 // SubOp<float>); 1146 // 1147 // // A kernel that expects one of the input tensors in host memory. 1148 // REGISTER_KERNEL_BUILDER( 1149 // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); 1150 // 1151 // See kernel_def_builder for details. 1152 1153 // Instantiate an OpKernel that has been registered. Returns nullptr 1154 // if no operation for that type of device / input signature combination 1155 // (and a NOT_FOUND *status), or there is an error in construction (and 1156 // an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership 1157 // of the returned pointer. 1158 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...); 1159 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). 1160 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type, 1161 DeviceBase* device, 1162 Allocator* allocator, 1163 const NodeDef& def, 1164 int graph_def_version, Status* status); 1165 Status CreateOpKernel(DeviceType device_type, DeviceBase* device, 1166 Allocator* allocator, FunctionLibraryRuntime* flib, 1167 const NodeDef& def, int graph_def_version, 1168 OpKernel** kernel); 1169 1170 // Returns into 'device_types' the subset of prioritized_types that this 1171 // binary has registered for the given NodeDef. 1172 // 1173 // REQUIRES: * 'device_types' is not nullptr. 1174 // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). 
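//
// For example (illustrative; `node_def` is assumed to be in scope):
//
//   std::vector<DeviceType> prioritized = {DeviceType(DEVICE_GPU),
//                                          DeviceType(DEVICE_CPU)};
//   DeviceTypeVector supported;
//   TF_RETURN_IF_ERROR(
//       SupportedDeviceTypesForNode(prioritized, node_def, &supported));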
1175 Status SupportedDeviceTypesForNode( 1176 const std::vector<DeviceType>& prioritized_types, const NodeDef& def, 1177 DeviceTypeVector* device_types); 1178 1179 // Returns a message with a description of the kernels registered for op 1180 // `op_name`. 1181 string KernelsRegisteredForOp(StringPiece op_name); 1182 1183 // Call once after Op registration has completed. 1184 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry); 1185 1186 // ----------------------------------------------------------------------------- 1187 // OpKernel registration implementation follows, please ignore. 1188 1189 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. 1190 namespace register_kernel { 1191 1192 class Name : public KernelDefBuilder { 1193 public: 1194 // With selective registration, kernels whose implementation class is not used 1195 // by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in 1196 // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an 1197 // implementation class with a used kernel would get through that mechanism. 1198 // 1199 // This mechanism stops that registration by changing the name of the kernel 1200 // for the unused op to one that is ignored by 1201 // OpKernelRegistrar::InitInternal. Note that this method alone is 1202 // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at 1203 // compilation time, so this method doesn't actually reduce code size. 1204 explicit Name(const char* op) 1205 : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {} 1206 }; 1207 1208 namespace system { 1209 1210 class Name : public KernelDefBuilder { 1211 public: 1212 // For system kernels, we ignore selective registration and 1213 // unconditionally register the kernel. 1214 explicit Name(const char* op) : KernelDefBuilder(op) {} 1215 }; 1216 1217 } // namespace system 1218 1219 } // namespace register_kernel 1220 1221 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ 1222 REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) 1223 1224 #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ 1225 REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) 1226 1227 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ 1228 constexpr bool should_register_##ctr##__flag = \ 1229 SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \ 1230 static ::tensorflow::kernel_factory::OpKernelRegistrar \ 1231 registrar__body__##ctr##__object( \ 1232 should_register_##ctr##__flag \ 1233 ? ::tensorflow::register_kernel::kernel_builder.Build() \ 1234 : nullptr, \ 1235 #__VA_ARGS__, \ 1236 [](::tensorflow::OpKernelConstruction* context) \ 1237 -> ::tensorflow::OpKernel* { \ 1238 return new __VA_ARGS__(context); \ 1239 }); 1240 1241 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as 1242 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered 1243 // unconditionally even when selective registration is used. 1244 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \ 1245 REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \ 1246 __VA_ARGS__) 1247 1248 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ 1249 REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) 1250 1251 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) 
\ 1252 static ::tensorflow::kernel_factory::OpKernelRegistrar \ 1253 registrar__body__##ctr##__object( \ 1254 ::tensorflow::register_kernel::system::kernel_builder.Build(), \ 1255 #__VA_ARGS__, \ 1256 [](::tensorflow::OpKernelConstruction* context) \ 1257 -> ::tensorflow::OpKernel* { \ 1258 return new __VA_ARGS__(context); \ 1259 }); 1260 1261 void* GlobalKernelRegistry(); 1262 1263 // If node_def has a corresponding kernel registered on device_type, 1264 // returns OK and fill in the kernel def and kernel_class_name. <def> and 1265 // <kernel_class_name> may be null. 1266 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, 1267 const KernelDef** def, string* kernel_class_name); 1268 1269 // Writes a list of all registered kernels to LOG(INFO), to help users debug 1270 // missing kernel errors. 1271 void LogAllRegisteredKernels(); 1272 1273 namespace kernel_factory { 1274 1275 class OpKernelRegistrar { 1276 public: 1277 typedef OpKernel* (*Factory)(OpKernelConstruction*); 1278 1279 OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, 1280 Factory factory) { 1281 // Perform the check in the header to allow compile-time optimization 1282 // to a no-op, allowing the linker to remove the kernel symbols. 1283 if (kernel_def != nullptr) { 1284 InitInternal(kernel_def, kernel_class_name, factory); 1285 } 1286 } 1287 1288 private: 1289 void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, 1290 Factory factory); 1291 }; 1292 1293 } // namespace kernel_factory 1294 1295 // ----------------------------------------------------------------------------- 1296 // Template and inline method implementations, please ignore 1297 1298 template <class T> 1299 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { 1300 return GetNodeAttr(def(), attr_name, value); 1301 } 1302 1303 inline DataType OpKernelContext::input_dtype(int index) const { 1304 DCHECK_GE(index, 0); 1305 DCHECK_LT(index, num_inputs()); 1306 const TensorValue& value((*params_->inputs)[index]); 1307 if (value.is_ref()) { 1308 return MakeRefType(value->dtype()); 1309 } else { 1310 return value->dtype(); 1311 } 1312 } 1313 1314 inline MemoryType OpKernelContext::input_memory_type(int index) const { 1315 DCHECK_GE(index, 0); 1316 DCHECK_LT(index, num_inputs()); 1317 return op_kernel().input_memory_types()[index]; 1318 } 1319 1320 inline DataType OpKernelContext::expected_output_dtype(int index) const { 1321 DCHECK_GE(index, 0); 1322 DCHECK_LT(index, num_outputs()); 1323 return params_->op_kernel->output_type(index); 1324 } 1325 1326 inline MemoryType OpKernelContext::output_memory_type(int index) const { 1327 DCHECK_GE(index, 0); 1328 DCHECK_LT(index, num_outputs()); 1329 return op_kernel().output_memory_types()[index]; 1330 } 1331 1332 inline bool OpKernelContext::input_is_ref(int index) const { 1333 const TensorValue& value((*params_->inputs)[index]); 1334 return value.is_ref(); 1335 } 1336 1337 inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { 1338 DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(), 1339 params_->record_tensor_accesses); 1340 if (params_->record_tensor_accesses) { 1341 really_record_tensor_reference(tensor); 1342 } 1343 } 1344 1345 inline void OpKernelContext::retrieve_accessed_tensors( 1346 TensorReferenceVector* out_vector) { 1347 if (params_->record_tensor_accesses) { 1348 mutex_lock l(mu_); 1349 referenced_tensors_->FreezeAndReturnReferences(out_vector); 1350 } 1351 } 1352 1353 // no input if 
tensor == nullptr. 1354 inline bool OpKernelContext::has_input(int index) const { 1355 DCHECK_GE(index, 0); 1356 DCHECK_LT(index, num_inputs()); 1357 return (*params_->inputs)[index].tensor != nullptr; 1358 } 1359 1360 inline mutex* OpKernelContext::input_ref_mutex(int index) { 1361 DCHECK_GE(index, 0); 1362 DCHECK_LT(index, num_inputs()); 1363 DCHECK(input_is_ref(index)); 1364 return (*params_->inputs)[index].mutex_if_ref; 1365 } 1366 1367 inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { 1368 if (t.IsInitialized()) { 1369 record_tensor_reference(t); 1370 } 1371 } 1372 1373 inline Tensor* OpKernelContext::mutable_output(int index) { 1374 DCHECK_GE(index, 0); 1375 DCHECK_LT(index, num_outputs()); 1376 // No need to record_tensor_reference since the output must already 1377 // have been set by a call that did so. 1378 return outputs_[index].tensor; 1379 } 1380 1381 inline TensorValue OpKernelContext::release_output(int index) { 1382 DCHECK_GE(index, 0); 1383 DCHECK_LT(index, num_outputs()); 1384 TensorValue value = outputs_[index]; 1385 outputs_[index] = TensorValue(); 1386 return value; 1387 } 1388 1389 inline Status OpKernelContext::forward_input_or_allocate_output( 1390 gtl::ArraySlice<int> candidate_input_indices, int output_index, 1391 const TensorShape& output_shape, Tensor** output) { 1392 for (int input_index : candidate_input_indices) { 1393 if (forward_input_to_output_with_shape(input_index, output_index, 1394 output_shape, output)) { 1395 return Status::OK(); 1396 } 1397 } 1398 return allocate_output(output_index, output_shape, output); 1399 } 1400 1401 inline Status OpKernelContext::forward_input_or_allocate_output( 1402 gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name, 1403 const TensorShape& output_shape, Tensor** output) { 1404 for (const StringPiece& input_name : candidate_input_names) { 1405 if (forward_input_to_output_with_shape(input_name, output_name, 1406 output_shape, output) 1407 .ok()) { 1408 return Status::OK(); 1409 } 1410 } 1411 return allocate_output(output_name, output_shape, output); 1412 } 1413 1414 template <typename T> 1415 T* OpKernelContext::op_device_context() { 1416 static_assert(std::is_base_of<DeviceContext, T>::value, 1417 "T is not a subclass of DeviceContext"); 1418 return static_cast<T*>(op_device_context()); 1419 } 1420 1421 template <typename T> 1422 T* OpKernelContext::input_device_context(int index) { 1423 DCHECK_GE(index, 0); 1424 DCHECK_LT(index, params_->input_device_contexts->size()); 1425 static_assert(std::is_base_of<DeviceContext, T>::value, 1426 "T is not a subclass of DeviceContext"); 1427 return static_cast<T*>((*params_->input_device_contexts)[index]); 1428 } 1429 1430 inline DeviceContext* OpKernelContext::input_device_context(int index) { 1431 DCHECK_GE(index, 0); 1432 DCHECK_LT(index, params_->input_device_contexts->size()); 1433 return (*params_->input_device_contexts)[index]; 1434 } 1435 1436 inline const Tensor& OpInputList::operator[](int i) const { 1437 DCHECK_GE(i, 0); 1438 DCHECK_LT(i, stop_ - start_); 1439 return ctx_->input(start_ + i); 1440 } 1441 1442 inline mutex* OpMutableInputList::ref_mutex(int i) { 1443 DCHECK_GE(i, 0); 1444 DCHECK_LT(i, stop_ - start_); 1445 return ctx_->input_ref_mutex(start_ + i); 1446 } 1447 1448 inline Tensor OpMutableInputList::at(int i, bool lock_held) { 1449 DCHECK_GE(i, 0); 1450 DCHECK_LT(i, stop_ - start_); 1451 return ctx_->mutable_input(start_ + i, lock_held); 1452 } 1453 1454 inline Tensor* OpOutputList::operator[](int i) { 1455 
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
}

// Convenience macros for asserting and handling exceptional conditions.
// Analogous to the CHECK* macros provided by logging.h.
//
// Example use:
// void Compute(OpKernelContext* context) {
//   OP_REQUIRES(context, context->num_inputs() == 2,
//               errors::InvalidArgument("FooOp requires 2 arguments"));
//   ...
//   Status status = SomeUncertainMethod();
//   OP_REQUIRES_OK(context, status);
//   ...
// }

#define OP_REQUIRES(CTX, EXP, STATUS)                  \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      return;                                          \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK(CTX, ...)                            \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return;                                               \
    }                                                       \
  } while (0)

#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      (CALLBACK)();                                    \
      return;                                          \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)();                                         \
      return;                                               \
    }                                                       \
  } while (0)

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_