1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/service/executable.h" 24 #include "tensorflow/compiler/xla/service/hlo_module.h" 25 #include "tensorflow/compiler/xla/service/llvm_compiler.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 #include "tensorflow/core/lib/gtl/optional.h" 30 #include "tensorflow/core/lib/hash/hash.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 34 #include "tensorflow/core/platform/thread_annotations.h" 35 36 namespace xla { 37 namespace gpu { 38 39 // The GPU compiler generates efficient GPU executables. 40 class GpuCompiler : public LLVMCompiler { 41 public: 42 GpuCompiler(); 43 ~GpuCompiler() override {} 44 45 // Bring in 46 // StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 47 // std::vector<std::unique_ptr<HloModule>> modules, 48 // std::vector<std::vector<perftools::gputools::StreamExecutor*>> 49 // stream_execs) 50 using LLVMCompiler::Compile; 51 52 StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 53 std::unique_ptr<HloModule> module, 54 perftools::gputools::StreamExecutor* stream_exec, 55 DeviceMemoryAllocator* device_allocator) override; 56 57 StatusOr<std::unique_ptr<Executable>> RunBackend( 58 std::unique_ptr<HloModule> module, 59 perftools::gputools::StreamExecutor* stream_exec, 60 DeviceMemoryAllocator* device_allocator) override; 61 62 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 63 CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module, 64 AotCompilationOptions const& options) override; 65 66 perftools::gputools::Platform::Id PlatformId() const override; 67 68 HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { 69 // Capture just the pointer size, not the entire GpuCompiler object. 70 int64 pointer_size = pointer_size_; 71 return [pointer_size](const Shape& shape) { 72 return ShapeUtil::ByteSizeOf(shape, pointer_size); 73 }; 74 } 75 76 // The triple that represents our target. 77 static const char* kTargetTriple; 78 79 // The data layout of the emitted module. Copied from computeDataLayout in 80 // NVPTXTargetMachine.cpp. 81 static const char* kDataLayout; 82 83 private: 84 // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. 85 const int64 pointer_size_; 86 87 tensorflow::mutex mutex_; 88 89 // When compiling an HLO module, we need to find a path to the nvvm libdevice 90 // files. We search in the module's config.debug_options().cuda_data_dir() 91 // and in tensorflow::LibdeviceRoot(), the latter of which is a constant. 92 // 93 // We cache the cuda_data_dir() and the result of our search, so that if the 94 // next module we have to compile has the same cuda_data_dir(), we can skip 95 // the search. 96 string cached_cuda_data_dir_ GUARDED_BY(mutex_); 97 string cached_libdevice_dir_ GUARDED_BY(mutex_); 98 99 // Tries to compile the given ptx string to cubin. Returns a vector with the 100 // compiled cubin. If compilation was unsuccessful, returns an empty vector. 101 std::vector<uint8> CompilePtxOrGetCachedResult(const string& ptx, 102 int cc_major, int cc_minor); 103 104 // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} 105 // -> cubin so we don't recompile the same ptx twice. This is important for 106 // some interactive workflows. (We also cache at the HLO level, but sometimes 107 // we can't realize that two modules are the same until we lower to ptx.) 108 // 109 // Compilation of distinct PTX happens in parallel. If more than one thread 110 // attempts to compile the same PTX, the fist thread to obtain 111 // cache_value_->mutex_ performs the compilation. The rest wait() on 112 // cache_value_->compilation_done_cv_ until the compilation is done. 113 // 114 // If compiling the ptx fails, we return an empty cubin, cross our fingers, 115 // and leave compilation up to the driver. 116 struct CompilationCacheKey { 117 CompilationCacheKey(std::string ptx, int cc_major, int cc_minor) 118 : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {} 119 string ptx; 120 int cc_major; 121 int cc_minor; 122 }; 123 struct CompilationCacheHash { 124 size_t operator()(const CompilationCacheKey& key) const { 125 return tensorflow::Hash64Combine( 126 tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major), 127 key.cc_minor); 128 } 129 }; 130 struct CompilationCacheEq { 131 size_t operator()(const CompilationCacheKey& a, 132 const CompilationCacheKey& b) const { 133 return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && 134 a.ptx == b.ptx; 135 } 136 }; 137 struct CompilationCacheValue { 138 bool compilation_done = false; 139 std::vector<uint8> cubin_data; 140 // mutex and condition variable to serialize compilation completing. 141 tensorflow::mutex mutex_; 142 tensorflow::condition_variable compilation_done_cv_; 143 }; 144 145 // Don't even think about switching this to FlatMap; iterator stability is 146 // critical here. 147 std::unordered_map<CompilationCacheKey, CompilationCacheValue, 148 CompilationCacheHash, CompilationCacheEq> 149 compilation_cache_ GUARDED_BY(mutex_); 150 151 TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); 152 }; 153 154 } // namespace gpu 155 } // namespace xla 156 157 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ 158