     19 #include <atomic>
     20 #include <functional>
     22 #include <utility>
     23 #include <vector>
     24 #include "tensorflow/core/framework/allocator.h"
     25 #include "tensorflow/core/framework/cancellation.h"
     26 #include "tensorflow/core/framework/control_flow.h"
     27 #include "tensorflow/core/framework/device_base.h"
     28 #include "tensorflow/core/framework/graph.pb.h"
     29 #include "tensorflow/core/framework/kernel_def.pb.h"
     30 #include "tensorflow/core/framework/kernel_def_builder.h"
     31 #include "tensorflow/core/framework/node_def_util.h"
     32 #include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
     33 #include "tensorflow/core/framework/rendezvous.h"
     34 #include "tensorflow/core/framework/selective_registration.h"
     35 #include "tensorflow/core/framework/session_state.h"
     36 #include "tensorflow/core/framework/tensor.h"
     37 #include "tensorflow/core/framework/tensor_shape.h"
     38 #include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
     39 #include "tensorflow/core/framework/tracking_allocator.h"
     40 #include "tensorflow/core/framework/types.h"
     41 #include "tensorflow/core/framework/types.pb.h"
     42 #include "tensorflow/core/framework/unique_tensor_references.h"
     43 #include "tensorflow/core/lib/core/errors.h"
     44 #include "tensorflow/core/lib/core/status.h"
     45 #include "tensorflow/core/lib/gtl/array_slice.h"
     46 #include "tensorflow/core/lib/gtl/manual_constructor.h"
     47 #include "tensorflow/core/platform/env.h"
     48 #include "tensorflow/core/platform/logging.h"
     49 #include "tensorflow/core/platform/macros.h"
     50 #include "tensorflow/core/platform/mutex.h"
     51 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
     52 #include "tensorflow/core/platform/thread_annotations.h"
     53 #include "tensorflow/core/platform/types.h"
     55 namespace Eigen {
     56 struct ThreadPoolDevice;
     57 struct GpuDevice;
     58 struct SyclDevice;
     59 }  // end namespace Eigen
     61 namespace tensorflow {
     63 namespace checkpoint {
     64 class TensorSliceReaderCacheWrapper;
     65 }  // namespace checkpoint
     67 class AsyncOpKernel;
     68 class CallFrameInterface;
     69 class FunctionLibraryRuntime;
     70 class OpKernelConstruction;  // declared below
     71 class OpKernelContext;       // declared below,
     72 class OpRegistryInterface;
     73 class ResourceMgr;
     74 class ScopedStepContainer;
     75 class CollectiveExecutor;
     76 class StepStatsCollectorInterface;
     78 class OpKernel {
     79  public:
     80   // OpKernel won't be instantiated by the scheduler, so you may perform
     81   // expensive initialization in the descendant's constructor.
     82   explicit OpKernel(OpKernelConstruction* context);
     84   // Specialized constructor that enables the descendant to provide a different
     85   // `NodeDef` value. For example, this constructor can be used to provide a
     86   // stripped-down `NodeDef` that does not contain the full set of attrs (such
     87   // as tensor values) if the descendant stores them in a different form.
     88   explicit OpKernel(OpKernelConstruction* context,
     89                     std::unique_ptr<const NodeDef> node_def);
     91   virtual ~OpKernel();
     93   // An OpKernel's computation can be either synchronous or
     94   // asynchronous. All OpKernel Compute() methods must be thread-safe as they
     95   // may be called concurrently (e.g. by multiple executions of the same graph
     96   // concurrently).
     97   //
     98   // Most OpKernels should compute synchronously.  They should
     99   // subclass OpKernel and override the Compute() method and have it
    100   // return after completing the supplied work.
    101   //
    102   // A few special kernels might need to be asynchronous to bound the
    103   // number of threads (e.g., network receive operations). These
    104   // kernels must subclass AsyncOpKernel and override
    105   // AsyncOpKernel::ComputeAsync().
    106   //
    107   // In both cases, implementations of Compute() and ComputeAsync()
    108   // get inputs and write outputs through the given OpKernelContext
    109   // and returns a status via context->SetStatus(). They must be
    110   // thread-safe.
    112   // Synchronous compute.
    113   //
    114   // "context" is guaranteed to be alive until Compute() returns.
    115   virtual void Compute(OpKernelContext* context) = 0;
    117   // Returns nullptr iff this op kernel is synchronous.
    118   virtual AsyncOpKernel* AsAsync() { return nullptr; }
    119   virtual const AsyncOpKernel* AsAsync() const { return nullptr; }
    121   // Initial time (in CPU cycles) we expect an operation to take.  Used to
    122   // determine whether an operation should be place in a threadpool.  Operations
    123   // start out "expensive".
    124   static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
    125   static const uint64 kOpIsExpensiveThresholdCycles = 5000;
    126   static const uint64 kCostDecay = 10;
    128   // Returns true iff this op kernel is considered "expensive". The
    129   // runtime may use this flag to optimize graph execution for example
    130   // to "inline" inexpensive kernels.
    131   virtual bool IsExpensive() {
    132     return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) >
    133                           kOpIsExpensiveThresholdCycles);
    134   }
    136   // Updates the dynamic cost estimate, which is used to determine whether this
    137   // op is expensive. The new cost estimate is a weighted average of the old
    138   // cost estimate and the latest cost.
    139   void UpdateCostEstimate(uint64 elapsed_cycles) {
    140     // N.B. Updates to `cost_estimate_` are atomic but unlocked.  Simulataneous
    141     // updates may result in one or more updates being ignored.  This does not
    142     // affect correctness but may slow down the update frequency.
    143     cost_estimate_.store(
    144         (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) /
    145                 kCostDecay +
    146             (elapsed_cycles / kCostDecay),
    147         std::memory_order_relaxed);
    148   }
    150   // Accessors.
    151   const NodeDef& def() const { return *def_; }
    152   const string& name() const;              // Same as def().name()
    153   const string& type_string() const;       // Same as def().op()
    154   const string& requested_device() const;  // Same as def().device()
    155   bool is_internal() const { return is_internal_; }
    157   int num_inputs() const { return input_types_.size(); }
    158   DataType input_type(int i) const { return input_types_[i]; }
    159   const DataTypeVector& input_types() const { return input_types_; }
    160   const MemoryTypeVector& input_memory_types() const {
    161     return input_memory_types_;
    162   }
    163   const string& requested_input(int i) const;  // Same as def().input(i)
    165   int num_outputs() const { return output_types_.size(); }
    166   DataType output_type(int o) const { return output_types_[o]; }
    167   const DataTypeVector& output_types() const { return output_types_; }
    168   const MemoryTypeVector& output_memory_types() const {
    169     return output_memory_types_;
    170   }
    172   Status InputRange(StringPiece input_name, int* start, int* stop) const;
    173   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
    175   // We allow legacy scalars within Google up until GraphDef version 6.
    176   // TODO(irving): Remove when we can drop support for GraphDef version 5.
    177   bool allow_legacy_scalars() const {
    178 #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
    179     return graph_def_version_ < 6;
    180 #else
    181     return false;
    182 #endif
    183   }
    185   // Allow either scalars or (if allowing legacy scalars) shape (1,).
    186   bool IsLegacyScalar(const TensorShape& shape) const {
    187     return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
    188                                  shape.dim_size(0) == 1);
    189   }
    191   // Allow rank 1 or (if allowing legacy scalars) rank 0.
    192   bool IsLegacyVector(const TensorShape& shape) const {
    193     return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
    194   }
    196   // Turn a shape Tensor into a TensorShape
    197   // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
    198   Status MakeShape(const Tensor& shape, TensorShape* out) const;
    200   static int DeviceNumaNode(const DeviceBase* device);
    202  private:
    203   const std::unique_ptr<const NodeDef> def_;
    204   const DataTypeVector input_types_;
    205   const MemoryTypeVector input_memory_types_;
    206   const DataTypeVector output_types_;
    207   const MemoryTypeVector output_memory_types_;
    208   const int graph_def_version_;
    209   const bool is_internal_;  // True if this is an internal operation
    210   NameRangeMap input_name_map_;
    211   NameRangeMap output_name_map_;
    212   bool expensive_;
    213   std::atomic_uint_fast64_t cost_estimate_;
    216 };
    218 class AsyncOpKernel : public OpKernel {
    219  public:
    220   using OpKernel::OpKernel;  // Lift OpKernel constructors.
    222   // Asynchronous compute.
    223   //
    224   // Implementations of ComputeAsync() must run "done" to signal the
    225   // completion of the computation. "context" is guaranteed to be
    226   // alive until the "done" callback starts.
    227   typedef std::function<void()> DoneCallback;
    228   virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
    230   AsyncOpKernel* AsAsync() final { return this; }
    231   const AsyncOpKernel* AsAsync() const final { return this; }
    233   void Compute(OpKernelContext* context) final;
    234 };
    236 // Wraps a tensor that is held by an Op across calls to Compute(). For
    237 // memory safety when using asynchronous devices like GPUs, the system
    238 // must be notified when a Tensor is used inside an Op execution. The
    239 // wrapper ensures that all uses of the Tensor are tracked, because in
    240 // order to retrieve the Tensor the caller must use AccessTensor which
    241 // notifies the context.
    242 class PersistentTensor {
    243  public:
    244   PersistentTensor() {}
    245   explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
    247   // Caller does not own the returned Tensor*.
    248   Tensor* AccessTensor(OpKernelConstruction* context);
    249   // Caller does not own the returned Tensor*.
    250   Tensor* AccessTensor(OpKernelContext* context);
    252   // The check for initialization does not need to access the
    253   // underlying tensor buffer.
    254   bool IsInitialized() const { return tensor_.IsInitialized(); }
    256   int64 NumElements() const { return tensor_.NumElements(); }
    258   int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }
    260  private:
    261   Tensor tensor_;
    262 };
