/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The compiler API is used by the XLA service to generate executables that
// run on a given platform. This is a registry and abstract interface, for
// pluggability by the various platforms.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {

// The following types are used for ahead-of-time compilation.

// Contains the object file data created as a result of ahead-of-time
// compilation.
using ObjectFileData = std::vector<char>;

// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
 public:
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(AotCompilationResult const&) = delete;

  virtual ~AotCompilationResult() = default;

 protected:
  AotCompilationResult() = default;
};

// Abstract superclass describing options to an ahead-of-time compilation.
class AotCompilationOptions {
 public:
  AotCompilationOptions(const AotCompilationOptions&) = delete;
  AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;

  virtual ~AotCompilationOptions() = default;

  // Returns the ID of the platform to which these options apply.
  virtual se::Platform::Id PlatformId() const = 0;

  // Optional allocator that may be used for allocating temp space on the
  // device during compilation.
  DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
  void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
    device_allocator_ = device_allocator;
  }

  const DebugOptions& debug_options() const { return debug_options_; }
  DebugOptions* mutable_debug_options() { return &debug_options_; }

  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

 protected:
  AotCompilationOptions();

 private:
  DeviceMemoryAllocator* device_allocator_ = nullptr;
  DebugOptions debug_options_;
  absl::optional<DeviceAssignment> static_device_assignment_;
};
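
// A minimal usage sketch (not part of the API): a hypothetical backend-specific
// options subclass and how a client might configure it before an ahead-of-time
// compile. ExampleAotCompilationOptions and the chosen platform id are
// illustrative assumptions; real backends define their own subclasses with
// additional fields.
//
//   class ExampleAotCompilationOptions : public AotCompilationOptions {
//    public:
//     se::Platform::Id PlatformId() const override {
//       return se::host::kHostPlatformId;  // assumed host platform id
//     }
//   };
//
//   ExampleAotCompilationOptions options;
//   // DebugOptions is a proto; set whichever fields the backend honors.
//   options.mutable_debug_options();
//   DeviceAssignment assignment(/*replica_count=*/1, /*computation_count=*/1);
//   options.set_static_device_assignment(assignment);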

// Abstract superclass describing metadata produced during ahead-of-time
// compilation.
class AotCompilationMetadata {
 public:
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete;

  virtual ~AotCompilationMetadata() = default;

 protected:
  AotCompilationMetadata() = default;
};

// Abstract compiler interface that is subclassed for compilation on a
// particular platform.
//
// The compiler ties together high level optimization (HLO) and low level
// optimization (LLO) / codegen (CG) to generate efficient executables for the
// target platform.
//
// The platform-based compiler singletons are registered via module
// initializers in their corresponding XLA compiler libraries, using the
// RegisterCompilerFactory API below.
//
// Thread-safety: subclasses of Compiler must be thread-safe, as multiple
// XLA clients may be requesting compilation concurrently for a given
// platform.
class Compiler {
 public:
  virtual ~Compiler() {}

  // Returns the ID of the platform that this compiler targets.
  virtual se::Platform::Id PlatformId() const = 0;

  // Runs HLO passes to optimize the given HLO module; returns the optimized
  // module.
  //
  // If device_allocator is not null, the compiler may use it to allocate temp
  // space on the device for use during compilation.  For example, the compiler
  // may allocate buffers on the device and then run variants of a given
  // algorithm over those buffers, to see which variant is fastest.  Any space
  // allocated should be deallocated before this function returns.
  virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Optimizes an HLO module group, a set of modules that run concurrently on
  // multiple devices, potentially communicating data between the modules.
  virtual Status RunHloPassesOnModuleGroup(
      HloModuleGroup* module_group,
      absl::Span<se::StreamExecutor* const> executors,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles the HLO module for execution on a device given by the executor,
  // and returns an executable object or an error status. No HLO passes are
  // applied to the module. Generally a module should be passed through
  // RunHloPasses prior to calling this method because some HLO passes are
  // required for correctness. Takes ownership of the HLO module.
  //
  // The compiler may optionally specialize to the individual device
  // (not just type of device) indicated by the executor.
  //
  // device_allocator is optional; see RunHloPasses.
  virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator) = 0;
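
  // A minimal sketch of the typical single-module JIT flow through this
  // interface (client code, not part of this header); the `compiler`,
  // `executor`, and `module` values are assumed to come from the surrounding
  // service code:
  //
  //   StatusOr<std::unique_ptr<HloModule>> optimized =
  //       compiler->RunHloPasses(std::move(module), executor,
  //                              /*device_allocator=*/nullptr);
  //   TF_RETURN_IF_ERROR(optimized.status());
  //   StatusOr<std::unique_ptr<Executable>> executable =
  //       compiler->RunBackend(optimized.ConsumeValueOrDie(), executor,
  //                            /*device_allocator=*/nullptr);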

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>>
  RunBackendOnModuleGroup(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules, and returns a corresponding
  // sequence of executable objects.
  //
  // device_allocator is optional; see RunHloPasses.
  //
  // TODO(b/68666782): Remove this method after adding support for multiple
  // modules to RunHloPasses and RunBackend.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Returns the backend configurations that the backend will consider for the
  // given HLO. Returns no configurations if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
  ComputeBackendConfigs(const HloInstruction& hlo,
                        se::StreamExecutor* executor) const;

  // Returns the backend configuration that the backend chooses by default for
  // the given HLO. Returns no configuration if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::unique_ptr<tensorflow::protobuf::Message>
  ComputeDefaultBackendConfig(const HloInstruction& hlo,
                              se::StreamExecutor* executor) const;

  // Compiles the HLO module group for ahead-of-time execution.  This is
  // intended for use in static compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) = 0;

  // Similar to CompileAheadOfTime above, but takes an additional
  // AotCompilationMetadata output argument that can be populated during
  // compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options,
                     std::unique_ptr<AotCompilationMetadata>* metadata);
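
  // A minimal sketch of an ahead-of-time compile (client code, not part of
  // this header); `compiler`, `module`, and `aot_options` are assumed to be
  // provided by the caller, and the single-module group construction mirrors
  // the common case:
  //
  //   auto module_group =
  //       absl::make_unique<HloModuleGroup>(std::move(module));
  //   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> results =
  //       compiler->CompileAheadOfTime(std::move(module_group), aot_options);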

  /////
  // The Compiler class also serves as a point to register compiler objects
  // for the various platforms.

  using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;

  // Registers the compiler singleton for the platform. This is assumed to
  // be a singleton, so no ownership is transferred.
  //
  // Precondition: a platform kind must not be registered more than once.
  static void RegisterCompilerFactory(se::Platform::Id platform_id,
                                      CompilerFactory compiler_factory);

  // Returns the compiler singleton pointer if it is available for the given
  // platform, or an error status if it is not.
  static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);
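
  // A minimal sketch of how a backend library typically plugs in (the
  // ExampleCompiler type and platform id are illustrative assumptions, not
  // part of this header). Registration usually happens in a module
  // initializer in the backend's .cc file, and clients later look the
  // compiler up by platform:
  //
  //   static bool InitExampleCompiler() {
  //     Compiler::RegisterCompilerFactory(
  //         se::host::kHostPlatformId,
  //         []() { return absl::make_unique<ExampleCompiler>(); });
  //     return true;
  //   }
  //   static bool example_compiler_registered = InitExampleCompiler();
  //
  //   // Later, e.g. in the XLA service:
  //   TF_ASSIGN_OR_RETURN(Compiler* compiler,
  //                       Compiler::GetForPlatform(platform));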

  // Returns a function that computes the size in bytes of the logical
  // buffer that contains a shape.
  virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;

  // Returns a function that computes the size in bytes of a given
  // logical buffer.
  std::function<int64(const BufferValue&)> BufferSizeBytesFunction() {
    HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
    return [shape_size](const BufferValue& buffer) {
      return shape_size(buffer.shape());
    };
  }

 private:
  // Mutex that guards the platform-compiler map.
  static tensorflow::mutex platform_compiler_mutex_;

  // Map from platform kind to compiler factory.
  static std::map<se::Platform::Id, CompilerFactory>*
  GetPlatformCompilerFactories();

  // Map from platform kind to compiler instance, if we made one already (based
  // on the factories above).
  static std::map<se::Platform::Id, std::unique_ptr<Compiler>>*
  GetPlatformCompilers();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_