      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
     17 #define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
     18 
     19 #include <functional>
     20 
     21 #include <utility>
     22 #include <vector>
     23 #include "tensorflow/core/framework/allocator.h"
     24 #include "tensorflow/core/framework/cancellation.h"
     25 #include "tensorflow/core/framework/control_flow.h"
     26 #include "tensorflow/core/framework/device_base.h"
     27 #include "tensorflow/core/framework/kernel_def_builder.h"
     28 #include "tensorflow/core/framework/node_def_util.h"
     29 #include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
     30 #include "tensorflow/core/framework/rendezvous.h"
     31 #include "tensorflow/core/framework/selective_registration.h"
     32 #include "tensorflow/core/framework/session_state.h"
     33 #include "tensorflow/core/framework/tensor.h"
     34 #include "tensorflow/core/framework/tensor_shape.h"
     35 #include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
     36 #include "tensorflow/core/framework/tracking_allocator.h"
     37 #include "tensorflow/core/framework/types.h"
     38 #include "tensorflow/core/framework/types.pb.h"
     39 #include "tensorflow/core/framework/unique_tensor_references.h"
     40 #include "tensorflow/core/lib/core/errors.h"
     41 #include "tensorflow/core/lib/core/status.h"
     42 #include "tensorflow/core/lib/gtl/array_slice.h"
     43 #include "tensorflow/core/lib/gtl/manual_constructor.h"
     44 #include "tensorflow/core/platform/env.h"
     45 #include "tensorflow/core/platform/logging.h"
     46 #include "tensorflow/core/platform/macros.h"
     47 #include "tensorflow/core/platform/mutex.h"
     48 #include "tensorflow/core/platform/thread_annotations.h"
     49 #include "tensorflow/core/platform/types.h"
     50 
     51 namespace Eigen {
     52 struct ThreadPoolDevice;
     53 struct GpuDevice;
     54 struct SyclDevice;
     55 }  // end namespace Eigen
     56 
     57 namespace tensorflow {
     58 
     59 namespace checkpoint {
     60 class TensorSliceReaderCacheWrapper;
     61 }  // namespace checkpoint
     62 
     63 class AsyncOpKernel;
     64 class CallFrameInterface;
     65 class FunctionLibraryRuntime;
     66 class OpKernelConstruction;  // declared below
     67 class OpKernelContext;       // declared below
     68 class OpRegistryInterface;
     69 class ResourceMgr;
     70 class ScopedStepContainer;
     71 class StepStatsCollector;
     72 
     73 class OpKernel {
     74  public:
     75   // OpKernel won't be instantiated by the scheduler, so you may perform
     76   // expensive initialization in the descendant's constructor.
     77   explicit OpKernel(OpKernelConstruction* context);
     78 
     79   // Specialized constructor that enables the descendant to provide a different
     80   // `NodeDef` value. For example, this constructor can be used to provide a
     81   // stripped-down `NodeDef` that does not contain the full set of attrs (such
     82   // as tensor values) if the descendant stores them in a different form.
     83   explicit OpKernel(OpKernelConstruction* context,
     84                     std::unique_ptr<const NodeDef> node_def);
     85 
     86   virtual ~OpKernel();
     87 
     88   // An OpKernel's computation can be either synchronous or
     89   // asynchronous. All OpKernel Compute() methods must be thread-safe as they
     90   // may be called concurrently (e.g. by multiple executions of the same graph
     91   // concurrently).
     92   //
     93   // Most OpKernels should compute synchronously.  They should
     94   // subclass OpKernel and override the Compute() method and have it
     95   // return after completing the supplied work.
     96   //
     97   // A few special kernels might need to be asynchronous to bound the
     98   // number of threads (e.g., network receive operations). These
     99   // kernels must subclass AsyncOpKernel and override
    100   // AsyncOpKernel::ComputeAsync().
    101   //
    102   // In both cases, implementations of Compute() and ComputeAsync()
    103   // read inputs and write outputs through the given OpKernelContext,
    104   // and report errors via context->SetStatus(). As noted above, they
    105   // must be thread-safe.
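          //
          // For illustration only, a minimal synchronous kernel might look like
          // the following sketch (the op name "AddOne" and the class AddOneOp
          // are hypothetical, not part of this header):
          //
          //   class AddOneOp : public OpKernel {
          //    public:
          //     explicit AddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
          //     void Compute(OpKernelContext* ctx) override {
          //       const Tensor& input = ctx->input(0);
          //       Tensor* output = nullptr;
          //       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
          //       auto in = input.flat<float>();
          //       auto out = output->flat<float>();
          //       for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
          //     }
          //   };
          //   REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp);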
    106 
    107   // Synchronous compute.
    108   //
    109   // "context" is guaranteed to be alive until Compute() returns.
    110   virtual void Compute(OpKernelContext* context) = 0;
    111 
    112   // Returns nullptr iff this op kernel is synchronous.
    113   virtual AsyncOpKernel* AsAsync() { return nullptr; }
    114 
    115   // Returns true iff this op kernel is considered "expensive". The
    116   // runtime may use this flag to optimize graph execution, for example
    117   // by "inlining" inexpensive kernels.
    118   virtual bool IsExpensive() { return expensive_; }
    119 
    120   // Accessors.
    121   const NodeDef& def() const { return *def_; }
    122   const string& name() const;              // Same as def().name()
    123   const string& type_string() const;       // Same as def().op()
    124   const string& requested_device() const;  // Same as def().device()
    125   bool is_internal() const { return is_internal_; }
    126 
    127   int num_inputs() const { return input_types_.size(); }
    128   DataType input_type(int i) const { return input_types_[i]; }
    129   const DataTypeVector& input_types() const { return input_types_; }
    130   const MemoryTypeVector& input_memory_types() const {
    131     return input_memory_types_;
    132   }
    133   const string& requested_input(int i) const;  // Same as def().input(i)
    134 
    135   int num_outputs() const { return output_types_.size(); }
    136   DataType output_type(int o) const { return output_types_[o]; }
    137   const DataTypeVector& output_types() const { return output_types_; }
    138   const MemoryTypeVector& output_memory_types() const {
    139     return output_memory_types_;
    140   }
    141 
    142   Status InputRange(StringPiece input_name, int* start, int* stop) const;
    143   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
    144 
    145   // We allow legacy scalars within Google up until GraphDef version 6.
    146   // TODO(irving): Remove when we can drop support for GraphDef version 5.
    147   bool allow_legacy_scalars() const {
    148 #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
    149     return graph_def_version_ < 6;
    150 #else
    151     return false;
    152 #endif
    153   }
    154 
    155   // Allow either a scalar or (if allowing legacy scalars) shape (1,).
    156   bool IsLegacyScalar(const TensorShape& shape) const {
    157     return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
    158                                  shape.dim_size(0) == 1);
    159   }
    160 
    161   // Allow rank 1 or (if allowing legacy scalars) rank 0.
    162   bool IsLegacyVector(const TensorShape& shape) const {
    163     return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
    164   }
    165 
    166   // Turns a shape Tensor into a TensorShape.
    167   // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
    168   Status MakeShape(const Tensor& shape, TensorShape* out) const;
    169 
    170  private:
    171   const std::unique_ptr<const NodeDef> def_;
    172   const DataTypeVector input_types_;
    173   const MemoryTypeVector input_memory_types_;
    174   const DataTypeVector output_types_;
    175   const MemoryTypeVector output_memory_types_;
    176   const int graph_def_version_;
    177   const bool is_internal_;  // True if this is an internal operation
    178   NameRangeMap input_name_map_;
    179   NameRangeMap output_name_map_;
    180   bool expensive_;
    181 
    182   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
    183 };
    184 
    185 class AsyncOpKernel : public OpKernel {
    186  public:
    187   using OpKernel::OpKernel;  // Lift OpKernel constructors.
    188 
    189   // Asynchronous compute.
    190   //
    191   // Implementations of ComputeAsync() must run "done" to signal the
    192   // completion of the computation. "context" is guaranteed to be
    193   // alive until the "done" callback starts.
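          //
          // As an illustrative sketch only (MyAsyncOp and StartDeferredWork are
          // hypothetical, not part of this header), an asynchronous kernel
          // typically hands `done` to whatever completes the work:
          //
          //   class MyAsyncOp : public AsyncOpKernel {
          //    public:
          //     explicit MyAsyncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}
          //     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
          //       StartDeferredWork(ctx, [ctx, done](const Status& s) {
          //         if (!s.ok()) ctx->SetStatus(s);
          //         done();  // Must run exactly once, on success or failure.
          //       });
          //     }
          //   };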
    194   typedef std::function<void()> DoneCallback;
    195   virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
    196 
    197   AsyncOpKernel* AsAsync() final { return this; }
    198 
    199   void Compute(OpKernelContext* context) final;
    200 
    201   bool IsExpensive() override { return true; }
    202 };
    203 
    204 // Wraps a tensor that is held by an Op across calls to Compute(). For
    205 // memory safety when using asynchronous devices like GPUs, the system
    206 // must be notified when a Tensor is used inside an Op execution. The
    207 // wrapper ensures that all uses of the Tensor are tracked, because in
    208 // order to retrieve the Tensor the caller must use AccessTensor which
    209 // notifies the context.
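        //
        // A rough usage sketch (the member accum_ and the DT_FLOAT/{128} choices
        // are illustrative only): the kernel allocates the tensor once via
        // OpKernelConstruction::allocate_persistent(), stores only the
        // PersistentTensor member, and retrieves it in each Compute() call
        // through AccessTensor():
        //
        //   // In the kernel constructor:
        //   OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
        //                           DT_FLOAT, TensorShape({128}), &accum_,
        //                           nullptr /* out_tensor */));
        //   // In Compute():
        //   Tensor* accum = accum_.AccessTensor(ctx);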
    210 class PersistentTensor {
    211  public:
    212   PersistentTensor() {}
    213   explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
    214 
    215   // Caller does not own the returned Tensor*.
    216   Tensor* AccessTensor(OpKernelConstruction* context);
    217   // Caller does not own the returned Tensor*.
    218   Tensor* AccessTensor(OpKernelContext* context);
    219 
    220   // The check for initialization does not need to access the
    221   // underlying tensor buffer.
    222   bool IsInitialized() const { return tensor_.IsInitialized(); }
    223 
    224   int64 NumElements() const { return tensor_.NumElements(); }
    225 
    226   int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }
    227 
    228  private:
    229   Tensor tensor_;
    230 };
    231 
    232 class OpKernelConstruction {
    233  public:
    234   OpKernelConstruction(DeviceType device_type, DeviceBase* device,
    235                        Allocator* allocator, const NodeDef* node_def,
    236                        const OpDef* op_def, FunctionLibraryRuntime* flib,
    237                        const DataTypeSlice& input_types,
    238                        const MemoryTypeSlice& input_memory_types,
    239                        const DataTypeSlice& output_types,
    240                        const MemoryTypeSlice& output_memory_types,
    241                        int graph_def_version, Status* status);
    242 
    243   Env* env() const { return device_->env(); }
    244 
    245   // Allocation of tensors during kernel construction:
    246   //
    247   // It is legal to temporarily allocate scratch tensor storage during
    248   // Op kernel construction. Scratch tensors should be allocated using
    249   // allocate_temp below. Some kernels need to keep tensors in between
    250   // invocations. If such a Tensor is allocated during kernel
    251   // construction this must be done using allocate_persistent, and the
    252   // Op may only store the returned PersistentTensor object. When the
    253   // Tensor is needed in a subsequent invocation, it can be retrieved
    254   // from the PersistentTensor using the AccessTensor method. This
    255   // ensures that the system is made aware of any use of the tensor's
    256   // allocated memory, which is needed for correctness on asynchronous
    257   // devices such as GPUs.
    258 
    259   // Allocates a temporary Tensor of the specified type and shape. The
    260   // Tensor must not be used after kernel construction is
    261   // complete. See comment above.
    262   Status allocate_temp(DataType type, const TensorShape& shape,
    263                        Tensor* out_temp);
    264 
    265   // Allocates a Tensor of the specified type and shape which the Op
    266   // plans to maintain as persistent state. out_persistent holds the
    267   // PersistentTensor which is the object the caller should store. For
    268   // convenience, if out_tensor is non-null then it will be filled in
    269   // with a Tensor* pointing to the newly-allocated tensor which the
    270   // caller can use instead of calling
    271   // out_persistent->AccessTensor. The caller does not own out_tensor
    272   // and should not keep a copy of it. See comment above.
    273   Status allocate_persistent(DataType type, const TensorShape& shape,
    274                              PersistentTensor* out_persistent,
    275                              Tensor** out_tensor);
    276 
    277   // User-supplied configuration of this operation.
    278   const NodeDef& def() const { return *def_; }
    279 
    280   // For inspecting the inputs to this operation.
    281   int num_inputs() const { return input_types_.size(); }
    282   DataType input_type(int i) const { return input_types_[i]; }
    283   const DataTypeSlice& input_types() const { return input_types_; }
    284   const MemoryTypeSlice& input_memory_types() const {
    285     return input_memory_types_;
    286   }
    287 
    288   // For inspecting the outputs expected from this operation.
    289   int num_outputs() const { return output_types_.size(); }
    290   DataType output_type(int i) const { return output_types_[i]; }
    291   const DataTypeSlice& output_types() const { return output_types_; }
    292   const MemoryTypeSlice& output_memory_types() const {
    293     return output_memory_types_;
    294   }
    295 
    296   // If expected_inputs == input_types() and expected_outputs ==
    297   // output_types(), returns OK, else returns INVALID_ARGUMENT with an
    298   // error message. Recommended for Ops with dynamic signatures.
    299   Status MatchSignature(const DataTypeSlice expected_inputs,
    300                         const DataTypeSlice expected_outputs);
    301 
    302   // For recording configuration errors during construction.
    303   void SetStatus(const Status& status);
    304   const Status& status() const { return *status_; }
    305 
    306   // Look up the attr with name attr_name and set *value to its value.  If no
    307   // attr with attr_name is found in def(), or the attr does not have
    308   // a matching type, a non-ok status will be returned.
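          //
          // For example (the attr name "axis" and the member axis_ are
          // illustrative only), a kernel constructor commonly does:
          //
          //   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));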
    309   template <class T>
    310   Status GetAttr(StringPiece attr_name, T* value) const;
    311 
    312   // Return true if the attr_name is defined in def().
    313   bool HasAttr(StringPiece attr_name) const;
    314 
    315   // Return the device type.
    316   const DeviceType& device_type() const { return device_type_; }
    317 
    318   // If not nullptr, the kernel can instantiate functions defined in
    319   // the library. E.g.,
    320   // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
    321   FunctionLibraryRuntime* function_library() const { return flib_; }
    322 
    323   // The GraphDef version whose behavior we should follow.
    324   int graph_def_version() const { return graph_def_version_; }
    325 
    326   // Helper routines for the OP_REQUIRES macros
    327   void CtxFailure(const Status& s);
    328   void CtxFailureWithWarning(const Status& s);
    329   void CtxFailure(const char* file, int line, const Status& s);
    330   void CtxFailureWithWarning(const char* file, int line, const Status& s);
    331 
    332   // Unrecommended functions: these functions have some current uses
    333   // but are not recommended for new code, and may go away in a future
    334   // major version release.
    335 
    336   // May be used, e.g., to get GPU handles, etc.
    337   //
    338   // Currently only used to call MakeTensorFromProto() for
    339   // implementing ConstantOp for every device.  See comments
    340   // on Device::MakeTensorFromProto for longer-term replacement
    341   // ideas.
    342   DeviceBase* device() const { return device_; }
    343 
    344  private:
    345   const DeviceType device_type_;
    346   DeviceBase* const device_;
    347   Allocator* allocator_;
    348   const NodeDef* def_;
    349   const OpDef* op_def_;
    350   FunctionLibraryRuntime* flib_;
    351   DataTypeSlice input_types_;
    352   MemoryTypeSlice input_memory_types_;
    353   DataTypeSlice output_types_;
    354   MemoryTypeSlice output_memory_types_;
    355   const int graph_def_version_;
    356   Status* status_;
    357 
    358   // Allow access to op_def_ from OpKernel, but not from subclasses.
    359   // TODO(irving): Remove protos from this header entirely.
    360   friend class OpKernel;
    361 
    362   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
    363 };
    364 
    365 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
    366 // tensorflow::gtl::iterator_range to make the below container classes
    367 // unnecessary.
    368 template <typename ListType, typename ElementType>
    369 class OpArgIterator {
    370  public:
    371   typedef OpArgIterator<ListType, ElementType> ME;
    372   OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
    373   bool operator==(const ME& rhs) {
    374     DCHECK(list_ == rhs.list_);
    375     return i_ == rhs.i_;
    376   }
    377   bool operator!=(const ME& rhs) {
    378     DCHECK(list_ == rhs.list_);
    379     return i_ != rhs.i_;
    380   }
    381   void operator++() { ++i_; }
    382   ElementType& operator*() { return (*list_)[i_]; }
    383 
    384  private:
    385   const ListType* const list_;
    386   int i_;
    387 };
    388 
    389 // Utility class for representing a list of immutable input tensors
    390 // that are passed to the op as a single named argument.
    391 class OpInputList {
    392  public:
    393   typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
    394   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
    395   OpInputList(OpKernelContext* ctx, int start, int stop)
    396       : ctx_(ctx), start_(start), stop_(stop) {}
    397   OpInputList& operator=(const OpInputList& other) = default;
    398   const Tensor& operator[](int i) const;
    399   int size() const { return stop_ - start_; }
    400   Iterator begin() const { return Iterator(this, 0); }
    401   Iterator end() const { return Iterator(this, size()); }
    402 
    403  private:
    404   OpKernelContext* ctx_;  // not owned
    405   int start_;
    406   int stop_;
    407 };
    408 
    409 // Utility class for representing a list of mutable ("ref") input tensors
    410 // that are passed to the op as a single named argument.
    411 class OpMutableInputList {
    412  public:
    413   typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
    414   OpMutableInputList(OpKernelContext* ctx, int start, int stop)
    415       : ctx_(ctx), start_(start), stop_(stop) {}
    416   OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
    417   OpMutableInputList& operator=(const OpMutableInputList& other) = default;
    418   Tensor at(int i, bool lock_held);
    419   mutex* ref_mutex(int i);
    420   int size() const { return stop_ - start_; }
    421   Iterator begin() const { return Iterator(this, 0); }
    422   Iterator end() const { return Iterator(this, size()); }
    423 
    424  private:
    425   OpKernelContext* ctx_;  // not owned
    426   int start_;
    427   int stop_;
    428 };
    429 
    430 // Utility class for representing a list of output tensors that are
    431 // grouped as a single named output.
    432 class OpOutputList {
    433  public:
    434   typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
    435   OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
    436   OpOutputList(OpKernelContext* ctx, int start, int stop)
    437       : ctx_(ctx), start_(start), stop_(stop) {}
    438   OpOutputList& operator=(const OpOutputList& other) = default;
    439   Tensor* operator[](int i);
    440   bool required(int i) const;
    441   DataType expected_output_dtype(int i) const;
    442   Status allocate(int i, const TensorShape& shape, Tensor** output);
    443   void set(int i, const Tensor& tensor);
    444   void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
    445   int size() const { return stop_ - start_; }
    446   Iterator begin() const { return Iterator(this, 0); }
    447   Iterator end() const { return Iterator(this, size()); }
    448 
    449  private:
    450   OpKernelContext* ctx_;  // not owned
    451   int start_;
    452   int stop_;
    453 };
    454 
    455 // Holds a tensor or tensor reference. For tensor references, we need
    456 // a mutex to prevent concurrent access to the tensor.
    457 struct TensorValue {
    458   TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
    459   TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
    460       : mutex_if_ref(nullptr), tensor(t) {}
    461   TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
    462   Tensor* operator->() const { return tensor; }
    463   bool is_ref() const { return mutex_if_ref != nullptr; }
    464 
    465   mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
    466   Tensor* tensor;
    467 };
    468 
    469 class OpKernelContext {
    470  public:
    471   // The first element of a WrappedAllocator is a "base" Allocator and
    472   // the second element is that Allocator wrapped by a
    473   // TrackingAllocator.
    474   typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
    475 
    476   // TODO(zhifengc): Do some cleanup of Params.
    477   // The Params struct is passed in to initialize an OpKernelContext,
    478   // and must outlive the OpKernelContext.
    479   struct Params {
    480     ~Params() { delete eigen_gpu_device; }
    481 
    482     // The step being executed.
    483     int64 step_id = 0;
    484 
    485     // The op kernel being computed.
    486     OpKernel* op_kernel = nullptr;
    487 
    488     // The device on which the kernel is running.
    489     DeviceBase* device = nullptr;
    490 
    491     // The Eigen GPU device wrapper, which may include a per-op
    492     // wrapped allocator. The concrete type of this object depends on
    493     // the type of this->device, so eigen_gpu_device can't be an
    494     // inline member and must be heap allocated. However, we don't
    495     // want to allocate a new eigen_gpu_device for every Op that is
    496     // executed. Instead this member is allocated on first use using
    497     // ensure_eigen_gpu_device, and then if the Params structure is
    498     // re-used for subsequent Ops, the eigen_gpu_device is
    499     // ReInitialized in the OpKernelContext constructor. Unlike the
    500     // other pointers in Params, this one is owned by Params.
    501     PerOpGpuDevice* eigen_gpu_device = nullptr;
    502 
    503     inline void ensure_eigen_gpu_device() {
    504       DCHECK(device);
    505       if (nullptr == eigen_gpu_device) {
    506         // Surprisingly, MakeGpuDevice will return nullptr if the
    507         // device is not a GPU device. This is ok, since those devices
    508         // will never use eigen_gpu_device. It seems better to have
    509         // ensure_eigen_gpu_device fall through and regenerate the
    510         // nullptr every time an OpKernelContext is instantiated, than
    511         // to do an unnecessary allocation of a dummy eigen GPU
    512         // device for CPU device Ops.
    513         eigen_gpu_device = device->MakeGpuDevice();
    514       }
    515     }
    516 
    517     bool track_allocations = false;
    518     bool log_memory = false;
    519     bool record_tensor_accesses = false;
    520 
    521     // Array indexed by output number for this node
    522     const AllocatorAttributes* output_attr_array = nullptr;
    523 
    524     // Shared resources accessible by this op kernel invocation.
    525     ResourceMgr* resource_manager = nullptr;
    526 
    527     // Per-step resources accessible by this op kernel invocation should be
    528     // stored in this container.
    529     ScopedStepContainer* step_container = nullptr;
    530 
    531     // Mechanism used by this op kernel invocation to communicate with
    532     // computations running on other devices.
    533     Rendezvous* rendezvous = nullptr;
    534 
    535     // The session state for this op.
    536     SessionState* session_state = nullptr;
    537 
    538     // The tensor store for this op.
    539     TensorStore* tensor_store = nullptr;
    540 
    541     // Mechanism used by this op kernel invocation to register a callback
    542     // for its cancellation.
    543     CancellationManager* cancellation_manager = nullptr;
    544 
    545     // Inputs to this op kernel.
    546     const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    547     bool is_input_dead = false;
    548 
    549     const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
    550         nullptr;
    551 
    552     // Device contexts.
    553     const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
    554         nullptr;
    555     DeviceContext* op_device_context = nullptr;
    556 
    557     // Control-flow op support.
    558     FrameAndIter frame_iter;
    559 
    560     // Function call support.
    561     CallFrameInterface* call_frame = nullptr;
    562     FunctionLibraryRuntime* function_library = nullptr;
    563     std::function<void(std::function<void()>)>* runner = nullptr;
    564     StepStatsCollector* stats_collector = nullptr;
    565 
    566     // TensorSliceReaderCache support.
    567     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
    568   };
    569 
    570   // params must outlive the OpKernelContext.
    571   explicit OpKernelContext(Params* params);
    572   OpKernelContext(Params* params, int noutputs);
    573   ~OpKernelContext();
    574 
    575   Env* env() const { return params_->device->env(); }
    576 
    577   int64 step_id() const { return params_->step_id; }
    578 
    579   const OpKernel& op_kernel() const { return *params_->op_kernel; }
    580 
    581   // Input/output signature.
    582 
    583   int num_inputs() const { return params_->inputs->size(); }
    584   DataType input_dtype(int index) const;
    585   Status input_dtype(StringPiece name, DataType* dtype) const;
    586   MemoryType input_memory_type(int index) const;
    587 
    588   int num_outputs() const { return outputs_.size(); }
    589   DataType expected_output_dtype(int index) const;
    590   MemoryType output_memory_type(int index) const;
    591 
    592   // Input
    593 
    594   // Returns an immutable input tensor. May only be used for non-Ref
    595   // inputs. For Ref inputs use mutable_input below.
    596   // REQUIRES: !IsRefType(input_dtype(index))
    597   // TODO(mrry): Convert this to return Status.
    598   const Tensor& input(int index);
    599 
    600   // Returns the named immutable input tensor in "tensor", as defined
    601   // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
    602   // use mutable_input below.
    603   // REQUIRES: !IsRefType(input_dtype(index))
    604   // REQUIRES: the named input must not be a list.
    605   Status input(StringPiece name, const Tensor** tensor);
    606 
    607   // Returns the named list-valued immutable input in "list", as
    608   // defined in the OpDef.  If the named input is not list-valued,
    609   // returns a one-element list. May only be used for non-Ref
    610   // inputs. For Ref inputs use mutable_input below.
    611   // REQUIRES: !IsRefType(input_dtype(index))
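          //
          // Illustrative sketch (the input name "values" is hypothetical):
          //
          //   OpInputList values;
          //   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
          //   for (int i = 0; i < values.size(); ++i) {
          //     const Tensor& t = values[i];
          //     ...
          //   }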
    612   Status input_list(StringPiece name, OpInputList* list);
    613 
    614   // For mutable inputs, use the following together to make sure there
    615   // is no concurrent access to mutable_input(), e.g.:
    616   // {
    617   //   mutex_lock lock(*context->input_ref_mutex(index));
    618   //   Tensor t = context->mutable_input(index, true /* lock_held */);
    619   //   // modify the values in t
    620   // }
    621   // REQUIRES: IsRefType(input_dtype(index))
    622   Status input_ref_mutex(StringPiece name, mutex** out_mutex);
    623 
    624   // Returns a mutable input tensor. Must be used to access Ref
    625   // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
    626   // modify the values stored in the Tensor buffer, and modifications
    627   // will be visible to other Ops reading the same ref tensor. If
    628   // !lock_held the input mutex will be acquired before returning the
    629   // Tensor.
    630   // TODO(mrry): Convert this to return Status.
    631   Tensor mutable_input(int index, bool lock_held);
    632 
    633   // Returns the named mutable input tensor in "tensor", as defined in
    634   // the OpDef. Must be used to access Ref inputs. The values stored
    635   // in the Tensor buffer may be modified, and modifications will be
    636   // visible to other Ops reading the same ref tensor. If !lock_held
    637   // the input mutex will be acquired before returning the Tensor.
    638   // REQUIRES: the named input must not be a list.
    639   // REQUIRES: the named input must be a ref tensor.
    640   Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
    641 
    642   // Returns the named list-valued mutable input in "list", as defined
    643   // in the OpDef.  If the named input is not list-valued, returns a
    644   // one-element list. Must be used to access Ref inputs. The values
    645   // stored in the Tensor buffer may be modified, and modifications
    646   // will be visible to other Ops reading the same ref tensor.
    647   // REQUIRES: the named input must be a ref tensor.
    648   Status mutable_input_list(StringPiece name, OpMutableInputList* list);
    649 
    650   // Replace the corresponding Ref Input to use the storage buffer
    651   // used by tensor. If !lock_held the input mutex will be acquired
    652   // before returning the Tensor.
    653   // REQUIRES: IsRefType(input_dtype(index)).
    654   void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
    655 
    656   // Replace the corresponding named Ref Input to use the storage
    657   // buffer used by tensor. If !lock_held the input mutex will be
    658   // acquired before returning the Tensor.
    659   // REQUIRES: IsRefType(input_dtype(index)).
    660   Status replace_ref_input(StringPiece name, const Tensor& tensor,
    661                            bool lock_held);
    662 
    663   // Deletes the Tensor object used as the Ref Input at
    664   // input_index. This is not usually necessary and should be used
    665   // with caution. If !lock_held the input mutex will be acquired
    666   // before returning the Tensor.
    667   // REQUIRES: IsRefType(input_dtype(input_index)).
    668   void delete_ref_input(int input_index, bool lock_held);
    669 
    670   // Return true if there is input at the given index. An operator has no
    671   // input at index if its tensor is null. This is primarily used by the
    672   // merge operator.
    673   // TODO(mrry): Convert this to return Status.
    674   bool has_input(int index) const;
    675 
    676   // Returns true if all inputs are the same shape, otherwise sets the
    677   // status to a non-OK value and returns false.
    678   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
    679   bool ValidateInputsAreSameShape(OpKernel* op);
    680 
    681   // Input to output forwarding.
    682 
    683   // Set the output Ref Tensor at output_index to be an alias of the
    684   // input Ref Tensor at input_index.
    685   // REQUIRES: IsRefType(input_dtype(input_index)).
    686   // REQUIRES: IsRefType(output_dtype(output_index)).
    687   void forward_ref_input_to_ref_output(int input_index, int output_index);
    688 
    689   // Returns true when an alias to input[input_index], reshaped to
    690   // output_shape and safe to use for in-place computation, was written to *output.
    691   // Returns false if input[input_index] has a refcount greater than one, or if
    692   // its type does not match the expected output type of output[output_index],
    693   // or the number of elements in input[input_index] does not equal the number
    694   // of elements in output_shape.
    695   bool forward_input_to_output_with_shape(int input_index, int output_index,
    696                                           const TensorShape& output_shape,
    697                                           Tensor** output) TF_MUST_USE_RESULT;
    698   Status forward_input_to_output_with_shape(StringPiece input_name,
    699                                             StringPiece output_name,
    700                                             const TensorShape& output_shape,
    701                                             Tensor** output) TF_MUST_USE_RESULT;
    702 
    703   // Returns a pointer to a Tensor aliasing the underlying buffer backing
    704   // input[input_index] iff
    705   //   * input[input_index] is not a ref,
    706   //   * the data type, shape, memory type, and allocator attributes of
    707   //     input[input_index] are compatible with those given in dtype, shape,
    708   //     memory_type, and attr,
    709   //   * refcount on the underlying buffer is one.
    710   // Otherwise returns nullptr.
    711   // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
    712   // forwarding is only safe if there are no reads via __ldg() after writes
    713   // to the same address.
    714   std::unique_ptr<Tensor> forward_input(
    715       int input_index, DataType dtype, const TensorShape& shape,
    716       MemoryType memory_type,
    717       const AllocatorAttributes& attr) TF_MUST_USE_RESULT;
    718 
    719   // Tries to forward one of the inputs given in input_indices to
    720   // output[output_index]. If none of the given inputs can be forwarded, calls
    721   // allocate_output() to allocate a new output buffer.
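          //
          // A common pattern for element-wise ops, sketched here with input 0
          // forwarded to output 0 (assumes an `input` Tensor already obtained
          // via ctx->input(0)):
          //
          //   Tensor* output = nullptr;
          //   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
          //                           {0}, 0, input.shape(), &output));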
    722   Status forward_input_or_allocate_output(
    723       gtl::ArraySlice<int> candidate_input_indices, int output_index,
    724       const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
    725   Status forward_input_or_allocate_output(
    726       gtl::ArraySlice<StringPiece> candidate_input_names,
    727       StringPiece output_name, const TensorShape& output_shape,
    728       Tensor** output) TF_MUST_USE_RESULT;
    729 
    730   // Tries to reuse one of the inputs given in input_indices as a temporary.
    731   // If none of the given inputs can be forwarded, calls
    732   // allocate_temp() to allocate a new temporary buffer.
    733   Status forward_input_or_allocate_temp(
    734       gtl::ArraySlice<int> candidate_input_indices, DataType type,
    735       const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    736       Tensor* out_temp) TF_MUST_USE_RESULT;
    737 
    738   Status forward_input_or_allocate_temp(
    739       gtl::ArraySlice<int> candidate_input_indices, DataType type,
    740       const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    741     return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
    742                                           AllocatorAttributes(), out_temp);
    743   }
    744 
    745   // Output
    746 
    747   // Returns the named list-valued output in "list", as defined in the OpDef.
    748   // If the named output is not list-valued, returns a one-element list.
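          //
          // Illustrative sketch (the output name "components" is hypothetical,
          // and each `shape` is computed by the kernel):
          //
          //   OpOutputList components;
          //   OP_REQUIRES_OK(ctx, ctx->output_list("components", &components));
          //   for (int i = 0; i < components.size(); ++i) {
          //     Tensor* out = nullptr;
          //     OP_REQUIRES_OK(ctx, components.allocate(i, shape, &out));
          //   }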
    749   Status output_list(StringPiece name, OpOutputList* list);
    750 
    751   // If output_required(index) returns true, the OpKernel's Compute() method
    752   // should call allocate_output(index, ...), set_output(index, ...),
    753   // set_output_ref(index, ...), or set the status to a non-ok value.
    754   // If it returns false, it may output, but is not required to do so.
    755   // TODO(mrry): Convert this to return Status, and implement a string
    756   // name version.
    757   bool output_required(int index) const {
    758     return true;  // TODO(josh11b): implement
    759   }
    760 
    761   // Allocation of tensors during kernel execution inside the Compute
    762   // method:
    763   //
    764   // There are three methods to allocate Tensors when an Op kernel
    765   // executes.
    766   //
    767   // 1) allocate_persistent. This is only needed for Tensors that will
    768   // be stored by the Op between invocations, and it *must* be used
    769   // for those Tensors. The call returns a PersistentTensor, and that
    770   // is the only object the Op is allowed to hold on to between
    771   // invocations. When the Tensor is needed in a subsequent
    772   // invocation, it can be retrieved from the PersistentTensor using
    773   // the AccessTensor method. This ensures that the system is made
    774   // aware of any use of the tensor's allocated memory, which is
    775   // needed for correctness on asynchronous devices such as GPUs.
    776   //
    777   // 2) allocate_output. This should be used to allocate any tensor
    778   // that is going to be used as an output from the Op at the end of
    779   // the current execution. The caller indicates which output the
    780   // Tensor will be assigned to, and the call returns the
    781   // newly-allocated Tensor. The Tensor can subsequently be assigned
    782   // to during kernel execution, and will be used as the designated
    783   // output when the kernel execution completes.
    784   //
    785   // 3) allocate_temp. This should be used to allocate any scratch
    786   // storage that is needed while the kernel is executing, and will
    787   // not be retained by the Op.
    788   //
    789   // In some cases a Tensor needs to be used as an output even though
    790   // it was previously allocated elsewhere. The Tensor may have been
    791   // passed as an input, or stored in a PersistentTensor during a
    792   // previous kernel execution, or allocated earlier in the kernel
    793   // execution at a time when it was not known which output it would
    794   // be assigned to. In this case the kernel can use set_output or
    795   // set_output_ref to indicate that the tensor should be used as the
    796   // designated output. It is legal to use any previously-allocated
    797   // Tensor as an argument to set_output or set_output_ref, including
    798   // Tensors allocated via allocate_temp. There may be a performance
    799   // penalty to using a Tensor that was not allocated using
    800   // allocate_output. This is because allocate_output uses the
    801   // AllocatorAttributes stored in output_attr_array for the
    802   // designated output. In some cases, using the wrong attributes may
    803   // cause an extra copy of the Tensor's buffer.
    804 
    805   // Allocates output for the specified output index with shape.
    806   // OpKernelContext retains ownership of the returned pointer. See
    807   // comment above.
    808   //
    809   // If memory allocation fails, returns an error status.
    810   //
    811   // REQUIRES: !IsRefType(expected_output_dtype(index))
    812   Status allocate_output(int index, const TensorShape& shape,
    813                          Tensor** tensor) TF_MUST_USE_RESULT;
    814   Status allocate_output(StringPiece name, const TensorShape& shape,
    815                          Tensor** tensor) TF_MUST_USE_RESULT;
    816   // The following methods use the supplied attributes instead of
    817   // those in output_attr_array. The caller is responsible for
    818   // ensuring that the attributes are "compatible" with the
    819   // output_attr_array, e.g. the tensor is allocated on the correct
    820   // device. See comment above.
    821   Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
    822                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
    823   Status allocate_output(StringPiece name, const TensorShape& shape,
    824                          Tensor** tensor,
    825                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
    826 
    827   // Allocates a temporary Tensor of the specified type and
    828   // shape. Devices such as GPUs that enqueue Ops for lazy execution
    829   // may retain references to the temporary tensors after the Op's
    830   // Compute method has run. See comment above.
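          //
          // Sketch of allocating per-invocation scratch space (the DT_FLOAT type
          // and the {rows, cols} shape are illustrative only):
          //
          //   Tensor scratch;
          //   OP_REQUIRES_OK(ctx, ctx->allocate_temp(
          //                           DT_FLOAT, TensorShape({rows, cols}), &scratch));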
    831   Status allocate_temp(DataType type, const TensorShape& shape,
    832                        Tensor* out_temp, AllocatorAttributes allocator_attr,
    833                        const AllocationAttributes& allocation_attr);
    834   Status allocate_temp(DataType type, const TensorShape& shape,
    835                        Tensor* out_temp, AllocatorAttributes allocator_attr) {
    836     return allocate_temp(type, shape, out_temp, allocator_attr,
    837                          AllocationAttributes());
    838   }
    839   Status allocate_temp(DataType type, const TensorShape& shape,
    840                        Tensor* out_temp) {
    841     return allocate_temp(type, shape, out_temp, AllocatorAttributes());
    842   }
    843 
    844   // Allocates a Tensor of the specified type and shape which the Op
    845   // plans to maintain as persistent state. out_persistent holds the
    846   // PersistentTensor which is the object the caller should store. For
    847   // convenience, if out_tensor is non-null then it will be filled in
    848   // with a Tensor* pointing to the newly-allocated tensor which the
    849   // caller can use instead of calling
    850   // out_persistent->AccessTensor. The caller does not own out_tensor
    851   // and should not keep a copy of it. See comment above.
    852   Status allocate_persistent(DataType type, const TensorShape& shape,
    853                              PersistentTensor* out_persistent,
    854                              Tensor** out_tensor, AllocatorAttributes attr);
    855   Status allocate_persistent(DataType type, const TensorShape& shape,
    856                              PersistentTensor* out_persistent,
    857                              Tensor** out_tensor) {
    858     return allocate_persistent(type, shape, out_persistent, out_tensor,
    859                                AllocatorAttributes());
    860   }
    861 
    862   // Copies a tensor (allocated by the caller) to the specified output
    863   // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
    864   // REQUIRES: 'tensor' must have the same MemoryType as
    865   // output_memory_types[index]. See comment above.
    866   Status set_output(StringPiece name, const Tensor& tensor);
    867 
    868   // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
    869   // and they must outlive all uses within the step. See comment above.
    870   // REQUIRES: IsRefType(expected_output_dtype(index))
    871   Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
    872 
    873   // Sets *tensor to nullptr if allocate_output() or set_output() has not been called.
    874   Status mutable_output(StringPiece name, Tensor** tensor);
    875 
    876   // Transfers ownership of an output tensor to the caller.
    877   // NOTE: For non-reference outputs, the caller takes responsibility
    878   // for deletion. For reference outputs, the caller does NOT take
    879   // responsibility for deletion.
    880   Status release_output(StringPiece name, TensorValue* value);
    881 
    882   // Records device specific state about how the input tensors were
    883   // computed.
    884   //
    885   // If using the templated function, the type must be a subclass
    886   // of DeviceContext.
    887   //
    888   // Returns the DeviceContext used for the index-th input.  Returns
    889   // nullptr if no DeviceContext was provided.
    890   template <typename T>
    891   T* input_device_context(int index);
    892   DeviceContext* input_device_context(int index);
    893 
    894   // Return the DeviceContext that should be used for this Op.
    895   //
    896   // If using the templated function, the type must be a subclass
    897   // of DeviceContext.
    898   //
    899   // Returns nullptr if the device did not provide one.
    900   template <typename T>
    901   T* op_device_context();
    902   DeviceContext* op_device_context() {
    903     DeviceContext* ret = params_->op_device_context;
    904     if (ret == nullptr) {
    905       auto* dev_info = device()->tensorflow_gpu_device_info();
    906       if (dev_info) ret = dev_info->default_context;
    907     }
    908     return ret;
    909   }
    910 
    911   AllocatorAttributes input_alloc_attr(int index) const {
    912     if (params_->input_alloc_attrs == nullptr) {
    913       return AllocatorAttributes();
    914     } else {
    915       DCHECK_GE(index, 0);
    916       DCHECK_LT(index, params_->input_alloc_attrs->size());
    917       return (*params_->input_alloc_attrs)[index];
    918     }
    919   }
    920 
    921   AllocatorAttributes output_alloc_attr(int index) const {
    922     return params_->output_attr_array[index];
    923   }
    924 
    925   gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const {
    926     mutex_lock lock(mu_);
    927     gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_;
    928     return retrieved;
    929   }
    930 
    931   // Communication.
    932   //
    933   // An op kernel communicates with outside environment through
    934   // Rendezvous Send() and Recv().
    935   Rendezvous* rendezvous() const { return params_->rendezvous; }
    936 
    937   // An op kernel can access the session state it belongs to.
    938   SessionState* session_state() const { return params_->session_state; }
    939 
    940   // An op kernel can access the tensor store of the run it belongs to.
    941   TensorStore* tensor_store() const { return params_->tensor_store; }
    942 
    943   // Function call support.
    944   //
    945   // If this kernel invocation is within a function execution,
    946   // call_frame() returns the call frame for the function call.
    947   CallFrameInterface* call_frame() const { return params_->call_frame; }
    948 
    949   // If not nullptr, the kernel can invoke functions defined in the
    950   // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
    951   FunctionLibraryRuntime* function_library() const {
    952     return params_->function_library;
    953   }
    954 
    955   std::function<void(std::function<void()>)>* runner() const {
    956     return params_->runner;
    957   }
    958   StepStatsCollector* stats_collector() const {
    959     return params_->stats_collector;
    960   }
    961 
    962   // Shared resources accessible to this kernel.
    963   ResourceMgr* resource_manager() const { return params_->resource_manager; }
    964 
    965   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    966     return params_->slice_reader_cache;
    967   }
    968 
    969   // Execution.
    970   //
    971   // OpKernels can use these eigen devices to carry out their
    972   // numerical computation.
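          //
          // For example (a sketch; `in` and `out` stand for flat<float>() views
          // of an input and a preallocated output), a CPU kernel can evaluate an
          // Eigen expression on the threadpool device:
          //
          //   out.device(context->eigen_cpu_device()) = in * 2.0f;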
    973   const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    974     return *device()->eigen_cpu_device();
    975   }
    976   const Eigen::GpuDevice& eigen_gpu_device() const {
    977     return params_->eigen_gpu_device->device();
    978   }
    979 #ifdef TENSORFLOW_USE_SYCL
    980   const Eigen::SyclDevice& eigen_sycl_device() const {
    981     return *device()->eigen_sycl_device();
    982   }
    983 #endif
    984   template <typename EigenDeviceType>
    985   const EigenDeviceType& eigen_device() const;
    986 
    987   // Error handling.
    988 
    989   // If expected_inputs == inputs() and expected_outputs == output_types(),
    990   // returns OK, else returns INVALID_ARGUMENT with an error message.
    991   // Recommended for Ops with dynamic signatures, where validation can only
    992   // be performed at runtime.
    993   Status MatchSignature(const DataTypeSlice expected_inputs,
    994                         const DataTypeSlice expected_outputs);
    995 
    996   // An OpKernel should call SetStatus() if Compute() encounters an
    997   // error.
    998   void SetStatus(const Status& status);
    999   const Status& status() const { return status_; }
   1000 
   1001   // Cancellation.
   1002   //
   1003   // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
   1004   // example of how to use this API.
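          //
          // A rough sketch of the intended pattern (see cancellation.h for the
          // authoritative API): register a callback under a token and deregister
          // it when the kernel finishes.
          //
          //   CancellationManager* cm = ctx->cancellation_manager();
          //   if (cm != nullptr) {
          //     CancellationToken token = cm->get_cancellation_token();
          //     const bool already_cancelled =
          //         !cm->RegisterCallback(token, [this]() { /* unblock work */ });
          //     ...
          //     cm->DeregisterCallback(token);
          //   }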
   1005   CancellationManager* cancellation_manager() const {
   1006     return params_->cancellation_manager;
   1007   }
   1008 
   1009   // Other accessors.
   1010 
   1011   // For control flow.
   1012   FrameAndIter frame_iter() const { return params_->frame_iter; }
   1013   bool is_input_dead() const { return params_->is_input_dead; }
   1014   bool* is_output_dead() { return &is_output_dead_; }
   1015 
   1016   // May be used, e.g., to get GPU handles, etc.
   1017   // TODO(tucker): Add example usage.
   1018   DeviceBase* device() const { return params_->device; }
   1019 
   1020   // Retrieve list of referenced tensors in out_vector. Once this is
   1021   // called, it is not legal to reference any more tensors.  Should
   1022   // not be called from Op kernels.
   1023   void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
   1024 
   1025   // Per-step container for use by white-listed internal ops.
   1026   ScopedStepContainer* step_container() const {
   1027     return params_->step_container;
   1028   }
   1029 
   1030   // Helper routines for the OP_REQUIRES macros
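          // (e.g. OP_REQUIRES(context, cond, error_status) calls CtxFailure() and
          // returns from Compute() when `cond` is false). A typical check, as a
          // sketch only:
          //
          //   OP_REQUIRES(context, input.dims() == 2,
          //               errors::InvalidArgument("input must be a matrix, got shape ",
          //                                       input.shape().DebugString()));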
   1031   void CtxFailure(const Status& s);
   1032   void CtxFailureWithWarning(const Status& s);
   1033   void CtxFailure(const char* file, int line, const Status& s);
   1034   void CtxFailureWithWarning(const char* file, int line, const Status& s);
   1035 
   1036   // Unrecommended functions: these functions have some current uses
   1037   // but are not recommended for new code, and may go away in a future
   1038   // major version release.
   1039   //
   1040   // The following functions all have versions that return Status
   1041   // to capture error conditions, and are strongly preferred.
   1042   Tensor* mutable_output(int index);
   1043   void set_output(int index, const Tensor& tensor);
   1044   mutex* input_ref_mutex(int index);
   1045   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
   1046   TensorValue release_output(int index);
   1047 
   1048   bool track_allocations() const { return params_->track_allocations; }
   1049 
   1050   // Records temp memory allocation. Tensor object is recorded to identify the
   1051   // case where temp memory is used as output memory.
   1052   void record_temp_memory_allocation(int64 size, const Tensor& t)
   1053       LOCKS_EXCLUDED(stats_mu_);
   1054 
   1055   // Returns the recorded size of temporary memory.
   1056   int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
   1057 
   1058   // Records a persistent memory allocation; a negative size indicates
   1059   // deallocation.
   1060   void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
   1061       LOCKS_EXCLUDED(stats_mu_);
   1062 
   1063   // Returns recorded size and ids of persistent memory.
   1064   int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
   1065 
   1066   std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);
   1067 
   1068   // Resets counters for temp and persistent memory and recorded ids.
   1069   void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);
   1070 
   1071   bool input_is_ref(int index) const;
   1072 
   1073  private:
   1074   Allocator* get_allocator(AllocatorAttributes attr);
   1075 
   1076   // Internal method to add a tensor's buffer to the list of buffers
   1077   // referenced during the execution of the Op, so that GPUs may
   1078   // accurately track the memory that may not be reused until the Op
   1079   // execution completes.
   1080   void record_tensor_reference(const Tensor& tensor);
   1081   void really_record_tensor_reference(const Tensor& tensor);
   1082 
   1083   // Internal common method used when allocating tensor memory
   1084   Status allocate_tensor(DataType type, const TensorShape& shape,
   1085                          Tensor* out_tensor,
   1086                          AllocatorAttributes allocator_attr) {
   1087     return allocate_tensor(type, shape, out_tensor, allocator_attr,
   1088                            AllocationAttributes());
   1089   }
   1090 
   1091   Status allocate_tensor(DataType type, const TensorShape& shape,
   1092                          Tensor* out_tensor, AllocatorAttributes allocator_attr,
   1093                          const AllocationAttributes& allocation_attr);
   1094 
   1095   // This is called by PersistentTensor::AccessTensor whenever the
   1096   // wrapped tensor is retrieved, to ensure the runtime knows that the
   1097   // Tensor is being accessed within an Op. This is necessary for
   1098   // memory safety of devices like GPUs that queue Ops for
   1099   // asynchronous execution after the Compute() method completes.
   1100   friend class PersistentTensor;
   1101   void NotifyUseOfPersistentTensor(const Tensor& tensor);
   1102 
   1103   Status status_;
   1104   Params* params_;    // not owned
   1105   mutable mutex mu_;  // mutable so const accessors can acquire the lock
   1106   gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
   1107   gtl::InlinedVector<TensorValue, 4> outputs_;
   1108 
   1109   // Constructed only if <params->record_tensor_accesses>.
   1110   ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
   1111 
   1112   bool is_output_dead_ = false;
   1113 
   1114   // The following data members are only used when allocation tracking is
   1115   // enabled.
   1116   mutable mutex stats_mu_;
   1117   int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
   1118   int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
   1119   std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
   1120       temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
   1121   std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
   1122       GUARDED_BY(stats_mu_);
   1123 
   1124   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
   1125 };
   1126 
   1127 // Register your OpKernel by specifying the Op's name, the device the
   1128 // kernel runs on, any type attr constraints for this kernel, any
   1129 // host-memory args, and the class to instantiate.  Examples:
   1130 //
   1131 //  // A kernel that supports all types.
   1132 //  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
   1133 //
   1134 //  // The following are equivalent ways of specifying that the kernel only
   1135 //  // works if the "T" type attr is set to DT_FLOAT.
   1136 //  REGISTER_KERNEL_BUILDER(
   1137 //      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
   1138 //      SubOp<float>);
   1139 //  // (You would then repeat this for every type supported by "Sub".)
   1140 //
   1141 //  // This form allows you to specify a list of types as the constraint.
   1142 //  REGISTER_KERNEL_BUILDER(Name("Sub")
   1143 //                              .Device(DEVICE_CPU)
   1144 //                              .TypeConstraint("T", {DT_FLOAT}),
   1145 //                          SubOp<float>);
   1146 //
   1147 //  // A kernel that expects one of the input tensors in host memory.
   1148 //  REGISTER_KERNEL_BUILDER(
   1149 //      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
   1150 //
   1151 // See kernel_def_builder for details.
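//
// For orientation, here is a hedged end-to-end sketch pairing a minimal kernel
// class with one of the registrations above.  The op name "AddOne" and the
// class MyAddOneOp are illustrative only (the op itself would be declared
// elsewhere with REGISTER_OP); neither is defined anywhere in this header:
//
//  class MyAddOneOp : public OpKernel {
//   public:
//    explicit MyAddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
//
//    void Compute(OpKernelContext* context) override {
//      const Tensor& input = context->input(0);
//      Tensor* output = nullptr;
//      OP_REQUIRES_OK(context,
//                     context->allocate_output(0, input.shape(), &output));
//      auto in = input.flat<float>();
//      auto out = output->flat<float>();
//      for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
//    }
//  };
//
//  REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), MyAddOneOp);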
   1152 
    1153 // Instantiate an OpKernel that has been registered.  Returns nullptr
    1154 // if no kernel is registered for that device type / input signature
    1155 // combination (and sets *status to NOT_FOUND), or if there is an error
    1156 // during construction (and sets *status to INVALID_ARGUMENT).  Otherwise,
    1157 // the caller takes ownership of the returned pointer.
   1158 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
   1159 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
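//
// A hedged usage sketch (the `device`, `node_def`, and `graph_def_version`
// values are assumed to be supplied by the caller; cpu_allocator() is declared
// in framework/allocator.h):
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device, cpu_allocator(), node_def,
//       graph_def_version, &status);
//   if (!status.ok()) {
//     // status now holds the NOT_FOUND or INVALID_ARGUMENT error.
//   }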
   1160 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
   1161                                          DeviceBase* device,
   1162                                          Allocator* allocator,
   1163                                          const NodeDef& def,
   1164                                          int graph_def_version, Status* status);
   1165 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
   1166                       Allocator* allocator, FunctionLibraryRuntime* flib,
   1167                       const NodeDef& def, int graph_def_version,
   1168                       OpKernel** kernel);
   1169 
   1170 // Returns into 'device_types' the subset of prioritized_types that this
   1171 // binary has registered for the given NodeDef.
   1172 //
   1173 // REQUIRES: * 'device_types' is not nullptr.
   1174 //           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
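//
// A hedged usage sketch (node_def is assumed to be a fully specified NodeDef
// provided by the caller):
//
//   DeviceTypeVector device_types;
//   Status s = SupportedDeviceTypesForNode(
//       {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}, node_def,
//       &device_types);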
   1175 Status SupportedDeviceTypesForNode(
   1176     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
   1177     DeviceTypeVector* device_types);
   1178 
   1179 // Returns a message with a description of the kernels registered for op
   1180 // `op_name`.
   1181 string KernelsRegisteredForOp(StringPiece op_name);
   1182 
   1183 // Call once after Op registration has completed.
   1184 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
   1185 
   1186 // -----------------------------------------------------------------------------
   1187 // OpKernel registration implementation follows, please ignore.
   1188 
   1189 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
   1190 namespace register_kernel {
   1191 
   1192 class Name : public KernelDefBuilder {
   1193  public:
    1194   // With selective registration, kernels whose implementation class is not
    1195   // needed by any op selected for registration are disabled with the
    1196   // SHOULD_REGISTER_OP_KERNEL call in REGISTER_KERNEL_BUILDER_UNIQ. However,
    1197   // an unused kernel that shares an implementation class with a used kernel
    1198   // would get through that mechanism.
    1199   //
    1200   // This class stops that registration by changing the name of the kernel for
    1201   // the unused op to one that is ignored by OpKernelRegistrar::InitInternal.
    1202   // Note that this alone is not sufficient - the compiler can't evaluate the
    1203   // entire KernelDefBuilder at compile time, so it doesn't reduce code size.
   1204   explicit Name(const char* op)
   1205       : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
   1206 };
   1207 
   1208 namespace system {
   1209 
   1210 class Name : public KernelDefBuilder {
   1211  public:
   1212   // For system kernels, we ignore selective registration and
   1213   // unconditionally register the kernel.
   1214   explicit Name(const char* op) : KernelDefBuilder(op) {}
   1215 };
   1216 
   1217 }  // namespace system
   1218 
   1219 }  // namespace register_kernel
   1220 
   1221 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
   1222   REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
   1223 
   1224 #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
   1225   REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
   1226 
   1227 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
   1228   constexpr bool should_register_##ctr##__flag =                      \
   1229       SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
   1230   static ::tensorflow::kernel_factory::OpKernelRegistrar              \
   1231       registrar__body__##ctr##__object(                               \
   1232           should_register_##ctr##__flag                               \
   1233               ? ::tensorflow::register_kernel::kernel_builder.Build() \
   1234               : nullptr,                                              \
   1235           #__VA_ARGS__,                                               \
   1236           [](::tensorflow::OpKernelConstruction* context)             \
   1237               -> ::tensorflow::OpKernel* {                            \
   1238             return new __VA_ARGS__(context);                          \
   1239           });
   1240 
    1241 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro behaves like
    1242 // `REGISTER_KERNEL_BUILDER()`, except that the kernel is registered
    1243 // unconditionally even when selective registration is used.
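//
// A hedged example (the op name "_MySystemOp" and class MySystemOp are
// illustrative only):
//
//   REGISTER_SYSTEM_KERNEL_BUILDER(Name("_MySystemOp").Device(DEVICE_CPU),
//                                  MySystemOp);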
   1244 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
   1245   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
   1246                                              __VA_ARGS__)
   1247 
   1248 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
   1249   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
   1250 
   1251 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
   1252   static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
   1253       registrar__body__##ctr##__object(                                  \
   1254           ::tensorflow::register_kernel::system::kernel_builder.Build(), \
   1255           #__VA_ARGS__,                                                  \
   1256           [](::tensorflow::OpKernelConstruction* context)                \
   1257               -> ::tensorflow::OpKernel* {                               \
   1258             return new __VA_ARGS__(context);                             \
   1259           });
   1260 
   1261 void* GlobalKernelRegistry();
   1262 
   1263 // If node_def has a corresponding kernel registered on device_type,
    1264 // returns OK and fills in the kernel def and kernel_class_name. <def> and
   1265 // <kernel_class_name> may be null.
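//
// A hedged usage sketch (node_def is assumed to be provided by the caller):
//
//   const KernelDef* kernel_def = nullptr;
//   string kernel_class_name;
//   Status s = FindKernelDef(DeviceType(DEVICE_GPU), node_def, &kernel_def,
//                            &kernel_class_name);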
   1266 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
   1267                      const KernelDef** def, string* kernel_class_name);
   1268 
   1269 // Writes a list of all registered kernels to LOG(INFO), to help users debug
   1270 // missing kernel errors.
   1271 void LogAllRegisteredKernels();
   1272 
   1273 namespace kernel_factory {
   1274 
   1275 class OpKernelRegistrar {
   1276  public:
   1277   typedef OpKernel* (*Factory)(OpKernelConstruction*);
   1278 
   1279   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
   1280                     Factory factory) {
   1281     // Perform the check in the header to allow compile-time optimization
   1282     // to a no-op, allowing the linker to remove the kernel symbols.
   1283     if (kernel_def != nullptr) {
   1284       InitInternal(kernel_def, kernel_class_name, factory);
   1285     }
   1286   }
   1287 
   1288  private:
   1289   void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
   1290                     Factory factory);
   1291 };
   1292 
   1293 }  // namespace kernel_factory
   1294 
   1295 // -----------------------------------------------------------------------------
   1296 // Template and inline method implementations, please ignore
   1297 
   1298 template <class T>
   1299 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
   1300   return GetNodeAttr(def(), attr_name, value);
   1301 }
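
// A hedged sketch of the usual GetAttr pattern in a kernel constructor.  The
// attr name "axis" and the member axis_ belong to the hypothetical kernel, not
// to this header:
//
//   explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
//     OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
//     OP_REQUIRES(context, axis_ >= 0,
//                 errors::InvalidArgument("axis must be non-negative"));
//   }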
   1302 
   1303 inline DataType OpKernelContext::input_dtype(int index) const {
   1304   DCHECK_GE(index, 0);
   1305   DCHECK_LT(index, num_inputs());
   1306   const TensorValue& value((*params_->inputs)[index]);
   1307   if (value.is_ref()) {
   1308     return MakeRefType(value->dtype());
   1309   } else {
   1310     return value->dtype();
   1311   }
   1312 }
   1313 
   1314 inline MemoryType OpKernelContext::input_memory_type(int index) const {
   1315   DCHECK_GE(index, 0);
   1316   DCHECK_LT(index, num_inputs());
   1317   return op_kernel().input_memory_types()[index];
   1318 }
   1319 
   1320 inline DataType OpKernelContext::expected_output_dtype(int index) const {
   1321   DCHECK_GE(index, 0);
   1322   DCHECK_LT(index, num_outputs());
   1323   return params_->op_kernel->output_type(index);
   1324 }
   1325 
   1326 inline MemoryType OpKernelContext::output_memory_type(int index) const {
   1327   DCHECK_GE(index, 0);
   1328   DCHECK_LT(index, num_outputs());
   1329   return op_kernel().output_memory_types()[index];
   1330 }
   1331 
   1332 inline bool OpKernelContext::input_is_ref(int index) const {
   1333   const TensorValue& value((*params_->inputs)[index]);
   1334   return value.is_ref();
   1335 }
   1336 
   1337 inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
   1338   DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
   1339             params_->record_tensor_accesses);
   1340   if (params_->record_tensor_accesses) {
   1341     really_record_tensor_reference(tensor);
   1342   }
   1343 }
   1344 
   1345 inline void OpKernelContext::retrieve_accessed_tensors(
   1346     TensorReferenceVector* out_vector) {
   1347   if (params_->record_tensor_accesses) {
   1348     mutex_lock l(mu_);
   1349     referenced_tensors_->FreezeAndReturnReferences(out_vector);
   1350   }
   1351 }
   1352 
   1353 // no input if tensor == nullptr.
   1354 inline bool OpKernelContext::has_input(int index) const {
   1355   DCHECK_GE(index, 0);
   1356   DCHECK_LT(index, num_inputs());
   1357   return (*params_->inputs)[index].tensor != nullptr;
   1358 }
   1359 
   1360 inline mutex* OpKernelContext::input_ref_mutex(int index) {
   1361   DCHECK_GE(index, 0);
   1362   DCHECK_LT(index, num_inputs());
   1363   DCHECK(input_is_ref(index));
   1364   return (*params_->inputs)[index].mutex_if_ref;
   1365 }
   1366 
   1367 inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
   1368   if (t.IsInitialized()) {
   1369     record_tensor_reference(t);
   1370   }
   1371 }
   1372 
   1373 inline Tensor* OpKernelContext::mutable_output(int index) {
   1374   DCHECK_GE(index, 0);
   1375   DCHECK_LT(index, num_outputs());
   1376   // No need to record_tensor_reference since the output must already
   1377   // have been set by a call that did so.
   1378   return outputs_[index].tensor;
   1379 }
   1380 
   1381 inline TensorValue OpKernelContext::release_output(int index) {
   1382   DCHECK_GE(index, 0);
   1383   DCHECK_LT(index, num_outputs());
   1384   TensorValue value = outputs_[index];
   1385   outputs_[index] = TensorValue();
   1386   return value;
   1387 }
   1388 
   1389 inline Status OpKernelContext::forward_input_or_allocate_output(
   1390     gtl::ArraySlice<int> candidate_input_indices, int output_index,
   1391     const TensorShape& output_shape, Tensor** output) {
   1392   for (int input_index : candidate_input_indices) {
   1393     if (forward_input_to_output_with_shape(input_index, output_index,
   1394                                            output_shape, output)) {
   1395       return Status::OK();
   1396     }
   1397   }
   1398   return allocate_output(output_index, output_shape, output);
   1399 }
   1400 
   1401 inline Status OpKernelContext::forward_input_or_allocate_output(
   1402     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
   1403     const TensorShape& output_shape, Tensor** output) {
   1404   for (const StringPiece& input_name : candidate_input_names) {
   1405     if (forward_input_to_output_with_shape(input_name, output_name,
   1406                                            output_shape, output)
   1407             .ok()) {
   1408       return Status::OK();
   1409     }
   1410   }
   1411   return allocate_output(output_name, output_shape, output);
   1412 }
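
// A hedged sketch of the typical in-place pattern inside Compute(), using the
// index-based overload above (input index 0 and output index 0 are
// illustrative):
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(context,
//                  context->forward_input_or_allocate_output(
//                      {0}, 0, context->input(0).shape(), &output));
//   // output now either aliases input 0's buffer or was freshly allocated.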
   1413 
   1414 template <typename T>
   1415 T* OpKernelContext::op_device_context() {
   1416   static_assert(std::is_base_of<DeviceContext, T>::value,
   1417                 "T is not a subclass of DeviceContext");
   1418   return static_cast<T*>(op_device_context());
   1419 }
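
// A hedged example of the cast above.  GPUDeviceContext is defined elsewhere
// in the runtime and appears here only to illustrate the call; it is an
// assumption of this sketch rather than part of this header:
//
//   auto* gpu_device_context =
//       context->op_device_context<GPUDeviceContext>();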
   1420 
   1421 template <typename T>
   1422 T* OpKernelContext::input_device_context(int index) {
   1423   DCHECK_GE(index, 0);
   1424   DCHECK_LT(index, params_->input_device_contexts->size());
   1425   static_assert(std::is_base_of<DeviceContext, T>::value,
   1426                 "T is not a subclass of DeviceContext");
   1427   return static_cast<T*>((*params_->input_device_contexts)[index]);
   1428 }
   1429 
   1430 inline DeviceContext* OpKernelContext::input_device_context(int index) {
   1431   DCHECK_GE(index, 0);
   1432   DCHECK_LT(index, params_->input_device_contexts->size());
   1433   return (*params_->input_device_contexts)[index];
   1434 }
   1435 
   1436 inline const Tensor& OpInputList::operator[](int i) const {
   1437   DCHECK_GE(i, 0);
   1438   DCHECK_LT(i, stop_ - start_);
   1439   return ctx_->input(start_ + i);
   1440 }
   1441 
   1442 inline mutex* OpMutableInputList::ref_mutex(int i) {
   1443   DCHECK_GE(i, 0);
   1444   DCHECK_LT(i, stop_ - start_);
   1445   return ctx_->input_ref_mutex(start_ + i);
   1446 }
   1447 
   1448 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
   1449   DCHECK_GE(i, 0);
   1450   DCHECK_LT(i, stop_ - start_);
   1451   return ctx_->mutable_input(start_ + i, lock_held);
   1452 }
   1453 
   1454 inline Tensor* OpOutputList::operator[](int i) {
   1455   DCHECK_GE(i, 0);
   1456   DCHECK_LT(i, stop_ - start_);
   1457   return ctx_->mutable_output(start_ + i);
   1458 }
   1459 
   1460 inline bool OpOutputList::required(int i) const {
   1461   DCHECK_GE(i, 0);
   1462   DCHECK_LT(i, stop_ - start_);
   1463   return ctx_->output_required(start_ + i);
   1464 }
   1465 
   1466 inline DataType OpOutputList::expected_output_dtype(int i) const {
   1467   DCHECK_GE(i, 0);
   1468   DCHECK_LT(i, stop_ - start_);
   1469   return ctx_->expected_output_dtype(start_ + i);
   1470 }
   1471 
   1472 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
   1473                                      Tensor** output) {
   1474   DCHECK_GE(i, 0);
   1475   DCHECK_LT(i, stop_ - start_);
   1476   return ctx_->allocate_output(start_ + i, shape, output);
   1477 }
   1478 
   1479 inline void OpOutputList::set(int i, const Tensor& tensor) {
   1480   DCHECK_GE(i, 0);
   1481   DCHECK_LT(i, stop_ - start_);
   1482   ctx_->set_output(start_ + i, tensor);
   1483 }
   1484 
   1485 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
   1486   DCHECK_GE(i, 0);
   1487   DCHECK_LT(i, stop_ - start_);
    1488   ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
   1489 }
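
// A hedged sketch of list-valued inputs and outputs inside Compute().  The
// argument names "values" and "outputs" are assumed to be declared by the op
// being implemented; they are not defined by this header:
//
//   OpInputList values;
//   OP_REQUIRES_OK(context, context->input_list("values", &values));
//   OpOutputList outputs;
//   OP_REQUIRES_OK(context, context->output_list("outputs", &outputs));
//   for (int i = 0; i < values.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(context, outputs.allocate(i, values[i].shape(), &out));
//   }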
   1490 
   1491 // Convenience macros for asserting and handling exceptional conditions.
   1492 // Analogous to the CHECK* macros provided by logging.h.
   1493 //
   1494 // Example use:
    1495 // void Compute(OpKernelContext* context) {
   1496 //   OP_REQUIRES(context, context->num_inputs() == 2,
   1497 //               errors::InvalidArgument("FooOp requires 2 arguments"));
   1498 //   ...
   1499 //   Status status = SomeUncertainMethod();
   1500 //   OP_REQUIRES_OK(context, status);
   1501 //   ...
   1502 // }
   1503 
   1504 #define OP_REQUIRES(CTX, EXP, STATUS)                  \
   1505   do {                                                 \
   1506     if (!TF_PREDICT_TRUE(EXP)) {                       \
   1507       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
   1508       return;                                          \
   1509     }                                                  \
   1510   } while (0)
   1511 
   1512 #define OP_REQUIRES_OK(CTX, ...)                            \
   1513   do {                                                      \
   1514     ::tensorflow::Status _s(__VA_ARGS__);                   \
   1515     if (!TF_PREDICT_TRUE(_s.ok())) {                        \
   1516       (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
   1517       return;                                               \
   1518     }                                                       \
   1519   } while (0)
   1520 
   1521 #define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
   1522   do {                                                 \
   1523     if (!TF_PREDICT_TRUE(EXP)) {                       \
   1524       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
   1525       (CALLBACK)();                                    \
   1526       return;                                          \
   1527     }                                                  \
   1528   } while (0)
   1529 
   1530 #define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
   1531   do {                                                      \
   1532     ::tensorflow::Status _s(STATUS);                        \
   1533     if (!TF_PREDICT_TRUE(_s.ok())) {                        \
   1534       (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
   1535       (CALLBACK)();                                         \
   1536       return;                                               \
   1537     }                                                       \
   1538   } while (0)
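
// A hedged sketch of the *_ASYNC variants inside an AsyncOpKernel.  The helper
// SomeAsyncWork() is hypothetical; the important point is that `done` is
// passed as the callback so it still runs on the early-return paths:
//
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//     OP_REQUIRES_ASYNC(context, context->num_inputs() == 2,
//                       errors::InvalidArgument("expected 2 inputs"), done);
//     SomeAsyncWork(context, [context, done](const Status& s) {
//       OP_REQUIRES_OK_ASYNC(context, s, done);
//       done();
//     });
//   }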
   1539 
   1540 }  // namespace tensorflow
   1541 
   1542 #endif  // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
   1543