// Context handed to an OpKernel's constructor. It exposes the kernel's
// NodeDef, its resolved input/output types and memory types, the device and
// function library, helpers to allocate tensors during construction, and the
// caller-supplied Status (via SetStatus()/status()) for reporting
// construction errors.
// NOTE(review): the raw pointers given to the constructor are stored without
// being deleted anywhere in this header; presumably they must outlive this
// object — confirm with callers.
class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  // The Env of the device this kernel is being constructed for.
  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == input_types() and expected_outputs ==
  // output_types(), returns OK, else returns INVALID_ARGUMENT with an error
  // message. Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  // The Status object supplied to the constructor.
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value.  If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device.  See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  // NOTE: declaration order matters to the out-of-line constructor's
  // member initializer list; do not reorder.
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};
    397 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
    398 // tensorflow::gtl::iterator_range to make the below container classes
    399 // unnecessary.
    400 template <typename ListType, typename ElementType>
    401 class OpArgIterator {
    402  public:
    403   using iterator_category = std::forward_iterator_tag;
    404   using value_type = ElementType;
    405   using pointer = ElementType*;
    406   using const_pointer = const ElementType*;
    407   using reference = ElementType&;
    408   using const_reference = const ElementType&;
    409   using difference_type = ptrdiff_t;
    411   OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
    413   bool operator==(const OpArgIterator& rhs) {
    414     DCHECK(list_ == rhs.list_);
    415     return i_ == rhs.i_;
    416   }
    418   bool operator!=(const OpArgIterator& rhs) {
    419     DCHECK(list_ == rhs.list_);
    420     return i_ != rhs.i_;
    421   }
    423   OpArgIterator operator++() {  // prefix ++it
    424     ++i_;
    425     return *this;
    426   }
    428   OpArgIterator operator++(int) {  // postfix it++
    429     OpArgIterator old_value = *this;
    430     ++i_;
    431     return old_value;
    432   }
    434   reference operator*() { return (*list_)[i_]; }
    435   pointer operator->() { return &(*list_)[i_]; }
    437   const_reference operator*() const { return (*list_)[i_]; }
    438   const_pointer operator->() const { return &(*list_)[i_]; }
    440  private:
    441   const ListType* const list_;
    442   int i_;
    443 };
    445 // Utility class for representing a list of immutable input tensors
    446 // that are passed to the op as a single named argument.
    447 class OpInputList {
    448  public:
    449   typedef OpArgIterator<OpInputList, const Tensor> Iterator;
    450   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
    451   OpInputList(OpKernelContext* ctx, int start, int stop)
    452       : ctx_(ctx), start_(start), stop_(stop) {}
    453   OpInputList& operator=(const OpInputList& other) = default;
    454   const Tensor& operator[](int i) const;
    455   int size() const { return stop_ - start_; }
    456   Iterator begin() const { return Iterator(this, 0); }
    457   Iterator end() const { return Iterator(this, size()); }
    459  private:
    460   OpKernelContext* ctx_;  // not owned
    461   int start_;
    462   int stop_;
    463 };
    465 // Utility class for representing a list of mutable ("ref") input tensors
    466 // that are passed to the op as a single named argument.
    467 class OpMutableInputList {
    468  public:
    469   typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
    470   OpMutableInputList(OpKernelContext* ctx, int start, int stop)
    471       : ctx_(ctx), start_(start), stop_(stop) {}
    472   OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
    473   OpMutableInputList& operator=(const OpMutableInputList& other) = default;
    474   Tensor at(int i, bool lock_held);
    475   mutex* ref_mutex(int i);
    476   int size() const { return stop_ - start_; }
    477   Iterator begin() const { return Iterator(this, 0); }
    478   Iterator end() const { return Iterator(this, size()); }
    480  private:
    481   OpKernelContext* ctx_;  // not owned
    482   int start_;
    483   int stop_;
    484 };
    486 // Utility class for representing a list of output tensors that are
    487 // grouped as a single named output.
    488 class OpOutputList {
    489  public:
    490   typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
    491   OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
    492   OpOutputList(OpKernelContext* ctx, int start, int stop)
    493       : ctx_(ctx), start_(start), stop_(stop) {}
    494   OpOutputList& operator=(const OpOutputList& other) = default;
    495   Tensor* operator[](int i);
    496   bool required(int i) const;
    497   DataType expected_output_dtype(int i) const;
    498   Status allocate(int i, const TensorShape& shape, Tensor** output);
    499   void set(int i, const Tensor& tensor);
    500   void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
    501   int size() const { return stop_ - start_; }
    502   Iterator begin() const { return Iterator(this, 0); }
    503   Iterator end() const { return Iterator(this, size()); }
    505  private:
    506   OpKernelContext* ctx_;  // not owned
    507   int start_;
    508   int stop_;
    509 };
    511 // Holds a tensor or tensor reference. For tensor references, we need
    512 // a mutex to prevent concurrent access to the tensor.
    513 struct TensorValue {
    514   TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
    515   TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
    516       : mutex_if_ref(nullptr), tensor(t) {}
    517   TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
    518   Tensor* operator->() const { return tensor; }
    519   bool is_ref() const { return mutex_if_ref != nullptr; }
    521   mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
    522   Tensor* tensor;
    523 };
    525 // Used to store partitioned graphs from function-calling ops.
    526 struct GraphCollector {
    527   mutex mu;
    528   std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
    529   GraphDef raw_graph GUARDED_BY(mu);
    530   GraphDef optimized_graph GUARDED_BY(mu);
    532   bool dirty GUARDED_BY(mu);
    534   GraphCollector() : dirty(false) {}
    536   void CollectRawGraph(const GraphDef& graph) {
    537     mutex_lock ml(mu);
    538     raw_graph.MergeFrom(graph);
    539     dirty = true;
    540   }
    542   void CollectOptimizedGraph(const GraphDef& graph) {
    543     mutex_lock ml(mu);
    544     optimized_graph.MergeFrom(graph);
    545     dirty = true;
    546   }
    548   void CollectPartitionedGraph(const GraphDef& graph) {
    549     mutex_lock ml(mu);
    550     partitioned_graphs.push_back(graph);
    551     dirty = true;
    552   }
    554   void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
    555     raw_graph.Clear();
    556     optimized_graph.Clear();
    557     partitioned_graphs.clear();
    558     dirty = false;
    559   }
    561   bool HasUpdatedGraphs() {
    562     mutex_lock ml(mu);
    563     return dirty;
    564   }
    565 };
    567 class OpKernelContext {
    568  public:
    569   // The first element of a WrappedAllocator is a "base" Allocator and
    570   // the second element is that Allocator wrapped by a
    571   // TrackingAllocator
    572   typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
    574   // TODO(zhifengc): Do some cleanup of Params.
    575   // The Params struct is passed in to initialize an OpKernelContext,
    576   // and must outlive the OpKernelContext.
  struct Params {
    // Params owns eigen_gpu_device (see the comment on that member) and
    // deletes it here; every other pointer member is borrowed.
    // NOTE(review): the copy constructor and copy assignment are still the
    // implicit member-wise ones, so copying a Params whose eigen_gpu_device
    // is non-null would lead to the same object being deleted twice. Confirm
    // no caller copies a populated Params.
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    // Lazily creates eigen_gpu_device from `device` if it is still null.
    // REQUIRES: device != nullptr (DCHECK'd).
    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    // Optional instrumentation flags; all are off by default.
    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    Rendezvous* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    string session_handle;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    // NOTE(review): presumably indexed by input number, like `inputs` —
    // confirm with the executor.
    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device contexts.
    const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
        nullptr;
    DeviceContext* op_device_context = nullptr;

    // Control-flow op supports.
    FrameAndIter frame_iter;

    // Function call supports.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static const int kNeverForward = -2;
    static const int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    // NOTE(review): presumably one entry per output, each either
    // kNeverForward, kNoReservation, or an output index — confirm.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops. Both default to no-ops.
    std::function<void()> inc_num_deferred_ops_function = []() {};
    std::function<void()> dec_num_deferred_ops_function = []() {};
  };
    686   // params must outlive the OpKernelContext.
    687   explicit OpKernelContext(Params* params);
    688   OpKernelContext(Params* params, int noutputs);
    689   ~OpKernelContext();
    691   Env* env() const { return params_->device->env(); }
    693   int64 step_id() const { return params_->step_id; }
    695   const OpKernel& op_kernel() const { return *params_->op_kernel; }
    697   // Input/output signature.
    699   int num_inputs() const { return params_->inputs->size(); }
    700   DataType input_dtype(int index) const;
    701   Status input_dtype(StringPiece name, DataType* dtype) const;
    702   MemoryType input_memory_type(int index) const;
    704   int num_outputs() const { return outputs_.size(); }
    705   DataType expected_output_dtype(int index) const;
    706   MemoryType output_memory_type(int index) const;
    708   // Input
    710   // Returns an immutable input tensor. May only be used for non-Ref
    711   // inputs. For Ref inputs use mutable_input below.
    712   // REQUIRES: !IsRefType(input_dtype(index))
    713   // TODO(mrry): Convert this to return Status.
    714   const Tensor& input(int index);
    716   // Returns the named immutable input tensor in "tensor", as defined
    717   // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
    718   // use mutable_input below.
    719   // REQUIRES: !IsRefType(input_dtype(index))
    720   // REQUIRES: the named input must not be a list.
    721   Status input(StringPiece name, const Tensor** tensor);
  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef.  If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input_list below.
  // REQUIRES: the named input must not be a ref tensor.
  Status input_list(StringPiece name, OpInputList* list);
  // Looks up the mutex guarding the named ref input and stores it in
  // "*out_mutex". For mutable inputs, use it together with mutable_input()
  // so that there is no concurrent access to the ref tensor, e.g.:
  // {
  //   mutex* mu = nullptr;
  //   Status s = context->input_ref_mutex(name, &mu);  // check s
  //   mutex_lock lock(*mu);
  //   Tensor t;
  //   s = context->mutable_input(name, &t, /*lock_held=*/true);  // check s
  //   // modify the values in t
  // }
  // REQUIRES: the named input must be a ref tensor.
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);
    740   // Returns a mutable input tensor. Must be used to access Ref
    741   // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
    742   // modify the values stored in the Tensor buffer, and modifications
    743   // will be visible to other Ops reading the same ref tensor. If
    744   // !lock_held the input mutex will be acquired before returning the
    745   // Tensor.
    746   // TODO(mrry): Convert this to return Status.
    747   Tensor mutable_input(int index, bool lock_held);
    749   // Returns the named mutable input tensor in "tensor", as defined in
    750   // the OpDef. Must be used to access Ref inputs. The values stored
    751   // in the Tensor buffer may be modified, and modifications will be
    752   // visible to other Ops reading the same ref tensor. If !lock_held
    753   // the input mutex will be acquired before returning the Tensor.
    754   // REQUIRES: the named input must not be a list.
    755   // REQUIRES: the named input must be a ref tensor.
    756   Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
    758   // Returns the named list-valued mutable input in "list", as defined
    759   // in the OpDef.  If the named input is not list-valued, returns a
    760   // one-element list. Must be used to access Ref inputs. The values
    761   // stored in the Tensor buffer may be modified, and modifications
    762   // will be visible to other Ops reading the same ref tensor.
    763   // REQUIRES: the named input must be a ref tensor.
    764   Status mutable_input_list(StringPiece name, OpMutableInputList* list);
    766   // Replace the corresponding Ref Input to use the storage buffer
    767   // used by tensor. If !lock_held the input mutex will be acquired
    768   // before returning the Tensor.
    769   // REQUIRES: IsRefType(input_dtype(index)).
    770   void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
    772   // Replace the corresponding named Ref Input to use the storage
    773   // buffer used by tensor. If !lock_held the input mutex will be
    774   // acquired before returning the Tensor.
    775   // REQUIRES: IsRefType(input_dtype(index)).
    776   Status replace_ref_input(StringPiece name, const Tensor& tensor,
    777                            bool lock_held);
    779   // Deletes the Tensor object used as the Ref Input at
    780   // input_index. This is not usually necessary and should be used
    781   // with caution. If !lock_held the input mutex will be acquired
    782   // before returning the Tensor.
    783   // REQUIRES: IsRefType(input_dtype(input_index)).
    784   void delete_ref_input(int input_index, bool lock_held);
    786   // Return true if there is input at the given index. An operator has no
    787   // input at index if its tensor is null. This is primarily used by the
    788   // merge operator.
    789   // TODO(mrry): Convert this to return Status.
    790   bool has_input(int index) const;
    792   // Returns true if all inputs are the same shape, otherwise sets the
    793   // status to a non-OK value and returns false.
    794   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
    795   bool ValidateInputsAreSameShape(OpKernel* op);
    797   // If non-null, kernels should populate with any partition subgraphs created.
    798   GraphCollector* graph_collector() { return params_->graph_collector; }
    800   // Input to output forwarding.
    802   // Set the output Ref Tensor at output_index to be an alias of the
    803   // input Ref Tensor at input_index.
    804   // REQUIRES: IsRefType(input_dtype(input_index)).
    805   // REQUIRES: IsRefType(output_dtype(output_index)).
    806   void forward_ref_input_to_ref_output(int input_index, int output_index);
    808   // Returns true when an alias to input[input_index], reshaped to output_shape,
    809   // which is safe to use for in-place computation was written to *output.
    810   // Returns false if input[input_index] has a refcount greater than one, or if
    811   // its type does not match the expected output type of output[output_index],
    812   // or the number of elements in input[input_index] does not equal the number
    813   // of elements in output_shape.
    814   bool forward_input_to_output_with_shape(int input_index, int output_index,
    815                                           const TensorShape& output_shape,
    816                                           Tensor** output) TF_MUST_USE_RESULT;
    817   Status forward_input_to_output_with_shape(StringPiece input_name,
    818                                             StringPiece output_name,
    819                                             const TensorShape& output_shape,
    820                                             Tensor** output) TF_MUST_USE_RESULT;
    822   // Returns a pointer to a Tensor aliasing the underlying buffer backing
    823   // input[input_index] iff
    824   //   * input[input_index] is not a ref,
    825   //   * the data type, shape, memory type, and allocator attributes of
    826   //     input[input_index] are compatible with those given in dtype, shape,
    827   //     memory_type, and attr,
    828   //   * refcount on the underlying buffer is one.
    829   //   * Either there is no forwarding reservation for either input_index
    830   //     or output_index or the specified input is reserved for the specified
    831   //     output. More precisely:
    832   //
    833   //     These cases mean neither input nor output has a reservation:
    834   //        forward_from_array = nullptr
    835   //     OR (input_index is not in forward_from_array AND
    836   //         (output_index == kNoReservation OR
    837   //          forward_from_array[output_index] == kNoReservation))
    838   //
    839   //     This case means that input_index is reserved for output_index:
    840   //        forward_from_array[output_index] == input_index
    841   //
    842   //     This case means the output is reserved to always be allocated,
    843   //     never assigned a forwarded input:
    844   //        forward_from_array[output_index] == kNeverForward
    845   //
    846   // Otherwise returns nullptr.
    847   // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
    848   // forwarding is only safe if there are no reads via __ldg() after writes
    849   // to the same address.
    850   std::unique_ptr<Tensor> forward_input(
    851       int input_index, int output_index, DataType output_dtype,
    852       const TensorShape& output_shape, MemoryType output_memory_type,
    853       const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
    855   // Tries to forward one of the inputs given in input_indices to
    856   // output[output_index]. If none of the given inputs can be forwarded, calls
    857   // allocate_output() to allocate a new output buffer.
    858   Status forward_input_or_allocate_output(
    859       gtl::ArraySlice<int> candidate_input_indices, int output_index,
    860       const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
    861   Status forward_input_or_allocate_output(
    862       gtl::ArraySlice<StringPiece> candidate_input_names,
    863       StringPiece output_name, const TensorShape& output_shape,
    864       Tensor** output) TF_MUST_USE_RESULT;
    866   // Tries to reuse one of the inputs given in input_indices as a temporary.
    867   // If none of the given inputs can be forwarded, calls
    868   // allocate_temp() to allocate a new temporary buffer.
    869   Status forward_input_or_allocate_temp(
    870       gtl::ArraySlice<int> candidate_input_indices, DataType type,
    871       const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    872       Tensor* out_temp) TF_MUST_USE_RESULT;
    874   Status forward_input_or_allocate_temp(
    875       gtl::ArraySlice<int> candidate_input_indices, DataType type,
    876       const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    877     return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
    878                                           AllocatorAttributes(), out_temp);
    879   }
    881   // Output
    883   // Returns the named list-valued output in "list", as defined in the OpDef.
    884   // If the named output is not list-valued, returns a one-element list.
    885   Status output_list(StringPiece name, OpOutputList* list);
    887   // If output_required(index) returns true, the OpKernel's Compute() method
    888   // should call allocate_output(index, ...), set_output(index, ...),
    889   // set_output_ref(index, ...), or set the status to a non-ok value.
    890   // If it returns false, it may output, but is not required to do so.
    891   // TODO(mrry): Convert this to return Status, and implement a string
    892   // name version.
    893   bool output_required(int index) const {
    894     return true;  // TODO(josh11b): implement
    895   }
    897   // Allocation of tensors during kernel execution inside the Compute
    898   // method:
    899   //
    900   // There are three methods to allocate Tensors when an Op kernel
    901   // executes.
    902   //
    903   // 1) allocate_persistent. This is only needed for Tensors that will
    904   // be stored by the Op between invocations, and it *must* be used
    905   // for those Tensors. The call returns a PersistentTensor, and that
    906   // is the only object the Op is allowed to hold on to between
    907   // invocations. When the Tensor is needed in a subsequent
    908   // invocation, it can be retrieved from the PersistentTensor using
    909   // the AccessTensor method. This ensures that the system is made
    910   // aware of any use of the tensor's allocated memory, which is
    911   // needed for correctness on asynchronous devices such as GPUs.
    912   //
    913   // 2) allocate_output. This should be used to allocate any tensor
    914   // that is going to be used as an output from the Op at the end of
    915   // the current execution. The caller indicates which output the
    916   // Tensor will be assigned to, and the call returns the
    917   // newly-allocated Tensor. The Tensor can subsequently be assigned
    918   // to during kernel execution, and will be used as the designated
    919   // output when the kernel execution completes.
    920   //
    921   // 3) allocate_temp. This should be used to allocate any scratch
    922   // storage that is needed while the kernel is executing, and will
    923   // not be retained by the Op.
    924   //
    925   // In some cases a Tensor needs to be used as an output even though
    926   // it was previously allocated elsewhere. The Tensor may have been
    927   // passed as an input, or stored in a PersistentTensor during a
    928   // previous kernel execution, or allocated earlier in the kernel
    929   // execution at a time when it was not known which output it would
    930   // be assigned to. In this case the kernel can use set_output or
    931   // set_output_ref to indicate that the tensor should be used as the
    932   // designated output. It is legal to use any previously-allocated
    933   // Tensor as an argument to set_output or set_output_ref, including
    934   // Tensors allocated via allocate_temp. There may be a performance
    935   // penalty to using a Tensor that was not allocated using
    936   // allocate_output. This is because allocate_output uses the
    937   // AllocatorAttributes stored in output_attr_array for the
    938   // designated output. In some cases, using the wrong attributes may
    939   // cause an extra copy of the Tensor's buffer.
    941   // Allocates output for the specified output index with shape.
    942   // OpKernelContext retains ownership of the returned pointer. See
    943   // comment above.
    944   //
    945   // If memory allocation fails, returns an error status.
    946   //
    947   // REQUIRES: !IsRefType(expected_output_dtype(index))
    948   Status allocate_output(int index, const TensorShape& shape,
    949                          Tensor** tensor) TF_MUST_USE_RESULT;
    950   Status allocate_output(StringPiece name, const TensorShape& shape,
    951                          Tensor** tensor) TF_MUST_USE_RESULT;
    952   // The following methods use the supplied attributes instead of
    953   // those in output_attr_array. The caller is responsible for
    954   // ensuring that the attributes are "compatible" with the
    955   // output_attr_array, e.g. the tensor is allocated on the correct
    956   // device. See comment above.
    957   Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
    958                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
    959   Status allocate_output(StringPiece name, const TensorShape& shape,
    960                          Tensor** tensor,
    961                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
    963   // Allocates a temporary Tensor of the specified type and
    964   // shape. Devices such as GPUs that enqueue Ops for lazy execution
    965   // may retain references to the temporary tensors after the Op's
    966   // Compute method has run. See comment above.
    967   Status allocate_temp(DataType type, const TensorShape& shape,
    968                        Tensor* out_temp, AllocatorAttributes allocator_attr,
    969                        const AllocationAttributes& allocation_attr);
    970   Status allocate_temp(DataType type, const TensorShape& shape,
    971                        Tensor* out_temp, AllocatorAttributes allocator_attr) {
    972     return allocate_temp(type, shape, out_temp, allocator_attr,
    973                          AllocationAttributes());
    974   }
    975   Status allocate_temp(DataType type, const TensorShape& shape,
    976                        Tensor* out_temp) {
    977     return allocate_temp(type, shape, out_temp, AllocatorAttributes());
    978   }
    980   // Allocates a Tensor of the specified type and shape which the Op
    981   // plans to maintain as persistent state. out_persistent holds the
    982   // PersistentTensor which is the object the caller should store. For
    983   // convenience, if out_tensor is non-null then it will be filled in
    984   // with a Tensor* pointing to the newly-allocated tensor which the
    985   // caller can use instead of calling
    986   // out_persistent->AccessTensor. The caller does not own out_tensor
    987   // and should not keep a copy of it. See comment above.
    988   Status allocate_persistent(DataType type, const TensorShape& shape,
    989                              PersistentTensor* out_persistent,
    990                              Tensor** out_tensor, AllocatorAttributes attr);
    991   Status allocate_persistent(DataType type, const TensorShape& shape,
    992                              PersistentTensor* out_persistent,
    993                              Tensor** out_tensor) {
    994     return allocate_persistent(type, shape, out_persistent, out_tensor,
    995                                AllocatorAttributes());
    996   }
    998   // Copies a tensor (allocated by the caller) to the specified output
    999   // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
   1000   // REQUIRES: 'tensor' must have the same MemoryType as
   1001   // output_memory_types[index]. See comment above.
   1002   Status set_output(StringPiece name, const Tensor& tensor);
   1004   // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
   1005   // and they must outlive all uses within the step. See comment above.
   1006   // REQUIRES: IsRefType(expected_output_dtype(index))
   1007   Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
   1009   // Returns nullptr if allocate_output() or set_output() have not been called.
   1010   Status mutable_output(StringPiece name, Tensor** tensor);
   1012   // Records device specific state about how the input tensors were
   1013   // computed.
   1014   //
   1015   // If using the templated function, the type must be a subclass
   1016   // of DeviceContext.
   1017   //
   1018   // Get the DeviceContext used for the index input.  Returns nullptr
   1019   // if no DeviceContext was provided.
   1020   template <typename T>
   1021   T* input_device_context(int index);
   1022   DeviceContext* input_device_context(int index);
   1024   // Return the DeviceContext that should be used for this Op.
   1025   //
   1026   // If using the templated function, the type must be a subclass
   1027   // of DeviceContext.
   1028   //
   1029   // Returns nullptr if the device did not provide one.
   1030   template <typename T>
   1031   T* op_device_context();
   1032   DeviceContext* op_device_context() {
   1033     DeviceContext* ret = params_->op_device_context;
   1034     if (ret == nullptr) {
   1035       auto* dev_info = device()->tensorflow_gpu_device_info();
   1036       if (dev_info) ret = dev_info->default_context;
   1037     }
   1038     return ret;
   1039   }
   1041   AllocatorAttributes input_alloc_attr(int index) const {
   1042     if (params_->input_alloc_attrs == nullptr) {
   1043       return AllocatorAttributes();
   1044     } else {
   1045       DCHECK_GE(index, 0);
   1046       DCHECK_LT(index, params_->input_alloc_attrs->size());
   1047       return (*params_->input_alloc_attrs)[index];
   1048     }
   1049   }
   1051   AllocatorAttributes output_alloc_attr(int index) const {
   1052     return params_->output_attr_array[index];
   1053   }
   1055   gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
   1056     mutex_lock lock(mu_);
   1057     gtl::InlinedVector<WrappedAllocator, 4> retrieved;
   1058     retrieved.swap(wrapped_allocators_);
   1059     return retrieved;
   1060   }
   1062   // Communication.
   1063   //
   1064   // An op kernel communicates with outside environment through
   1065   // Rendezvous Send() and Recv().
   1066   Rendezvous* rendezvous() const { return params_->rendezvous; }
   1068   CollectiveExecutor* collective_executor() const {
   1069     return params_->collective_executor;
   1070   }
   1072   // An op kernel can access the session state it belongs to.
   1073   SessionState* session_state() const { return params_->session_state; }
   1075   // Unique identifier of the session it belongs to. Can be empty.
   1076   string session_handle() const { return params_->session_handle; }
   1078   // An op kernel can access the tensor store of the run it belongs to.
   1079   TensorStore* tensor_store() const { return params_->tensor_store; }
   1081   // Function call support.
   1082   //
   1083   // If this kernel invocation is within a function execution,
   1084   // call_frame() returns the call frame for the function call.
   1085   CallFrameInterface* call_frame() const { return params_->call_frame; }
   1087   // If not nullptr, the kernel invoke functions defined in the
   1088   // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
   1089   FunctionLibraryRuntime* function_library() const {
   1090     return params_->function_library;
   1091   }
   1093   std::function<void(std::function<void()>)>* runner() const {
   1094     return params_->runner;
   1095   }
   1096   StepStatsCollectorInterface* stats_collector() const {
   1097     return params_->stats_collector;
   1098   }
   1100   // Shared resources accessible to this kernel.
   1101   ResourceMgr* resource_manager() const { return params_->resource_manager; }
   1103   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
   1104     return params_->slice_reader_cache;
   1105   }
   1107   // Execution.
   1108   //
   1109   // OpKernels can use these eigen devices to carry out their
   1110   // numerical computation.
   1111   const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
   1112     return *device()->eigen_cpu_device();
   1113   }
   1114   const Eigen::GpuDevice& eigen_gpu_device() const {
   1115     return params_->eigen_gpu_device->device();
   1116   }
   1117 #ifdef TENSORFLOW_USE_SYCL
   1118   const Eigen::SyclDevice& eigen_sycl_device() const {
   1119     return *device()->eigen_sycl_device();
   1120   }
   1121 #endif
   1122   template <typename EigenDeviceType>
   1123   const EigenDeviceType& eigen_device() const;
   1125   // Error handling.
   1127   // If expected_inputs == inputs() and expected_outputs == output_types(),
   1128   // returns OK, else returns INVALID_ARGUMENT with an error message.
   1129   // Recommended for Ops with dynamic signatures, where validation can only
   1130   // be performed at runtime.
   1131   Status MatchSignature(const DataTypeSlice expected_inputs,
   1132                         const DataTypeSlice expected_outputs);
   1134   // An OpKernel should call SetStatus() if Compute() encounters an
   1135   // error.
   1136   void SetStatus(const Status& status);
   1137   const Status& status() const { return status_; }
   1139   // Cancellation.
   1140   //
   1141   // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
   1142   // example of how to use this API.
   1143   CancellationManager* cancellation_manager() const {
   1144     return params_->cancellation_manager;
   1145   }
   1147   // Other accessors.
   1149   // For control flow.
   1150   FrameAndIter frame_iter() const { return params_->frame_iter; }
   1151   bool is_input_dead() const { return params_->is_input_dead; }
   1153   // May be used, e.g., to get GPU handles, etc.
   1154   // TODO(tucker): Add example usage.
   1155   DeviceBase* device() const { return params_->device; }
   1157   // Retrieve list of referenced tensors in out_vector. Once this is
   1158   // called, it is not legal to reference any more tensors.  Should
   1159   // not be called from Op kernels.
   1160   void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
   1162   // Per-step container for use by white-listed internal ops.
   1163   ScopedStepContainer* step_container() const {
   1164     return params_->step_container;
   1165   }
   1167   // Helper routines for the OP_REQUIRES macros
   1168   void CtxFailure(const Status& s);
   1169   void CtxFailureWithWarning(const Status& s);
   1170   void CtxFailure(const char* file, int line, const Status& s);
   1171   void CtxFailureWithWarning(const char* file, int line, const Status& s);
   1173   // Unrecommended functions: these are functions that have some
   1174   // current uses but are not recommended for use, and may go away at
   1175   // some future major version release.
   1176   //
   1177   // The following functions all have versions that return Status
   1178   // to capture error conditions, and are strongly preferred.
   1179   Tensor* mutable_output(int index);
   1180   void set_output(int index, const Tensor& tensor);
   1181   mutex* input_ref_mutex(int index);
   1182   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
   1183   TensorValue release_output(int index);
   1185   bool track_allocations() const { return params_->track_allocations; }
   1187   // Records temp memory allocation. Tensor object is recorded to identify the
   1188   // case where temp memory is used as output memory.
   1189   void record_temp_memory_allocation(int64 size, const Tensor& t)
   1190       LOCKS_EXCLUDED(stats_mu_);
   1192   // Returns recorded size of temporary memory;
   1193   int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
   1195   // Records persistent memory allocation, size can be negative indicating
   1196   // deallocation.
   1197   void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
   1198       LOCKS_EXCLUDED(stats_mu_);
   1200   // Returns recorded size and ids of persistent memory.
   1201   int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
   1203   std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);
   1205   // Resets counters for temp and persistent memory and recorded ids.
   1206   void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);
   1208   bool input_is_ref(int index) const;
   1210   // Used by OpKernel implementations to track actively running deferred ops.
   1211   //
   1212   // A deferred op is one whose Compute method returns (or whose ComputeAsync
   1213   // method invokes the callback) when work is scheduled onto a device. At that
   1214   // point, we don't know when the work will actually complete (or if it has
   1215   // already completed) on the device. These functions allow the executor to
   1216   // track the status of deferred ops and act accordingly.
   1217   //
   1218   // Deferred OpKernel implementations must use these methods to get two
   1219   // functions. It then must call these two functions in pairs, before and after
   1220   // device execution, respectively.
   1221   TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
   1222     return params_->inc_num_deferred_ops_function;
   1223   }
   1224   TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
   1225     return params_->dec_num_deferred_ops_function;
   1226   }
   1228  private:
   1229   Allocator* get_allocator(AllocatorAttributes attr);
   1231   // Internal method to add a tensor's buffer to the list of buffers
   1232   // referenced during the execution of the Op, so that GPUs may
   1233   // accurately track the memory that may not be reused until the Op
   1234   // execution completes.
   1235   void record_tensor_reference(const Tensor& tensor);
   1236   void really_record_tensor_reference(const Tensor& tensor);
   1238   // Internal common method used when allocating tensor memory
   1239   Status allocate_tensor(DataType type, const TensorShape& shape,
   1240                          Tensor* out_tensor,
   1241                          AllocatorAttributes allocator_attr) {
   1242     return allocate_tensor(type, shape, out_tensor, allocator_attr,
   1243                            AllocationAttributes());
   1244   }
   1246   Status allocate_tensor(DataType type, const TensorShape& shape,
   1247                          Tensor* out_tensor, AllocatorAttributes allocator_attr,
   1248                          const AllocationAttributes& allocation_attr);
   1250   // This is called by PersistentTensor::AccessTensor whenever the
   1251   // wrapped tensor is retrieved, to ensure the runtime knows that the
   1252   // Tensor is being accessed within an Op. This is necessary for
   1253   // memory safety of devices like GPUs that queue Ops for
   1254   // asynchronous execution after the Compute() method completes.
   1255   friend class PersistentTensor;
   1256   void NotifyUseOfPersistentTensor(const Tensor& tensor);
   1258   Status status_;
   1259   friend class CollectiveExecutor;  // for access to params_
   1260   Params* params_;                  // not owned
   1261   mutable mutex mu_;  // mutable so const accessors can acquire the lock
   1262   gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
   1263   gtl::InlinedVector<TensorValue, 4> outputs_;
   1265   // Constructed only if <params->record_tensor_accesses>.
   1266   ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
   1268   // The following data members are only used when allocation tracking is
   1269   // enabled.
   1270   mutable mutex stats_mu_;
   1271   int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
   1272   int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
   1273   std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
   1274       temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
   1275   std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
   1276       GUARDED_BY(stats_mu_);
   1278   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
   1279 };
   1281 // Register your OpKernel by specifying the Op's name, the device the
   1282 // kernel runs on, any type attr constraints for this kernel, any
   1283 // host-memory args, and the class to instantiate.  Examples:
   1284 //
   1285 //  // A kernel that supports all types.
   1286 //  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
   1287 //
   1288 //  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
   1293 //  // (You would then repeat this for every type supported by "Sub".)
   1294 //
   1295 //  // This form allows you to specify a list of types as the constraint.
   1296 //  REGISTER_KERNEL_BUILDER(Name("Sub")
   1297 //                              .Device(DEVICE_CPU)
   1298 //                              .TypeConstraint("T", {DT_FLOAT}),
   1299 //                          SubOp<float>);
   1300 //
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
   1304 //
   1305 // See kernel_def_builder for details.
   1307 // Instantiate an OpKernel that has been registered.  Returns nullptr
   1308 // if no operation for that type of device / input signature combination
   1309 // (and a NOT_FOUND *status), or there is an error in construction (and
   1310 // an INVALID_ARGUMENT *status).  Otherwise, the caller takes ownership
   1311 // of the returned pointer.
   1312 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
   1313 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
   1314 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
   1315                                          DeviceBase* device,
   1316                                          Allocator* allocator,
   1317                                          const NodeDef& def,
   1318                                          int graph_def_version, Status* status);
   1319 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
   1320                       Allocator* allocator, FunctionLibraryRuntime* flib,
   1321                       const NodeDef& def, int graph_def_version,
   1322                       OpKernel** kernel);
   1324 // Returns into 'device_types' the subset of prioritized_types that this
   1325 // binary has registered for the given NodeDef.
   1326 //
   1327 // REQUIRES: * 'device_types' is not nullptr.
   1328 //           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
   1329 Status SupportedDeviceTypesForNode(
   1330     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
   1331     PrioritizedDeviceTypeVector* device_types);
   1333 // Returns a message with a description of the kernels registered for op
   1334 // `op_name`.
   1335 string KernelsRegisteredForOp(StringPiece op_name);
   1337 // Call once after Op registration has completed.
   1338 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
   1340 // -----------------------------------------------------------------------------
   1341 // OpKernel registration implementation follows, please ignore.
   1343 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

