/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// Options to configure the service when it is created.
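//
// A minimal configuration sketch (hypothetical usage; each setter returns the
// ServiceOptions& itself, so calls can be chained):
//
//   ServiceOptions options;
//   options.set_number_of_replicas(2)
//       .set_intra_op_parallelism_threads(4);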
class ServiceOptions {
 public:
  // Set the platform backing the service, or nullptr for the default platform.
  ServiceOptions& set_platform(se::Platform* platform);
  se::Platform* platform() const;

  // Set the default number of replicas to use when compiling replicated
  // programs.
  ServiceOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual operator.
  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

  // Sets the allowed_devices set for selectively constructing stream executors
  // on the platform.
  ServiceOptions& set_allowed_devices(
      const absl::optional<std::set<int>>& allowed_devices);
  const absl::optional<std::set<int>>& allowed_devices() const;

 private:
  se::Platform* platform_ = nullptr;
  int number_of_replicas_ = 1;
  int intra_op_parallelism_threads_ = -1;
  absl::optional<std::set<int>> allowed_devices_;
};

// The XLA service object, which is the same across all platforms. It maintains
// the service state of computations and allocations, and delegates
// target-specific requests to the target-specific infrastructure
// (target-specific compiler, StreamExecutor).
class Service : public ServiceInterface {
 public:
  // Factory method for creating a new Service.
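  //
  // A minimal creation sketch (hypothetical usage; error handling is elided,
  // and passing no platform, or nullptr, selects the default platform):
  //
  //   ServiceOptions options;
  //   options.set_number_of_replicas(1);
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
  //                       Service::NewService(options));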
  static StatusOr<std::unique_ptr<Service>> NewService(
      se::Platform* platform = nullptr);
  static StatusOr<std::unique_ptr<Service>> NewService(
      const ServiceOptions& options);

  // Unregisters a previously-allocated global handle.
  //
  // If the handle given is not currently allocated, a NOT_FOUND status is
  // returned.
  Status Unregister(const UnregisterRequest* arg,
                    UnregisterResponse* result) override;

  // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
  // element in the tuple.
  Status DeconstructTuple(const DeconstructTupleRequest* arg,
                          DeconstructTupleResponse* result) override;

  // Compiles a computation into an executable. The request contains the whole
  // computation graph. Returns the handle to the executable.
  Status Compile(const CompileRequest* arg, CompileResponse* result) override;

  // Executes an executable with the provided global data passed as immutable
  // arguments. The request contains the handle to the executable. Returns
  // global data output and execution timing.
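  //
  // A typical Compile-then-Execute sequence (hypothetical sketch; request
  // fields are left unset here and error handling is elided):
  //
  //   CompileRequest compile_req;    // carries the computation to compile
  //   CompileResponse compile_resp;  // receives the executable handle
  //   TF_RETURN_IF_ERROR(service->Compile(&compile_req, &compile_resp));
  //
  //   ExecuteRequest execute_req;    // references the handle and arguments
  //   ExecuteResponse execute_resp;  // receives the output global data
  //   TF_RETURN_IF_ERROR(service->Execute(&execute_req, &execute_resp));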
  Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;

  // Executes one or more computations in parallel with the provided global data
  // passed as immutable arguments. Returns global data output for each
  // computation.
  Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                              ExecuteParallelResponse* result) override;

  // Requests one or more device handles from the target.
  //
  // When N device handles are requested and the number of replicas is R, at
  // least N * R devices must be available. The devices are assigned based on
  // the device ordinals such that the first R available devices are assigned to
  // the first set of replicas, and the next R devices to the second set of
  // replicas, etc. Each returned device handle represents the device with the
  // replica id 0.
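  //
  // Worked example (illustrative): with R = 2 replicas, N = 3 requested
  // handles, and devices with ordinals 0-5 available, devices 0 and 1 back the
  // first handle, devices 2 and 3 the second, and devices 4 and 5 the third;
  // each returned handle refers to the replica-0 device (ordinals 0, 2, and 4).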
  Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                          GetDeviceHandlesResponse* result) override;

  // Waits until the specified execution is complete and returns the result.
  // Calling this API more than once with the same execution handle returns an
  // error, since the execution handle is destroyed after the first call.
  Status WaitForExecution(const WaitForExecutionRequest* arg,
                          WaitForExecutionResponse* result) override;

  // Requests that global data be transferred to the client in literal form.
  Status TransferToClient(const TransferToClientRequest* arg,
                          TransferToClientResponse* result) override;

  // Transfers data from a literal provided by the client, into device memory.
  Status TransferToServer(const TransferToServerRequest* arg,
                          TransferToServerResponse* result) override;

  // Transfers data from a literal provided by the client, into the Infeed
  // buffer of the device.
  Status TransferToInfeed(const TransferToInfeedRequest* arg,
                          TransferToInfeedResponse* result) override;

  // Transfers data from the Outfeed of the device to the literal provided by
  // the client.
  Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                             TransferFromOutfeedResponse* result) override;

  // Resets devices, clearing all existing state on all the devices associated
  // with this service (including memory allocated on the devices).
  //
  // ResetDevice may only be called when no state from a previous Execution on
  // the device is needed by the next Execution.
  //
  // ResetDevice should be called before an Execution that expects the device
  // to be in the reset state, for example when the prior Execution modifies
  // device state (e.g., architectural state) that the next Execution depends
  // on.
  Status ResetDevice(const ResetDeviceRequest* arg,
                     ResetDeviceResponse* result) override;

  Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                              ComputeConstantResponse* result) override;

  // Returns the shape (with layout) of an array associated with a given data
  // handle.
  Status GetShape(const GetShapeRequest* arg,
                  GetShapeResponse* result) override;

  // Retrieves the statistics of a computation.
  Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg,
                                  ComputationStatsResponse* result) override;

  // Creates a unique channel handle that can be used for Send/Recv
  // instructions.
  Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
                             CreateChannelHandleResponse* result) override;

  // Returns the backend used to execute computations.
  const Backend& backend() const { return *execute_backend_; }
  Backend* mutable_backend() { return execute_backend_.get(); }

 private:
  // A private overload for Service itself, used by other methods within this
  // class.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      absl::Span<const ShapedBuffer* const> arguments,
      const ExecutionOptions& execution_options);

  // Prepares the executors for parallel execution.
  StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
      const ExecutionOptions& execution_options, int64 requests_size,
      int64 request_index) const;

  // Prepares the arguments for parallel execution.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
      const ExecutionOptions& execution_options,
      absl::Span<const GlobalDataHandle* const> arguments) const;

 protected:
  friend class LocalExecutable;

  // The constructor is not public; use the NewService factory to create new
  // service objects.
  Service(const ServiceOptions& options,
          std::unique_ptr<Backend> execute_backend);

  // Resolves the given argument handles in the allocation tracker and returns
  // the corresponding allocations for every replica. The function also verifies
  // that each allocation matches the execution platform and device ordinal of
  // the corresponding replica.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
  ResolveAndValidateArguments(
      absl::Span<const GlobalDataHandle* const> arguments,
      absl::Span<se::StreamExecutor* const> stream_executors) const;

  // Creates an HloModuleConfig for the given program shape and arguments.
  // execution_options is optional; if not given, a default is used.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      absl::Span<const Shape* const> argument_shapes,
      const ExecutionOptions* execution_options);

  // Builds an Executable for the given parameters.
  //
  // If device_allocator is not null, the compiler may use it to allocate temp
  // buffers, which the compiler is responsible for freeing.  The allocator
  // given here need not match the allocator used when running the executable.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const HloModuleProto& module_proto,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator = nullptr);

  // Same as BuildExecutable() above, but builds a list of Executables for the
  // given computations that may interact with each other.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      const std::vector<const HloModuleProto*>& module_protos,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
      DeviceMemoryAllocator* device_allocator);

  // Runs the given executable with the given arguments and registers the result
  // in the allocation tracker. The handle of the result from the tracker is
  // returned. If the parameter "profile" is not null, it points to an
  // ExecutionProfile object which will be filled in with profile data.
  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
      Executable* executable,
      absl::Span<const std::vector<const ShapedBuffer*>> arguments,
      Backend* backend, const DeviceHandle& device_handle,
      const string& result_tag, ExecutionProfile* profile);

  // Runs the given executables with the given arguments and registers the
  // result from each executable in the allocation tracker. The handles of the
  // results from the tracker are returned.
  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
      absl::Span<Executable* const> executables,
      absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
      Backend* backend, absl::Span<const DeviceHandle> device_handles,
      absl::Span<const string> result_tags, ExecutionProfile* profile);

  // Convenience function which checks whether the given client_shape
  // (presumably passed by the client to set the result layout) is valid for the
  // given computation result shape.
  Status ValidateResultShape(const Shape& client_shape,
                             const Shape& result_shape) const;

  // Returns the stream executors assigned to the replicas represented by the
  // given device handle. Each device_handle is a virtual replicated device that
  // represents a set of physical devices for the replicas.
  StatusOr<std::vector<se::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;

  ServiceOptions options_;

  // Cache containing previously built Executables.
  CompilationCache compilation_cache_;

  // Tracks channels created via the API.
  ChannelTracker channel_tracker_;

  // Tracks allocations made via the API and computation execution.
  AllocationTracker allocation_tracker_;

  // Tracks asynchronously launched executions via the API.
  ExecutionTracker execution_tracker_;

  // Backend to compile and execute computations on.
  std::unique_ptr<Backend> execute_backend_;

  TF_DISALLOW_COPY_AND_ASSIGN(Service);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_