Home | History | Annotate | Download | only in jit
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
     17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
     18 
     19 #include "absl/container/flat_hash_map.h"
     20 #include "absl/types/optional.h"
     21 #include "absl/types/span.h"
     22 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
     23 #include "tensorflow/compiler/tf2xla/xla_context.h"
     24 #include "tensorflow/compiler/xla/client/local_client.h"
     25 #include "tensorflow/compiler/xla/statusor.h"
     26 #include "tensorflow/core/common_runtime/device.h"
     27 #include "tensorflow/core/common_runtime/device_mgr.h"
     28 #include "tensorflow/core/framework/graph.pb.h"
     29 #include "tensorflow/core/framework/op_kernel.h"
     30 #include "tensorflow/core/lib/core/threadpool.h"
     31 #include "tensorflow/core/platform/mutex.h"
     32 #include "tensorflow/core/platform/thread_annotations.h"
     33 
     34 namespace tensorflow {
     35 
     36 // The XlaCompilationCache class caches the results of the XlaCompiler class,
     37 // which converts a Tensorflow graph into a compiled XLA computation.
     38 //
     39 // Since XLA computations must have static shapes, the cache generates a new
     40 // XLA computation for each new set of input shapes.
     41 //
     42 // Currently no cache eviction policy is implemented and the cache grows without
     43 // bound.
     44 class XlaCompilationCache : public ResourceBase {
     45  public:
          // Creates a cache for the given XLA client and device type. The
          // constructor body is not in this header; presumably the arguments
          // initialize `client_` and `device_type_` -- TODO confirm.
     46   XlaCompilationCache(xla::LocalClient* client, DeviceType device_type);
     47   ~XlaCompilationCache() override;
     48 
          // Selects how Compile() behaves on a cache miss; see the Compile()
          // comment below for the full contract.
     49   enum class CompileMode {
            // The cache may decide not to compile the cluster at this time.
     50     kLazy,
            // The cache always attempts the compilation.
     51     kStrict,
     52   };
     53 
     54   // Compiles a function into a XlaCompiler::CompilationResult that can be used
     55   // to execute an XLA Computation. Compilation results are cached.
     56   // `function` is the name of a Tensorflow function to compile.
     57   // `args` is a description of the arguments to the computation.
     58   //
     59   // `compile_mode` controls the behavior of the compilation cache on a cache
     60   // miss.  If `compile_mode` is `kLazy` then, based on some profitability
     61   // heuristics, the compilation cache may decide not to compile the cluster at
     62   // this time.  In this case it returns null into both `out_compilation_result`
     63   // and `out_executable`.  If `compile_mode` is `kStrict` then the compilation
     64   // cache always attempts the compilation on a cache miss.
     65   //
     66   // The result of compilation is written to `*out_compilation_result`, which
     67   // must be non-null. If `out_executable` is non-null, also builds an
     68   // xla::LocalExecutable and sets `*out_executable` to point to it. The
     69   // resulting executable pointer may be null if the computation has no
     70   // non-constant outputs.
          //
          // `options` and `compile_options` are XlaCompiler settings; how they are
          // consumed is not visible in this header.
     71   Status Compile(const XlaCompiler::Options& options,
     72                  const NameAttrList& function,
     73                  absl::Span<const XlaCompiler::Argument> args,
     74                  const XlaCompiler::CompileOptions& compile_options,
     75                  CompileMode compile_mode,
     76                  const XlaCompiler::CompilationResult** out_compilation_result,
     77                  xla::LocalExecutable** out_executable);
     78 
     79   // As above, but calls XlaCompiler::CompileSingleOp instead of
     80   // XlaCompiler::CompileFunction.
          // Takes an OpKernelContext in place of a function name and has no
          // `compile_mode` parameter.
     81   Status CompileSingleOp(
     82       const XlaCompiler::Options& options,
     83       absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
     84       const XlaCompiler::CompileOptions& compile_options,
     85       const XlaCompiler::CompilationResult** out_compilation_result,
     86       xla::LocalExecutable** out_executable);
     87 
          // Accessors for the XLA client and device type held by this cache.
     88   xla::LocalClient* client() const { return client_; }
     89   const DeviceType& device_type() const { return device_type_; }
     90 
          // Returns a string describing this cache; overrides a virtual of the base
          // class. The exact format is defined outside this header.
     91   string DebugString() const override;
     92 
     93   // Describes the types, shapes and any compile-time constant arguments
     94   // to a kernel. Key that uniquely identifies a compilation output.
     95   struct Signature {
            // Name component of the key.
            // NOTE(review): presumably derived from the function being compiled;
            // confirm in BuildSignature.
     96     string name;
     97 
     98     // List of Tensor types & shapes for compile-time constant arguments to the
     99     // compilation, ordered by argument number.
            // NOTE(review): this wording duplicates the `arg_values` comment below;
            // confirm whether these are really the shapes of the non-constant
            // arguments.
    100     std::vector<std::pair<DataType, std::vector<int64>>> arg_shapes;
    101 
    102     // List of Tensor values for compile-time constant arguments to the
    103     // compilation, ordered by argument number. Tensors must be in host memory.
    104     std::vector<Tensor> arg_values;
    105 
            // Equality on signatures. Together with `Hash` below, this lets Signature
            // serve as the key type of the `cache_` map.
    106     bool operator==(const Signature& other) const;
    107 
            // Hash functor; supplied as the hasher of the `cache_` map.
    108     struct Hash {
    109       uint64 operator()(const Signature& signature) const;
    110     };
    111 
    112     // Returns a human-readable description of the signature.
    113     string HumanString() const;
    114   };
    115 
    116   // Builds the signature for a compilation.
          // Returns the Signature wrapped in a StatusOr, so callers must check for an
          // error before using the value.
    117   static xla::StatusOr<Signature> BuildSignature(
    118       const NameAttrList& function,
    119       absl::Span<const XlaCompiler::Argument> args);
    120 
    121  private:
    122   // Common implementation of Compile and CompileSingleOp.
          // `compile_fn` receives an XlaCompiler and a CompilationResult to fill in;
          // presumably it performs the actual compilation step -- TODO confirm.
          // `compile_threshold` is optional; presumably it is the request count
          // required before a lazy compilation proceeds (see
          // kDefaultCompilationThreshold) -- TODO confirm.
    123   Status CompileImpl(
    124       const XlaCompiler::Options& options, const NameAttrList& function,
    125       absl::Span<const XlaCompiler::Argument> args,
    126       const std::function<Status(XlaCompiler* compiler,
    127                                  XlaCompiler::CompilationResult*)>& compile_fn,
    128       absl::optional<int64> compile_threshold,
    129       const XlaCompiler::CompilationResult** out_compilation_result,
    130       xla::LocalExecutable** out_executable);
    131 
    132   // Takes `result` which has been compiled from a Tensorflow subgraph to a
    133   // XLA computation already, and generates an XLA LocalExecutable `executable`.
          // The executable is returned through a std::unique_ptr out-parameter.
    134   Status BuildExecutable(const XlaCompiler::Options& options,
    135                          const XlaCompiler::CompilationResult& result,
    136                          std::unique_ptr<xla::LocalExecutable>* executable);
    137 
          // XLA client and device type returned by client() and device_type().
    138   xla::LocalClient* const client_;
    139   const DeviceType device_type_;
    140 
    141   // The value associated with a cache entry.
    142   struct Entry {
            // Guards the fields annotated GUARDED_BY(mu) below.
    143     mutex mu;
    144 
    145     // Have we tried compiling this entry?
            // NOTE(review): `compiled` and `request_count` carry no GUARDED_BY
            // annotation, unlike the remaining fields; confirm what synchronizes
            // access to them.
    146     bool compiled = false;
    147 
    148     // The number of times a compilation with this signature has been requested.
    149     int64 request_count = 0;
    150 
    151     // Did compilation succeed?
    152     Status compilation_status GUARDED_BY(mu);
    153 
    154     // Output of the XlaCompiler.
    155     XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
    156 
    157     // The XLA executable compiled from <computation>. May be null if no
    158     // executable has been built.
    159     std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
    160   };
    161 
          // The cache itself: maps a Signature to its Entry. `compile_cache_mu_`
          // guards the map; each Entry additionally has its own `mu`.
    162   mutex compile_cache_mu_;
    163   absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
    164       GUARDED_BY(compile_cache_mu_);
    165 
          // Per-cluster compilation statistics; stored in `cluster_compile_stats_`.
    166   struct ClusterCompileStats {
    167     // Number of times the cluster has been (re-)compiled.
    168     int64 compile_count = 0;
    169 
    170     // The number of times this cluster has been executed.
    171     int64 execution_count = 0;
    172 
    173     // Cumulative time spent compiling the cluster.
    174     int64 cumulative_compile_time_us = 0;
    175 
    176     // True if we have decided that this cluster is too dynamic (i.e. its shapes
    177     // change too frequently) to profitably JIT compile.  Once a cluster is
    178     // tagged megamorphic, it stays megamorphic forever.
    179     bool is_megamorphic = false;
    180   };
    181 
          // Guards `cluster_compile_stats_`.
    182   mutex cluster_compile_stats_mu_;
    183 
    184   // Maps cluster names to compilation statistics for said cluster.
    185   absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
    186       GUARDED_BY(cluster_compile_stats_mu_);
    187 
    188   // The number of times a lazy compilation must be requested for a specific
    189   // signature before we attempt to compile it.
    190   static constexpr int64 kDefaultCompilationThreshold = 2;
    191 
    192   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
    193 };
    194 
    195 }  // namespace tensorflow
    196 
    197 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
    198