// Builder spelled as `Name("op_name")` inside REGISTER_KERNEL_BUILDER. It is a
// KernelDefBuilder whose op name is swapped for "_no_register" when
// selective registration (SHOULD_REGISTER_OP) rejects `op`.
class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not
  // used by any kernel are disabled by the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This mechanism stops that registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal.  Note that this method alone is
  // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so this method doesn't actually reduce code size.
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

namespace system {

// Builder spelled as `Name("op_name")` inside REGISTER_SYSTEM_KERNEL_BUILDER.
class Name : public KernelDefBuilder {
 public:
  // For system kernels, we ignore selective registration and
  // unconditionally register the kernel: `op` is passed through unchanged.
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace system

}  // namespace register_kernel
   1375 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

// The _HELPER indirection above exists so that `ctr` is macro-expanded before
// it reaches REGISTER_KERNEL_BUILDER_UNIQ: an argument used with `##` is not
// expanded, but the helper does not paste `ctr`, so the preprocessor expands
// it first (presumably REGISTER_KERNEL_BUILDER passes __COUNTER__ here).
//
// REGISTER_KERNEL_BUILDER_UNIQ declares, using `ctr` to keep names distinct:
//  * should_register_<ctr>__flag: a constexpr bool computed by
//    SHOULD_REGISTER_OP_KERNEL from the stringized kernel class name;
//  * registrar__body__<ctr>__object: a static OpKernelRegistrar that receives
//    register_kernel::<kernel_builder>.Build() when the flag is true and
//    nullptr otherwise (a null KernelDef makes the registrar a no-op), plus a
//    captureless lambda (convertible to the registrar's function-pointer
//    factory parameter) that returns `new <kernel class>(context)`.
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });
   1395 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
   1396 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
   1397 // unconditionally even when selective registration is used.
   1398 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
   1400                                              __VA_ARGS__)
