/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

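// Out-of-line definition for the static constexpr member declared in the
// header (required for ODR-use under pre-C++17 rules).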
constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;

XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client), device_type_(std::move(device_type)) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure that all uses of our programs have completed by waiting for all
  // stream executors to synchronize their outstanding activity.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache, which means we must wait for
  // program completion in the destructor. There are multiple compilation
  // caches around, which complicates things a little. Perhaps having programs
  // be shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Compute a human-readable signature string which encodes the shapes of the
// non-constant arguments and the values of the constant arguments.
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : arg_shapes) {
    absl::StrAppend(&result, ",", DataTypeString(a.first));
    absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
  }

  for (const auto& v : arg_values) {
    absl::StrAppend(&result, "; ", v.DebugString());
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_shapes != other.arg_shapes) return false;

  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0; i < arg_values.size(); ++i) {
    if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
        arg_values[i].shape() != other.arg_values[i].shape() ||
        arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
      return false;
    }
  }
  return true;
}

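// Hash a Signature by combining the canonicalized function name, each
// argument's dtype and dimension sizes, and the raw byte contents of each
// constant argument, mirroring the fields compared in operator== above.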
uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_shapes) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
  }
  return h;
}

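// Builds the cache key for one compilation request. As an illustrative
// (hypothetical) example, a function "cluster_0" called with a float [2,3]
// parameter and an int32 scalar constant 7 would produce a signature whose
// HumanString() looks roughly like:
//   cluster_0[...],DT_FLOAT [2,3]; Tensor<type: int32 shape: [] values: 7>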
xla::StatusOr<XlaCompilationCache::Signature>
XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes());
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";
  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator);

  auto compile_result =
      client_->Compile(*result.computation, argument_layouts, build_options);
  if (!compile_result.ok()) {
    return compile_result.status();
  }
  *executable = std::move(compile_result.ValueOrDie());
  return Status::OK();
}

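// A minimal caller-side sketch (illustrative only; `cache`, `options`,
// `function`, and `args` are assumed to be set up by the calling kernel):
//
//   const XlaCompiler::CompilationResult* compilation_result;
//   xla::LocalExecutable* executable;
//   TF_RETURN_IF_ERROR(cache->Compile(options, function, args,
//                                     XlaCompiler::CompileOptions(),
//                                     XlaCompilationCache::CompileMode::kLazy,
//                                     &compilation_result, &executable));
//   // Under kLazy, a null executable means compilation was deferred.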
Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  absl::optional<int64> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  }
  auto compile_fn = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    return compiler->CompileFunction(compile_options, function, args, result);
  };
  return CompileImpl(options, function, args, compile_fn,
                     /*compile_threshold=*/compile_threshold,
                     out_compilation_result, out_executable);
}

static bool IsMegamorphic(int64 compile_count, int64 execution_count) {
  const int64 kCompileThreshold = 10;
  const int64 kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk a
  // certain minimum amount of compile time into the cluster that didn't quite
  // "pay off"?
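  // For example, with these constants a cluster compiled 12 times but executed
  // only 500 times is megamorphic, since 12 > 10 and 500 < 50 * 12 = 600.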
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute only carries information for the
  // colocator and would make otherwise identical nodes appear distinct in the
  // cache.
  name.mutable_attr()->erase("_class");
  auto compile_op = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    std::vector<DataType> result_dtypes(ctx->num_outputs());
    for (int i = 0; i < result_dtypes.size(); ++i) {
      result_dtypes[i] = ctx->expected_output_dtype(i);
    }
    return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
                                     args, result_dtypes, result);
  };
  return CompileImpl(options, name, args, compile_op,
                     /*compile_threshold=*/absl::nullopt,
                     out_compilation_result, out_executable);
}

Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const std::function<Status(XlaCompiler* compiler,
                               XlaCompiler::CompilationResult*)>& compile_fn,
    absl::optional<int64> compile_threshold,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0; i < args.size(); i++) {
      VLOG(2) << i << ": " << args[i].HumanString();
    }
  }

  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
  VLOG(2) << "Signature: " << signature.HumanString();

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }
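  // The raw Entry pointer stays valid after compile_cache_mu_ is released
  // because entries are never erased from the cache (see the eviction TODO
  // below).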

  // We always compile a cluster the very first time it is executed.  This is an
  // optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
  // most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  // We avoid compiling clusters that have "gone megamorphic", i.e. that have
  // an excessive amount of shape dynamism.
  bool is_megamorphic;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;

    // The is_megamorphic bit is "sticky".  We assume clusters that have been
    // observed to be megamorphic once stay megamorphic forever.
    it->second.is_megamorphic |=
        IsMegamorphic(/*compile_count=*/it->second.compile_count,
                      /*execution_count=*/it->second.execution_count);
    is_megamorphic = it->second.is_megamorphic;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we implement
  // cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64 current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: " << entry->compiled
          << " signature: " << signature.HumanString() << " with request count "
          << current_request_count << " and compile threshold "
          << compile_threshold.value_or(0);
  if (!entry->compiled) {
    const bool should_compile = [&] {
      if (!compile_threshold.has_value()) {
        // Lazy compilation is disabled.
        return true;
      }

      if (is_megamorphic) {
        VLOG(3) << "Not compiling cluster " << function.name()
                << " because it is megamorphic.";
        return false;
      }

      if (is_first_execution) {
        return true;
      }

      bool reached_compile_threshold =
          current_request_count >= *compile_threshold;
      if (!reached_compile_threshold) {
        VLOG(3) << "Not compiling cluster " << function.name()
                << " because it has not reached the compile threshold; "
                << "threshold is " << *compile_threshold
                << ", execution count is " << current_request_count << ".";
      }
      return reached_compile_threshold;
    }();

    if (!should_compile) {
      VLOG(2) << "Not compiling for signature: " << signature.HumanString();
      *out_compilation_result = nullptr;
      *out_executable = nullptr;
      return Status::OK();
    }

    tensorflow::Env* env = tensorflow::Env::Default();
    const uint64 compile_start_us = env->NowMicros();
    // Do the actual JIT compilation while holding only this entry's lock, not
    // the cache-wide compile_cache_mu_, since compilation can take a long
    // time.

    XlaCompiler compiler(options);
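    // Mark the entry as compiled before invoking compile_fn so that a failed
    // compilation is also recorded (via compilation_status below) and is not
    // retried on subsequent requests.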
    entry->compiled = true;

    entry->compilation_status =
        compile_fn(&compiler, &entry->compilation_result);
    TF_RETURN_IF_ERROR(entry->compilation_status);
    CHECK_EQ(entry->executable.get(), nullptr);
    entry->compilation_status =
        BuildExecutable(options, entry->compilation_result, &entry->executable);

    const uint64 compile_end_us = env->NowMicros();
    const uint64 compile_time_us = compile_end_us - compile_start_us;
    {
      mutex_lock lock(cluster_compile_stats_mu_);
      auto it = cluster_compile_stats_.find(function.name());
      it->second.compile_count++;
      it->second.cumulative_compile_time_us += compile_time_us;
      VLOG(1) << "compiled " << function.name() << " "
              << it->second.compile_count
              << " times, compile time: " << compile_time_us
              << " us, cumulative: " << it->second.cumulative_compile_time_us
              << " us ("
              << tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
                                                               1.0e6)
              << " / "
              << tensorflow::strings::HumanReadableElapsedTime(
                     it->second.cumulative_compile_time_us / 1.0e6)
              << ")";
    }
  }
  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return Status::OK();
}

}  // namespace tensorflow