1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/jit/xla_compilation_cache.h" 17 18 #include <numeric> 19 20 #include "absl/strings/str_cat.h" 21 #include "absl/strings/str_join.h" 22 #include "tensorflow/compiler/tf2xla/shape_util.h" 23 #include "tensorflow/compiler/tf2xla/type_util.h" 24 #include "tensorflow/compiler/tf2xla/xla_context.h" 25 #include "tensorflow/compiler/xla/client/client_library.h" 26 #include "tensorflow/core/common_runtime/device.h" 27 #include "tensorflow/core/common_runtime/function.h" 28 #include "tensorflow/core/common_runtime/graph_optimizer.h" 29 #include "tensorflow/core/framework/attr_value_util.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/graph/graph_constructor.h" 32 #include "tensorflow/core/graph/node_builder.h" 33 #include "tensorflow/core/lib/hash/hash.h" 34 #include "tensorflow/core/platform/env.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/public/version.h" 37 #include "tensorflow/core/util/dump_graph.h" 38 39 namespace tensorflow { 40 41 constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; 42 43 XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, 44 DeviceType device_type) 45 : client_(client), device_type_(std::move(device_type)) {} 46 47 XlaCompilationCache::~XlaCompilationCache() { 48 // Ensure any use of our programs have completed by waiting for all stream 49 // executors to complete. 50 for (auto* executor : client_->backend().stream_executors()) { 51 bool ok = executor->SynchronizeAllActivity(); 52 if (!ok) { 53 LOG(ERROR) << "Error synchronizing activity while waiting for all " 54 "programs to complete"; 55 } 56 } 57 // TODO(b/110813685): Think about the program ownership model. Programs are 58 // currently owned by the compilation cache which means we must wait for 59 // program completion in the destructor. There are multiple compilation caches 60 // around, which complicates things a little. Perhaps having programs be 61 // shared_ptrs (an invasive change) would make the model easier to reason 62 // about? 63 } 64 65 string XlaCompilationCache::DebugString() const { 66 return "XLA JIT compilation cache"; 67 } 68 69 // Compute a string signature which encodes the shapes of the 70 // arguments in the supplied list. 71 string XlaCompilationCache::Signature::HumanString() const { 72 string result = name; 73 for (const auto& a : arg_shapes) { 74 absl::StrAppend(&result, ",", DataTypeString(a.first)); 75 absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]"); 76 } 77 78 for (const auto& v : arg_values) { 79 absl::StrAppend(&result, "; ", v.DebugString()); 80 } 81 return result; 82 } 83 84 bool XlaCompilationCache::Signature::operator==(const Signature& other) const { 85 if (name != other.name) return false; 86 if (arg_shapes != other.arg_shapes) return false; 87 88 if (arg_values.size() != other.arg_values.size()) return false; 89 for (int i = 0; i < arg_values.size(); ++i) { 90 if (arg_values[i].dtype() != other.arg_values[i].dtype() || 91 arg_values[i].shape() != other.arg_values[i].shape() || 92 arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { 93 return false; 94 } 95 } 96 return true; 97 } 98 99 uint64 XlaCompilationCache::Signature::Hash::operator()( 100 const XlaCompilationCache::Signature& signature) const { 101 uint64 h = std::hash<string>()(signature.name); 102 for (const auto& arg : signature.arg_shapes) { 103 h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first))); 104 h = Hash64Combine(h, std::hash<int>()(arg.second.size())); 105 for (int dim : arg.second) { 106 h = Hash64Combine(h, std::hash<int>()(dim)); 107 } 108 } 109 for (const auto& arg : signature.arg_values) { 110 h = Hash64Combine( 111 h, Hash64(arg.tensor_data().data(), arg.tensor_data().size())); 112 } 113 return h; 114 } 115 116 xla::StatusOr<XlaCompilationCache::Signature> 117 XlaCompilationCache::BuildSignature( 118 const NameAttrList& function, 119 absl::Span<const XlaCompiler::Argument> args) { 120 Signature signature; 121 signature.name = Canonicalize(function.name(), AttrSlice(&function.attr())); 122 for (const XlaCompiler::Argument& arg : args) { 123 switch (arg.kind) { 124 case XlaCompiler::Argument::kConstant: 125 signature.arg_values.push_back(arg.constant_value); 126 break; 127 case XlaCompiler::Argument::kParameter: 128 case XlaCompiler::Argument::kResource: 129 signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes()); 130 break; 131 default: 132 return errors::InvalidArgument( 133 "Unhandled argument kind in XlaCompilationCache: ", 134 arg.HumanString()); 135 } 136 } 137 return std::move(signature); 138 } 139 140 Status XlaCompilationCache::BuildExecutable( 141 const XlaCompiler::Options& options, 142 const XlaCompiler::CompilationResult& result, 143 std::unique_ptr<xla::LocalExecutable>* executable) { 144 VLOG(2) << "Compiling to local executable"; 145 146 std::vector<const xla::Shape*> argument_layouts( 147 result.xla_input_shapes.size()); 148 for (int i = 0; i < result.xla_input_shapes.size(); ++i) { 149 argument_layouts[i] = &result.xla_input_shapes[i]; 150 } 151 xla::ExecutableBuildOptions build_options; 152 build_options.set_device_ordinal(options.device_ordinal != -1 153 ? options.device_ordinal 154 : client_->default_device_ordinal()); 155 build_options.set_result_layout(result.xla_output_shape); 156 build_options.set_device_allocator(options.device_allocator); 157 158 auto compile_result = 159 client_->Compile(*result.computation, argument_layouts, build_options); 160 if (!compile_result.ok()) { 161 return compile_result.status(); 162 } 163 *executable = std::move(compile_result.ValueOrDie()); 164 return Status::OK(); 165 } 166 167 Status XlaCompilationCache::Compile( 168 const XlaCompiler::Options& options, const NameAttrList& function, 169 absl::Span<const XlaCompiler::Argument> args, 170 const XlaCompiler::CompileOptions& compile_options, 171 CompileMode compile_mode, 172 const XlaCompiler::CompilationResult** out_compilation_result, 173 xla::LocalExecutable** out_executable) { 174 absl::optional<int64> compile_threshold; 175 if (compile_mode == CompileMode::kLazy) { 176 compile_threshold = kDefaultCompilationThreshold; 177 } 178 auto compile_fn = [&](XlaCompiler* compiler, 179 XlaCompiler::CompilationResult* result) { 180 return compiler->CompileFunction(compile_options, function, args, result); 181 }; 182 return CompileImpl(options, function, args, compile_fn, 183 /*compile_threshold=*/compile_threshold, 184 out_compilation_result, out_executable); 185 } 186 187 static bool IsMegamorphic(int64 compile_count, int64 execution_count) { 188 const int64 kCompileThreshold = 10; 189 const int64 kMinExecutionsPerCompile = 50; 190 191 // This heuristic is trying to capture the following property: have we sunk a 192 // certain minimum amount of compile time into the cluster that didn't quite 193 // "pay off"? 194 return compile_count > kCompileThreshold && 195 execution_count < kMinExecutionsPerCompile * compile_count; 196 } 197 198 Status XlaCompilationCache::CompileSingleOp( 199 const XlaCompiler::Options& options, 200 absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx, 201 const XlaCompiler::CompileOptions& compile_options, 202 const XlaCompiler::CompilationResult** out_compilation_result, 203 xla::LocalExecutable** out_executable) { 204 const NodeDef& def = ctx->op_kernel().def(); 205 NameAttrList name; 206 name.set_name(def.op()); 207 *name.mutable_attr() = def.attr(); 208 // Remove the "_class" attribute from the attribute set used to create the 209 // compilation cache key. This attribute is information for the colocator 210 // and causes false uniqueness between nodes. 211 name.mutable_attr()->erase("_class"); 212 auto compile_op = [&](XlaCompiler* compiler, 213 XlaCompiler::CompilationResult* result) { 214 std::vector<DataType> result_dtypes(ctx->num_outputs()); 215 for (int i = 0; i < result_dtypes.size(); ++i) { 216 result_dtypes[i] = ctx->expected_output_dtype(i); 217 } 218 return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(), 219 args, result_dtypes, result); 220 }; 221 return CompileImpl(options, name, args, compile_op, 222 /*compile_threshold=*/absl::nullopt, 223 out_compilation_result, out_executable); 224 } 225 226 Status XlaCompilationCache::CompileImpl( 227 const XlaCompiler::Options& options, const NameAttrList& function, 228 absl::Span<const XlaCompiler::Argument> args, 229 const std::function<Status(XlaCompiler* compiler, 230 XlaCompiler::CompilationResult*)>& compile_fn, 231 absl::optional<int64> compile_threshold, 232 const XlaCompiler::CompilationResult** out_compilation_result, 233 xla::LocalExecutable** out_executable) { 234 DCHECK_NE(out_executable, nullptr); 235 VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); 236 237 if (VLOG_IS_ON(2)) { 238 VLOG(2) << "num_inputs=" << args.size(); 239 for (int i = 0; i < args.size(); i++) { 240 VLOG(2) << i << ": " << args[i].HumanString(); 241 } 242 } 243 244 TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); 245 VLOG(2) << "Signature: " << signature.HumanString(); 246 247 // The outer lock protects the existence of the cache entry. It does not 248 // protect the contents of the cache entry. 249 Entry* entry; 250 { 251 mutex_lock lock(compile_cache_mu_); 252 // Find or create a cache entry. 253 std::unique_ptr<Entry>& e = cache_[signature]; 254 if (!e) { 255 e.reset(new Entry); 256 } 257 entry = e.get(); 258 } 259 260 // We always compile a cluster the very first time it is executed. This is an 261 // optimistic guess that pays off for statically shaped TensorFlow graphs 262 // (since they get the benefit of XLA right away without waiting for warmup) 263 // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at 264 // most one cluster-compilation's worth of compile time). 265 bool is_first_execution; 266 267 // We avoid compiling clusters that have "gone megamorphic" i.e. have an 268 // excessive amount of shape dynamism. 269 bool is_megamorphic; 270 271 { 272 mutex_lock lock(cluster_compile_stats_mu_); 273 auto it = 274 cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) 275 .first; 276 is_first_execution = it->second.execution_count++ == 0; 277 278 // The is_megamorphic bit is "sticky". We assume clusters that have been 279 // observed to be megamorphic once stay megamorphic forever. 280 it->second.is_megamorphic |= 281 IsMegamorphic(/*compile_count=*/it->second.compile_count, 282 /*execution_count=*/it->second.execution_count); 283 is_megamorphic = it->second.is_megamorphic; 284 } 285 286 // Acquire the cache entry lock and compile, if necessary. 287 // TODO(phawkins): this locking will need to be restructured when we implement 288 // cache eviction. 289 mutex_lock entry_lock(entry->mu); 290 int64 current_request_count = ++entry->request_count; 291 VLOG(2) << "Compilation cache entry hit: " << entry->compiled 292 << " signature: " << signature.HumanString() << " with request count " 293 << current_request_count << " and compile threshold " 294 << compile_threshold.value_or(0); 295 if (!entry->compiled) { 296 const bool should_compile = [&] { 297 if (!compile_threshold.has_value()) { 298 // Lazy compilation is disabled. 299 return true; 300 } 301 302 if (is_megamorphic) { 303 VLOG(3) << "Not compiling cluster " << function.name() 304 << " because it is megamorphic."; 305 return false; 306 } 307 308 if (is_first_execution) { 309 return true; 310 } 311 312 bool reached_compile_threshold = 313 current_request_count >= *compile_threshold; 314 if (!reached_compile_threshold) { 315 VLOG(3) 316 << "Not compiling cluster " << function.name() 317 << " because it has not reached compile threshold; threshold is " 318 << *compile_threshold << " execution count " 319 << current_request_count << "."; 320 } 321 return reached_compile_threshold; 322 }(); 323 324 if (!should_compile) { 325 VLOG(2) << "Not compiling for signature: " << signature.HumanString(); 326 *out_compilation_result = nullptr; 327 *out_executable = nullptr; 328 return Status::OK(); 329 } 330 331 tensorflow::Env* env = tensorflow::Env::Default(); 332 const uint64 compile_start_us = env->NowMicros(); 333 // Do the actual JIT compilation without holding the lock (it can take 334 // a long time.) 335 336 XlaCompiler compiler(options); 337 entry->compiled = true; 338 339 entry->compilation_status = 340 compile_fn(&compiler, &entry->compilation_result); 341 TF_RETURN_IF_ERROR(entry->compilation_status); 342 CHECK_EQ(entry->executable.get(), nullptr); 343 entry->compilation_status = 344 BuildExecutable(options, entry->compilation_result, &entry->executable); 345 346 const uint64 compile_end_us = env->NowMicros(); 347 const uint64 compile_time_us = compile_end_us - compile_start_us; 348 { 349 mutex_lock lock(cluster_compile_stats_mu_); 350 auto it = cluster_compile_stats_.find(function.name()); 351 it->second.compile_count++; 352 it->second.cumulative_compile_time_us += compile_time_us; 353 VLOG(1) << "compiled " << function.name() << " " 354 << it->second.compile_count 355 << " times, compile time: " << compile_time_us 356 << " us, cumulative: " << it->second.cumulative_compile_time_us 357 << " us (" 358 << tensorflow::strings::HumanReadableElapsedTime(compile_time_us / 359 1.0e6) 360 << " / " 361 << tensorflow::strings::HumanReadableElapsedTime( 362 it->second.cumulative_compile_time_us / 1.0e6) 363 << ")"; 364 } 365 } 366 TF_RETURN_IF_ERROR(entry->compilation_status); 367 *out_compilation_result = &entry->compilation_result; 368 *out_executable = entry->executable.get(); 369 return Status::OK(); 370 } 371 372 } // namespace tensorflow 373