1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // The compiler API is used by the XLA service to generate executables that 17 // run on a given platform. This is a registry and abstract interface, for 18 // pluggability by the various platforms. 19 20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 21 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 22 23 #include <functional> 24 #include <map> 25 #include <memory> 26 #include <string> 27 28 #include "tensorflow/compiler/xla/service/executable.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 31 #include "tensorflow/compiler/xla/service/logical_buffer.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/types.h" 34 #include "tensorflow/core/lib/gtl/array_slice.h" 35 #include "tensorflow/core/platform/mutex.h" 36 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 37 #include "tensorflow/core/platform/thread_annotations.h" 38 39 namespace xla { 40 41 // The following types are used for ahead of time compilation. 42 43 // Contains the object file data created as a result of ahead-of-time 44 // computation. 
45 using ObjectFileData = std::vector<char>; 46 47 // Contains the buffer sizes information needed to allocate buffers to execute 48 // an ahead-of-time computation. Entries which contain -1 designate a parameter 49 // which should be skipped over during allocation. 50 using BufferSizes = std::vector<int64>; 51 52 // Abstract superclass describing the result of an ahead-of-time compilation. 53 class AotCompilationResult { 54 public: 55 AotCompilationResult(const AotCompilationResult&) = delete; 56 AotCompilationResult& operator=(AotCompilationResult const&) = delete; 57 58 virtual ~AotCompilationResult() = default; 59 60 protected: 61 AotCompilationResult() = default; 62 }; 63 64 // Abstract superclass describing options to an ahead-of-time compilation. 65 class AotCompilationOptions { 66 public: 67 AotCompilationOptions(const AotCompilationOptions&) = delete; 68 AotCompilationOptions& operator=(AotCompilationOptions const&) = delete; 69 70 virtual ~AotCompilationOptions() = default; 71 72 // Returns the ID of the platform to which these options apply. 73 virtual perftools::gputools::Platform::Id PlatformId() const = 0; 74 75 // Optional allocator that may be used for allocating temp space on the device 76 // during compilation. 77 DeviceMemoryAllocator* device_allocator() const { return device_allocator_; } 78 void set_device_allocator(DeviceMemoryAllocator* device_allocator) { 79 device_allocator_ = device_allocator; 80 } 81 82 protected: 83 AotCompilationOptions() = default; 84 85 private: 86 DeviceMemoryAllocator* device_allocator_ = nullptr; 87 }; 88 89 // Abstract compiler interface that is subclassed for compilation on a 90 // particular platform. 91 // 92 // The compiler ties together high level optimization (HLO) and low level 93 // optimization (LLO) / codegen (CG) to generate efficient executables for the 94 // target platform. 
95 // 96 // The platform-based compiler singletons are registered via module initializers 97 // in their corresponding XLA compiler libraries, and are registered via the 98 // RegisterCompilerFactory API below. 99 // 100 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple 101 // XLA clients may be requesting compilation concurrently for a given 102 // platform. 103 class Compiler { 104 public: 105 virtual ~Compiler() {} 106 107 // Returns the ID of the platform that this compiler targets. 108 virtual perftools::gputools::Platform::Id PlatformId() const = 0; 109 110 // Runs Hlo passes to optimize the given Hlo module, returns the optimized 111 // module. 112 // 113 // If device_allocator is not null, the compiler may use it to allocate temp 114 // space on the device for use during compilation. For example, the compiler 115 // may allocate buffers on the device and then run variants of a given 116 // algorithm over those buffers, to see which variant is fastest. Any space 117 // allocated should be deallocated before this function returns. 118 virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 119 std::unique_ptr<HloModule> module, 120 perftools::gputools::StreamExecutor* executor, 121 DeviceMemoryAllocator* device_allocator) = 0; 122 123 // Compiles the HLO module for execution on a device given by the executor, 124 // and returns an executable object or an error status. No HLO passes are 125 // applied to module. Generally a module should be passed through RunHloPasses 126 // prior to calling this method because the some HLO passes are required for 127 // correctness. Takes ownership of the HLO module and is free to transform it. 128 // 129 // The compiler may optionally specialize to the individual device 130 // (not just type of device) indicated by the executor. 131 // 132 // device_allocator is optional; see RunHloPasses. 133 // 134 // Use the overload below to compile computations that run in parallel. 
135 virtual StatusOr<std::unique_ptr<Executable>> RunBackend( 136 std::unique_ptr<HloModule> module, 137 perftools::gputools::StreamExecutor* executor, 138 DeviceMemoryAllocator* device_allocator) = 0; 139 140 // Compiles a set of HLO modules that can run in parallel, potentially 141 // communicating data between the modules, and returns a corresponding 142 // sequence of executable objects. 143 // 144 // device_allocator is optional; see RunHloPasses. 145 // 146 // TODO(b/68666782): Remove this method after adding support for multiple 147 // modules to RunHloPasses and RunBackends. 148 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 149 std::vector<std::unique_ptr<HloModule>> modules, 150 std::vector<std::vector<perftools::gputools::StreamExecutor*>> 151 stream_exec, 152 DeviceMemoryAllocator* device_allocator) = 0; 153 154 // Compiles the HLO module for ahead-of-time execution. This is intended for 155 // use in static compilation. 156 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 157 CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, 158 const AotCompilationOptions& options) = 0; 159 160 ///// 161 // The Compiler class also serves as a point to register compiler objects 162 // for the various platforms. 163 164 using CompilerFactory = std::function<std::unique_ptr<Compiler>()>; 165 166 // Registers the compiler singleton for the platform. This is assumed to 167 // be a singleton, so no ownership is transferred. 168 // 169 // Precondition: a platform kind must not be registered more than once. 170 static void RegisterCompilerFactory( 171 perftools::gputools::Platform::Id platform_id, 172 CompilerFactory compiler_factory); 173 174 // Returns the compiler singleton pointer if it is available for the given 175 // platform, or an error status if it is not. 
176 static StatusOr<Compiler*> GetForPlatform( 177 const perftools::gputools::Platform* platform); 178 179 // Returns a function that computes the size in bytes of the logical 180 // buffer that contains a shape. 181 virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0; 182 183 // Returns a function that computes the size in bytes of a given 184 // logical buffer. 185 std::function<int64(const LogicalBuffer&)> BufferSizeBytesFunction() { 186 HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); 187 return [shape_size](const LogicalBuffer& buffer) { 188 return shape_size(buffer.shape()); 189 }; 190 } 191 192 private: 193 // Mutex that guards the platform-compiler map. 194 static tensorflow::mutex platform_compiler_mutex_; 195 196 // Map from platform kind to compiler factory. 197 static std::map<perftools::gputools::Platform::Id, CompilerFactory>* 198 GetPlatformCompilerFactories(); 199 200 // Map from platform kind to compiler instance, if we made one already (based 201 // on the factories above). 202 static std::map<perftools::gputools::Platform::Id, std::unique_ptr<Compiler>>* 203 GetPlatformCompilers(); 204 }; 205 206 } // namespace xla 207 208 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 209