// System-kernel counterparts of REGISTER_KERNEL_BUILDER_UNIQ_HELPER / _UNIQ.
// The helper forces macro expansion of `ctr` before it is token-pasted. The
// _UNIQ macro builds the KernelDef from register_kernel::system::<builder>
// and always passes it to the registrar: there is no SHOULD_REGISTER_OP_KERNEL
// flag, so selective registration never filters these kernels out.
#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
  static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
      registrar__body__##ctr##__object(                                  \
          ::tensorflow::register_kernel::system::kernel_builder.Build(), \
          #__VA_ARGS__,                                                  \
          [](::tensorflow::OpKernelConstruction* context)                \
              -> ::tensorflow::OpKernel* {                               \
            return new __VA_ARGS__(context);                             \
          });
// Returns true if a kernel matching `node_def` is registered on `device_type`.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If `node_def` has a corresponding kernel registered on `device_type`,
// returns OK and fills in the kernel def and kernel class name. `def` and
// `kernel_class_name` may be null, in which case that output is skipped.
// NOTE(review): the exact error code returned when no kernel matches is not
// visible here — confirm against the definition.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Returns a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Returns a list of the registered kernels for which `predicate` returns true.
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);

// Returns a list of the registered kernels for op `op_name`.
KernelList GetRegisteredKernelsForOp(StringPiece op_name);
   1438 namespace kernel_factory {
   1440 // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
   1441 // them. You register factories with the TensorFlow core by constructing an
   1442 // OpKernelRegistrar and passing the factory as a constructor parameter.
   1443 class OpKernelFactory {
   1444  public:
   1445   virtual OpKernel* Create(OpKernelConstruction* context) = 0;
   1446   virtual ~OpKernelFactory() = default;
   1447 };
   1449 class OpKernelRegistrar {
   1450  public:
   1451   // Registers the given kernel factory with TensorFlow. TF will call the
   1452   // factory Create() method when it determines that a kernel matching the given
   1453   // KernelDef is required.
   1454   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
   1455                     std::unique_ptr<OpKernelFactory> factory) {
   1456     // Perform the check in the header to allow compile-time optimization
   1457     // to a no-op, allowing the linker to remove the kernel symbols.
   1458     if (kernel_def != nullptr) {
   1459       InitInternal(kernel_def, kernel_class_name, std::move(factory));
   1460     }
   1461   }
   1463   // Registers the given factory function with TensorFlow. This is equivalent
   1464   // to registering a factory whose Create function invokes `create_fn`.
   1465   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
   1466                     OpKernel* (*create_fn)(OpKernelConstruction*)) {
   1467     // Perform the check in the header to allow compile-time optimization
   1468     // to a no-op, allowing the linker to remove the kernel symbols.
   1469     if (kernel_def != nullptr) {
   1470       InitInternal(kernel_def, kernel_class_name,
   1471                    absl::make_unique<PtrOpKernelFactory>(create_fn));
   1472     }
   1473   }
   1475  private:
   1476   struct PtrOpKernelFactory : public OpKernelFactory {
   1477     explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
   1478         : create_func_(create_func) {}
   1480     OpKernel* Create(OpKernelConstruction* context) override;
   1482     OpKernel* (*create_func_)(OpKernelConstruction*);
   1483   };
   1485   void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
   1486                     std::unique_ptr<OpKernelFactory> factory);
   1487 };
   1489 }  // namespace kernel_factory
   1491 // -----------------------------------------------------------------------------
   1492 // Template and inline method implementations, please ignore
   1494 template <class T>
   1495 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
   1496   return GetNodeAttr(def(), attr_name, value);
   1497 }
   1499 inline DataType OpKernelContext::input_dtype(int index) const {
   1500   DCHECK_GE(index, 0);
   1501   DCHECK_LT(index, num_inputs());
   1502   const TensorValue& value((*params_->inputs)[index]);
   1503   if (value.is_ref()) {
   1504     return MakeRefType(value->dtype());
   1505   } else {
   1506     return value->dtype();
   1507   }
   1508 }
   1510 inline MemoryType OpKernelContext::input_memory_type(int index) const {
   1511   DCHECK_GE(index, 0);
   1512   DCHECK_LT(index, num_inputs());
   1513   return op_kernel().input_memory_types()[index];
   1514 }
   1516 inline DataType OpKernelContext::expected_output_dtype(int index) const {
   1517   DCHECK_GE(index, 0);
   1518   DCHECK_LT(index, num_outputs());
   1519   return params_->op_kernel->output_type(index);
   1520 }
   1522 inline MemoryType OpKernelContext::output_memory_type(int index) const {
   1523   DCHECK_GE(index, 0);
   1524   DCHECK_LT(index, num_outputs());
   1525   return op_kernel().output_memory_types()[index];
   1526 }
   1528 inline bool OpKernelContext::input_is_ref(int index) const {
   1529   const TensorValue& value((*params_->inputs)[index]);
   1530   return value.is_ref();
   1531 }
   1533 inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
   1534   DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
   1535             params_->record_tensor_accesses);
   1536   if (params_->record_tensor_accesses) {
   1537     really_record_tensor_reference(tensor);
   1538   }
   1539 }
   1541 inline void OpKernelContext::retrieve_accessed_tensors(
   1542     TensorReferenceVector* out_vector) {
   1543   if (params_->record_tensor_accesses) {
   1544     mutex_lock l(mu_);
   1545     referenced_tensors_->FreezeAndReturnReferences(out_vector);
   1546   }
   1547 }
   1549 // no input if tensor == nullptr.
   1550 inline bool OpKernelContext::has_input(int index) const {
   1551   DCHECK_GE(index, 0);
   1552   DCHECK_LT(index, num_inputs());
   1553   return (*params_->inputs)[index].tensor != nullptr;
   1554 }
   1556 inline mutex* OpKernelContext::input_ref_mutex(int index) {
   1557   DCHECK_GE(index, 0);
   1558   DCHECK_LT(index, num_inputs());
   1559   DCHECK(input_is_ref(index));
   1560   return (*params_->inputs)[index].mutex_if_ref;
   1561 }
   1563 inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
   1564   if (t.IsInitialized()) {
   1565     record_tensor_reference(t);
   1566   }
   1567 }
   1569 inline Tensor* OpKernelContext::mutable_output(int index) {
   1570   DCHECK_GE(index, 0);
   1571   DCHECK_LT(index, num_outputs());
   1572   // No need to record_tensor_reference since the output must already
   1573   // have been set by a call that did so.
   1574   return outputs_[index].tensor;
   1575 }
   1577 inline TensorValue OpKernelContext::release_output(int index) {
   1578   DCHECK_GE(index, 0);
   1579   DCHECK_LT(index, num_outputs());
   1580   TensorValue value = outputs_[index];
   1581   outputs_[index] = TensorValue();
   1582   return value;
   1583 }
   1585 inline Status OpKernelContext::forward_input_or_allocate_output(
   1586     gtl::ArraySlice<int> candidate_input_indices, int output_index,
   1587     const TensorShape& output_shape, Tensor** output) {
   1588   for (int input_index : candidate_input_indices) {
   1589     if (forward_input_to_output_with_shape(input_index, output_index,
   1590                                            output_shape, output)) {
   1591       return Status::OK();
   1592     }
   1593   }
   1594   return allocate_output(output_index, output_shape, output);
   1595 }
   1597 inline Status OpKernelContext::forward_input_or_allocate_output(
   1598     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
   1599     const TensorShape& output_shape, Tensor** output) {
   1600   for (const StringPiece& input_name : candidate_input_names) {
   1601     if (forward_input_to_output_with_shape(input_name, output_name,
   1602                                            output_shape, output)
   1603             .ok()) {
   1604       return Status::OK();
   1605     }
   1606   }
   1607   return allocate_output(output_name, output_shape, output);
   1608 }
   1610 template <typename T>
   1611 T* OpKernelContext::op_device_context() {
   1612   static_assert(std::is_base_of<DeviceContext, T>::value,
   1613                 "T is not a subclass of DeviceContext");
   1614   return static_cast<T*>(op_device_context());
   1615 }
   1617 template <typename T>
   1618 T* OpKernelContext::input_device_context(int index) {
   1619   DCHECK_NE(params_->input_device_contexts, nullptr);
   1620   DCHECK_GE(index, 0);
   1621   DCHECK_LT(index, params_->input_device_contexts->size());
   1622   static_assert(std::is_base_of<DeviceContext, T>::value,
   1623                 "T is not a subclass of DeviceContext");
   1624   return static_cast<T*>((*params_->input_device_contexts)[index]);
   1625 }
   1627 inline DeviceContext* OpKernelContext::input_device_context(int index) {
   1628   DCHECK_NE(params_->input_device_contexts, nullptr);
   1629   DCHECK_GE(index, 0);
   1630   DCHECK_LT(index, params_->input_device_contexts->size());
   1631   return (*params_->input_device_contexts)[index];
   1632 }
   1634 inline const Tensor& OpInputList::operator[](int i) const {
   1635   DCHECK_GE(i, 0);
   1636   DCHECK_LT(i, stop_ - start_);
   1637   return ctx_->input(start_ + i);
   1638 }
   1640 inline mutex* OpMutableInputList::ref_mutex(int i) {
   1641   DCHECK_GE(i, 0);
   1642   DCHECK_LT(i, stop_ - start_);
   1643   return ctx_->input_ref_mutex(start_ + i);
   1644 }
   1646 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
   1647   DCHECK_GE(i, 0);
   1648   DCHECK_LT(i, stop_ - start_);
   1649   return ctx_->mutable_input(start_ + i, lock_held);
   1650 }
   1652 inline Tensor* OpOutputList::operator[](int i) {
   1653   DCHECK_GE(i, 0);
   1654   DCHECK_LT(i, stop_ - start_);
   1655   return ctx_->mutable_output(start_ + i);
   1656 }
   1658 inline bool OpOutputList::required(int i) const {
   1659   DCHECK_GE(i, 0);
   1660   DCHECK_LT(i, stop_ - start_);
   1661   return ctx_->output_required(start_ + i);
   1662 }
   1664 inline DataType OpOutputList::expected_output_dtype(int i) const {
   1665   DCHECK_GE(i, 0);
   1666   DCHECK_LT(i, stop_ - start_);
   1667   return ctx_->expected_output_dtype(start_ + i);
   1668 }
   1670 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
   1671                                      Tensor** output) {
   1672   DCHECK_GE(i, 0);
   1673   DCHECK_LT(i, stop_ - start_);
   1674   return ctx_->allocate_output(start_ + i, shape, output);
   1675 }
   1677 inline void OpOutputList::set(int i, const Tensor& tensor) {
   1678   DCHECK_GE(i, 0);
   1679   DCHECK_LT(i, stop_ - start_);
   1680   ctx_->set_output(start_ + i, tensor);
   1681 }
   1683 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
   1684   DCHECK_GE(i, 0);
   1685   DCHECK_LT(i, stop_ - start_);
   1686   ctx_->set_output_ref(i, mu, tensor_for_ref);
   1687 }
   1689 // Convenience macros for asserting and handling exceptional conditions.
   1690 // Analogous to the CHECK* macros provided by logging.h.
   1691 //
   1692 // Example use:
