/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_

#include <atomic>

#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run.
class XlaPlatformInfo {
 public:
  XlaPlatformInfo() : device_type_("") {}
  XlaPlatformInfo(XlaPlatformInfo&&) = default;
  explicit XlaPlatformInfo(const DeviceType device_type,
                           se::Platform::Id platform_id,
                           const XlaDevice::Metadata* xla_device_metadata,
                           std::unique_ptr<XlaAllocator> xla_allocator,
                           xla::DeviceMemoryAllocator* device_allocator)
      : device_type_(device_type),
        platform_id_(platform_id),
        xla_device_metadata_(xla_device_metadata),
        xla_allocator_(std::move(xla_allocator)),
        device_allocator_(device_allocator) {
    CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
  }

  XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;

  bool UseMultipleStreams() const {
    return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
  }

  xla::DeviceMemoryAllocator* allocator() const {
    return device_allocator_ ? device_allocator_ : xla_allocator_.get();
  }
  DeviceType device_type() const { return device_type_; }

  // This is equal to xla_device_metadata()->platform()->id() if
  // xla_device_metadata() is not nullptr.
  se::Platform::Id platform_id() const { return platform_id_; }

  // This may be null if the op this XlaPlatformInfo is for was not placed on an
  // XLA device.
  const XlaDevice::Metadata* xla_device_metadata() const {
    return xla_device_metadata_;
  }
  bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }

 private:
  DeviceType device_type_;
  se::Platform::Id platform_id_;

  // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
  // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
  // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
  const XlaDevice::Metadata* xla_device_metadata_;

  // If the op associated with this XlaPlatformInfo is placed on an XLA device
  // then device_allocator_ is the xla::Backend's memory allocator and
  // xla_allocator_ is null.  If the op is placed on a regular CPU or GPU device
  // then device_allocator_ is null and xla_allocator_ points to an appropriate
  // XlaAllocator instance.
  std::unique_ptr<XlaAllocator> xla_allocator_;
  xla::DeviceMemoryAllocator* device_allocator_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};

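// An illustrative sketch (not part of this header) of how an XlaPlatformInfo
// could be put together for an op placed on a regular, non-XLA GPU device:
// an XlaAllocator wraps the device's TF allocator and device_allocator stays
// null, satisfying the XOR invariant checked in the constructor. The
// `platform` and `tf_allocator` names below are hypothetical; the real
// construction happens in the corresponding .cc file.
//
//   // Hypothetical setup: `platform` is a se::Platform* and `tf_allocator`
//   // is the Allocator obtained from the underlying Device.
//   std::unique_ptr<XlaAllocator> xla_allocator(
//       new XlaAllocator(platform, tf_allocator));
//   XlaPlatformInfo info(DeviceType(DEVICE_GPU), platform->id(),
//                        /*xla_device_metadata=*/nullptr,
//                        std::move(xla_allocator),
//                        /*device_allocator=*/nullptr);
//   // info.allocator() now resolves to the XlaAllocator wrapper.
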
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow
// the "constants, then regular args, then resources" order.
// Instead, it takes the index vectors of the constant and resource
// arguments explicitly (see the example sketch after the class).
// It does not have a corresponding OpDef because it is never present
// in the GraphDef.
// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
class XlaLocalLaunchBase : public OpKernel {
 public:
  XlaLocalLaunchBase(OpKernelConstruction* ctx,
                     const std::vector<int>& constants,
                     const std::vector<int>& resources,
                     const NameAttrList& function);
  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
  ~XlaLocalLaunchBase() override = default;

  void Compute(OpKernelContext* ctx) override;

 protected:
  // Indexes of compile-time constant inputs
  const std::vector<int> constants_;
  // Indexes of resource inputs
  const std::vector<int> resources_;

  const NameAttrList function_;
  const XlaPlatformInfo platform_info_;
};

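// For illustration (a hypothetical function, not taken from this header):
// if a kernel's inputs arrive in the mixed order
//   (data0, dim_size, var_handle, data1)
// then the index vectors passed to XlaLocalLaunchBase would be
//
//   constants_ = {1};   // input 1 must be a compile-time constant
//   resources_ = {2};   // input 2 is a resource-variable handle
//
// whereas XlaLocalLaunchOp would require the inputs to already be laid out
// as (dim_size, data0, data1, var_handle).
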
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can
// use an 'XlaCompilationCache' to compile and execute code which is specific
// to the shapes of input Tensors.
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory (see the sketch after the class).
class XlaLocalLaunchOp : public XlaLocalLaunchBase {
 public:
  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
  ~XlaLocalLaunchOp() override;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};

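// In rough outline, the launch path inside Compute() looks like the following
// (a simplified sketch of the logic in the .cc file; argument lists are
// abbreviated and error handling is elided):
//
//   XlaCompilationCache* cache = ...;   // looked up in the ResourceMgr
//   const XlaCompiler::CompilationResult* result = nullptr;
//   xla::LocalExecutable* executable = nullptr;
//   OP_REQUIRES_OK(ctx, cache->Compile(options, function_, /*...*/,
//                                      &result, &executable));
//   auto run_result = executable->Run(input_buffers, run_options);
//   // The outputs in run_result are then handed back to the TF executor
//   // through ctx, using the machinery in xla_launch_util.h.
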
// XlaCompileOp implements the _XlaCompile operator: it compiles the given
// function into an XLA executable and emits a key with which the compiled
// program can be looked up and executed by a companion _XlaRun op.
class XlaCompileOp : public OpKernel {
 public:
  explicit XlaCompileOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  // Indexes of compile-time constant inputs
  const std::vector<int> constants_;
  // Indexes of resource inputs
  const std::vector<int> resources_;

  const NameAttrList function_;

  XlaPlatformInfo platform_info_;

  // Whether compilation is mandatory: if false, the op may skip compilation
  // and signal failure so that a fallback path can run the function instead.
  const bool must_compile_;

  // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
  // error when compiling the cluster this _XlaCompile is supposed to compile.
  // If `cannot_compile_cluster_` is true then we avoid compiling this cluster
  // on any future calls to _XlaCompile.
  bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;

  mutex cannot_compile_cluster_mu_;
};

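// The guarded flag above follows the usual mutex pattern; a sketch of how
// Compute() might consult it (details elided, local names hypothetical):
//
//   bool cannot_compile;
//   {
//     mutex_lock guard(cannot_compile_cluster_mu_);
//     cannot_compile = cannot_compile_cluster_;
//   }
//   if (cannot_compile) {
//     // Emit an invalid key and report compilation as unsuccessful so the
//     // fallback path (plain TF execution) is taken.
//   }
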
// XlaRunOp implements the _XlaRun operator: it executes an XLA program that
// was previously compiled by a _XlaCompile op, identified by the key that
// _XlaCompile produced.
class XlaRunOp : public OpKernel {
 public:
  explicit XlaRunOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  const XlaPlatformInfo platform_info_;
};

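// Taken together, _XlaCompile and _XlaRun split the work of XlaLocalLaunchOp
// in two. A schematic sketch of the dataflow the JIT graph rewrite produces
// (pseudocode, not actual GraphDef syntax):
//
//   key, compiled_ok = _XlaCompile(constants, args, resources)
//   if compiled_ok:
//     outputs = _XlaRun(args, key)
//   else:
//     outputs = <fallback: run the original function via the TF executor>
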
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_