/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// Options to configure the service when it is created.
class ServiceOptions {
 public:
  // Sets the platform backing the service, or nullptr for the default platform.
  ServiceOptions& set_platform(perftools::gputools::Platform* platform);
  perftools::gputools::Platform* platform() const;

  // Sets the number of replicas to use when compiling replicated
  // programs.
  ServiceOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual operator.
  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

 private:
  perftools::gputools::Platform* platform_ = nullptr;
  int number_of_replicas_ = 1;
  int intra_op_parallelism_threads_ = -1;
};

// The XLA service object, which is the same across all platforms. It maintains
// the service state of computations and allocations, and delegates
// target-specific requests to the target-specific infrastructure
// (target-specific compiler, StreamExecutor).
class Service : public ServiceInterface {
 public:
  // Factory methods for creating a new Service.
  static StatusOr<std::unique_ptr<Service>> NewService(
      perftools::gputools::Platform* platform = nullptr);
  static StatusOr<std::unique_ptr<Service>> NewService(
      const ServiceOptions& options);
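
  // Example of creating a Service with the factory above (an illustrative
  // sketch only: it relies on ServiceOptions defaulting to the default
  // platform and unwraps the StatusOr without error handling):
  //
  //   ServiceOptions options;
  //   options.set_number_of_replicas(1);
  //   options.set_intra_op_parallelism_threads(4);
  //   std::unique_ptr<Service> service =
  //       Service::NewService(options).ConsumeValueOrDie();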

  // Creates a new computation with the given name.
  // A unique ComputationHandle is returned.
  tensorflow::Status Computation(const ComputationRequest* arg,
                                 ComputationResponse* result) override;

  // Unregisters a previously-allocated global handle.
  //
  // If the handle given is not currently allocated, a NOT_FOUND status is
  // returned.
  tensorflow::Status Unregister(const UnregisterRequest* arg,
                                UnregisterResponse* result) override;

  // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
  // element in the tuple.
  tensorflow::Status DeconstructTuple(
      const DeconstructTupleRequest* arg,
      DeconstructTupleResponse* result) override;

  // Modifies the provided computation so that subsequent executions
  // will compute the provided ComputationDataHandle, rather than the
  // last expression enqueued on that Computation.
  tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
                                    SetReturnValueResponse* results) override;

  // Executes a computation with the provided global data passed as
  // immutable arguments. Returns global data output and execution timing.
  tensorflow::Status Execute(const ExecuteRequest* arg,
                             ExecuteResponse* result) override;
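
  // A minimal calling sketch for Execute. The proto field accessors below
  // (mutable_computation, add_arguments, output) are illustrative and assume
  // the ExecuteRequest/ExecuteResponse message layouts of this version of
  // xla.proto; the handles are assumed to have been obtained earlier from
  // Computation() and TransferToServer().
  //
  //   ExecuteRequest request;
  //   *request.mutable_computation() = computation_handle;
  //   *request.add_arguments() = argument_handle;
  //   ExecuteResponse response;
  //   tensorflow::Status status = service->Execute(&request, &response);
  //   if (status.ok()) {
  //     const GlobalDataHandle& result = response.output();
  //   }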

  // Executes one or more computations in parallel with the provided global data
  // passed as immutable arguments. Returns global data output for each
  // computation.
  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
                                     ExecuteParallelResponse* result) override;

  // Requests one or more device handles from the target.
  //
  // When N device handles are requested and the number of replicas is R, at
  // least N * R devices must be available. The devices are assigned based on
  // the device ordinals such that the first R available devices are assigned to
  // the first set of replicas, and the next R devices to the second set of
  // replicas, etc. Each returned device handle represents the device with
  // replica id 0.
  tensorflow::Status GetDeviceHandles(
      const GetDeviceHandlesRequest* arg,
      GetDeviceHandlesResponse* result) override;
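
  // Worked example of the assignment described above: with R = 2 replicas and
  // a request for N = 3 device handles, at least 6 devices must be available.
  // Devices 0 and 1 back the first handle, devices 2 and 3 the second, and
  // devices 4 and 5 the third, and each returned handle names the replica-0
  // device of its group.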

  // Asynchronously executes a computation with the provided global data passed
  // as immutable arguments. Returns a handle to the execution.
  //
  // (Note: The corresponding function in xla::Client was removed as part of
  // b/64116060, in an attempt to simplify our API.  We're keeping this around
  // for now in case we want to expose this to clients in a different way.)
  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
                                  ExecuteAsyncResponse* result) override;

  // Waits until the specified execution is complete and returns the result.
  // Calling this API multiple times with the same execution handle fails with
  // an error, since the execution handle is destroyed after the first call.
  tensorflow::Status WaitForExecution(
      const WaitForExecutionRequest* arg,
      WaitForExecutionResponse* result) override;
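
  // A sketch of the asynchronous path: ExecuteAsync returns an execution
  // handle, which is then redeemed exactly once via WaitForExecution. The
  // proto field accessors (mutable_computation, execution, mutable_execution)
  // are illustrative and assume the ExecuteAsyncRequest/Response and
  // WaitForExecutionRequest message layouts.
  //
  //   ExecuteAsyncRequest async_request;
  //   *async_request.mutable_computation() = computation_handle;
  //   ExecuteAsyncResponse async_response;
  //   TF_RETURN_IF_ERROR(
  //       service->ExecuteAsync(&async_request, &async_response));
  //
  //   WaitForExecutionRequest wait_request;
  //   *wait_request.mutable_execution() = async_response.execution();
  //   WaitForExecutionResponse wait_response;
  //   TF_RETURN_IF_ERROR(
  //       service->WaitForExecution(&wait_request, &wait_response));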

  // Requests that global data be transferred to the client in literal form.
  tensorflow::Status TransferToClient(
      const TransferToClientRequest* arg,
      TransferToClientResponse* result) override;

  // Transfers data from a literal provided by the client, into device memory.
  tensorflow::Status TransferToServer(
      const TransferToServerRequest* arg,
      TransferToServerResponse* result) override;
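
  // Round-trip sketch combining the two transfer calls above: a client-side
  // literal is copied into device memory and then read back. The proto field
  // accessors (mutable_literal, data, mutable_data, literal) are illustrative
  // and assume the TransferToServer/TransferToClient message layouts.
  //
  //   TransferToServerRequest to_server;
  //   *to_server.mutable_literal() = literal_proto;
  //   TransferToServerResponse to_server_response;
  //   TF_RETURN_IF_ERROR(
  //       service->TransferToServer(&to_server, &to_server_response));
  //
  //   TransferToClientRequest to_client;
  //   *to_client.mutable_data() = to_server_response.data();
  //   TransferToClientResponse to_client_response;
  //   TF_RETURN_IF_ERROR(
  //       service->TransferToClient(&to_client, &to_client_response));
  //   // to_client_response.literal() holds the round-tripped literal.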

  // Transfers data from a literal provided by the client, into the Infeed
  // buffer of the device.
  tensorflow::Status TransferToInfeed(
      const TransferToInfeedRequest* arg,
      TransferToInfeedResponse* result) override;

  // Transfers data from the Outfeed of the device to the literal provided by
  // the client.
  tensorflow::Status TransferFromOutfeed(
      const TransferFromOutfeedRequest* arg,
      TransferFromOutfeedResponse* result) override;

  // Resets devices, clearing all existing state on all the devices associated
  // with this service (including memory allocated on the devices).
  //
  // ResetDevice may only be called where no previous Execution state on the
  // device is used by the next Execution.
  //
  // ResetDevice should be called before an Execution that expects the device
  // to be in the reset state; for example, when the prior Execution modifies
  // device state (e.g., architectural state) that the next Execution depends
  // on.
  tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
                                 ResetDeviceResponse* result) override;

  // Tests if an expression is a compile-time constant.
  tensorflow::Status IsConstant(const IsConstantRequest* arg,
                                IsConstantResponse* result) override;

  // Computes the value of a constant expression.
  tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
                                     ComputeConstantResponse* result) override;

  // Returns the shape (with layout) of an array associated with a given data
  // handle.
  tensorflow::Status GetShape(const GetShapeRequest* arg,
                              GetShapeResponse* result) override;

  // Returns the program shape of the computation associated with the given
  // handle.
  tensorflow::Status GetComputationShape(
      const GetComputationShapeRequest* arg,
      GetComputationShapeResponse* result) override;

  /////
  // Computation-oriented methods.

  // Enqueues an Op on the computation.
  tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override;

  // Retrieves the inferred shape for a value within a computation.
  tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
                                   GetLocalShapeResponse* result) override;

  // Retrieves the statistics of a computation.
  tensorflow::Status GetComputationStats(
      const ComputationStatsRequest* arg,
      ComputationStatsResponse* result) override;

  // Snapshots the current state of a computation handle into a serializable
  // protocol buffer form, so it can be loaded via
  // LoadComputationSnapshot.
  tensorflow::Status SnapshotComputation(
      const SnapshotComputationRequest* arg,
      SnapshotComputationResponse* result) override;

  // Loads a computation from a serialized protocol buffer created via
  // SnapshotComputation.
  tensorflow::Status LoadComputationSnapshot(
      const LoadComputationSnapshotRequest* arg,
      LoadComputationSnapshotResponse* result) override;

  // Creates a unique channel handle that can be used for Send/Recv
  // instructions.
  tensorflow::Status CreateChannelHandle(
      const CreateChannelHandleRequest* arg,
      CreateChannelHandleResponse* result) override;

  // Returns the ComputationTracker of the current service instance.
  // Only used in unit tests to access user computations from the client.
  const ComputationTracker& computation_tracker() {
    return computation_tracker_;
  }

  // Returns the backend used to execute computations.
  const Backend& backend() const { return *execute_backend_; }
  Backend* mutable_backend() { return execute_backend_.get(); }

 private:
  // A private overload of CreateModuleConfig, used by other methods within
  // this class.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      const ExecutionOptions& execution_options,
      const UserComputation& user_computation);

 protected:
  friend class LocalExecutable;

  // The constructor is not public; use the NewService factory to create new
  // service objects.
  Service(const ServiceOptions& options,
          std::unique_ptr<Backend> execute_backend);

  static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();

  // Resolves the given argument handles in the allocation tracker and returns
  // the corresponding allocations. The function also verifies that each
  // allocation matches the execution platform and device ordinal.
  StatusOr<std::vector<const ShapedBuffer*>> ResolveAndValidateArguments(
      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
      int device_ordinal);

  // Creates an HloModuleConfig for the given program shape and arguments.
  // execution_options is optional; if not given a default is used.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
      const ExecutionOptions* execution_options,
      const UserComputation& user_computation);

  // Builds an Executable for the given parameters.
  //
  // If device_allocator is not null, the compiler may use it to allocate temp
  // buffers, which the compiler is responsible for freeing.  The allocator
  // given here need not match the allocator used when running the executable.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const VersionedComputationHandle& versioned_handle,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator = nullptr);

  // Same as BuildExecutable() above, but builds a list of Executables for the
  // given computations that may interact with each other.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      std::vector<VersionedComputationHandle> versioned_handles,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend,
      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
      DeviceMemoryAllocator* device_allocator);

  // Similar to BuildExecutable, but looks in the compilation cache for the
  // executable first. If the executable is not in the cache, it is built and
  // inserted into the cache.
  StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
      const VersionedComputationHandle& versioned_handle,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile,
      DeviceMemoryAllocator* device_allocator = nullptr);

  // Runs the given executable with the given arguments and registers the
  // result in the allocation tracker. The handle of the result from the
  // tracker is returned. If the parameter "profile" is not null, it points to
  // an ExecutionProfile object which will be filled in with profile data.
  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
      Executable* executable,
      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      Backend* backend, perftools::gputools::StreamExecutor* executor,
      const string& result_tag, ExecutionProfile* profile);

  // Runs the given executables with the given arguments and registers the
  // result from each executable in the allocation tracker. The handles of the
  // results from the tracker are returned.
  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
      tensorflow::gtl::ArraySlice<Executable*> executables,
      tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>> arguments,
      Backend* backend,
      tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
      tensorflow::gtl::ArraySlice<string> result_tags,
      ExecutionProfile* profile);

  // Convenience function for adding an instruction to a user computation.
  template <typename RequestT, typename ResponseT>
  tensorflow::Status AddInstruction(
      const RequestT* arg, ResponseT* result,
      const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
          adder);

  // Convenience function which checks whether the given shape_with_layout
  // (presumably passed by the client to set the result layout) is valid for the
  // given computation result shape.
  tensorflow::Status ValidateResultShapeWithLayout(
      const Shape& shape_with_layout, const Shape& result_shape) const;

  // Returns the stream executors assigned to the replicas represented by the
  // given device handle. Each device_handle is a virtual replicated device that
  // represents a set of physical devices for the replicas.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  Status MaybeDumpHloModule(const HloModule& module) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;

  ServiceOptions options_;

  // Tracks computations built via the API.
  ComputationTracker computation_tracker_;

  // Tracks channels created via the API.
  ChannelTracker channel_tracker_;

  // Tracks allocations made via the API and computation execution.
  AllocationTracker allocation_tracker_;

  // Tracks asynchronously launched executions via the API.
  ExecutionTracker execution_tracker_;

  // Cache containing previously built Executables.
  CompilationCache compilation_cache_;

  // Backend to compile and execute computations on.
  //
  // TODO(b/28616830): Support multiple backends for execution.
  std::unique_ptr<Backend> execute_backend_;

  TF_DISALLOW_COPY_AND_ASSIGN(Service);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_