/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_

#include <memory>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

class LocalExecutable {
 public:
  // Run the compiled computation with the given arguments and options and
  // return the result.
  StatusOr<std::unique_ptr<ScopedShapedBuffer>> Run(
      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      ExecutableRunOptions run_options);
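  // Example of invoking Run() (an illustrative sketch, not part of this API;
  // `client`, `executable`, and `arg_buffer` are assumptions for the example,
  // e.g. a LocalClient* from ClientLibrary::LocalClientOrDie(), a
  // LocalExecutable produced by LocalClient::Compile, and a device-resident
  // argument buffer):
  //
  //   ExecutableRunOptions run_options;
  //   run_options.set_allocator(client->backend().memory_allocator());
  //   StatusOr<std::unique_ptr<ScopedShapedBuffer>> result =
  //       executable->Run({arg_buffer.get()}, run_options);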

  // Return the layout (contained in a shape) of the result produced by the
  // computation.
  const Shape& result_layout() const {
    return executable_->module_config()
        .entry_computation_layout()
        .result_layout()
        .shape();
  }

  // Return the options used to build the executable.
  const ExecutableBuildOptions& build_options() const { return build_options_; }

  // Return the built executable.
  Executable* executable() const { return executable_.get(); }

 private:
  // Only a local client can construct these objects.
  friend class LocalClient;

  // Constructor invoked by LocalClient.
  LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
                  ExecutableBuildOptions build_options);

  // Validates that the given arguments and options satisfy various constraints
  // of the computation.
  tensorflow::Status ValidateExecutionOptions(
      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      const ExecutableRunOptions& options, const Backend& backend);

  // Records the computation in a SessionModule proto with the arguments used to
  // invoke it, and the result. Enabled by flag: --xla_dump_executions_to.
  StatusOr<std::unique_ptr<ScopedShapedBuffer>> ExecuteAndDump(
      const ServiceExecutableRunOptions* run_options,
      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);

  // Records the arguments used to invoke the computation in a SessionModule
  // proto.
  tensorflow::Status RecordArguments(
      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      SessionModule* session_module);

  // Records the result of the computation in a SessionModule proto.
  tensorflow::Status RecordResult(const ShapedBuffer* result,
                                  SessionModule* session_module);

  // Returns a literal containing the contents of the given ShapedBuffer.
  StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
      const ShapedBuffer& shaped_buffer);

  // The ordinal of the device which this executable was compiled for. The
  // executable can run on all equivalent devices (as determined by
  // Backend::devices_equivalent).
  int build_device_ordinal() const { return build_options_.device_ordinal(); }

  // Compiled computation.
  std::unique_ptr<Executable> executable_;

  // Execution backend.
  Backend* backend_ = nullptr;

  // Options used to build the executable.
  const ExecutableBuildOptions build_options_;
};

// An XLA Client specialization for use when the client and service run in
// the same process.
class LocalClient : public Client {
 public:
  explicit LocalClient(LocalService* service)
      : Client(service), local_service_(service) {}

  LocalClient(const LocalClient&) = delete;
  void operator=(const LocalClient&) = delete;

  // Build and return a LocalExecutable object. The executable is compiled using
  // the given argument layouts and options.
  StatusOr<std::unique_ptr<LocalExecutable>> Compile(
      const Computation& computation,
      const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
      const ExecutableBuildOptions& options);
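  // Example of calling Compile() (an illustrative sketch, not part of this
  // header; `client` is assumed to be a LocalClient*, e.g. from
  // ClientLibrary::LocalClientOrDie(), and `computation` a Computation built
  // elsewhere, e.g. with ComputationBuilder):
  //
  //   Shape arg_shape = ShapeUtil::MakeShape(F32, {2, 2});
  //   ExecutableBuildOptions build_options;
  //   StatusOr<std::unique_ptr<LocalExecutable>> executable =
  //       client->Compile(computation, {&arg_shape}, build_options);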

  // Copy the literal data to the device with the given ordinal and return as a
  // ScopedShapedBuffer. If non-null the given memory allocator is used for
  // device memory allocation. If null, the default memory allocator for the
  // device is used.
  StatusOr<std::unique_ptr<ScopedShapedBuffer>> LiteralToShapedBuffer(
      const Literal& literal, int device_ordinal,
      DeviceMemoryAllocator* allocator = nullptr);
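  // Example (a sketch; the literal contents and device ordinal 0 are
  // illustrative assumptions, and the Literal factory may be spelled
  // LiteralUtil::CreateR2 in some versions):
  //
  //   std::unique_ptr<Literal> literal =
  //       Literal::CreateR2<float>({{1, 2}, {3, 4}});
  //   StatusOr<std::unique_ptr<ScopedShapedBuffer>> device_buffer =
  //       client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0);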

  // Copy the data from the device contained in the given ShapedBuffer and
  // return as a Literal.
  StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
      const ShapedBuffer& shaped_buffer);
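  // Example (a sketch; `device_buffer` is assumed to be a ShapedBuffer
  // obtained from LiteralToShapedBuffer or from a computation result):
  //
  //   StatusOr<std::unique_ptr<Literal>> host_literal =
  //       client->ShapedBufferToLiteral(*device_buffer);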

  // Transfer the given literal to the infeed queue of the given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferToInfeed.
  Status TransferToInfeedLocal(const Literal& literal, int device_ordinal);

  // Transfer and return a value of the given shape from the outfeed of the
  // given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferFromOutfeed.
  StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
      const Shape& shape, int device_ordinal);
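  // Example of the infeed/outfeed pair (a sketch; the literal, shape, and
  // device ordinal are illustrative assumptions):
  //
  //   Status infeed_status =
  //       client->TransferToInfeedLocal(*literal, /*device_ordinal=*/0);
  //   StatusOr<std::unique_ptr<Literal>> outfeed_value =
  //       client->TransferFromOutfeedLocal(ShapeUtil::MakeShape(F32, {2}),
  //                                        /*device_ordinal=*/0);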

  // Returns the device ordinal that corresponds to the given replica number.
  //
  // This returns an error if there is not a one-to-one correspondence of
  // replicas to device ordinals, but is useful as a short term mechanism for
  // the "easy" case where a single replica is a single device.
  StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);

  // Returns the platform that the underlying service targets.
  perftools::gputools::Platform* platform() const;

  // Returns the number of devices on the system of the service platform
  // type. Not all devices may be supported by the service (see
  // device_ordinal_supported method).
  int device_count() const;

  // Returns the default device ordinal that the service will run computations
  // on if no device ordinal is specified in execute options.
  int default_device_ordinal() const;

  // Returns whether the device with the given ordinal can be used by the
  // service to execute computations. Not all devices of a particular platform
  // may be usable by the service (e.g., a GPU with insufficient CUDA compute
  // capability).
  bool device_ordinal_supported(int device_ordinal) const;

  // Returns the backend used to execute computations.
  const Backend& backend() const;
  Backend* mutable_backend();

 private:
  LocalService* local_service_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_