// void Compute(OpKernelContext* context) override {
   1694 //   OP_REQUIRES(context, context->num_inputs() == 2,
   1695 //               errors::InvalidArgument("FooOp requires 2 arguments"));
   1696 //   ...
   1697 //   Status status = SomeUncertainMethod();
   1698 //   OP_REQUIRES_OK(context, status);
   1699 //   ...
   1700 // }
   1702 // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
   1703 // AsyncOpKernel implementations. If these macros are used and the condition
   1704 // does not hold, the `done` callback will never be called and the system will
   1705 // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
   1706 // are legal to use in AsyncOpKernel constructors, we use overload resolution
   1707 // to distinguish between OpKernelConstruction* and OpKernelContext* context
   1708 // types.
class XlaOpKernelContext;
// No-op overloads: OP_REQUIRES[_OK] are legitimately usable with these
// context types, so the check compiles to nothing for them.
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
// Defined out of line. `correct_macro_name` names the macro the caller should
// have used instead: OP_REQUIRES passes "OP_REQUIRES_ASYNC" and
// OP_REQUIRES_OK passes "OP_REQUIRES_OK_ASYNC". NOTE(review): per the comment
// above, this presumably raises the fatal error only when `ctx` belongs to an
// AsyncOpKernel's ComputeAsync — confirm against the definition.
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);
// OP_REQUIRES(CTX, EXP, STATUS): if EXP is false, reports STATUS on CTX via
// CtxFailure(__FILE__, __LINE__, ...) and returns from the enclosing function
// (which must therefore return void). STATUS is evaluated only on failure,
// and CTX is evaluated more than once on that path. Before failing, calls
// CheckNotInComputeAsync so misuse in async kernels is caught.
#define OP_REQUIRES(CTX, EXP, STATUS)                     \
  do {                                                    \
    if (!TF_PREDICT_TRUE(EXP)) {                          \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
      return;                                             \
    }                                                     \
  } while (0)
