Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // The compiler API is used by the XLA service to generate executables that
     17 // run on a given platform. This is a registry and abstract interface, for
     18 // pluggability by the various platforms.
     19 
     20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
     21 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
     22 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
     38 
     39 namespace xla {
     40 
     41 // The following types are used for ahead of time compilation.
     42 
     43 // Contains the object file data created as a result of ahead-of-time
     44 // compuation.
     45 using ObjectFileData = std::vector<char>;
     46 
     47 // Contains the buffer sizes information needed to allocate buffers to execute
     48 // an ahead-of-time computation.  Entries which contain -1 designate a parameter
     49 // which should be skipped over during allocation.
     50 using BufferSizes = std::vector<int64>;
     51 
     52 // Abstract superclass describing the result of an ahead-of-time compilation.
     53 class AotCompilationResult {
     54  public:
     55   AotCompilationResult(const AotCompilationResult&) = delete;
     56   AotCompilationResult& operator=(AotCompilationResult const&) = delete;
     57 
     58   virtual ~AotCompilationResult() = default;
     59 
     60  protected:
     61   AotCompilationResult() = default;
     62 };
     63 
     64 // Abstract superclass describing options to an ahead-of-time compilation.
     65 class AotCompilationOptions {
     66  public:
     67   AotCompilationOptions(const AotCompilationOptions&) = delete;
     68   AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;
     69 
     70   virtual ~AotCompilationOptions() = default;
     71 
     72   // Returns the ID of the platform to which these options apply.
     73   virtual perftools::gputools::Platform::Id PlatformId() const = 0;
     74 
     75   // Optional allocator that may be used for allocating temp space on the device
     76   // during compilation.
     77   DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
     78   void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
     79     device_allocator_ = device_allocator;
     80   }
     81 
     82  protected:
     83   AotCompilationOptions() = default;
     84 
     85  private:
     86   DeviceMemoryAllocator* device_allocator_ = nullptr;
     87 };
     88 
     89 // Abstract compiler interface that is subclassed for compilation on a
     90 // particular platform.
     91 //
     92 // The compiler ties together high level optimization (HLO) and low level
     93 // optimization (LLO) / codegen (CG) to generate efficient executables for the
     94 // target platform.
     95 //
     96 // The platform-based compiler singletons are registered via module initializers
     97 // in their corresponding XLA compiler libraries, and are registered via the
     98 // RegisterCompilerFactory API below.
     99 //
    100 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple
    101 // XLA clients may be requesting compilation concurrently for a given
    102 // platform.
    103 class Compiler {
    104  public:
    105   virtual ~Compiler() {}
    106 
    107   // Returns the ID of the platform that this compiler targets.
    108   virtual perftools::gputools::Platform::Id PlatformId() const = 0;
    109 
    110   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
    111   // module.
    112   //
    113   // If device_allocator is not null, the compiler may use it to allocate temp
    114   // space on the device for use during compilation.  For example, the compiler
    115   // may allocate buffers on the device and then run variants of a given
    116   // algorithm over those buffers, to see which variant is fastest.  Any space
    117   // allocated should be deallocated before this function returns.
    118   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
    119       std::unique_ptr<HloModule> module,
    120       perftools::gputools::StreamExecutor* executor,
    121       DeviceMemoryAllocator* device_allocator) = 0;
    122 
    123   // Compiles the HLO module for execution on a device given by the executor,
    124   // and returns an executable object or an error status. No HLO passes are
    125   // applied to module. Generally a module should be passed through RunHloPasses
    126   // prior to calling this method because the some HLO passes are required for
    127   // correctness. Takes ownership of the HLO module and is free to transform it.
    128   //
    129   // The compiler may optionally specialize to the individual device
    130   // (not just type of device) indicated by the executor.
    131   //
    132   // device_allocator is optional; see RunHloPasses.
    133   //
    134   // Use the overload below to compile computations that run in parallel.
    135   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
    136       std::unique_ptr<HloModule> module,
    137       perftools::gputools::StreamExecutor* executor,
    138       DeviceMemoryAllocator* device_allocator) = 0;
    139 
    140   // Compiles a set of HLO modules that can run in parallel, potentially
    141   // communicating data between the modules, and returns a corresponding
    142   // sequence of executable objects.
    143   //
    144   // device_allocator is optional; see RunHloPasses.
    145   //
    146   // TODO(b/68666782): Remove this method after adding support for multiple
    147   // modules to RunHloPasses and RunBackends.
    148   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
    149       std::vector<std::unique_ptr<HloModule>> modules,
    150       std::vector<std::vector<perftools::gputools::StreamExecutor*>>
    151           stream_exec,
    152       DeviceMemoryAllocator* device_allocator) = 0;
    153 
    154   // Compiles the HLO module for ahead-of-time execution.  This is intended for
    155   // use in static compilation.
    156   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
    157   CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
    158                      const AotCompilationOptions& options) = 0;
    159 
    160   /////
    161   // The Compiler class also serves as a point to register compiler objects
    162   // for the various platforms.
    163 
    164   using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
    165 
    166   // Registers the compiler singleton for the platform. This is assumed to
    167   // be a singleton, so no ownership is transferred.
    168   //
    169   // Precondition: a platform kind must not be registered more than once.
    170   static void RegisterCompilerFactory(
    171       perftools::gputools::Platform::Id platform_id,
    172       CompilerFactory compiler_factory);
    173 
    174   // Returns the compiler singleton pointer if it is available for the given
    175   // platform, or an error status if it is not.
    176   static StatusOr<Compiler*> GetForPlatform(
    177       const perftools::gputools::Platform* platform);
    178 
    179   // Returns a function that computes the size in bytes of the logical
    180   // buffer that contains a shape.
    181   virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;
    182 
    183   // Returns a function that computes the size in bytes of a given
    184   // logical buffer.
    185   std::function<int64(const LogicalBuffer&)> BufferSizeBytesFunction() {
    186     HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
    187     return [shape_size](const LogicalBuffer& buffer) {
    188       return shape_size(buffer.shape());
    189     };
    190   }
    191 
    192  private:
    193   // Mutex that guards the platform-compiler map.
    194   static tensorflow::mutex platform_compiler_mutex_;
    195 
    196   // Map from platform kind to compiler factory.
    197   static std::map<perftools::gputools::Platform::Id, CompilerFactory>*
    198   GetPlatformCompilerFactories();
    199 
    200   // Map from platform kind to compiler instance, if we made one already (based
    201   // on the factories above).
    202   static std::map<perftools::gputools::Platform::Id, std::unique_ptr<Compiler>>*
    203   GetPlatformCompilers();
    204 };
    205 
    206 }  // namespace xla
    207 
    208 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
    209