// OP_REQUIRES_OK(CTX, status_expr): evaluates status_expr exactly once into a
// local `::tensorflow::Status _s`. If it is not OK, reports it on CTX via
// CtxFailureWithWarning and returns from the enclosing (void) function.
// The status argument is variadic, presumably so expressions containing
// unparenthesized commas (e.g. template arguments) can be passed. Because the
// local is named `_s`, the expression must not itself refer to an outer
// variable called `_s`.
#define OP_REQUIRES_OK(CTX, ...)                             \
  do {                                                       \
    ::tensorflow::Status _s(__VA_ARGS__);                    \
    if (!TF_PREDICT_TRUE(_s.ok())) {                         \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
      return;                                                \
    }                                                        \
  } while (0)
   1735   do {                                                 \
   1736     if (!TF_PREDICT_TRUE(EXP)) {                       \
   1737       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
   1738       (CALLBACK)();                                    \
   1739       return;                                          \
   1740     }                                                  \
   1741   } while (0)
// OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK): async counterpart of
// OP_REQUIRES_OK. It evaluates STATUS once into `_s`. On a non-OK status it
// reports it on CTX via CtxFailureWithWarning, invokes CALLBACK() (so the
// `done` callback is not skipped), and returns. Unlike OP_REQUIRES_OK it does
// not call CheckNotInComputeAsync, and STATUS is a single macro argument
// (wrap comma-containing expressions in parentheses).
#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)();                                         \
      return;                                               \
    }                                                       \
  } while (0)
   1753 }  // namespace tensorflow